How to use the tf2onnx.graph_matcher.GraphMatcher class in tf2onnx

To help you get started, we’ve selected a few tf2onnx examples based on popular ways GraphMatcher is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github onnx / tensorflow-onnx / tests / test_internals.py View on Github external
def test_rewrite_subgraph(self):
        graph_proto = self.sample_net()
        g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
        pattern = \
            OpTypePattern('Abs', name='output', inputs=[
                OpTypePattern('Add', name='input')
            ])
        ops = g.get_nodes()
        matcher = GraphMatcher(pattern)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            input_node = match.get_op('input')
            output_node = match.get_op('output')
            op_name = utils.make_name("ReplacedOp")
            out_name = utils.port_name(op_name)
            new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
            g.replace_all_inputs(ops, output_node.output[0], new_node.output[0])
            for n in set(match.get_nodes()):
                g.remove_node(n.name)
        g.topological_sort(ops)
        result = onnx_to_graphviz(g)
        expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
                   'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__5 [op_type=Sub] ' \
                   'n6 [op_type=Identity] n5_graph_outputs_Identity__3 [op_type=Identity] ' \
                   'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 n3:0 -> ReplacedOp__5 ' \
github onnx / tensorflow-onnx / tf2onnx / rewriter / flatten_rewriter.py View on Github external
OpTypePattern('Reshape', name='reshape', inputs=[
            OpTypePattern("*", name="input"),
            OpTypePattern('Pack', name="pack", inputs=[
                OpTypePattern('StridedSlice', name="slice", inputs=[
                    OpTypePattern('Shape', inputs=[
                        OpTypePattern("*", name="input2")
                    ]),
                    "*", "*", "*",
                ]),
                "*",
            ]),
        ])
    matcher = GraphMatcher(pattern_fixed_shape_input)
    match_results_1 = list(matcher.match_ops(ops))

    matcher = GraphMatcher(pattern_non_fixed_shape_input)
    match_results_2 = list(matcher.match_ops(ops))

    match_results = [(match_results_1, True), (match_results_2, False)]
    for match_results, check_fixed_input_shape in match_results:
        for match in match_results:
            input_node = match.get_op('input')
            reshape_node = match.get_op('reshape')
            pack_node = match.get_op('pack')
            slice_node = match.get_op('slice')
            need_rewrite = pack_node.inputs[1].is_const() and pack_node.inputs[1].get_tensor_value() == -1
            if not need_rewrite:
                continue

            input_shape = g.get_shape(reshape_node.input[0])
            need_rewrite = input_shape is not None
            if not need_rewrite:
github onnx / tensorflow-onnx / tf2onnx / rewriter / unit_rnn_rewriter_base.py View on Github external
def find_sequence_length_node(self, context):
        """Return the node feeding the sequence length into an RNN loop, if any.

        The first state variable's next-iteration input is traced back to its
        producing node. If that producer is not a TF Select op, no sequence
        length was supplied and ``None`` is returned. Otherwise the producer is
        matched against ``seq_len_pattern``, and the node bound to
        ``"seq_len_node"`` is returned.

        Raises:
            IndexError: if ``context.state_variables`` is empty.
            RuntimeError: if the Select op does not match ``seq_len_pattern``.
        """
        # Every state variable is gated the same way, so inspecting one is enough.
        first_state_var = tuple(context.state_variables.values())[0]
        select_candidate = self.g.get_node_by_output(first_state_var.next_iteration_input.id)

        if is_tf_select_op(select_candidate):
            seq_len_match = GraphMatcher(seq_len_pattern).match_op(select_candidate)
            if seq_len_match:
                return seq_len_match.get_op("seq_len_node")
            raise RuntimeError("failed to find sequence length.")

        logger.debug("no sequence length node is given")
        return None
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / rewriter / random_uniform.py View on Github external
def rewrite_random_uniform(g, ops):
    """Fuse a TF ``RandomUniform``-scale-shift subgraph into one ONNX random op.

    Matches ``Add(Mul(RandomUniform(*), Sub(max, min)), *)`` and replaces it with
    the node built by ``create_onnx_random_uniform_op`` using the constant
    ``max``/``min`` bounds read from the ``Sub``. Consumers of the ``Add`` output
    are rewired to the new node, and the matched nodes are then removed via
    ``g.safe_remove_nodes``.

    The pattern does not constrain the ``Add``'s second operand (presumably
    ``min`` -- NOTE(review): confirm), nor the ``Sub`` operands. Matches whose
    ``Sub`` operands are not constants are left untouched.

    Args:
        g: graph being rewritten; mutated in place.
        ops: nodes to scan for the pattern; also passed to ``replace_all_inputs``.

    Returns:
        ``ops`` -- the same object that was passed in.
    """
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', inputs=[
                OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
                OpTypePattern('Sub', name='input2', inputs=["*", "*"]),
            ]), None
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        input2 = match.get_op('input2')
        output = match.get_op('output')
        ru_op = match.get_op('input1')
        # max is on input 0, min on input 1
        tmax_node = input2.inputs[0]
        tmin_node = input2.inputs[1]
        # The pattern accepts any producer for the Sub operands, but the bounds
        # are baked into the new op at conversion time. Skip the match instead
        # of calling get_tensor_value() on a non-constant node, which would
        # abort the whole conversion.
        if not (tmax_node.is_const() and tmin_node.is_const()):
            continue
        tmax = tmax_node.get_tensor_value()
        tmin = tmin_node.get_tensor_value()
        to_delete = list(set(match.get_nodes()))
        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
        g.replace_all_inputs(ops, output.output[0], new_node.output[0])
        g.safe_remove_nodes(to_delete)

    return ops
github onnx / tensorflow-onnx / tf2onnx / rewriter / loop_rewriter_base.py View on Github external
def _parse_input_ta(self, context):
        """Register loop-body TensorArray reads as scan inputs of the loop.

        Runs ``self.ta_read_input_pattern`` over ``self.g`` and keeps only the
        matches whose ``ta_index`` output is the switch-true identity output of
        one of ``context.loop_properties.all_variables``. For each kept match,
        an ``InputTensorArray`` is built from the scatter's value input, the
        read's index input and the read's output. It is then passed to
        ``context.loop_properties.add_scan_input``.

        Args:
            context: loop context; its ``loop_properties`` is read and mutated.
        """
        # A set gives O(1) membership in the filter below; variables whose
        # switch-true identity output id is falsy are excluded.
        graph_inputs = {v.switch_true_identity_output.id for v in context.loop_properties.all_variables.values()
                        if v.switch_true_identity_output.id}
        matcher = GraphMatcher(self.ta_read_input_pattern, allow_reorder=False)
        match_results = matcher.match_ops(self.g.get_nodes())
        # Keep only reads indexed by one of the loop's own variables.
        match_results = [r for r in match_results if r.get_op("ta_index").output[0] in graph_inputs]
        for match in match_results:
            ta_input_scatter = match.get_op("ta_input_scatter")
            # the 3rd input of scatter is the value
            data_input_id = ta_input_scatter.input[2]
            ta_read_node = match.get_op("ta_read")

            # todo: need check ta's index variable is a scalar starting from 1, and increase by 1 each iteration.
            # then we can be sure this is equivalent to scan input behavior.
            index_input_id = ta_read_node.input[1]
            unstacked_ta_consumer = ta_read_node.output[0]
            ta = InputTensorArray(data_input_id, index_input_id, unstacked_ta_consumer, self.g)
            context.loop_properties.add_scan_input(ta)
github onnx / tensorflow-onnx / tf2onnx / rewriter / unit_rewriter_base.py View on Github external
2 input_x
            3 weight
            4 sequence node
            5 initializer
            6 state output & hidden output
        3 process found info according to ONNX requirement

        remember: op pattern and scope name are useful
                  they are used to get needed info from tensorflow graph
                  raw found info need to be formatted according to ONNX requirement
        """
        # allow_reorder must be true. because LSTMCell and BasicLSTMCell's call function
        # are defining the calculation with different orders. Then we can share the same
        # pattern.
        cell_pattern = get_pattern(unit_type)
        matcher = GraphMatcher(cell_pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(self.g.get_nodes()))

        if match_results:
            for match in match_results:
                self.run_single_match(match)

            self.g.delete_unused_nodes(self.g.outputs)
            self.print_step("finish handling")

        return self.g.get_nodes()
github onnx / tensorflow-onnx / tf2onnx / rewriter / random_uniform.py View on Github external
def rewrite_random_uniform(g, ops):
    """Fuse a TF ``RandomUniform``-scale-shift subgraph into one ONNX random op.

    Matches ``Add(Mul(RandomUniform(*), Sub(max, min)), *)`` and replaces it with
    the node built by ``create_onnx_random_uniform_op`` using the constant
    ``max``/``min`` bounds read from the ``Sub``. Consumers of the ``Add`` output
    are rewired to the new node, and the matched nodes are then removed by name.

    The pattern does not constrain the ``Add``'s second operand (presumably
    ``min`` -- NOTE(review): confirm), nor the ``Sub`` operands. Matches whose
    ``Sub`` operands are not constants are left untouched.

    Args:
        g: graph being rewritten; mutated in place.
        ops: nodes to scan for the pattern; also passed to ``replace_all_inputs``.

    Returns:
        ``ops`` -- the same object that was passed in.
    """
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', inputs=[
                OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
                OpTypePattern('Sub', name='input2', inputs=["*", "*"]),
            ]), None
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        input2 = match.get_op('input2')
        output = match.get_op('output')
        ru_op = match.get_op('input1')
        # max is on input 0, min on input 1
        tmax_node = input2.inputs[0]
        tmin_node = input2.inputs[1]
        # The pattern accepts any producer for the Sub operands, but the bounds
        # are baked into the new op at conversion time. Skip the match instead
        # of calling get_tensor_value() on a non-constant node, which would
        # abort the whole conversion.
        if not (tmax_node.is_const() and tmin_node.is_const()):
            continue
        tmax = tmax_node.get_tensor_value()
        tmin = tmin_node.get_tensor_value()
        to_delete = list(set(match.get_nodes()))
        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
        g.replace_all_inputs(ops, output.output[0], new_node.output[0])
        # NOTE(review): every matched node is removed even if a node outside
        # the match still consumes one of its outputs; consider checking for
        # remaining consumers before removal.
        for n in to_delete:
            g.remove_node(n.name)

    return ops
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / rewriter / loop_rewriter_base.py View on Github external
def _parse_input_ta(self, context):
        """Register loop-body TensorArray reads as scan inputs of the loop.

        Runs ``self.ta_read_input_pattern`` over ``self.g`` and keeps only the
        matches whose ``ta_index`` output is the switch-true identity output of
        one of ``context.loop_properties.all_variables``. For each kept match,
        an ``InputTensorArray`` is built from the scatter's value input, the
        read's index input and the read's output. It is then passed to
        ``context.loop_properties.add_scan_input``.

        Args:
            context: loop context; its ``loop_properties`` is read and mutated.
        """
        # A set gives O(1) membership in the filter below; variables whose
        # switch-true identity output id is falsy are excluded.
        graph_inputs = {v.switch_true_identity_output.id for v in context.loop_properties.all_variables.values()
                        if v.switch_true_identity_output.id}
        matcher = GraphMatcher(self.ta_read_input_pattern, allow_reorder=False)
        match_results = matcher.match_ops(self.g.get_nodes())
        # Keep only reads indexed by one of the loop's own variables.
        match_results = [r for r in match_results if r.get_op("ta_index").output[0] in graph_inputs]
        for match in match_results:
            ta_input_scatter = match.get_op("ta_input_scatter")
            # the 3rd input of scatter is the value
            data_input_id = ta_input_scatter.input[2]
            ta_read_node = match.get_op("ta_read")

            # todo: need check ta's index variable is a scalar starting from 1, and increase by 1 each iteration.
            # then we can be sure this is equivalent to scan input behavior.
            index_input_id = ta_read_node.input[1]
            unstacked_ta_consumer = ta_read_node.output[0]
            ta = InputTensorArray(data_input_id, index_input_id, unstacked_ta_consumer, self.g)
            context.loop_properties.add_scan_input(ta)
github onnx / tensorflow-onnx / tf2onnx / rewriter / gemm_rewriter.py View on Github external
OpTypePattern('Const', name='beta'),
                OpTypePattern('*', name='C')
            ])
        ])

    # pattern3: A*B + C
    pattern3 = \
        OpTypePattern('Add|AddV2', name='add', inputs=[
            OpTypePattern('MatMul', name='matmul'),
            OpTypePattern('*', name='C'),
        ])

    pattern_list = [pattern0, pattern1, pattern2, pattern3]

    for pattern in pattern_list:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(ops))
        if match_results:
            for match in match_results:
                matmul_node = match.get_op("matmul")

                if g.get_dtype(matmul_node.input[0]) != onnx_pb.TensorProto.FLOAT:
                    logging.warning(u"For now, onnxruntime only support float32 type for Gemm rewriter")
                    continue

                attr, is_valid = get_gemm_attr(match)
                if not is_valid:
                    continue

                add_node = match.get_op('add')
                input_c_node = match.get_op("C")
                a_edge_name = matmul_node.input[0]
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / rewriter / unit_rnn_rewriter_base.py View on Github external
def find_sequence_length_node(self, context):
        """Return the node feeding the sequence length into an RNN loop, if any.

        The first state variable's next-iteration input is traced back to its
        producing node. If that producer is not a TF Select op, no sequence
        length was supplied and ``None`` is returned. Otherwise the producer is
        matched against ``seq_len_pattern``, and the node bound to
        ``"seq_len_node"`` is returned.

        Raises:
            IndexError: if ``context.state_variables`` is empty.
            RuntimeError: if the Select op does not match ``seq_len_pattern``.
        """
        # Every state variable is gated the same way, so inspecting one is enough.
        first_state_var = tuple(context.state_variables.values())[0]
        select_candidate = self.g.get_node_by_output(first_state_var.next_iteration_input.id)

        if is_tf_select_op(select_candidate):
            seq_len_match = GraphMatcher(seq_len_pattern).match_op(select_candidate)
            if seq_len_match:
                return seq_len_match.get_op("seq_len_node")
            raise RuntimeError("failed to find sequence length.")

        logger.debug("no sequence length node is given")
        return None