Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): fragment from the middle of a node-fusion routine; the
# enclosing function header and the original indentation are not visible here.
# connect fused node to entry and output nodes
connect_edge(graph, current_node.name, fused_bn_node.name)
connect_dests(graph, fused_bn_node.name, bn_outputs)
# correct output's inputs order
# For each consumer that has two or more inputs, swap the positions of
# out_node.name and fused_bn_node.name inside its input list (in place).
# Presumably everything down to the swap belongs to the loop body -- confirm
# against the original indentation.
for out in bn_outputs:
if len(graph[out].inputs) < 2:
continue
out_inputs = graph[out].inputs
# list.index raises ValueError when the name is absent; presumably both names
# are present in the list at this point -- TODO confirm.
a = out_inputs.index(out_node.name)
b = out_inputs.index(fused_bn_node.name)
out_inputs[a], out_inputs[b] = out_inputs[b], out_inputs[a]
# delete merged nodes
for name in bn_node_names:
delete_node(graph, name)
# NOTE(review): fragment from the middle of a pass that removes
# BatchToSpaceND / SpaceToBatchND nodes; the function header, the loop that
# binds `n`, and the original indentation are not visible here.
cropping_values[1] = croppings[1, 1] # right
needs_cropping_after = False
# Read the node's 'padding' attribute, defaulting to '' and lower-cased.
border_mode = n.attr.get('padding', '').lower()
# A non-zero cropping sum is only handled when the padding mode is not
# 'valid'; the 'valid' case raises NotImplementedError.
# The `else` presumably pairs with the inner `if border_mode` check -- confirm.
if sum(cropping_values) != 0:
if border_mode != 'valid':
needs_cropping_after = True
else:
raise NotImplementedError('unhandled BatchToSpaceND case.')
# Record the cropping values on the output node under '_cropping_after'.
if needs_cropping_after:
graph[output_node].attr.update({'_cropping_after': cropping_values})
# adjust type inference
# The output node keeps its primitive type but takes the shape of previous_node.
shape = list(graph[previous_node].datatype.get_shape())
graph[output_node].datatype = builtins.tensor(graph[output_node].datatype.get_primitive(), tuple(shape))
delete_node(graph, n.name)
count += 1
# Print a summary only when at least one node was removed.
if count > 0:
print('[Op Fusion] Skipped {} BatchToSpaceND / SpaceToBatchND nodes.'.format(count))
# NOTE(review): fragment from the middle of a GeLU fusion pass; the function
# header, the pattern match that fills `gelu_nodes`, and the original
# indentation are not visible here.
out_node = gelu_nodes[2]
# Copy the outputs list (slice copy) before the matched nodes are deleted
# below; it is used afterwards to reconnect the fused node.
gelu_outputs = out_node.outputs[:]
# Instantiate a new fused node in the graph
fused_gelu_node = ParsedNode()
fused_gelu_node.op = 'GeLU'
fused_gelu_node.name = out_node.name + '_gelu'
fused_gelu_node.attr = {}
# The fused node reuses the datatype object of the entry node.
fused_gelu_node.datatype = current_node.datatype
graph[fused_gelu_node.name] = fused_gelu_node
# Delete nodes
gelu_node_names = [x.name for x in gelu_nodes]
for name in gelu_node_names:
delete_node(graph, name)
# Connect fused node to entry and output nodes
connect_edge(graph, current_node.name, fused_gelu_node.name)
connect_dests(graph, fused_gelu_node.name, gelu_outputs)
count += 1
# Print a summary only when at least one fusion happened.
if count > 0:
print('[Op Fusion] Fused {} GeLUs.'.format(count))
# NOTE(review): fragment from the middle of a pass that removes Identity
# nodes; the function header, the initialisation of `delete_count`, and the
# original indentation are not visible here.
# Snapshot the keys because nodes are deleted from f.graph during the loop.
keys = list(f.graph.keys())
for k in keys:
# Skip keys whose node was already removed earlier in this loop.
if k not in f.graph:
continue
node = f.graph[k]
# Only nodes with exactly one input and exactly one output are candidates.
if len(node.inputs) != 1 or len(node.outputs) != 1:
continue
inp_node = f.graph[node.inputs[0]]
# Remove an Identity node unless its input is a 'get_tuple' node.
if node.op == 'Identity' and inp_node.op != 'get_tuple':
delete_count += 1
parent_name = f.graph[k].inputs[0]
disconnect_edge(f.graph, parent_name, k)
# Re-point each control input of the removed node at the parent instead.
for control_input in f.graph[k].control_inputs:
replace_control_dest(f.graph, control_input, k, parent_name)
replace_node(f.graph, k, parent_name) # join parent to children
delete_node(f.graph, k)
# Number of Identity nodes removed.
return delete_count
# NOTE(review): fragment from the middle of a pass that expands a
# normalization node into elementwise ops; the function header, the bindings
# of x / scale / offset / estimated_mean / estimated_variance, and the
# original indentation are not visible here.
delete_node(graph, node.name)
# New nodes are created through a builder using node.name + '/' as its prefix.
builder = GraphBuilder(graph, node.name + '/', ParsedTFNode)
# Builds (x - mean) * (scale * rsqrt(variance)) + offset.
# No epsilon is added here; presumably estimated_variance already includes
# it -- TODO confirm at the call site.
x_center = builder.add_elementwise("Sub", [x, estimated_mean])
scaling_factor = builder.add_elementwise(
"Mul", [scale, builder.add_elementwise("Rsqrt", [estimated_variance])])
x_scaled = builder.add_elementwise("Mul", [x_center, scaling_factor])
x_shifted = builder.add_elementwise("Add", [x_scaled, offset])
# A second builder with an empty prefix adds the final identity; presumably
# the second argument names it node.name, reusing the deleted node's name
# -- confirm against GraphBuilder.add_identity.
x_final = GraphBuilder(graph, '', ParsedTFNode).add_identity(x_shifted, node.name)
outputs = [x_final]
# Replace each original output node with the entry selected by its 'index'
# attribute. `outputs` holds a single entry, so an index past it raises
# IndexError.
for o in original_node_outputs:
replace_node(graph, o, outputs[graph[o].attr['index']])
delete_node(graph, o)
self.node_value_trace[cur_node_key] == source_const_key:
delete_nodes.append(nodename)
# NOTE(review): this fragment starts inside a condition and ends inside an
# assignment; the enclosing method, its loops, and the original indentation
# are not visible here.
# For 'get_tuple' nodes: look up this node's trace and its first input's
# trace, both keyed by (fname, node name).
elif node.op == 'get_tuple':
my_trace = self.node_value_trace[(
fname,
nodename,
)]
parent_trace = self.node_value_trace[(
fname,
node.inputs[0],
)]
# Only when the parent's trace is exactly a list is the node's 'index'
# attribute set, to the position of this node's trace in that list.
# list.index raises ValueError if my_trace is not in it -- presumably
# guaranteed by how the trace is built; confirm.
if type(parent_trace) is list:
node.attr['index'] = parent_trace.index(my_trace)
# In the source function: remove the source node, store source_const_node in
# target_fn.graph under the source node's name, then delete every other
# collected node.
if fname == source_fname:
delete_node(source_fn.graph, source_node_name)
target_fn.graph[source_node_name] = source_const_node
for d in delete_nodes:
if d != source_node_name:
delete_node(fn.graph, d)
elif fname == target_fname:
# if this is the target function, we rewrite.
# Each collected node must have exactly one input; it is turned into an
# Identity fed by source_const_node instead of its previous input.
for d in delete_nodes:
assert (len(fn.graph[d].inputs) == 1)
fn.graph[d].op = 'Identity'
disconnect_edge(fn.graph, fn.graph[d].inputs[0], d)
connect_edge(fn.graph, source_const_node.name, d)
self.node_value_trace[(
fname,
d,
)] = (
fname,
# NOTE(review): fragment from the middle of a layer-normalization fusion
# pass; the function header, the creation of `fused_ln_node`, and the
# original indentation are not visible here.
fused_ln_node.op = 'LayerNormalization'
fused_ln_node.name = out_node.name + '_layernorm'
fused_ln_node.attr = ln_params
# The fused node reuses the datatype object of the entry node.
fused_ln_node.datatype = current_node.datatype
graph[fused_ln_node.name] = fused_ln_node
# Connect fused node to entry and output nodes
connect_edge(graph, current_node.name, fused_ln_node.name)
# Here replace_node is called on out_node.name before the matched nodes are
# deleted, in place of the commented-out connect_dests call below.
replace_node(graph, out_node.name, fused_ln_node.name)
# connect_dests(graph, fused_ln_node.name, ln_outputs)
# Delete nodes
ln_node_names = [x.name for x in ln_nodes]
for name in ln_node_names:
delete_node(graph, name)
count += 1
# Print a summary only when at least one fusion happened.
if count > 0:
print('[Op Fusion] Fused {} layer normalizations.'.format(count))