How to use the x2paddle.decoder.tf_decoder.TFGraphNode class in x2paddle

To help you get started, we’ve selected a few x2paddle examples, based on popular ways TFGraphNode is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: GitHub, PaddlePaddle/X2Paddle, x2paddle/decoder/tf_decoder.py (view on GitHub)
def build(self):
        """Fill ``self.node_map`` from the model and link every node's inputs.

        Each entry of ``self.model.node`` is wrapped in a ``TFGraphNode``
        (built with ``data_format=self.tf_data_format``) and stored under its
        name with '/' and '-' replaced by '_'.  Afterwards every node's
        ``layer.input`` names are normalized the same way, with '^' removed
        as well (presumably TF's control-dependency prefix), and linked to the
        node through ``self.connect(source_name, node_name)``.  When a
        normalized input is not a key of ``self.node_map``, the part before
        the first ':' (after stripping whitespace) is tried instead, e.g.
        ``x:1`` -> ``x``.

        Raises:
            Exception: if neither the normalized input nor its pre-':' part
                is a key of ``self.node_map``.
        """
        def _sanitize(raw):
            # Character substitutions shared by node keys and input names.
            return raw.replace('/', '_').replace('-', '_')

        for tf_layer in self.model.node:
            graph_node = TFGraphNode(tf_layer, data_format=self.tf_data_format)
            self.node_map[_sanitize(tf_layer.name)] = graph_node

        for dst_name, dst_node in self.node_map.items():
            for raw_input in dst_node.layer.input:
                src_name = _sanitize(raw_input).replace('^', '')
                if src_name in self.node_map:
                    self.connect(src_name, dst_name)
                    continue
                # Not a direct key: drop a ':<suffix>' (e.g. an output index).
                base_name = src_name.strip().split(':')[0]
                if base_name not in self.node_map:
                    raise Exception(
                        'input[{}] of node[{}] does not exist in node_map'.
                        format(src_name, dst_name))
                self.connect(base_name, dst_name)
Source: GitHub, PaddlePaddle/X2Paddle, x2paddle/decoder/tf_decoder.py (view on GitHub)
def __init__(self, layer, layer_name=None, data_format="NHWC"):
        """Wrap a TensorFlow graph node for conversion to Paddle.

        Args:
            layer: the node object; its ``name`` and ``op`` attributes are
                read here (presumably a TF ``NodeDef`` -- confirm at callers).
            layer_name: optional name to use instead of ``layer.name``.
            data_format: tensor layout of the TF model, "NHWC" by default.
                The Paddle-side layout is always "NCHW".
        """
        raw_name = layer.name if layer_name is None else layer_name
        # Normalize the node name: '/' and '-' become '_' and '^' is dropped
        # ('^' is presumably TF's control-dependency prefix).
        super(TFGraphNode, self).__init__(
            layer,
            raw_name.replace('/', '_').replace('-', '_').replace('^', ''))

        self.layer_type = layer.op
        self.tf_data_format = data_format
        self.pd_data_format = "NCHW"
        self.fluid_code = FluidCode()

        # Keys are TensorFlow DataType enum values (tensorflow/core/framework/
        # types.proto); values are the corresponding Paddle dtype names.
        self.dtype_map = {
            1: "float32",
            2: "float64",
            3: "int32",
            4: "uint8",
            5: "int16",
            6: "int8",
            9: "int64",
            10: "bool"
        }
Source: GitHub, PaddlePaddle/X2Paddle, x2paddle/decoder/tf_decoder.py (view on GitHub)
def _check_input_shape(self, graph_def):
        numpy.random.seed(13)
        graph_def = cp.deepcopy(graph_def)
        input_map = dict()
        for layer in graph_def.node:
            if layer.op != "Placeholder":
                continue
            graph_node = TFGraphNode(layer)
            dtype = graph_node.layer.attr['dtype'].type

            need_define_shape = 0
            if self.define_input_shape:
                need_define_shape = 3
            elif graph_node.layer.attr[
                    'shape'].shape.unknown_rank or not graph_node.get_attr(
                        "shape"):
                need_define_shape = 1
            else:
                value = graph_node.layer.attr["shape"].shape
                shape = [dim.size for dim in value.dim]
                if shape.count(-1) > 1:
                    need_define_shape = 2

            if need_define_shape > 0: