Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def forward(self, weight):
    """Return ``weight`` with a binary mask applied.

    Two paths are taken depending on ``is_tracing_state()``:

    * truthy: inside a ``no_jit_trace()`` context, ``weight.mul_`` is
      called with the stored ``self.binary_mask`` and its result is
      returned.
    * falsy: a mask is obtained from
      ``self._calc_training_binary_mask(weight)`` and handed, together
      with ``weight``, to ``apply_binary_mask_impl``.

    NOTE(review): ``mul_`` is presumably the in-place tensor multiply,
    which would mean the tracing path mutates the caller's ``weight`` --
    confirm against the type of ``weight``.
    """
    if is_tracing_state():
        # The stored mask is used as-is here; no training mask is computed.
        with no_jit_trace():
            masked_weight = weight.mul_(self.binary_mask)
            return masked_weight

    return apply_binary_mask_impl(
        self._calc_training_binary_mask(weight),
        weight,
    )
def apply_binary_mask(self, weight):
    """Apply the stored ``self.binary_mask`` to ``weight``.

    Delegates to ``apply_binary_mask_impl`` with the mask as the first
    argument and ``weight`` as the second, and returns whatever that
    helper returns.
    """
    stored_mask = self.binary_mask
    return apply_binary_mask_impl(stored_mask, weight)