How to use the tf2onnx.constants module in tf2onnx

To help you get started, we’ve selected a few tf2onnx.constants examples based on popular ways the module is used in public projects.
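
As a quick orientation, the module can be imported directly. A minimal sketch; the concrete values in the comments are assumptions inferred from the snippets below, not guaranteed:

from tf2onnx import constants

# layout permutations used throughout the converter
print(constants.NHWC_TO_NCHW)          # expected: [0, 3, 1, 2]
print(constants.NCHW_TO_NHWC)          # expected: [0, 2, 3, 1]
print(constants.TF2ONNX_PACKAGE_NAME)  # the package's logger name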


github onnx/keras-onnx (keras2onnx/ktf2onnx/tf2onnx/convert.py)
def main():
    args = get_args()
    logging.basicConfig(level=logging.get_verbosity_level(args.verbose))
    if args.debug:
        utils.set_debug_mode(True)

    logger = logging.getLogger(constants.TF2ONNX_PACKAGE_NAME)

    extra_opset = args.extra_opset or []
    custom_ops = {}
    if args.custom_ops:
        # default custom ops for tensorflow-onnx are in the "tf" namespace
        custom_ops = {op: (default_custom_op_handler, []) for op in args.custom_ops.split(",")}
        extra_opset.append(constants.TENSORFLOW_OPSET)

    # get the frozen tensorflow model from graphdef, checkpoint or saved_model.
    if args.graphdef:
        graph_def, inputs, outputs = loader.from_graphdef(args.graphdef, args.inputs, args.outputs)
        model_path = args.graphdef
    if args.checkpoint:
        graph_def, inputs, outputs = loader.from_checkpoint(args.checkpoint, args.inputs, args.outputs)
        model_path = args.checkpoint
    if args.saved_model:
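
The extra_opset list built in this excerpt is later handed to the converter. A minimal sketch of the same idea through the Python API, assuming an already-loaded tf_graph; the extra_opset and custom_op_handlers parameter names are assumptions, as this page only shows process_tf_graph's other arguments further below:

from tf2onnx import constants, tfonnx

extra_opset = [constants.TENSORFLOW_OPSET]  # custom ops live in the converter's "tf" domain
# onnx_graph = tfonnx.process_tf_graph(tf_graph, extra_opset=extra_opset,
#                                      custom_op_handlers=custom_ops)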
github onnx/keras-onnx (keras2onnx/ktf2onnx/tf2onnx/onnx_opset/tensor.py)
def version_9(cls, ctx, node, **kwargs):
        # T output = OneHot(uint8/int32/int64 input, T depth, T on-value, T off-value, @int axis, @dtype)
        # TF requires dtype to match the dtype of on-value and off-value.
        # In ONNX, the op's schema is (input, depth, value, @int axis), where "value" packs [off-value, on-value].
        # onnxruntime only supports int64.
        output_dtype = ctx.get_dtype(node.input[2])
        if ctx.is_target(constants.TARGET_RS6) \
                and output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
            logger.warning("unsupported dtype in onnxruntime, onehot-9 can't be used directly")
            cls.version_1(ctx, node, **kwargs)
            return

        depth = node.input[1]
        depth = ctx.make_node("Unsqueeze", [depth], attr={"axes": [0]}).output[0]

        on_value = node.input[2]
        off_value = node.input[3]
        on_value = ctx.make_node("Unsqueeze", [on_value], attr={"axes": [0]}).output[0]
        off_value = ctx.make_node("Unsqueeze", [off_value], attr={"axes": [0]}).output[0]
        off_on_value = ctx.make_node("Concat", [off_value, on_value], attr={"axis": 0}).output[0]

        indices = node.input[0]
        if ctx.is_target(constants.TARGET_RS6) \
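
The comment block above describes ONNX's OneHot schema, where both fill values are packed into a single "values" tensor; that is why the excerpt Unsqueezes each scalar and Concats them. An illustrative reconstruction in plain numpy, with hypothetical fill values:

import numpy as np

off_value = np.array([0], dtype=np.int64)
on_value = np.array([1], dtype=np.int64)
values = np.concatenate([off_value, on_value], axis=0)  # ONNX OneHot "values" = [off, on]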
github onnx/tensorflow-onnx (tf2onnx/onnx_opset/nn.py)
new_kernel_shape: reshape the kernel
    """

    if input_indices is None:
        input_indices = [0]
    if output_indices is None:
        output_indices = [0]

    if node.is_nhwc():
        # transpose input if needed, no need to record shapes on input
        for idx in input_indices:
            parent = node.inputs[idx]
            if node.inputs[idx].is_const() and len(ctx.find_output_consumers(node.input[1])) == 1:
                # if input is a constant, transpose that one if we are the only consumer
                val = parent.get_tensor_value(as_list=False)
                parent.set_tensor_value(val.transpose(constants.NHWC_TO_NCHW))
            else:
                # if the input comes from an op, insert a transpose op
                input_name = node.input[idx]
                transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
                transpose.set_attr("perm", constants.NHWC_TO_NCHW)
                transpose.skip_conversion = True
                shape = ctx.get_shape(input_name)
                if shape is not None:
                    new_shape = spatial_map(shape, constants.NHWC_TO_NCHW)
                    ctx.set_shape(transpose.output[0], new_shape)

    # the kernel must be transposed
    if with_kernel:
        parent = node.inputs[1]
        need_transpose = True
        if node.inputs[1].is_const():
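
The NHWC_TO_NCHW permutation applied above is plain axis reordering. An equivalent standalone sketch; the value [0, 3, 1, 2] is an assumption based on how the constant is used here:

import numpy as np

x = np.zeros((1, 28, 28, 3))          # NHWC: batch, height, width, channels
x_nchw = x.transpose([0, 3, 1, 2])    # same permutation as constants.NHWC_TO_NCHW
print(x_nchw.shape)                   # (1, 3, 28, 28)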
github onnx/tensorflow-onnx (tf2onnx/tfonnx.py)
"""Insert a transpose from NHWC to NCHW on model input on users request."""
    ops = []
    for node in ctx.get_nodes():
        for idx, output_name in enumerate(node.output):
            if output_name in inputs_as_nchw:
                shape = ctx.get_shape(output_name)
                if len(shape) != len(constants.NCHW_TO_NHWC):
                    logger.warning("transpose_input for %s: shape must be rank 4, ignored" % output_name)
                    ops.append(node)
                    continue
                # insert transpose
                op_name = utils.make_name(node.name)
                transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
                transpose.set_attr("perm", constants.NCHW_TO_NHWC)
                ctx.copy_shape(output_name, transpose.output[0])
                ctx.set_shape(output_name, np.array(shape)[constants.NHWC_TO_NCHW])
                ops.append(transpose)
                ops.append(node)
                continue
        ops.append(node)
    ctx.reset_nodes(ops)
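
This rewrite corresponds to process_tf_graph's inputs_as_nchw argument, which is visible in the nnabla excerpt further below. A minimal sketch, assuming an already-loaded tf_graph; the input name is hypothetical:

from tf2onnx import tfonnx

# onnx_graph = tfonnx.process_tf_graph(tf_graph, inputs_as_nchw=["input:0"])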
github onnx/tensorflow-onnx (tf2onnx/utils.py)
return helper.make_opsetid(domain, version)


def is_onnx_domain(domain):
    if domain is None or domain == "":
        return True
    return False


def parse_bool(val):
    if val is None:
        return False
    return val.lower() in ("yes", "true", "t", "y", "1")


_is_debug_mode = parse_bool(os.environ.get(constants.ENV_TF2ONNX_DEBUG_MODE))


def is_debug_mode():
    return _is_debug_mode


def set_debug_mode(enabled):
    global _is_debug_mode
    _is_debug_mode = enabled


def get_max_value(np_dtype):
    return np.iinfo(np_dtype).max


def get_min_value(np_dtype):
    return np.iinfo(np_dtype).min
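
Debug mode can therefore be toggled either through the environment before import or programmatically. A small sketch; the variable name "TF2ONNX_DEBUG_MODE" is an assumption based on constants.ENV_TF2ONNX_DEBUG_MODE:

import os

os.environ["TF2ONNX_DEBUG_MODE"] = "true"  # assumed variable name; read once at import time
from tf2onnx import utils
utils.set_debug_mode(True)                 # or toggle explicitly at runtime
assert utils.is_debug_mode()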
github onnx/tensorflow-onnx (tf2onnx/verbose_logging.py)
def set_level(level):
    """ Set logging level for tf2onnx package. tf verbosity is updated accordingly. """
    _logging.getLogger(constants.TF2ONNX_PACKAGE_NAME).setLevel(level)
    set_tf_verbosity(level)
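
Because the package logger is registered under constants.TF2ONNX_PACKAGE_NAME, the same level can also be set with the standard logging module. A sketch, assuming the name resolves to "tf2onnx":

import logging

logging.getLogger("tf2onnx").setLevel(logging.WARNING)  # quiet everything below WARNING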
github onnx/keras-onnx (keras2onnx/wrapper.py)
scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
    # onnxruntime only supports scaling the last two dims, so a transpose is inserted
    input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
    roi = ctx.make_const(tf2onnx.utils.make_name("roi"), np.array([]).astype(np.float32))
    attrs = {"mode": mode}
    attrs['coordinate_transformation_mode'] = 'asymmetric'
    if attrs['mode'] == 'nearest':
        attrs['nearest_mode'] = 'floor'

    upsample = ctx.make_node("Resize", [input_nchw.output[0], roi.output[0], scales.output[0]],
                             attr=attrs)

    shapes = node.output_shapes
    dtypes = node.output_dtypes
    ctx.remove_node(node.name)
    ctx.make_node("Transpose", upsample.output, {"perm": constants.NCHW_TO_NHWC},
                  name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
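
The design choice here: ONNX Resize takes one scale factor per axis, so a 2x spatial upsample of an NCHW tensor needs scales of [1, 1, 2, 2]. That is why a ones-array is concatenated in front of the height/width scales above. Illustrative values only:

import numpy as np

scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32)  # N, C, H, W scale factors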
github sony/nnabla (python/src/nnabla/utils/converter/tensorflow/importer.py)
def convert_to_onnx(self, graph_def, inputs, outputs):

        # FIXME: constant folding is disabled here (the trailing False)
        graph_def = tf2onnx.tfonnx.tf_optimize(
            inputs, outputs, graph_def, False)
        with tf.Graph().as_default() as tf_graph:
            tf.import_graph_def(graph_def, name='')
        with tf.Session(graph=tf_graph):
            onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph,
                                                         continue_on_error=False,
                                                         verbose=False,
                                                         target=",".join(
                                                             constants.DEFAULT_TARGET),
                                                         opset=9,
                                                         input_names=inputs,
                                                         output_names=outputs,
                                                         inputs_as_nchw=None)
        model_proto = onnx_graph.make_model(
            "converted from {}".format(self._tf_file))
        new_model_proto = GraphUtil.optimize_model_proto(model_proto)
        if new_model_proto:
            model_proto = new_model_proto
        return model_proto
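
DEFAULT_TARGET is a list of target names, so the ",".join(...) above turns it into the comma-separated string that process_tf_graph accepts. A sketch; the default is assumed to be an empty list, meaning no runtime workarounds are requested:

from tf2onnx import constants

print(constants.DEFAULT_TARGET)            # assumed: []
print(",".join(constants.DEFAULT_TARGET))  # joins to "" when the list is empty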
github onnx/tensorflow-onnx (tf2onnx/utils.py)
def find_opset(opset):
    """Find opset."""
    if opset is None or opset == 0:
        opset = defs.onnx_opset_version()
        if opset > constants.PREFERRED_OPSET:
            # if we use a newer onnx opset than most runtimes support, default to the one most supported
            opset = constants.PREFERRED_OPSET
    return opset
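
A typical call site, as a minimal sketch:

from tf2onnx import utils

opset = utils.find_opset(None)  # onnx's current opset, capped at constants.PREFERRED_OPSET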
github onnx/tensorflow-onnx (tf2onnx/onnx_opset/common.py)
def version_6(cls, ctx, node, **kwargs):
        """Elementwise Ops with broadcast flag."""
        if node.type == "AddV2":
            node.type = "Add"
        shape0 = ctx.get_shape(node.input[0])
        shape1 = ctx.get_shape(node.input[1])
        if shape0 != shape1:
            # this works around shortcomings in the broadcasting code
            # of caffe2 and winml/rs4.
            if ctx.is_target(constants.TARGET_RS4):
                # in RS4, Mul and Add do not handle scalars correctly
                if not shape0:
                    if node.inputs[0].is_const():
                        shape0 = node.inputs[0].scalar_to_dim1()
                if not shape1:
                    if node.inputs[1].is_const():
                        shape1 = node.inputs[1].scalar_to_dim1()
            if shape0 and shape1 and len(shape0) < len(shape1) and node.type in ["Mul", "Add", "AddV2"]:
                # swap the inputs so the higher-rank shape comes first
                tmp = node.input[0]
                node.input[0] = node.input[1]
                node.input[1] = tmp
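
Target-specific branches like the RS4 workaround above are enabled by the target list passed at conversion time. A sketch, assuming an already-loaded tf_graph; the target parameter itself appears in the nnabla excerpt above:

from tf2onnx import constants, tfonnx

# onnx_graph = tfonnx.process_tf_graph(tf_graph, target=[constants.TARGET_RS4])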