Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _register_weight_sparsifying_operations(self, device, ignored_scopes, target_scopes, logger):
    """Attach a weight-sparsifying pre-forward operation to every eligible NNCF module.

    Modules are discovered with ``get_all_modules_by_type(self._model, NNCF_MODULES)``.
    A module whose name matches ``ignored_scopes`` is skipped, and the skip is logged.
    When ``target_scopes`` is None, every remaining module is processed. Otherwise only
    modules matching ``target_scopes`` are processed, and the rest are skipped silently.

    For each processed module, the operation built by
    ``self.create_weight_sparsifying_operation`` is wrapped in ``UpdateWeight``, moved to
    ``device`` and registered as a pre-forward operation. A ``SparseModuleInfo`` record
    (name, module, operand of the registered op) is then appended to
    ``self.sparsified_module_info``. That list is reset at the start of every call.

    Args:
        device: Target passed to ``.to(...)`` on each wrapped operation.
        ignored_scopes: Scope patterns whose modules must never be sparsified.
        target_scopes: Scope patterns to restrict sparsification to, or None for "all".
        logger: Receives an ``info`` message for every ignored or sparsified module.
    """
    candidates = get_all_modules_by_type(self._model, NNCF_MODULES)
    self.sparsified_module_info = []

    for scope_name, nncf_module in candidates.items():
        if in_scope_list(scope_name, ignored_scopes):
            logger.info("Ignored adding Weight Sparsifier in scope: {}".format(scope_name))
            continue

        is_targeted = target_scopes is None or in_scope_list(scope_name, target_scopes)
        if not is_targeted:
            continue

        logger.info("Adding Weight Sparsifier in scope: {}".format(scope_name))
        sparsifier = self.create_weight_sparsifying_operation(nncf_module)
        wrapped_op = UpdateWeight(sparsifier).to(device)
        op_id = nncf_module.register_pre_forward_operation(wrapped_op)
        # Record the operand of the op as actually registered on the module.
        registered_operand = nncf_module.get_pre_op(op_id).operand
        self.sparsified_module_info.append(
            SparseModuleInfo(scope_name, nncf_module, registered_operand))
def __init__(self, model, num_init_steps):
    """Collect every module with a registered min/max initializer and build the initializer.

    Args:
        model: Model whose submodules are searched. It is stored as ``self.model``.
        num_init_steps: Forwarded unchanged to ``MinMaxInitializer``.

    Side effects:
        Sets ``self.model``, ``self.modules_to_init`` (an ``OrderedDict`` sorted by item)
        and ``self.initializer``.
    """
    self.model = model

    def apply_collected_fn(initializer, modules_to_init_, distributed_):
        # Callback handed to MinMaxInitializer. It reads the collected max/min values for
        # each module and applies the initializer registered for that module's class name.
        for module_scope, target in modules_to_init_.items():
            # Modules that expose a truthy ``initialized`` attribute are left alone.
            if hasattr(target, 'initialized') and target.initialized:
                continue
            upper = initializer.get_max_value(target)
            lower = initializer.get_min_value(target)
            init_fn = MIN_MAX_INITIALIZERS.get(type(target).__name__)
            init_fn(target, module_scope, lower, upper, distributed_)

    # Gather modules of every type that has a registered min/max initializer.
    self.modules_to_init = OrderedDict()
    for registered_type in MIN_MAX_INITIALIZERS.registry_dict:
        found = get_all_modules_by_type(self.model, registered_type)
        self.modules_to_init.update(found)
    # NOTE: Order of modules must be the same to correctly broadcast parameters (e.g. input_low and input_range)
    ordered_items = sorted(self.modules_to_init.items())
    self.modules_to_init = OrderedDict(ordered_items)
    self.initializer = MinMaxInitializer(self.modules_to_init, apply_collected_fn, num_init_steps)