How to use the tf2onnx.rewriter.rnn_utils.REWRITER_RESULT result codes in tf2onnx

To help you get started, we’ve selected a few tf2onnx examples based on popular ways REWRITER_RESULT is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: onnx/tensorflow-onnx — tf2onnx/rewriter/custom_rnn_rewriter.py (view on GitHub)
for input_tensor_info in scan_props.state_inputs:
                scan_body_g.add_graph_input(input_tensor_info.id, input_tensor_info.dtype, input_tensor_info.shape)

            for input_tensor_info in scan_props.scan_inputs:
                scan_body_g.add_graph_input(input_tensor_info.id, input_tensor_info.dtype, input_tensor_info.shape)

            scan_node = self._create_scan_node(context, scan_props,
                                               state_inputs_initial_values + scan_inputs_initial_values)
            if not scan_node:
                log.error("failed to create scan node during rewrite")
                return REWRITER_RESULT.FAIL

            scan_node.set_body_graph_as_attr("body", scan_body_g)
            self._connect_scan_with_output(context, scan_node)

            return REWRITER_RESULT.OK

        except Exception as ex:
            tb = traceback.format_exc()
            log.error("custom rnn rewrite failed, due to exception: %s, details:%s", ex, tb)
            return REWRITER_RESULT.FAIL
Source: onnx/tensorflow-onnx — tf2onnx/rewriter/unit_rewriter_base.py (view on GitHub)
if not rnn_scope_name:
            log.debug("unable to find rnn scope name, skip")
            return REWRITER_RESULT.SKIP
        log.debug("rnn scope name is %s", rnn_scope_name)

        self.print_step("get_weight_and_bias starts")
        rnn_weights = self.get_weight_and_bias(match)
        if not rnn_weights:
            log.debug("rnn weights check failed, skip")
            return REWRITER_RESULT.SKIP

        rnn_props = RnnProperties()
        res = self.get_var_initializers(match, rnn_props, rnn_scope_name)
        if not res or not rnn_props.var_initializers.keys:
            log.debug("no cell variable initializers found, skip")
            return REWRITER_RESULT.SKIP

        seq_len_input_node = self.find_sequence_length_node(rnn_scope_name)
        input_filter = self.get_rnn_input_blacklist(rnn_weights, rnn_props)
        if seq_len_input_node:
            input_filter.append(seq_len_input_node)

        self.find_inputs(rnn_scope_name, rnn_props, match, input_filter)
        if not rnn_props.is_valid():
            log.debug("rnn properties are not valid, skip")
            return REWRITER_RESULT.SKIP

        if not self.process_input_x(rnn_props, rnn_scope_name):
            log.debug("rnn input x not found, skip")
            return REWRITER_RESULT.SKIP

        self.print_step("process the weights/bias/ft_bias, to fit onnx weights/bias requirements")
Source: onnx/tensorflow-onnx — tf2onnx/rewriter/loop_rewriter.py (view on GitHub)
loop_body_g.replace_all_inputs(loop_body_g.get_nodes(), input_ta.consumer.id, data_node.output[0])

            ## create Loop node
            loop_node = self._create_loop_node(context, loop_props, init_cond_output)
            if not loop_node:
                logger.error("failed to create loop node during rewrite")
                return REWRITER_RESULT.FAIL
            loop_node.set_body_graph_as_attr("body", loop_body_g)

            logger.debug("rewrite successfully")
            return REWRITER_RESULT.OK

        except Exception as ex:
            tb = traceback.format_exc()
            logger.error("loop rewrite failed, due to exception: %s, details:%s", ex, tb)
            return REWRITER_RESULT.FAIL
Source: onnx/keras-onnx — keras2onnx/ktf2onnx/tf2onnx/rewriter/custom_rnn_rewriter.py (view on GitHub)
for input_tensor_info in scan_props.state_inputs:
                scan_body_g.add_graph_input(input_tensor_info.id, input_tensor_info.dtype, input_tensor_info.shape)

            for input_tensor_info in scan_props.scan_inputs:
                scan_body_g.add_graph_input(input_tensor_info.id, input_tensor_info.dtype, input_tensor_info.shape)

            scan_node = self._create_scan_node(context, scan_props,
                                               state_inputs_initial_values + scan_inputs_initial_values)
            if not scan_node:
                logger.error("failed to create scan node during rewrite")
                return REWRITER_RESULT.FAIL

            scan_node.set_body_graph_as_attr("body", scan_body_g)
            self._connect_scan_with_output(context, scan_node)

            return REWRITER_RESULT.OK

        except Exception as ex:
            tb = traceback.format_exc()
            logger.error("custom rnn rewrite failed, due to exception: %s, details:%s", ex, tb)
            return REWRITER_RESULT.FAIL
Source: onnx/tensorflow-onnx — tf2onnx/rewriter/loop_rewriter_base.py (view on GitHub)
def rewrite(self, context):
        """Default rewrite hook for a matched loop.

        This base implementation performs no rewrite of its own and simply
        reports failure by returning ``REWRITER_RESULT.FAIL``. Subclasses are
        presumably expected to override it with a real rewrite (TODO confirm
        against the concrete rewriter classes).

        Args:
            context: the per-loop context object; not read by this default.

        Returns:
            ``REWRITER_RESULT.FAIL`` unconditionally.
        """
        # The default hook ignores its input; drop the name to make that explicit.
        del context
        return REWRITER_RESULT.FAIL
Source: onnx/tensorflow-onnx — tf2onnx/rewriter/unit_rewriter_base.py (view on GitHub)
most found info is stored in "rnn_props"
        """
        log.debug("=========================")
        self.print_step("start handling a new potential rnn cell")
        self.all_nodes = self.g.get_nodes()
        # FIXME:
        # pylint: disable=assignment-from-none,assignment-from-no-return

        # when bi-directional, node in while will be rnnxx/fw/fw/while/... >> scope name is rnnxx/fw/fw
        # when single direction, node in while will be rnnxx/while/... >> scope name is rnnxx
        # and rnnxx can be assigned by users but not "fw", though maybe "FW" in another tf version
        rnn_scope_name = self.get_rnn_scope_name(match)
        if not rnn_scope_name:
            log.debug("unable to find rnn scope name, skip")
            return REWRITER_RESULT.SKIP
        log.debug("rnn scope name is %s", rnn_scope_name)

        self.print_step("get_weight_and_bias starts")
        rnn_weights = self.get_weight_and_bias(match)
        if not rnn_weights:
            log.debug("rnn weights check failed, skip")
            return REWRITER_RESULT.SKIP

        rnn_props = RnnProperties()
        res = self.get_var_initializers(match, rnn_props, rnn_scope_name)
        if not res or not rnn_props.var_initializers.keys:
            log.debug("no cell variable initializers found, skip")
            return REWRITER_RESULT.SKIP

        seq_len_input_node = self.find_sequence_length_node(rnn_scope_name)
        input_filter = self.get_rnn_input_blacklist(rnn_weights, rnn_props)
Source: onnx/keras-onnx — keras2onnx/ktf2onnx/tf2onnx/rewriter/loop_rewriter_base.py (view on GitHub)
# self.g.get_nodes may change inside this loop so that we parse all LoopCond first
        for op in loopcond_ops:
            logger.debug("======================\n handling loop cond node called %s", op.name)
            context = self.create_context()
            context.loop_cond = op

            self._check_in_read_only_mode(context)

            if self.need_rewrite(context):
                # cut off connection between cell/cond graphs and useless nodes like Merge, NextIteration.
                self._cut_off_connection_for_cell(context)
                context.cell_graph = self._crop_loop_body_sub_graph(context)
                context.cond_graph = self._crop_loop_condition_sub_graph(context)

                _result = self.rewrite(context)
                if _result == REWRITER_RESULT.OK:
                    logger.debug("rewrite successfully")
                elif _result == REWRITER_RESULT.SKIP:
                    logger.debug("rewrite skipped for LoopCond called %s", op.name)
                    continue
                elif _result == REWRITER_RESULT.FAIL:
                    raise ValueError("rewrite failed, so just fast fail it")

        if self.g.outputs:
            # clean the graph based on output names.
            self.g.delete_unused_nodes(self.g.outputs)
        return self.g.get_nodes()
Source: onnx/keras-onnx — keras2onnx/ktf2onnx/tf2onnx/rewriter/loop_rewriter.py (view on GitHub)
for input_ta in loop_props.tensor_array_inputs:
                # Loop does not have scan inputs, so we use Gather to get data for each iteration.
                index_node = loop_body_g.make_node("Unsqueeze", [input_ta.index_input_id], attr={"axes": [0]})
                gather_node = loop_body_g.make_node("Gather", [input_ta.data_input_id, index_node.output[0]])
                data_node = loop_body_g.make_node("Squeeze", [gather_node.output[0]], attr={"axes": [0]})
                loop_body_g.replace_all_inputs(loop_body_g.get_nodes(), input_ta.consumer.id, data_node.output[0])

            ## create Loop node
            loop_node = self._create_loop_node(context, loop_props)
            if not loop_node:
                logger.error("failed to create loop node during rewrite")
                return REWRITER_RESULT.FAIL
            loop_node.set_body_graph_as_attr("body", loop_body_g)

            logger.debug("rewrite successfully")
            return REWRITER_RESULT.OK

        except Exception as ex:
            tb = traceback.format_exc()
            logger.error("loop rewrite failed, due to exception: %s, details:%s", ex, tb)
            return REWRITER_RESULT.FAIL