How to use `nncf.layer_utils.COMPRESSION_MODULES` in nncf

To help you get started, we’ve selected a few nncf examples based on popular ways `COMPRESSION_MODULES` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: opencv/openvino_training_extensions — pytorch_toolkit/nncf/nncf/sparsity/layers.py (view on GitHub)
Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import torch
import torch.nn as nn

from nncf.layer_utils import COMPRESSION_MODULES
from nncf.sparsity.functions import apply_binary_mask as apply_binary_mask_impl
from nncf.utils import is_tracing_state, no_jit_trace


@COMPRESSION_MODULES.register()
class BinaryMask(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.register_buffer("_binary_mask", torch.ones(size))

    @property
    def binary_mask(self):
        return self._binary_mask

    @binary_mask.setter
    def binary_mask(self, tensor):
        with torch.no_grad():
            self._binary_mask.set_(tensor)

    def forward(self, weight):
        if is_tracing_state():