f = f + forget_bias
if not use_peephole:
    wci = wcf = wco = 0
i = sigmoid(cs_prev * wci + i)
f = sigmoid(cs_prev * wcf + f)
ci = tanh(ci)
cs = ci .* i + cs_prev .* f
cs = clip(cs, cell_clip)
o = sigmoid(cs * wco + o)
co = tanh(cs)
h = co .* o
"""
builder = GraphBuilder(graph, node.name + '/', ParsedTFNode)
zero = builtins.int32()
zero.val = 0
one = builtins.int32()
one.val = 1
concat_axis = builder.add_const(one, name='concat_axis')
expand_axis = builder.add_const(zero, name='expand_axis')
h_prev_expand = builder.add_expanddims(h_prev, expand_axis)
xh = builder.add_concat([x, h_prev_expand], concat_axis)
icifo_presplit = linear(builder, xh, w, b)
icifo = builder.add_split(value=icifo_presplit, split_dim=concat_axis, num_split=4)
i = builder.add_get_tuple(icifo, index=0)
ci = builder.add_get_tuple(icifo, index=1)
f = builder.add_get_tuple(icifo, index=2)
o = builder.add_get_tuple(icifo, index=3)
if forget_bias is not None and forget_bias != 0.0:
    fb = builtins.fp32()
    fb.val = forget_bias
    bias = builder.add_const(fb, name='forget_bias')
    f = builder.add_elementwise("Add", [f, bias])
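# For orientation, a minimal NumPy sketch of the LSTM block-cell equations
# quoted in the docstring above. The function name, argument names, and shapes
# are illustrative assumptions, not the converter's API; peephole weights are
# omitted (the wci = wcf = wco = 0 path) and clipping is assumed to use the
# usual symmetric [-cell_clip, cell_clip] convention.
import numpy as np

def lstm_block_cell_reference(x, h_prev, cs_prev, w, b, forget_bias=1.0, cell_clip=None):
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    xh = np.concatenate([x, h_prev], axis=-1)       # xh = [x, h_prev]
    i, ci, f, o = np.split(xh @ w + b, 4, axis=-1)  # [i, ci, f, o] = xh * w + b
    f = sigmoid(f + forget_bias)
    i = sigmoid(i)
    cs = np.tanh(ci) * i + cs_prev * f              # new cell state
    if cell_clip is not None:
        cs = np.clip(cs, -cell_clip, cell_clip)
    h = np.tanh(cs) * sigmoid(o)                    # h = co .* o
    return cs, h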
def visit_iff(self, node):
    # An op we inserted; equivalent to a functional IF:
    # iff(cond, value_if_true, value_if_false).
    assert (len(node.inputs) == 3)
    typecond = self.visit(node.inputs[0])
    # assert (builtins.is_tensor(typecond) == False)
    typea = self.visit(node.inputs[1])
    typeb = self.visit(node.inputs[2])
    if typea is not None and typeb is not None:
        compatible, restype = builtins.is_tensor_and_is_compatible_general_shape(typea, typeb)
        if compatible:
            return restype
        elif typea == typeb:
            return typea
        else:
            logging.warning(
                "In an IFF node %s != %s", builtins.get_type_info(typea),
                builtins.get_type_info(typeb))
            return typea
    if typea is not None:
        return typea
    else:
        return typeb
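# A hedged, self-contained sketch of the shape-merge rule the compatibility
# check above relies on, using plain shape tuples rather than the converter's
# builtins (merge_shapes and the -1 "unknown" marker are illustrative
# assumptions): shapes of equal rank merge dimension-wise, keeping the more
# specific dimension wherever one side is unknown.
def merge_shapes(sa, sb):
    if len(sa) != len(sb):
        return None
    merged = []
    for da, db in zip(sa, sb):
        if da == db or db == -1:
            merged.append(da)   # keep the more specific dimension
        elif da == -1:
            merged.append(db)
        else:
            return None         # genuinely conflicting dimensions
    return tuple(merged)

assert merge_shapes((1, -1, 4), (1, 3, 4)) == (1, 3, 4)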
def _visit_broadcast(self, node, is_predicate=False):
    # Type inference for a broadcasting binary elementwise op (e.g. Mul).
    assert (len(node.inputs) == 2)
    typea = self.visit(node.inputs[0])
    typeb = self.visit(node.inputs[1])
    if typea is not None and typeb is not None:
        primitive_type = builtins.bool if is_predicate else self._promoted_primitive_type(
            typea, typeb)
        if primitive_type is None:
            raise ValueError('Incompatible primitive types in broadcast operation')
        if builtins.is_tensor(typea):
            if builtins.is_tensor(typeb):
                retshape = self._broadcast_shape(node, typea.get_shape(), typeb.get_shape())
                retval = builtins.tensor(primitive_type, retshape)
            else:
                # a is a tensor, b is a scalar: the result keeps a's shape
                retval = builtins.tensor(primitive_type, typea.get_shape())
        elif builtins.is_tensor(typeb):
            # b is a tensor, a is a scalar: the result keeps b's shape
            retval = builtins.tensor(primitive_type, typeb.get_shape())
        else:
            # neither input is a tensor: the result is the promoted scalar type
            retval = primitive_type
        return retval
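# For reference, a stand-alone version of the broadcast-shape rule that
# _broadcast_shape is assumed to implement (NumPy semantics; the real helper
# additionally copes with symbolic dimensions): shapes are right-aligned,
# and a dimension of 1 stretches to match the other side.
def broadcast_shape(sa, sb):
    sa, sb = tuple(sa), tuple(sb)
    sa = (1,) * (len(sb) - len(sa)) + sa   # left-pad the shorter shape with 1s
    sb = (1,) * (len(sa) - len(sb)) + sb
    out = []
    for da, db in zip(sa, sb):
        if da == 1:
            out.append(db)
        elif db == 1 or da == db:
            out.append(da)
        else:
            raise ValueError('shapes do not broadcast: %s vs %s' % (sa, sb))
    return tuple(out)

assert broadcast_shape((3, 1), (4,)) == (3, 4)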
def recursive_replace_symbols_in_type_with_unknown(dtype):
    # Recursively rebuild a type, replacing every symbolic (sympy) dimension
    # in tensor shapes with -1, the converter's marker for "unknown".
    if builtins.is_list(dtype):
        return builtins.list(recursive_replace_symbols_in_type_with_unknown(dtype.T[0]))
    elif builtins.is_tuple(dtype):
        return builtins.tuple(
            tuple(recursive_replace_symbols_in_type_with_unknown(t) for t in dtype.T))
    elif builtins.is_tensor(dtype):
        return builtins.tensor(
            dtype.get_primitive(),
            tuple(-1 if issubclass(type(t), sm.Basic) else int(t) for t in dtype.get_shape()))
    else:
        return dtype
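# The symbol test above in isolation (sm is sympy in this module; the symbol
# name and shape below are illustrative):
import sympy as sm
batch = sm.Symbol('batch')
shape = (batch, 3, 224, 224)
assert tuple(-1 if issubclass(type(t), sm.Basic) else int(t) for t in shape) == \
    (-1, 3, 224, 224)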
ranka = len(typea.get_shape())
rankb = len(typeb.get_shape())
assert (ranka == rankb)
# rankcond is the rank of the condition input computed earlier; a rank-1
# condition selecting between higher-rank operands must be broadcast by
# expanding trailing dimensions.
if rankcond == 1 and ranka > 1:
    node.attr['expand_dims'] = [-i - 1 for i in range(ranka - rankcond)]
if typea is not None and typeb is not None:
    compatible, restype = builtins.is_tensor_and_is_compatible_general_shape(typea, typeb)
    if compatible:
        return restype
    elif typea == typeb:
        return typea
    else:
        logging.error(
            "%s != %s", builtins.get_type_info(typea), builtins.get_type_info(typeb))
if typea is not None:
    return typea
else:
    return typeb
def type_is_unknown(t):
    # A type is "unknown" if it is None, the explicit unknown type, or a
    # container or tensor any of whose element types or dimensions is unknown.
    if builtins.is_tuple(t):
        return any(type_is_unknown(a) for a in t.T)
    elif builtins.is_tensor(t):
        return type_is_unknown(t.get_primitive()) or \
               t.get_shape() is None or \
               len(t.get_shape()) == 0 or \
               any_symbolic_or_unknown(t.get_shape())
    elif builtins.is_list(t):
        return type_is_unknown(t.T[0])
    elif t is builtins.unknown:
        return True
    else:
        return t is None
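# A few concrete consequences of the definition above (hedged examples):
#   type_is_unknown(None)             -> True
#   type_is_unknown(builtins.unknown) -> True
#   a tensor whose shape tuple contains a sympy symbol is also "unknown",
#   via any_symbolic_or_unknown on its shape.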
def parse_from_attr(self):
    if 'value' in self.attr:
        self.datatype = self.attr['value'].__class__
    elif '_output_shapes' in self.attr:
        output_shapes = self.attr['_output_shapes']
        if output_shapes[0] is not None and len(output_shapes[0]) > 0:
            if 'dtype' in self.attr:
                rettype = builtins.tensor(self.attr['dtype'], tuple(output_shapes[0]))
            elif 'T' in self.attr:
                rettype = builtins.tensor(self.attr['T'], tuple(output_shapes[0]))
            elif 'Tparams' in self.attr:
                rettype = builtins.tensor(self.attr['Tparams'], tuple(output_shapes[0]))
            else:
                raise NotImplementedError(
                    "Op-(%s) %s not implemented\nWith attribute: %s" %
                    (self.op, self.name, self.attr))
            self.datatype = rettype
    elif 'dtype' in self.attr:
        self.datatype = self.attr['dtype']
    elif 'shape' in self.attr:
        shape = self.attr['shape']
        assert ('dtype' in self.attr)
        if len(shape) == 0:
            self.datatype = self.attr['dtype']
        else:
            # Non-empty shape: assumed continuation, mirroring the
            # '_output_shapes' branch above.
            self.datatype = builtins.tensor(self.attr['dtype'], tuple(shape))
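# Illustrative inputs (a hypothetical node, not taken from a real graph):
#   self.attr = {'_output_shapes': [[1, 224, 224, 3]], 'dtype': builtins.fp32}
# parse_from_attr() would then set
#   self.datatype = builtins.tensor(builtins.fp32, (1, 224, 224, 3))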
if transpose_node_name in graph:
    tp_node = graph[transpose_node_name]
    if dst.name not in tp_node.outputs:
        tp_node.outputs.append(dst.name)
else:
    # the node does not exist, so create a fresh one
    tp_node = ParsedNode()
    tp_node.op = 'Transpose'
    tp_node.name = transpose_node_name
    # Adjust type inference: permute a rank-4 shape by transpose_params
    if builtins.is_tensor(src.datatype):
        s = src.datatype.get_shape()
        if len(s) == 4:
            tp_shape = tuple(s[p] for p in transpose_params)
            tp_node.datatype = builtins.tensor(src.datatype.get_primitive(), tp_shape)
        else:
            tp_node.datatype = src.datatype
    tp_node.inputs = [src.name]
    tp_node.outputs = [dst.name]
    tp_node.attr['dim'] = transpose_params
    graph[transpose_node_name] = tp_node
# Rename dst's input 'src' to 'transpose_node_name'
for idx, inp in enumerate(dst.inputs):
    if inp == src.name:
        dst.inputs[idx] = transpose_node_name
        break
# Rename src's output from 'dst' to 'transpose_node_name'
if transpose_node_name in src.outputs:
    # src already feeds the transpose node; just drop the direct edge to dst
    # (a hedged completion of this truncated excerpt).
    if dst.name in src.outputs:
        src.outputs.remove(dst.name)
else:
    for idx, outp in enumerate(src.outputs):
        if outp == dst.name:
            src.outputs[idx] = transpose_node_name
            break
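# Sanity check of the shape permutation used above, with an assumed
# NHWC -> NCHW parameter list (transpose_params = [0, 3, 1, 2]):
s = (1, 224, 224, 3)                                          # NHWC
assert tuple(s[p] for p in [0, 3, 1, 2]) == (1, 3, 224, 224)  # NCHW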