How to use the tf2onnx.graph_matcher.OpTypePattern class in tf2onnx

To help you get started, we’ve selected a few tf2onnx examples based on popular ways OpTypePattern is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github onnx / tensorflow-onnx / tests / test_internals.py View on Github external
def test_match_flipped(self):
        """A Mul(Add, Sub) pattern must match a graph Mul(Sub, Add) when reordering is allowed.

        The graph's Mul consumes n1:0 (the Sub) before n2:0 (the Add), while the
        pattern lists Add first, so exactly one match is expected with
        allow_reorder=True.
        """
        float_2x2 = (TensorProto.FLOAT, [2, 2])
        sub_node = helper.make_node("Sub", ["i1", "i1"], ["n1:0"], name="n1")
        add_node = helper.make_node("Add", ["i2", "i2"], ["n2:0"], name="n2")
        mul_node = helper.make_node("Mul", ["n1:0", "n2:0"], ["n3:0"], name="n3")

        # NOTE(review): the declared graph output is n2:0 (the Add), not the Mul's
        # n3:0 - looks unintended but presumably does not affect matching; confirm.
        graph_proto = helper.make_graph(
            nodes=[sub_node, add_node, mul_node],
            name="test",
            inputs=[helper.make_tensor_value_info(input_name, *float_2x2)
                    for input_name in ("i1", "i2")],
            outputs=[helper.make_tensor_value_info("n2:0", *float_2x2)],
            initializer=[],
        )
        graph = GraphUtil.create_graph_from_onnx_graph(graph_proto)

        flipped = OpTypePattern('Mul', inputs=[OpTypePattern('Add'), OpTypePattern('Sub')])
        matcher = GraphMatcher(flipped, allow_reorder=True)
        found = list(matcher.match_ops(graph.get_nodes()))
        self.assertEqual(1, len(found))
github onnx / tensorflow-onnx / tests / test_internals.py View on Github external
def test_rewrite_subgraph(self):
        graph_proto = self.sample_net()
        g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
        pattern = \
            OpTypePattern('Abs', name='output', inputs=[
                OpTypePattern('Add', name='input')
            ])
        ops = g.get_nodes()
        matcher = GraphMatcher(pattern)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            input_node = match.get_op('input')
            output_node = match.get_op('output')
            op_name = utils.make_name("ReplacedOp")
            out_name = utils.port_name(op_name)
            new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
            g.replace_all_inputs(ops, output_node.output[0], new_node.output[0])
            for n in set(match.get_nodes()):
                g.remove_node(n.name)
        g.topological_sort(ops)
        result = onnx_to_graphviz(g)
        expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
github onnx / tensorflow-onnx / tf2onnx / rewriter / rnn_utils.py View on Github external
OpTypePattern("Enter", inputs=[
                                OpTypePattern("*", name="hidden_state_bias")
                            ]),
                            OpTypePattern("MatMul", inputs=[
                                OpTypePattern("Enter", inputs=[
                                    OpTypePattern("*", name="hidden_state_kernel"),
                                ]),
                                OpTypePattern("Identity")
                            ])
                        ])
                    ]),
                    OpTypePattern("BiasAdd", inputs=[
                        OpTypePattern("Enter", inputs=[
                            OpTypePattern("*", name="hidden_input_bias")
                        ]),
                        OpTypePattern("MatMul", inputs=[
                            OpTypePattern("Enter", inputs=[
                                OpTypePattern("*", name="hidden_input_kernel"),
                            ]),
                            OpTypePattern("*")
                        ])
                    ])
                ])
            ])
        ]),
        OpTypePattern("Mul", inputs=[
            gru_split_pattern,
            OpTypePattern("Identity")
        ])
    ])
github onnx / tensorflow-onnx / tf2onnx / rewriter / random_uniform.py View on Github external
def rewrite_random_uniform_fold_const(g, ops):
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='mul', inputs=[
                OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
                None,
            ]),
            None,
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        mul = match.get_op('mul')
        ru_op = match.get_op('input1')

        tmax_minus_tmin = mul.inputs[1].get_tensor_value()
        tmin = output.inputs[1].get_tensor_value()
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / rewriter / rnn_utils.py View on Github external
logger = logging.getLogger(name=__name__)  # module-level logger named after this module


class REWRITER_RESULT(Enum):
    """Result codes for a rewriter step.

    NOTE(review): the meanings below are inferred from the member names only;
    confirm against the rewriters that return these values.
    """

    SKIP = 1  # presumably: pattern not applicable, nothing rewritten
    OK = 2    # presumably: rewrite applied successfully
    FAIL = 3  # presumably: rewrite attempted but failed


# TensorFlow LSTMCell/BasicLSTMCell computation graph matching
# Pattern for the fused gate pre-activation of an LSTM cell:
#   Split(Const, BiasAdd(MatMul(ConcatV2|Concat, Enter(*)), Enter(*)))
# lstmcell_pattern feeds this Split into its ot/ft/it/gt gate activations.
# Named captures: "bias_add" (the BiasAdd), "xh" (the concat fed to MatMul),
# "cell_kernel" / "cell_bias" (the ops feeding the two Enter nodes).
# NOTE(review): "xh" presumably concatenates the cell input x with the previous
# hidden state h, and Enter presumably brings the weights into the while-loop
# frame - confirm against the rewriter that consumes these names.
xc_pattern = OpTypePattern('Split', inputs=[
    OpTypePattern("Const"), # axis for split
    OpTypePattern("BiasAdd", name="bias_add", inputs=[
        # MatMul(xh, kernel)
        OpTypePattern("MatMul", inputs=[
            OpTypePattern("ConcatV2|Concat", name="xh"),
            OpTypePattern("Enter", inputs=[
                OpTypePattern("*", name="cell_kernel"),
            ]),
        ]),
        # bias added to the MatMul result
        OpTypePattern("Enter", inputs=[
            OpTypePattern("*", name="cell_bias"),
        ]),
    ]),
])


lstmcell_pattern = \
    OpTypePattern('Mul', name='ht', inputs=[
        OpTypePattern("Sigmoid", name="ot", inputs=[xc_pattern]),
        OpTypePattern('Tanh', inputs=[
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / rewriter / rnn_utils.py View on Github external
OpTypePattern("Enter", inputs=[
                    OpTypePattern("*", name="gate_bias")
                ]),
                OpTypePattern("MatMul", name="update_reset_gate", inputs=[
                    OpTypePattern("Enter", inputs=[
                        OpTypePattern("*", name="gate_kernel")
                    ]),
                    OpTypePattern("ConcatV2|Concat", name="cell_inputs")
                ])
            ])
        ])
    ])


grucell_pattern = \
    OpTypePattern("Add", name="cell_output", inputs=[
        OpTypePattern("Mul", inputs=[
            gru_split_pattern,
            OpTypePattern("Identity")
        ]),
        OpTypePattern("Mul", inputs=[
            OpTypePattern("Sub", inputs=[
                OpTypePattern("Const"),  # 1-u
                gru_split_pattern
            ]),
            OpTypePattern("*", name="optional_activation", inputs=[
                OpTypePattern("BiasAdd", inputs=[
                    OpTypePattern("Enter", inputs=[
                        OpTypePattern("*", name="hidden_bias")
                    ]),
                    OpTypePattern("MatMul", inputs=[
                        OpTypePattern("Enter", inputs=[
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / rewriter / rnn_utils.py View on Github external
])


# Pattern for an LSTM cell output. Reading the nesting:
#   ht = ot * tanh(ct)
#   ct = ft * <"*" input> + it * gt
#   ot = sigmoid(xc), ft = sigmoid(xc + ft_bias), it = sigmoid(xc), gt = tanh(xc)
# where every xc is xc_pattern.
# NOTE(review): the unnamed "*" multiplied by ft is presumably the previous cell
# state c_{t-1}; the Mul consuming it is captured as "ct_identity_consumer" - confirm.
lstmcell_pattern = \
    OpTypePattern('Mul', name='ht', inputs=[
        OpTypePattern("Sigmoid", name="ot", inputs=[xc_pattern]),
        OpTypePattern('Tanh', inputs=[
            OpTypePattern("Add", name="ct", inputs=[
                # ft * <"*">
                OpTypePattern("Mul", name="ct_identity_consumer", inputs=[
                    OpTypePattern("Sigmoid", name="ft", inputs=[
                        OpTypePattern("Add", inputs=[
                            xc_pattern,
                            OpTypePattern("*", name="ft_bias"),
                        ]),
                    ]),
                    OpTypePattern("*"),
                ]),
                # it * gt
                OpTypePattern("Mul", inputs=[
                    OpTypePattern("Sigmoid", name="it", inputs=[xc_pattern]),
                    OpTypePattern("Tanh", name="gt", inputs=[xc_pattern]),
                ]),
            ]),
        ]),
    ])

# input sequence: top to down, left to right
# split into update gate and reset gate
gru_split_pattern = \
    OpTypePattern("Split", inputs=[
        OpTypePattern("Const"),  # split dim, a constant
        OpTypePattern("Sigmoid", inputs=[
            OpTypePattern("BiasAdd", inputs=[
github onnx / tensorflow-onnx / tf2onnx / rewriter / rnn_utils.py View on Github external
OpTypePattern("*"),
                ]),
                OpTypePattern("Mul", inputs=[
                    OpTypePattern("Sigmoid", name="it", inputs=[xc_pattern]),
                    OpTypePattern("Tanh", name="gt", inputs=[xc_pattern]),
                ]),
            ]),
        ]),
    ])

# input sequence: top to bottom, left to right
# split into update gate and reset gate:
#   Split(Const, Sigmoid(BiasAdd(Enter(gate_bias), MatMul(Enter(gate_kernel), cell_inputs))))
# "update_reset_gate" captures the MatMul; "cell_inputs" is its ConcatV2/Concat input.
# NOTE(review): the BiasAdd and MatMul inputs are listed in the opposite order
# from xc_pattern (weights first here); presumably the matcher is run with
# allow_reorder=True - confirm, otherwise this order must match the graph exactly.
gru_split_pattern = \
    OpTypePattern("Split", inputs=[
        OpTypePattern("Const"),  # split dim, a constant
        OpTypePattern("Sigmoid", inputs=[
            OpTypePattern("BiasAdd", inputs=[
                OpTypePattern("Enter", inputs=[
                    OpTypePattern("*", name="gate_bias")
                ]),
                OpTypePattern("MatMul", name="update_reset_gate", inputs=[
                    OpTypePattern("Enter", inputs=[
                        OpTypePattern("*", name="gate_kernel")
                    ]),
                    OpTypePattern("ConcatV2|Concat", name="cell_inputs")
                ])
            ])
        ])
    ])


grucell_pattern = \
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / rewriter / loop_rewriter_base.py View on Github external
def __init__(self, g):
        self.g = g
        self.ta_read_input_pattern = \
            OpTypePattern("TensorArrayReadV3", name="ta_read", inputs=[
                OpTypePattern("Enter", name="ta_enter", inputs=[
                    OpTypePattern("TensorArrayV3")
                ]),
                OpTypePattern("Identity", name="ta_index"),
                OpTypePattern("Enter", name="ta_scatter_enter", inputs=[
                    OpTypePattern("TensorArrayScatterV3", name="ta_input_scatter")
                ]),
github onnx / tensorflow-onnx / tf2onnx / rewriter / transpose_rewriter.py View on Github external
def rewrite_transpose(g, ops):
    """Replace a computed axis-reversing permutation with a static ``perm``.

    Matches ``Transpose(x, Sub(Sub(*, *), Range(*, *, *)))``. This is presumably
    the subgraph TensorFlow emits for ``tf.transpose`` without an explicit perm.
    For each match, it sets the Transpose's ``perm`` attribute to
    ``[rank - 1, ..., 1, 0]`` and removes the computed perm input (index 1).
    The other matched nodes are then handed to ``g.safe_remove_nodes``.

    Matches whose data input has no known shape (``g.get_shape`` returns None)
    are left untouched. The reversed permutation cannot be derived for them,
    and skipping lets the remaining matches still be rewritten.

    Args:
        g: graph being rewritten.
        ops: nodes of ``g`` to search for the pattern.

    Returns:
        ``ops``, the same list object that was passed in.
    """
    pattern = \
        OpTypePattern('Transpose', name='output', inputs=[
            OpTypePattern(None),
            OpTypePattern('Sub', inputs=[
                OpTypePattern('Sub', inputs=["*", "*"]),
                OpTypePattern('Range', inputs=["*", "*", "*"]),
            ]),
        ])

    matcher = GraphMatcher(pattern)
    # Collect every match up front so the graph edits below cannot disturb matching.
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        shape = g.get_shape(output.input[0])
        if shape is None:
            # Unknown rank: keep the dynamic perm rather than failing on len(None).
            continue
        # Reverse all axes: [rank - 1, ..., 1, 0].
        dims = list(range(len(shape) - 1, -1, -1))
        output.set_attr("perm", dims)
        # perm is now an attribute, so drop the computed perm input.
        g.remove_input(output, output.input[1])
        to_delete = [n for n in match.get_nodes() if n != output]
        g.safe_remove_nodes(to_delete)
    return ops