How to use the tf2onnx.graph_builder.GraphBuilder class in tf2onnx

To help you get started, we’ve selected a few tf2onnx examples based on popular ways GraphBuilder is used in public projects.
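
All of the examples below share one pattern: construct a GraphBuilder over the conversion graph and call make_slice with a dict that maps Slice's inputs and attributes ("data", "starts", "ends", "axes", and optionally "steps" and "outputs") to values; the call emits an opset-appropriate Slice node and returns its output name. A minimal sketch, where g is assumed to be an existing tf2onnx graph (the ctx handed to op handlers) and input_name an existing tensor name in it:

from tf2onnx.graph_builder import GraphBuilder

gb = GraphBuilder(g)  # g: an existing tf2onnx Graph (assumed)
# Slice elements 1:2 along axis 0; the returned value is the output name of
# the new Slice node, usable as an input to further make_node calls.
sliced = gb.make_slice({"data": input_name, "axes": [0], "starts": [1], "ends": [2]})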

From onnx/tensorflow-onnx, tf2onnx/onnx_opset/nn.py (view on GitHub)
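Here make_slice supplies the depth input of OneHot while converting sparse softmax cross-entropy: when the last dimension of the logits is not known statically, the class count is sliced off the end of a runtime Shape node instead of being baked in as a constant.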
        # the math of this op: a = onehot(labels), b = logsoftmax(features), reduce_sum(mul(a, b))
        logit_node = node.inputs[0]
        logit_shape = ctx.get_shape(node.input[0])
        logit_dtype = ctx.get_dtype(node.input[0])

        label_name = node.input[1]

        if logit_shape is not None and logit_shape[-1] != -1:
            num_class = logit_shape[-1]
            node_name = utils.make_name("onehot_depth")
            depth_node = ctx.make_const(node_name, np.array([num_class]).astype(np.int64)).output[0]
        else:
            logit_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
            slice_args = {"data": logit_shape,
                          "starts": [-1], "ends": [int(utils.get_max_value(np.int32))]}
            num_class = GraphBuilder(ctx).make_slice(kwargs=slice_args)
            depth_node = num_class
        values_node = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1]).astype(np.int64)).output[0]
        label_dtype = ctx.get_dtype(label_name)
        if label_dtype != TensorProto.INT64:
            onehot_indice = ctx.make_node("Cast", [label_name], attr={"to": TensorProto.INT64}).output[0]
        else:
            onehot_indice = label_name
        label_node = ctx.make_node(op_type="OneHot",
                                   inputs=[onehot_indice, depth_node, values_node])
        # the above logic makes the output dtype of label_node always int64
        # make sure the label has the same dtype as the logits
        if logit_dtype != TensorProto.INT64:
            label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])

        _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
From onnx/keras-onnx, keras2onnx/ktf2onnx/tf2onnx/rewriter/lstm_rewriter.py (view on GitHub)
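This LSTM rewriter splits a concatenated c/h initial-state tensor into its two halves with make_slice along axis 1, then adds the leading num_directions axis with Unsqueeze, as ONNX LSTM expects.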
    def _process_non_tuple_ch_init_nodes(self, context):
        input_id = context.state_variables["ct_ht"].enter_input_id
        hidden_size = context.hidden_size

        attr = {"axes": [1], "starts": [0], "ends": [hidden_size]}
        inputs_map = {"data": input_id, **attr}
        slice_node1 = GraphBuilder(self.g).make_slice(inputs_map)
        unsqueeze_node_1 = self.g.make_node("Unsqueeze", [slice_node1], attr={"axes": [0]})

        attr = {"axes": [1], "starts": [hidden_size], "ends": [hidden_size*2]}
        inputs_map = {"data": input_id, **attr}
        slice_node2 = GraphBuilder(self.g).make_slice(inputs_map)
        unsqueeze_node_2 = self.g.make_node("Unsqueeze", [slice_node2], attr={"axes": [0]})

        return unsqueeze_node_1.output[0], unsqueeze_node_2.output[0]
From onnx/tensorflow-onnx, tf2onnx/onnx_opset/tensor.py (view on GitHub)
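This StridedSlice conversion passes all of Slice's dynamic inputs (starts, ends, steps, axes) to make_slice by name in a single inputs_map; the "outputs" entry makes the generated Slice reuse the original node's output names.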
        # create axes input
        axes_const = ctx.make_const(
            utils.make_name("slice_axes"),
            np.array(axes, dtype=np_dtype)
        )
        axes_output = axes_const.output[0]

        inputs_map = {
            "data": node.input[0],
            "starts": begin.output[0],
            "ends": end_output,
            "steps": strides_output,
            "axes": axes_output
        }
        kwargs = {**inputs_map, "outputs": node.output}
        node = GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=out_dtypes, shapes=out_shapes)
        node = ctx.get_node_by_output(node)
        if needs_squeeze:
            name = utils.make_name(node.name)
            squeeze_node = ctx.insert_new_node_on_output("Squeeze", node.output[0], name)
            squeeze_node.set_attr("axes", needs_squeeze)
            input_dtype = ctx.get_dtype(node.output[0])
            ctx.set_dtype(squeeze_node.output[0], input_dtype)
            ctx.copy_shape(node.output[0], squeeze_node.output[0])
From onnx/keras-onnx, keras2onnx/ktf2onnx/tf2onnx/rewriter/rnn_utils.py (view on GitHub)
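After merging forward and backward RNNs into one bidirectional node, make_slice carves each direction's result back out of the combined output along the num_directions axis, and the original consumers are rewired to the slices.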
        # remove reverse op for rnn_bw
        reverse_nodes = get_reverse_nodes_after_y_output(g, rnn_bw)

        for r_op in reverse_nodes:
            logger.debug("remove reverse op %s", r_op.name)
            g.replace_all_inputs(all_nodes, r_op.output[0], r_op.input[0])
            to_remove.append(r_op.name)
    elif rnn_output_index in [1, 2]:
        axis = 0
    else:
        raise ValueError("rnn should have exactly 3 outputs.")

    if fw_consumers:
        attr = {"axes": [axis], "starts": [0], "ends": [1]}
        inputs_map = {"data": bi_rnn.output[rnn_output_index], **attr}
        slice_node_fw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_fw))
        g.replace_all_inputs(fw_consumers, rnn_fw.output[rnn_output_index], slice_node_fw)

    if bw_consumers:
        attr = {"axes": [axis], "starts": [1], "ends": [2]}
        inputs_map = {"data": bi_rnn.output[rnn_output_index], **attr}
        slice_node_bw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_bw))
        g.replace_all_inputs(bw_consumers, rnn_bw.output[rnn_output_index], slice_node_bw)
From onnx/tensorflow-onnx, tf2onnx/rewriter/unit_rnn_rewriter_base.py (view on GitHub)
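Here a missing sequence-length input is synthesized from the input's runtime shape: make_slice extracts the batch size and time-step count as one-element tensors, and Tile repeats the time step once per batch entry before casting to int32.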
    def process_seq_length(self, context):
        # output: [time step, batch size, input size]
        seq_len_node = context.seq_len_node
        shape_node = self.g.make_node("Shape", [context.onnx_input_ids["X"]])
        # LSTMCell only allows inputs of [batch size, input size], so we assume dynamic_rnn's input has 3 dims.
        # Slice does not support int64 in opset 7, so we cast here.
        cast_shape_node = self.g.make_node(
            "Cast", [shape_node.output[0]],
            attr={"to": TensorProto.FLOAT},
            shapes=[self.g.get_shape(shape_node.output[0])]
        )

        attr = {"axes": [0], "starts": [1], "ends": [2]}
        inputs_map = {"data": cast_shape_node.output[0], **attr}
        batchsize_node = GraphBuilder(self.g).make_slice(inputs_map)
        if not seq_len_node:
            # Tile's repeats must be INT64
            repeat_node = self.g.make_node(
                "Cast", [batchsize_node],
                attr={"to": TensorProto.INT64}
            )

            attr = {"axes": [0], "starts": [0], "ends": [1]}
            inputs_map = {"data": cast_shape_node.output[0], **attr}
            timestep_node = GraphBuilder(self.g).make_slice(inputs_map)
            tile_node = self.g.make_node("Tile", [timestep_node, repeat_node.output[0]])

            # LSTM sequence_lens needs to be int32
            seq_len_node = self.g.make_node(
                "Cast", [tile_node.output[0]],
                attr={"to": TensorProto.INT32}
From onnx/tensorflow-onnx, tf2onnx/onnx_opset/tensor.py (view on GitHub)
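The tail of a GatherND conversion: the final output shape, indices.shape[:-1] + gathernd_output.shape[1:], is assembled from make_slice calls over the relevant Shape tensors and then applied with Concat and Reshape.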
    # reshape to target shape
    # output shape of gathernd: indices.shape[:-1] + gathernd_output.shape[1:]
    inner_loop_shape = ctx.make_node("Shape", [gathernd_loop.output[1]], dtypes=[TensorProto.INT64])
    # workaround in case gathernd_loop is 1-dimensional
    one_const = ctx.make_const(utils.make_name("one"), np.array([1], dtype=np.int64))
    inner_loop_shape_ = ctx.make_node("Concat",
                                      [inner_loop_shape.output[0], one_const.output[0]],
                                      attr={"axis": 0},
                                      dtypes=[TensorProto.INT64])
    attr = {"axes": [0], "ends": [sys.maxsize], "starts": [1]}
    inputs_map = {"data": inner_loop_shape_.output[0], **attr}
    output_inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
    attr = {"axes": [0], "ends": [-1], "starts": [0]}
    inputs_map = {"data": indices_shape.output[0], **attr}
    indices_outer_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
    output_shape_ = ctx.make_node("Concat",
                                  [indices_outer_shape, output_inner_shape],
                                  attr={"axis": 0},
                                  dtypes=[TensorProto.INT64])
    attr = {"axes": [0], "ends": [-1], "starts": [0]}
    inputs_map = {"data": output_shape_.output[0], **attr}
    output_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
    ctx.make_node("Reshape",
                  [gathernd_loop.output[1], output_shape],
                  outputs=[output],
                  shapes=shapes,
                  dtypes=dtypes)
From onnx/keras-onnx, keras2onnx/wrapper.py (view on GitHub)
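This resize wrapper builds the "scales" input for ONNX Resize. With static spatial dimensions the scales become a constant; otherwise make_slice pulls H and W out of a runtime Shape node and the scales are computed with a Div.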
    # first create the "scales" input for onnx upsample
    # if the input and output shapes are known, "scales" is calculated statically and set as a const node
    shape = ctx.get_shape(node.input[0])
    if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const():
        target_shape = node.inputs[1].get_tensor_value()
        n, h, w, c = shape
        nh, nw = target_shape
        # scales is nchw
        # data is not stored in the raw field because of this bug: https://github.com/onnx/onnx/issues/1852
        scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32)
        scales = ctx.make_const(tf2onnx.utils.make_name("scales"), scale_val, raw=False)
    else:
        ori_shape = ctx.make_node("Shape", [node.input[0]])
        attr = {"axes": [0], "starts": [1], "ends": [3]}
        inputs_map = {"data": ori_shape.output[0], **attr}
        ori_shape_hw = GraphBuilder(ctx).make_slice(inputs_map)
        ori_shape_hw_float = ctx.make_node("Cast", [ori_shape_hw], attr={"to": onnx_pb.TensorProto.FLOAT})

        target_hw = node.inputs[1]
        target_hw_float = ctx.make_node("Cast", target_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})

        scales_hw = ctx.make_node("Div", [target_hw_float.output[0], ori_shape_hw_float.output[0]])

        const_one_array = ctx.make_const(tf2onnx.utils.make_name("one"), np.array([1.0, 1.0]).astype(np.float32))
        # scales is nchw
        scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
    # because onnxruntime only supports scaling the last two dims, a transpose is inserted
    input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
    roi = ctx.make_const(tf2onnx.utils.make_name("roi"), np.array([]).astype(np.float32))
    attrs = {"mode": mode}
    attrs['coordinate_transformation_mode'] = 'asymmetric'
    if attrs['mode'] == 'nearest':
        attrs['nearest_mode'] = 'floor'
From onnx/tensorflow-onnx, tf2onnx/onnx_opset/nn.py (view on GitHub)
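The same scales-building pattern inside tensorflow-onnx itself; here the empty roi constant is only added when the target opset's Resize requires it.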
        # first create the "scales" input for onnx upsample
        # if the input and output shapes are known, "scales" is calculated statically and set as a const node
        shape = ctx.get_shape(node.input[0])
        if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const():
            target_shape = node.inputs[1].get_tensor_value()
            n, h, w, c = shape
            nh, nw = target_shape
            # scales is nchw
            # data is not stored in the raw field because of this bug: https://github.com/onnx/onnx/issues/1852
            scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32)
            scales = ctx.make_const(utils.make_name("scales"), scale_val, raw=False)
        else:
            ori_shape = ctx.make_node("Shape", [node.input[0]])
            attr = {"axes": [0], "starts": [1], "ends": [3]}
            inputs_map = {"data": ori_shape.output[0], **attr}
            ori_shape_hw = GraphBuilder(ctx).make_slice(inputs_map)
            ori_shape_hw_float = ctx.make_node("Cast", [ori_shape_hw], attr={"to": onnx_pb.TensorProto.FLOAT})

            target_hw = node.inputs[1]
            target_hw_float = ctx.make_node("Cast", target_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})

            scales_hw = ctx.make_node("Div", [target_hw_float.output[0], ori_shape_hw_float.output[0]])

            const_one_array = ctx.make_const(utils.make_name("one"), np.array([1.0, 1.0]).astype(np.float32))
            # scales is nchw
            scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
        # because onnxruntime only supports scaling the last two dims, a transpose is inserted
        input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
        if roi_required:
            roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32))
            upsample = ctx.make_node("Resize", [input_nchw.output[0], roi.output[0], scales.output[0]],
                                     attr={"mode": mode, "nearest_mode": "floor",