Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def build(self):
    """Populate ``self.node_map`` from ``self.model.node`` and wire up edges.

    Each node is stored under its name with '/' and '-' mapped to '_'.
    Every entry of a node's ``layer.input`` is sanitized the same way, with
    '^' additionally removed, and then linked to its consumer through
    ``self.connect(src, dst)``. When the sanitized reference itself is not a
    known node, the text before the first ':' (after stripping whitespace)
    is tried instead.

    Raises:
        Exception: if an input reference cannot be resolved to any key of
            ``self.node_map``.
    """
    key_table = str.maketrans({'/': '_', '-': '_'})
    input_table = str.maketrans({'/': '_', '-': '_', '^': None})

    for tf_layer in self.model.node:
        key = tf_layer.name.translate(key_table)
        self.node_map[key] = TFGraphNode(tf_layer,
                                         data_format=self.tf_data_format)

    for dst_name, graph_node in self.node_map.items():
        for raw_input in graph_node.layer.input:
            src_name = raw_input.translate(input_table)
            if src_name in self.node_map:
                self.connect(src_name, dst_name)
                continue
            # Fall back to the bare node name, dropping any ":<suffix>".
            base_name = src_name.strip().split(':')[0]
            if base_name not in self.node_map:
                raise Exception(
                    'input[{}] of node[{}] does not exist in node_map'.format(
                        src_name, dst_name))
            self.connect(base_name, dst_name)
def __init__(self, layer, layer_name=None, data_format="NHWC"):
    """Wrap ``layer`` (presumably a TF NodeDef — confirm) as a graph node.

    Args:
        layer: source node. Its ``op`` is always read, and its ``name`` is
            read when ``layer_name`` is not given.
        layer_name: optional name to use instead of ``layer.name``.
        data_format: stored verbatim as ``self.tf_data_format``.
    """
    raw_name = layer.name if layer_name is None else layer_name
    # Sanitize: '/' and '-' become '_', and '^' is removed entirely.
    clean_name = raw_name.replace('/', '_').replace('-', '_').replace('^', '')
    super(TFGraphNode, self).__init__(layer, clean_name)

    self.layer_type = layer.op
    self.tf_data_format = data_format
    self.pd_data_format = "NCHW"
    self.fluid_code = FluidCode()
    # NOTE(review): keys look like TensorFlow DataType enum codes — confirm.
    self.dtype_map = {1: "float32", 3: "int32", 4: "uint8", 9: "int64", 10: "bool"}
def _check_input_shape(self, graph_def):
numpy.random.seed(13)
graph_def = cp.deepcopy(graph_def)
input_map = dict()
for layer in graph_def.node:
if layer.op != "Placeholder":
continue
graph_node = TFGraphNode(layer)
dtype = graph_node.layer.attr['dtype'].type
need_define_shape = 0
if self.define_input_shape:
need_define_shape = 3
elif graph_node.layer.attr[
'shape'].shape.unknown_rank or not graph_node.get_attr(
"shape"):
need_define_shape = 1
else:
value = graph_node.layer.attr["shape"].shape
shape = [dim.size for dim in value.dim]
if shape.count(-1) > 1:
need_define_shape = 2
if need_define_shape > 0: