Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
'Activation' : Activation,
'LeakyReLU' : ParametrizedActivation,
'ThresholdedReLU' : ParametrizedActivation,
'ELU' : ParametrizedActivation,
'PReLU' : PReLU,
'Dense' : Dense,
'BinaryDense' : Dense,
'TernaryDense' : Dense,
'Conv1D' : Conv1D,
'Conv2D' : Conv2D,
'BatchNormalization' : BatchNormalization,
'MaxPooling1D' : Pooling1D,
'AveragePooling1D' : Pooling1D,
'MaxPooling2D' : Pooling2D,
'AveragePooling2D' : Pooling2D,
'Merge' : Merge,
'Concatenate' : Concatenate,
}
def register_layer(name, clazz):
    """Register ``clazz`` as the implementation class for layer type ``name``.

    Adds a new entry to the module-level ``layer_map`` dict, or overwrites
    the existing entry for ``name``.

    Args:
        name: Layer type name used as the dictionary key (e.g. ``'Dense'``).
        clazz: Class to associate with ``name``.
    """
    # ``layer_map`` is only mutated here, never rebound, so no ``global``
    # declaration is required to reach the module-level dict.
    layer_map[name] = clazz
params['input1_t'] = self.get_input_variable(self.inputs[0]).type.name
params['input2_t'] = self.get_input_variable(self.inputs[1]).type.name
params['output_t'] = self.get_output_variable().type.name
params['input1'] = self.get_input_variable(self.inputs[0]).name
params['input2'] = self.get_input_variable(self.inputs[1]).name
params['output'] = self.get_output_variable().name
return [self._function_template.format(**params)]
def config_cpp(self):
    """Render this merge layer's C++ configuration text.

    Takes the layer's default config parameters and sets ``n_elem`` to
    whatever ``size_cpp()`` reports for the first input variable. It then
    formats ``self._config_template`` with the combined parameters.
    """
    config = self._default_config_params()
    first_input = self.get_input_variable(self.inputs[0])
    config['n_elem'] = first_input.size_cpp()
    return self._config_template.format(**config)
class Concatenate(Merge):
def initialize(self, axis=-1):
    """Create the output variable holding the concatenation of the two inputs.

    Only the dimension along ``axis`` grows: it becomes the sum of the two
    inputs' sizes on that axis. Every other dimension must agree between
    the inputs and is copied unchanged. For example, inputs of shape
    ``[4, 5]`` and ``[4, 3]`` concatenated on the last axis give ``[4, 8]``.

    Args:
        axis: Axis to concatenate along; negative values count from the
            end. Defaults to ``-1``, the last axis, which matches the Keras
            ``Concatenate`` default.
            NOTE(review): callers that invoke ``initialize()`` with no
            arguments always get the last axis. TODO: confirm where a
            model-specified axis is recorded and pass it through.

    Raises:
        AssertionError: if the layer does not have exactly two inputs.
        ValueError: if the inputs differ in rank, ``axis`` is out of range,
            or a non-concatenated dimension differs between the inputs.
    """
    assert len(self.inputs) == 2
    inp1 = self.get_input_variable(self.inputs[0])
    inp2 = self.get_input_variable(self.inputs[1])
    # Copy both shapes so the input variables' own shape lists are never
    # aliased into, or mutated through, the output variable.
    shape = list(inp1.shape)
    other = list(inp2.shape)
    rank = len(shape)
    if len(other) != rank:
        raise ValueError(
            'Concatenate inputs must have the same rank, got {} and {}'.format(shape, other))
    if not -rank <= axis < rank:
        raise ValueError(
            'Concatenate axis {} is out of range for inputs of rank {}'.format(axis, rank))
    axis %= rank
    for i, (size1, size2) in enumerate(zip(shape, other)):
        if i != axis and size1 != size2:
            raise ValueError(
                'Concatenate inputs {} and {} differ on non-concatenated axis {}'.format(
                    shape, other, i))
    # Only the concatenation axis grows. Summing every dimension would
    # produce an invalid output shape for any rank > 1.
    shape[axis] += other[axis]
    if rank > 1:
        dims = ['OUT_CONCAT_{}_{}'.format(i, self.index) for i in range(rank)]
    else:
        dims = ['OUT_CONCAT_{}'.format(self.index)]
    self.add_output_variable(shape, dims)
def config_cpp(self):
params = self._default_config_params()
for i in range(3):
params.setdefault('n_elem1_{}'.format(i), 0)