How to use the tf2onnx.utils.get_tf_tensor_shape function in tf2onnx

To help you get started, we’ve selected a few tf2onnx examples based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

Source: onnx/tensorflow-onnx — tf2onnx/shape_inference.py (view on GitHub)
return False

    if op.type == "TensorArrayReadV3":
        # TensorArrayRead reads an element from the TensorArray into output value.
        # The TensorArray's shape can be got from TensorArrayScatter.
        # So the process is: first find TensorArrayScatter's shape and then TensorArray's
        # and finally take its last n-1 dim.
        flow_in_op = op.inputs[2].op
        if flow_in_op.type != "Enter":
            return False

        scatter_op = flow_in_op.inputs[0].op
        if scatter_op.type != "TensorArrayScatterV3":
            return False

        value_shape_before_scatter = utils.get_tf_tensor_shape(scatter_op.inputs[2])
        if value_shape_before_scatter is None:
            return False

        new_shape = value_shape_before_scatter[1:]
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
            return True
        return False

    return False
Source: onnx/keras-onnx — keras2onnx/ktf2onnx/tf2onnx/shape_inference.py (view on GitHub)
logger.warning("Shapes of Merge %s have different ranks: %s, %s", op.name, len(s1), len(s2))
                return False

            logger.debug("Inputs of Merge %s have different shapes: %s, %s, but the same rank", op.name, s1, s2)
            new_shape = _merge_shapes_for_tf(s1, s2)
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
        else:
            new_shape = s1
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)

        return True

    if op.type == "Switch":
        new_shape = utils.get_tf_tensor_shape(op.inputs[0])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            op.outputs[1].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[1].name, new_shape)
            return True
        return False

    if op.type == "Enter":
        new_shape = utils.get_tf_tensor_shape(op.inputs[0])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
            return True
        return False
Source: onnx/keras-onnx — keras2onnx/ktf2onnx/tf2onnx/shape_inference.py (view on GitHub)
return False

    if op.type == "TensorArrayReadV3":
        # TensorArrayRead reads an element from the TensorArray into output value.
        # The TensorArray's shape can be got from TensorArrayScatter.
        # So the process is: first find TensorArrayScatter's shape and then TensorArray's
        # and finally take its last n-1 dim.
        flow_in_op = op.inputs[2].op
        if flow_in_op.type != "Enter":
            return False

        scatter_op = flow_in_op.inputs[0].op
        if scatter_op.type != "TensorArrayScatterV3":
            return False

        value_shape_before_scatter = utils.get_tf_tensor_shape(scatter_op.inputs[2])
        if value_shape_before_scatter is None:
            return False

        new_shape = value_shape_before_scatter[1:]
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
            return True
        return False

    return False
Source: onnx/keras-onnx — keras2onnx/ktf2onnx/tf2onnx/shape_inference.py (view on GitHub)
return False

    if op.type == "Placeholder":
        # if placeholder shape is not found, try to get it from "shape" attribute.
        attr_shape = utils.get_tf_shape_attr(op)
        if attr_shape is not None:
            new_shape = list(attr_shape)
            op.outputs[0].set_shape(new_shape)
            logger.debug("set placeholder op [%s] with new shape %s", op.outputs[0].name, new_shape)
            return True
        logger.warning("Shape of placeholder %s is unknown, treated it as a scalar", op.name)
        op.outputs[0].set_shape([])
        return True

    if op.type == "Merge":
        s1 = utils.get_tf_tensor_shape(op.inputs[0])
        s2 = utils.get_tf_tensor_shape(op.inputs[1])
        new_shape = None
        if s1 is None and s2 is None:
            return False
        if s1 is None and s2 is not None:
            new_shape = s2
        if s1 is not None and s2 is None:
            new_shape = s1

        if new_shape is not None:
            op.inputs[0].set_shape(new_shape)
            op.inputs[1].set_shape(new_shape)
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
            return True
Source: onnx/keras-onnx — keras2onnx/ktf2onnx/tf2onnx/shape_inference.py (view on GitHub)
else:
            axis = 0

        shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
        op.outputs[0].set_shape(shape)
        return True

    if op.type in ["All", "Any", "Max", "Min"]:
        axis_op = op.inputs[1].op
        if not utils.is_tf_const_op(axis_op):
            return False
        axis = utils.get_tf_const_value(axis_op)
        if not isinstance(axis, list):
            axis = [axis]
        keep_dims = op.get_attr("keep_dims")
        shape = utils.get_tf_tensor_shape(op.inputs[0])
        for i, _ in enumerate(axis):
            if axis[i] < 0:
                axis[i] += len(shape)

        new_shape = []
        for i, _ in enumerate(shape):
            if i in axis:
                if keep_dims:
                    new_shape.append(1)
            else:
                new_shape.append(shape[i])

        op.outputs[0].set_shape(new_shape)
        logger.debug("set %s op [%s] with new shape %s", op.type, op.outputs[0].name, new_shape)
        return True
Source: onnx/tensorflow-onnx — tf2onnx/shape_inference.py (view on GitHub)
def set_shape_from_inputs_broadcast(input_tensors, output_tensor):
    """Give ``output_tensor`` the broadcast of its first two inputs' static shapes.

    The static shapes of ``input_tensors[0]`` and ``input_tensors[1]`` are read
    via ``utils.get_tf_tensor_shape`` and combined with
    ``broadcast_shape_inference``. Only the first two entries of
    ``input_tensors`` are consulted.

    Args:
        input_tensors: indexable sequence of tensors; items 0 and 1 are used.
        output_tensor: tensor whose static shape is set on success.

    Returns:
        True if a broadcast shape was inferred and applied to
        ``output_tensor``; False if ``broadcast_shape_inference`` returned None
        (presumably an unknown or incompatible pair — confirm against that
        helper).
    """
    # Shapes are looked up in input order: first operand, then second.
    operand_shapes = [utils.get_tf_tensor_shape(input_tensors[idx]) for idx in (0, 1)]
    inferred = broadcast_shape_inference(*operand_shapes)

    # Guard clause: nothing to apply when no broadcast shape could be derived.
    if inferred is None:
        return False

    output_tensor.set_shape(inferred)
    logger.debug("set [%s] with new shape %s", output_tensor.name, inferred)
    return True
Source: onnx/keras-onnx — keras2onnx/ktf2onnx/tf2onnx/shape_inference.py (view on GitHub)
return False

        axis = op.get_attr("axis")
        axis = axis if axis >= 0 else axis + len(input_shape)
        # the link below says that the rank of output is "rank(input) -1",
        # from this statement "num" must equal to input_shape[axis], and if not tf will throw a runtime error
        # https://www.tensorflow.org/api_docs/python/tf/unstack
        new_shape = input_shape[:axis] + input_shape[axis + 1:]
        for output in op.outputs:
            output.set_shape(new_shape)
            logger.debug("set %s op [%s] with new shape %s", op.type, output.name, new_shape)
        return True

    if op.type in ["Minimum", "Maximum"]:
        # ops that are elementwise and support broadcasting
        input_shapes = [utils.get_tf_tensor_shape(op) for op in op.inputs]
        new_shape = broadcast_shape_inference(*input_shapes)
        op.outputs[0].set_shape(new_shape)
        return True

    return False
Source: onnx/tensorflow-onnx — tf2onnx/shape_inference.py (view on GitHub)
if infer_input_shapes(op):
            return True

    if not has_unknown_output_shape:
        return False

    # for those ops, we don't expect all input shapes available to infer output shapes.
    ret = infer_output_shapes_with_partial_inputs(op)
    if ret is not None:
        return ret

    # for ops, we need all input shapes ready to infer output shapes.
    are_all_input_shape_ready = True
    no_shape = []
    for i in op.inputs:
        if utils.get_tf_tensor_shape(i) is None:
            are_all_input_shape_ready = False
            no_shape.append(i.name)

    if not are_all_input_shape_ready:
        logger.debug("op %s has inputs don't have shape specified, they are: %s", op.name, no_shape)
        return False

    if op.type in direct_ops:
        return set_shape_from_input(op.inputs[0], op.outputs[0])

    if op.type in broadcast_ops:
        return set_shape_from_inputs_broadcast(op.inputs, op.outputs[0])

    if op.type == "RandomUniform":
        shape_op = op.inputs[0].op
        if not shape_op or shape_op.type != "Shape":
Source: onnx/tensorflow-onnx — tf2onnx/shape_inference.py (view on GitHub)
return set_shape_from_input(op.inputs[0], op.outputs[0])

    if op.type in broadcast_ops:
        return set_shape_from_inputs_broadcast(op.inputs, op.outputs[0])

    if op.type == "RandomUniform":
        shape_op = op.inputs[0].op
        if not shape_op or shape_op.type != "Shape":
            return False
        return set_shape_from_input(shape_op.inputs[0], op.outputs[0])

    if op.type == "Gather":
        # uses the follwing link to know how to infer shape of output
        # https://www.tensorflow.org/api_docs/python/tf/gather
        shape_params = utils.get_tf_tensor_shape(op.inputs[0])
        shape_indices = utils.get_tf_tensor_shape(op.inputs[1])
        # gather can only have 2 inputs
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/gather.html
        if len(op.inputs) == 3:
            axis_op = op.inputs[2].op
            if not utils.is_tf_const_op(axis_op):
                return False
            axis = utils.get_tf_const_value(axis_op)
        else:
            axis = 0

        shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
        op.outputs[0].set_shape(shape)
        return True

    if op.type in ["All", "Any", "Max", "Min"]:
        axis_op = op.inputs[1].op