How to use the batchflow.models.torch.utils.unpack_fn_from_config function in batchflow

To help you get started, we’ve selected a few examples from batchflow that show popular ways `unpack_fn_from_config` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: GitHub — analysiscenter/batchflow, `batchflow/models/torch/base.py` (view on GitHub)
def _make_loss(self, config):
        """Resolve the loss specification(s) found under the ``'loss'`` key of `config`.

        The result of ``unpack_fn_from_config('loss', config)`` is wrapped in a
        one-item list unless it is already a list. Each item is then unpacked as a
        ``(loss, args)`` pair.

        A ``loss`` given as a string is looked up, in order:

        1. as an attribute of ``nn`` with exactly that name;
        2. as an attribute of ``nn`` with ``"Loss"`` appended
           (presumably ``nn`` is ``torch.nn``, so e.g. ``'MSE'`` -> ``MSELoss``);
        3. in the ``LOSSES`` registry. The key is the name with ``-``, ``_`` and
           spaces removed, then lower-cased, and ``.get`` falls back to ``None``.

        A ``loss`` that is a class (``type``) is left unchanged. An ``nn.Module``
        instance is taken directly as ``loss_fn``.

        NOTE(review): this excerpt is truncated. How ``loss_fn``, ``args`` and
        ``losses`` are used afterwards, and what the method returns, is not shown.
        """
        res = unpack_fn_from_config('loss', config)
        # Normalize to a list so single and multiple loss specs share the loop below.
        res = res if isinstance(res, list) else [res]

        losses = []
        for loss, args in res:
            # Stays None unless the spec is already an nn.Module instance.
            loss_fn = None
            if isinstance(loss, str):
                if hasattr(nn, loss):
                    loss = getattr(nn, loss)
                elif hasattr(nn, loss + "Loss"):
                    loss = getattr(nn, loss + "Loss")
                else:
                    # Normalize the name (drop '-', '_', ' '; lower-case) before the
                    # registry lookup; an unmatched name becomes None via .get.
                    loss = LOSSES.get(re.sub('[-_ ]', '', loss).lower(), None)
            elif isinstance(loss, type):
                # A class is kept as-is (presumably instantiated later with `args`; not visible here).
                pass
            elif isinstance(loss, nn.Module):
                loss_fn = loss
Source: GitHub — analysiscenter/batchflow, `batchflow/models/torch/base.py` (view on GitHub)
def _make_loss(self, config):
        """Resolve the loss specification(s) found under the ``'loss'`` key of `config`.

        The result of ``unpack_fn_from_config('loss', config)`` is wrapped in a
        one-item list unless it is already a list. Each item is then unpacked as a
        ``(loss, args)`` pair.

        A ``loss`` given as a string is looked up, in order:

        1. as an attribute of ``nn`` with exactly that name;
        2. as an attribute of ``nn`` with ``"Loss"`` appended
           (presumably ``nn`` is ``torch.nn``, so e.g. ``'MSE'`` -> ``MSELoss``);
        3. in the ``LOSSES`` registry. The key is the name with ``-``, ``_`` and
           spaces removed, then lower-cased, and ``.get`` falls back to ``None``.

        A ``loss`` that is a class (``type``) is left unchanged. An ``nn.Module``
        instance is taken directly as ``loss_fn``.

        NOTE(review): this excerpt is truncated. How ``loss_fn``, ``args`` and
        ``losses`` are used afterwards, and what the method returns, is not shown.
        """
        res = unpack_fn_from_config('loss', config)
        # Normalize to a list so single and multiple loss specs share the loop below.
        res = res if isinstance(res, list) else [res]

        losses = []
        for loss, args in res:
            # Stays None unless the spec is already an nn.Module instance.
            loss_fn = None
            if isinstance(loss, str):
                if hasattr(nn, loss):
                    loss = getattr(nn, loss)
                elif hasattr(nn, loss + "Loss"):
                    loss = getattr(nn, loss + "Loss")
                else:
                    # Normalize the name (drop '-', '_', ' '; lower-case) before the
                    # registry lookup; an unmatched name becomes None via .get.
                    loss = LOSSES.get(re.sub('[-_ ]', '', loss).lower(), None)
            elif isinstance(loss, type):
                # A class is kept as-is (presumably instantiated later with `args`; not visible here).
                pass
            elif isinstance(loss, nn.Module):
                loss_fn = loss
Source: GitHub — analysiscenter/batchflow, `batchflow/models/torch/base.py` (view on GitHub)
def _make_optimizer(self, config):
        """Build the optimizer and, if configured, its learning-rate decay from `config`.

        ``unpack_fn_from_config`` unpacks the ``'optimizer'`` entry into an
        ``(optimizer, optimizer_args)`` pair. ``optimizer`` is resolved as follows:

        * a callable or class is used as-is;
        * a string naming an attribute of ``torch.optim`` is replaced by that attribute;
        * anything else raises ``ValueError``.

        The resolved optimizer is called with ``self.model.parameters()`` and
        ``**optimizer_args``.

        The decay pair comes from ``self._make_decay(config)``. When ``decay`` is
        not ``None``, it is called with the optimizer instance and ``**decay_args``.

        Raises:
            ValueError: if the optimizer spec cannot be resolved.

        NOTE(review): this excerpt is truncated; the method's return value is not shown.
        """
        optimizer, optimizer_args = unpack_fn_from_config('optimizer', config)

        # Classes are callable too, so the isinstance(..., type) test is redundant but harmless.
        if callable(optimizer) or isinstance(optimizer, type):
            pass
        elif isinstance(optimizer, str) and hasattr(torch.optim, optimizer):
            optimizer = getattr(torch.optim, optimizer)
        else:
            raise ValueError("Unknown optimizer", optimizer)

        # NOTE(review): the else-branch below looks unreachable for a missing (None)
        # optimizer. None already fails every check above and raises "Unknown optimizer";
        # this branch only fires for a callable that is itself falsy.
        if optimizer:
            optimizer = optimizer(self.model.parameters(), **optimizer_args)
        else:
            raise ValueError("Optimizer is not defined", optimizer)

        decay, decay_args = self._make_decay(config)
        if decay is not None:
            decay = decay(optimizer, **decay_args)
Source: GitHub — analysiscenter/batchflow, `batchflow/models/torch/base.py` (view on GitHub)
def _make_optimizer(self, config):
        """Build the optimizer and, if configured, its learning-rate decay from `config`.

        ``unpack_fn_from_config`` unpacks the ``'optimizer'`` entry into an
        ``(optimizer, optimizer_args)`` pair. ``optimizer`` is resolved as follows:

        * a callable or class is used as-is;
        * a string naming an attribute of ``torch.optim`` is replaced by that attribute;
        * anything else raises ``ValueError``.

        The resolved optimizer is called with ``self.model.parameters()`` and
        ``**optimizer_args``.

        The decay pair comes from ``self._make_decay(config)``. When ``decay`` is
        not ``None``, it is called with the optimizer instance and ``**decay_args``.

        Raises:
            ValueError: if the optimizer spec cannot be resolved.

        NOTE(review): this excerpt is truncated; the method's return value is not shown.
        """
        optimizer, optimizer_args = unpack_fn_from_config('optimizer', config)

        # Classes are callable too, so the isinstance(..., type) test is redundant but harmless.
        if callable(optimizer) or isinstance(optimizer, type):
            pass
        elif isinstance(optimizer, str) and hasattr(torch.optim, optimizer):
            optimizer = getattr(torch.optim, optimizer)
        else:
            raise ValueError("Unknown optimizer", optimizer)

        # NOTE(review): the else-branch below looks unreachable for a missing (None)
        # optimizer. None already fails every check above and raises "Unknown optimizer";
        # this branch only fires for a callable that is itself falsy.
        if optimizer:
            optimizer = optimizer(self.model.parameters(), **optimizer_args)
        else:
            raise ValueError("Optimizer is not defined", optimizer)

        decay, decay_args = self._make_decay(config)
        if decay is not None:
            decay = decay(optimizer, **decay_args)
Source: GitHub — analysiscenter/batchflow, `batchflow/models/torch/base.py` (view on GitHub)
def _make_decay(self, config):
        """Resolve the learning-rate decay spec under the ``'decay'`` key of `config`.

        ``unpack_fn_from_config`` unpacks the entry into ``(decay, decay_args)``.
        If ``decay`` is ``None``, that pair is returned immediately. Otherwise
        ``config`` must contain ``'n_iters'``, and ``decay`` is resolved as:

        * a callable or class: used as-is;
        * a string naming an attribute of ``torch.optim.lr_scheduler``: that attribute;
        * a member of ``DECAYS``: ``DECAYS.get(decay)``;
        * anything else: ``ValueError``.

        Raises:
            ValueError: if a decay is requested but ``'n_iters'`` is missing from
                `config`, or if the decay spec cannot be resolved.

        NOTE(review): this excerpt is truncated. How ``n_iters`` and ``decay_args``
        are used afterwards, and the final return value, are not shown.
        """
        decay, decay_args = unpack_fn_from_config('decay', config)
        # Read up front; not used in the visible part of this excerpt.
        n_iters = config.get('n_iters')

        # No decay requested: return early and skip the 'n_iters' check.
        if decay is None:
            return decay, decay_args
        # NOTE(review): the message below misspells "configuration" as "cofiguration".
        if 'n_iters' not in config:
            raise ValueError("Missing required key ``'n_iters'`` in the cofiguration dict.")

        if callable(decay) or isinstance(decay, type):
            pass
        elif isinstance(decay, str) and hasattr(torch.optim.lr_scheduler, decay):
            decay = getattr(torch.optim.lr_scheduler, decay)
        # NOTE(review): if DECAYS is a dict, an unhashable spec (e.g. a list) would raise
        # TypeError on this membership test instead of the ValueError below; confirm.
        elif decay in DECAYS:
            decay = DECAYS.get(decay)
        else:
            raise ValueError("Unknown learning rate decay method", decay)
Source: GitHub — analysiscenter/batchflow, `batchflow/models/torch/base.py` (view on GitHub)
def _make_loss(self, config):
        """Resolve the loss specification(s) found under the ``'loss'`` key of `config`.

        The result of ``unpack_fn_from_config('loss', config)`` is wrapped in a
        one-item list unless it is already a list. Each item is then unpacked as a
        ``(loss, args)`` pair.

        A ``loss`` given as a string is looked up, in order:

        1. as an attribute of ``nn`` with exactly that name;
        2. as an attribute of ``nn`` with ``"Loss"`` appended
           (presumably ``nn`` is ``torch.nn``, so e.g. ``'MSE'`` -> ``MSELoss``);
        3. in the ``LOSSES`` registry. The key is the name with ``-``, ``_`` and
           spaces removed, then lower-cased, and ``.get`` falls back to ``None``.

        A ``loss`` that is a class (``type``) is left unchanged. An ``nn.Module``
        instance is taken directly as ``loss_fn``.

        NOTE(review): this excerpt is truncated. How ``loss_fn``, ``args`` and
        ``losses`` are used afterwards, and what the method returns, is not shown.
        """
        res = unpack_fn_from_config('loss', config)
        # Normalize to a list so single and multiple loss specs share the loop below.
        res = res if isinstance(res, list) else [res]

        losses = []
        for loss, args in res:
            # Stays None unless the spec is already an nn.Module instance.
            loss_fn = None
            if isinstance(loss, str):
                if hasattr(nn, loss):
                    loss = getattr(nn, loss)
                elif hasattr(nn, loss + "Loss"):
                    loss = getattr(nn, loss + "Loss")
                else:
                    # Normalize the name (drop '-', '_', ' '; lower-case) before the
                    # registry lookup; an unmatched name becomes None via .get.
                    loss = LOSSES.get(re.sub('[-_ ]', '', loss).lower(), None)
            elif isinstance(loss, type):
                # A class is kept as-is (presumably instantiated later with `args`; not visible here).
                pass
            elif isinstance(loss, nn.Module):
                loss_fn = loss