import onnx
from onnx import helper, numpy_helper
import numpy as np

def handle_common_attributes(node, default_activations):
    # Shared handling of RNN (LSTM/GRU) attributes: direction, activations, clip, etc.
    direction = NodeFactory.get_attribute(node, 'direction')
    if direction:
        direction = str(direction, 'utf-8')  # string attributes come back as bytes
    else:
        direction = 'forward'
    num_directions = 2 if direction == 'bidirectional' else 1

    activations = NodeFactory.get_attribute(node, 'activations')
    if activations:
        activations = [str(x, 'utf-8').lower().capitalize() for x in activations]
    else:
        activations = default_activations * num_directions

    activation_alpha = NodeFactory.get_attribute(node, 'activation_alpha')
    activation_beta = NodeFactory.get_attribute(node, 'activation_beta')
    clip_threshold = NodeFactory.get_attribute(node, 'clip')
    # TODO: support these activation attributes
    return direction, num_directions, activations
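
# A minimal standalone sketch (not part of the original script) showing why the string
# attributes above are decoded with str(x, 'utf-8'): onnx stores STRING attributes as
# bytes. Only the standard onnx.helper API is used here.
demo_lstm = helper.make_node('LSTM', inputs=['X', 'W', 'R'], outputs=['Y'],
                             hidden_size=4, direction='bidirectional',
                             activations=['Sigmoid', 'Tanh', 'Tanh'])
for demo_attr in demo_lstm.attribute:
    demo_val = helper.get_attribute_value(demo_attr)
    if demo_attr.name == 'direction':
        assert str(demo_val, 'utf-8') == 'bidirectional'                      # bytes -> str
    elif demo_attr.name == 'activations':
        assert [str(s, 'utf-8') for s in demo_val] == ['Sigmoid', 'Tanh', 'Tanh']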
# NodeFactory.make_node (a method of NodeFactory; the signature and the list-normalization
# preamble below are reconstructed from its call sites elsewhere in this file)
def make_node(self, op_type, inputs, attributes={}, output_names=None, node=None):
    if type(inputs) != list:
        inputs = [inputs]
    if output_names and type(output_names) != list:
        output_names = [output_names]
    input_names = []
    for i in inputs:
        if type(i) in [onnx.NodeProto, onnx.TensorProto, onnx.ValueInfoProto]:
            input_names.append(i.name)
        elif type(i) == str:
            input_names.append(i)
        elif type(i) == np.ndarray:
            # numpy arrays are promoted to graph initializers
            new_initializer = self.make_initializer(i)
            input_names.append(new_initializer.name)
        else:
            assert False  # unexpected type in input
    if not node:
        node = self.graph_.node.add()
    name = self.name_prefix_ + op_type + '_' + str(NodeFactory.node_count_)
    NodeFactory.node_count_ = NodeFactory.node_count_ + 1
    if not output_names:
        output_names = [name]
    node.CopyFrom(helper.make_node(op_type, input_names, output_names, name, **attributes))
    return node
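
# Hedged illustration (not from the original file): the wrapper above ultimately delegates to
# onnx.helper.make_node(op_type, input_names, output_names, name, **attributes), so a call such
# as nf.make_node('Transpose', ['X_proj'], {'perm': [1, 0, 2]}) corresponds roughly to:
demo_transpose = helper.make_node('Transpose',
                                  inputs=['X_proj'],               # resolved input names
                                  outputs=['prefix_Transpose_0'],  # defaults to the generated node name
                                  name='prefix_Transpose_0',
                                  perm=[1, 0, 2])                  # the attributes dict, expanded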
def convert_lstm_to_scan(node, out_main_graph):
    assert node.op_type == 'LSTM'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None
        InitCa = node.input[6] if num_inputs > 6 else None
        PB = node.input[7] if num_inputs > 7 else None
        # TODO: support peephole
        assert not PB
        direction, num_directions, activations = handle_common_attributes(node, ['Sigmoid', 'Tanh', 'Tanh'])
# NodeFactory.scoped_prefix: the context manager used in the `with` blocks above and below
# to namespace the names of generated nodes (a method of NodeFactory)
def scoped_prefix(self, prefix):
    return NodeFactory.ScopedPrefix(self, prefix)
        # (continuation of convert_lstm_to_scan: direction_index comes from a surrounding
        #  per-direction loop, and hidden_size, name_prefix and batch_size are set up in
        #  parts of the function not shown here)
        input_size = Wa.shape[len(Wa.shape) - 1]
        Wt = np.transpose(Wa[direction_index])
        Rt = np.transpose(Ra[direction_index])
        B = Ba[direction_index].reshape(2, 4*hidden_size).sum(axis=0)  # [4*hidden_size]
        X_proj = nf.make_node('MatMul', [X, Wt])  # [seq_len, batch_size, 4*hidden_size]
        X_proj = nf.make_node('Add', [X_proj, B])
        if num_directions == 1:
            is_backward = 0 if direction == 'forward' else 1
        else:
            is_backward = direction_index

        scan_body = onnx.GraphProto()
        scan_body.name = name_prefix + '_subgraph'

        nf_body = NodeFactory(out_main_graph, scan_body)
        with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
            # subgraph inputs
            X_proj_subgraph = X_proj.name + '_subgraph'
            prev_h_subgraph = name_prefix + '_h_subgraph'
            prev_c_subgraph = name_prefix + '_c_subgraph'

            seq_len_subgraph = declare_seq_len_in_subgraph(seq_len, nf_body, X_proj.name, batch_size)

            for subgraph_i in [prev_h_subgraph, prev_c_subgraph]:
                nf_body.make_value_info(subgraph_i,
                                        data_type=onnx.TensorProto.FLOAT,
                                        shape=(batch_size, hidden_size),
                                        usage=NodeFactory.ValueInfoType.input)

            nf_body.make_value_info(X_proj_subgraph,
                                    data_type=onnx.TensorProto.FLOAT,
                                    # shape/usage assumed to follow the same pattern as the states above
                                    shape=(batch_size, 4*hidden_size),
                                    usage=NodeFactory.ValueInfoType.input)
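
# Hedged sketch (an assumption, not the original NodeFactory helper): make_value_info above
# presumably wraps the standard onnx value_info construction, roughly equivalent to:
demo_vi = helper.make_tensor_value_info('lstm_h_subgraph',
                                        onnx.TensorProto.FLOAT,
                                        ['batch_size', 16])  # symbolic batch dim, hidden_size = 16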
def convert_gru_to_scan(node, out_main_graph):
    assert node.op_type == 'GRU'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None
        direction, num_directions, activations = handle_common_attributes(node, ['Sigmoid', 'Tanh'])
        hidden_size = NodeFactory.get_attribute(node, 'hidden_size')
        linear_before_reset = NodeFactory.get_attribute(node, 'linear_before_reset')
        InitHa = handle_init_state(InitHa, nf, num_directions)
        batch_size, batch_node = handle_batch_size(X, nf, InitHa is None)
    # (reconstructed loop header: this snippet resumes inside optimize_input_projection's walk
    #  over the source graph's nodes, looking for Scan nodes whose input projection can be hoisted)
    for in_n in in_mp.graph.node:
        optimize_scan = False
        if in_n.op_type == 'Scan':
            in_sg = NodeFactory.get_attribute(in_n, 'body')
            num_scan_inputs = NodeFactory.get_attribute(in_n, 'num_scan_inputs')
            # only support 1 scan input
            if num_scan_inputs == 1:
                optimize_scan = True

        # copy the node if it's not the scan node that is supported at the moment
        if not optimize_scan:
            out_n = out_mp.graph.node.add()
            out_n.CopyFrom(in_n)
            continue

        scan_input_directions = NodeFactory.get_attribute(in_n, 'scan_input_directions')
        scan_output_directions = NodeFactory.get_attribute(in_n, 'scan_output_directions')
        out_sg = onnx.GraphProto()
        out_sg.CopyFrom(in_sg)
        out_sg.ClearField('node')
        nf_subgraph = NodeFactory(out_mp.graph, out_sg, prefix='opt_inproj_sg_' + in_n.name + '_')
        new_inputs = list(in_n.input)
        in_sg_inputs = [i.name for i in in_sg.input]
        replaced_matmul = None
        for in_sn in in_sg.node:
            if in_sn.op_type == 'Concat' and len(in_sn.input) == 2 and all([i in in_sg_inputs for i in in_sn.input]):
                # make sure the concat's inputs are scan input and scan state
                if NodeFactory.get_attribute(in_sn, 'axis') != len(in_sg.input[-1].type.tensor_type.shape.dim) - 1:
                    continue  # must concat last dim
                matmul_node = [nn for nn in in_sg.node if nn.op_type == 'MatMul' and in_sn.output[0] in nn.input]
                if not matmul_node:
                    continue
                replaced_matmul = matmul_node[0]
                assert replaced_matmul.input[1] in initializers
                aa = nf.get_initializer(replaced_matmul.input[1])
                input_size = in_sg.input[-1].type.tensor_type.shape.dim[-1].dim_value
                if in_sg_inputs[-1] == in_sn.input[0]:
def handle_subgraph_outputs(nf_body, seq_len_subgraph, batch_size, hidden_size, subgraph_output_or_default):
    final_subgraph_output = []
    if seq_len_subgraph:
        seq_len_output = nf_body.make_node('Sub', [seq_len_subgraph, np.asarray([1]).astype(np.int32)])
        nf_body.make_value_info(seq_len_output,
                                data_type=onnx.TensorProto.INT32,
                                shape=(batch_size,),
                                usage=NodeFactory.ValueInfoType.output)
        final_subgraph_output.append(seq_len_output)
        # since seq_len is rank-1, need to unsqueeze for Where op on rank-2 states
        condition = nf_body.make_node('Unsqueeze',
                                      nf_body.make_node('Greater', [seq_len_subgraph, np.zeros(shape=(), dtype=np.int32)]),
                                      {'axes': [1]})
        for valid, default in subgraph_output_or_default:
            final_subgraph_output.append(nf_body.make_node('Where', [condition, valid, default]))
    else:
        final_subgraph_output.append(None)
        for valid, default in subgraph_output_or_default:
            final_subgraph_output.append(nf_body.make_node('Identity', valid))
    for subgraph_o in final_subgraph_output[1:]:
        nf_body.make_value_info(subgraph_o,
                                data_type=onnx.TensorProto.FLOAT,
                                shape=(batch_size, hidden_size),
                                usage=NodeFactory.ValueInfoType.output)
    return final_subgraph_output
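
# Illustrative numpy sketch (not from the original file) of the masking the Where node above
# performs: batch entries whose remaining seq_len has reached zero keep the default
# (carried-over) state rather than the freshly computed one.
demo_seq_len = np.array([2, 0, 1], dtype=np.int32)               # remaining steps per batch entry
demo_valid = np.ones((3, 4), dtype=np.float32)                   # newly computed state
demo_default = np.zeros((3, 4), dtype=np.float32)                # previous state to carry over
demo_condition = np.expand_dims(demo_seq_len > 0, axis=1)        # rank-1 -> rank-2, like the Unsqueeze
demo_masked = np.where(demo_condition, demo_valid, demo_default)  # row 1 stays at the default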
def optimize_input_projection(input_model, output_model):
    in_mp = onnx.load(input_model)
    out_mp = onnx.ModelProto()
    out_mp.CopyFrom(in_mp)
    out_mp.ir_version = 5  # update ir version to avoid requirement of initializer in graph input
    out_mp.graph.ClearField('node')
    nf = NodeFactory(out_mp.graph, prefix='opt_inproj_')
    initializers = dict([(i.name, i) for i in in_mp.graph.initializer])
    # first find possible fused SVD and do constant folding on MatMul of initializers
    const_matmuls = [n for n in in_mp.graph.node if n.op_type == 'MatMul' and all([i in initializers for i in n.input])]
    for mm in const_matmuls:
        lhs = numpy_helper.to_array(initializers[mm.input[0]])
        rhs = numpy_helper.to_array(initializers[mm.input[1]])
        val = np.matmul(lhs, rhs)
        new_initializer = out_mp.graph.initializer.add()
        new_initializer.CopyFrom(numpy_helper.from_array(val, mm.output[0]))
        # drop the folded initializers if no other node still references them
        if not [n for n in in_mp.graph.node if n != mm and mm.input[0] in n.input]:
            nf.remove_initializer(mm.input[0])
        if not [n for n in in_mp.graph.node if n != mm and mm.input[1] in n.input]:
            nf.remove_initializer(mm.input[1])
    initializers = dict([(i.name, i) for i in out_mp.graph.initializer])
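
# Hedged usage sketch: the remainder of optimize_input_projection (not shown above) is assumed
# to rewrite the matching Scan subgraphs and save out_mp to output_model; the file names below
# are placeholders.
# optimize_input_projection('model_with_scan.onnx', 'model_with_scan.opt.onnx')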