How to use the tf2onnx.utils.get_tf_const_value function in tf2onnx

To help you get started, we’ve selected a few tf2onnx examples, based on popular ways the get_tf_const_value function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / shape_inference.py View on Github external
axis_op = op.inputs[2].op
            if not utils.is_tf_const_op(axis_op):
                return False
            axis = utils.get_tf_const_value(axis_op)
        else:
            axis = 0

        shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
        op.outputs[0].set_shape(shape)
        return True

    if op.type in ["All", "Any", "Max", "Min"]:
        axis_op = op.inputs[1].op
        if not utils.is_tf_const_op(axis_op):
            return False
        axis = utils.get_tf_const_value(axis_op)
        if not isinstance(axis, list):
            axis = [axis]
        keep_dims = op.get_attr("keep_dims")
        shape = utils.get_tf_tensor_shape(op.inputs[0])
        for i, _ in enumerate(axis):
            if axis[i] < 0:
                axis[i] += len(shape)

        new_shape = []
        for i, _ in enumerate(shape):
            if i in axis:
                if keep_dims:
                    new_shape.append(1)
            else:
                new_shape.append(shape[i])
github onnx / tensorflow-onnx / tf2onnx / shape_inference.py View on Github external
axis_op = op.inputs[2].op
            if not utils.is_tf_const_op(axis_op):
                return False
            axis = utils.get_tf_const_value(axis_op)
        else:
            axis = 0

        shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
        op.outputs[0].set_shape(shape)
        return True

    if op.type in ["All", "Any", "Max", "Min"]:
        axis_op = op.inputs[1].op
        if not utils.is_tf_const_op(axis_op):
            return False
        axis = utils.get_tf_const_value(axis_op)
        if not isinstance(axis, list):
            axis = [axis]
        keep_dims = op.get_attr("keep_dims")
        shape = utils.get_tf_tensor_shape(op.inputs[0])
        for i, _ in enumerate(axis):
            if axis[i] < 0:
                axis[i] += len(shape)

        new_shape = []
        for i, _ in enumerate(shape):
            if i in axis:
                if keep_dims:
                    new_shape.append(1)
            else:
                new_shape.append(shape[i])
github onnx / tensorflow-onnx / tf2onnx / shape_inference.py View on Github external
new_shape.append(1)
            else:
                new_shape.append(shape[i])

        op.outputs[0].set_shape(new_shape)
        logger.debug("set %s op [%s] with new shape %s", op.type, op.outputs[0].name, new_shape)
        return True

    if op.type == "ExpandDims":
        # https://www.tensorflow.org/api_docs/python/tf/expand_dims
        input_shape = utils.get_tf_tensor_shape(op.inputs[0])
        dim_op = op.inputs[1].op
        if input_shape is None or not utils.is_tf_const_op(dim_op):
            return False

        dim = utils.get_tf_const_value(dim_op)
        if dim < 0:
            dim = dim + len(input_shape) + 1

        new_shape = input_shape[:dim] + [1] + input_shape[dim:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
        return True

    if op.type == "Unpack":
        input_shape = utils.get_tf_tensor_shape(op.inputs[0])
        if input_shape is None:
            return False

        axis = op.get_attr("axis")
        axis = axis if axis >= 0 else axis + len(input_shape)
        # the link below says that the rank of output is "rank(input) -1",
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / shape_inference.py View on Github external
if utils.is_tf_concat_op(op):
        data_inputs = op.inputs[:-1]
        input_shapes = [utils.get_tf_tensor_shape(inp) for inp in data_inputs]
        input_shapes = [shape for shape in input_shapes if shape is not None]
        if not input_shapes:
            logger.debug("all input shapes of concat op %s are None, can't infer its output shape", op.name)
            return False

        new_shape = input_shapes[0]
        axis_op = op.inputs[-1]
        rank = len(new_shape)
        if not utils.is_tf_const_op(axis_op):
            op.outputs[0].set_shape([-1] * rank)
            return True

        axis = utils.get_tf_const_value(axis_op)
        axis = axis if axis >= 0 else axis + rank
        new_shape[axis] = -1
        if len(input_shapes) == len(data_inputs):  # all input shapes are known
            concat_dim_vals = list(np.array(input_shapes)[:, axis])
            # only when inputs' shape are known, then val of concat dim can be calculated
            if concat_dim_vals.count(-1) == 0:
                new_shape[axis] = sum(concat_dim_vals)

        op.outputs[0].set_shape(new_shape)
        logger.debug("set Concat op [%s] with new shape %s", op.outputs[0].name, new_shape)
        return True

    if op.type == "Select":
        new_shape = utils.get_tf_tensor_shape(op.inputs[1])
        if new_shape is None:
            new_shape = utils.get_tf_tensor_shape(op.inputs[2])
github onnx / tensorflow-onnx / tf2onnx / shape_inference.py View on Github external
if utils.is_tf_concat_op(op):
        data_inputs = op.inputs[:-1]
        input_shapes = [utils.get_tf_tensor_shape(inp) for inp in data_inputs]
        input_shapes = [shape for shape in input_shapes if shape is not None]
        if not input_shapes:
            logger.debug("all input shapes of concat op %s are None, can't infer its output shape", op.name)
            return False

        new_shape = input_shapes[0]
        axis_op = op.inputs[-1]
        rank = len(new_shape)
        if not utils.is_tf_const_op(axis_op):
            op.outputs[0].set_shape([-1] * rank)
            return True

        axis = utils.get_tf_const_value(axis_op)
        axis = axis if axis >= 0 else axis + rank
        new_shape[axis] = -1
        if len(input_shapes) == len(data_inputs):  # all input shapes are known
            concat_dim_vals = list(np.array(input_shapes)[:, axis])
            # only when inputs' shape are known, then val of concat dim can be calculated
            if concat_dim_vals.count(-1) == 0:
                new_shape[axis] = sum(concat_dim_vals)

        op.outputs[0].set_shape(new_shape)
        logger.debug("set Concat op [%s] with new shape %s", op.outputs[0].name, new_shape)
        return True

    if op.type == "Select":
        new_shape = utils.get_tf_tensor_shape(op.inputs[1])
        if new_shape is None:
            new_shape = utils.get_tf_tensor_shape(op.inputs[2])
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / shape_inference.py View on Github external
new_shape.append(1)
            else:
                new_shape.append(shape[i])

        op.outputs[0].set_shape(new_shape)
        logger.debug("set %s op [%s] with new shape %s", op.type, op.outputs[0].name, new_shape)
        return True

    if op.type == "ExpandDims":
        # https://www.tensorflow.org/api_docs/python/tf/expand_dims
        input_shape = utils.get_tf_tensor_shape(op.inputs[0])
        dim_op = op.inputs[1].op
        if input_shape is None or not utils.is_tf_const_op(dim_op):
            return False

        dim = utils.get_tf_const_value(dim_op)
        if dim < 0:
            dim = dim + len(input_shape) + 1

        new_shape = input_shape[:dim] + [1] + input_shape[dim:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
        return True

    if op.type == "Unpack":
        input_shape = utils.get_tf_tensor_shape(op.inputs[0])
        if input_shape is None:
            return False

        axis = op.get_attr("axis")
        axis = axis if axis >= 0 else axis + len(input_shape)
        # the link below says that the rank of output is "rank(input) -1",
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / shape_inference.py View on Github external
if not shape_op or shape_op.type != "Shape":
            return False
        return set_shape_from_input(shape_op.inputs[0], op.outputs[0])

    if op.type == "Gather":
        # uses the following link to know how to infer the shape of the output
        # https://www.tensorflow.org/api_docs/python/tf/gather
        shape_params = utils.get_tf_tensor_shape(op.inputs[0])
        shape_indices = utils.get_tf_tensor_shape(op.inputs[1])
        # gather can only have 2 inputs
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/gather.html
        if len(op.inputs) == 3:
            axis_op = op.inputs[2].op
            if not utils.is_tf_const_op(axis_op):
                return False
            axis = utils.get_tf_const_value(axis_op)
        else:
            axis = 0

        shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
        op.outputs[0].set_shape(shape)
        return True

    if op.type in ["All", "Any", "Max", "Min"]:
        axis_op = op.inputs[1].op
        if not utils.is_tf_const_op(axis_op):
            return False
        axis = utils.get_tf_const_value(axis_op)
        if not isinstance(axis, list):
            axis = [axis]
        keep_dims = op.get_attr("keep_dims")
        shape = utils.get_tf_tensor_shape(op.inputs[0])