Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def __init__(self, var_name, type_name, precision, data, reuse_factor, **kwargs):
super(CompressedWeightVariable, self).__init__(var_name, type_name, precision, data, **kwargs)
self.extra_zeros = 0
self.data_length = np.prod(data.shape) - self.nzeros
while self.data_length % reuse_factor != 0:
self.extra_zeros += 1
self.data_length += 1
self.nonzeros = np.prod(data.shape) - self.nzeros + self.extra_zeros
# Compress the array
weights = []
extra_nzero_cnt = self.extra_zeros
it = np.nditer(data, order='C', flags=['multi_index'])
max_idx = 0
while not it.finished:
val = it[0]
if not (val == 0 and extra_nzero_cnt < 1):
if val == 0:
data = self.model.get_weights_data(self.name, name)
elif isinstance(data, six.string_types):
data = self.model.get_weights_data(self.name, data)
if quantize > 0:
data = self.model.quantize_data(data, quantize)
if quantize == 1:
precision = 'ap_uint<1>'
type_name = name + '{index}_t'
elif quantize == 2 or quantize == 3:
precision = 'ap_int<2>'
type_name = name + '{index}_t'
if compression:
rf = self.model.config.get_reuse_factor(self)
var = CompressedWeightVariable(var_name, type_name=type_name, precision=precision, data=data, reuse_factor=rf, index=self.index)
else:
var = WeightVariable(var_name, type_name=type_name, precision=precision, data=data, index=self.index)
self.weights[name] = var
self.precision[var.type.name] = var.type