How to use the skorch.callbacks.Callback class in skorch

To help you get started, we've selected a few skorch examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github skorch-dev / skorch / skorch / callbacks / training.py View on Github external
--------

    Use ``Initializer`` to initialize all dense layer weights with
    values sampled from an uniform distribution on the beginning of
    the first epoch:

    >>> init_fn = partial(torch.nn.init.uniform_, a=-1e-3, b=1e-3)
    >>> cb = Initializer('dense*.weight', fn=init_fn)
    >>> net = Net(myModule, callbacks=[cb])
    """
    def __init__(self, *args, **kwargs):
        # Unless the caller picked a value for ``at`` explicitly, default it
        # to 1 so the callback fires at the beginning of the first epoch.
        kwargs.setdefault('at', 1)
        super().__init__(*args, **kwargs)


class LoadInitState(Callback):
    """Loads the model, optimizer, and history from a checkpoint into a
    :class:`.NeuralNet` when training begins.

    Examples
    --------

    Consider running the following example multiple times:

    >>> cp = Checkpoint(monitor='valid_loss_best')
    >>> load_state = LoadInitState(cp)
    >>> net = NeuralNet(..., callbacks=[cp, load_state])
    >>> net.fit(X, y)

    On the first run, the :class:`.Checkpoint` saves the model, optimizer, and
    history when the validation loss is minimized. During the first run,
    there are no files on disk, thus :class:`.LoadInitState` will
github skorch-dev / skorch / skorch / callbacks / logging.py View on Github external
data = net.history[-1]
        verbose = net.verbose
        tabulated = self.table(data)

        if self.first_iteration_:
            header, lines = tabulated.split('\n', 2)[:2]
            self._sink(header, verbose)
            self._sink(lines, verbose)
            self.first_iteration_ = False

        self._sink(tabulated.rsplit('\n', 1)[-1], verbose)
        if self.sink is print:
            sys.stdout.flush()


class ProgressBar(Callback):
    """Display a progress bar for each epoch.

    The progress bar includes elapsed and estimated remaining time for
    the current epoch, the number of batches processed, and other
    user-defined metrics. The progress bar is erased once the epoch is
    completed.

    ``ProgressBar`` needs to know the total number of batches per
    epoch in order to display a meaningful progress bar. By default,
    this number is determined automatically using the dataset length
    and the batch size. If this heuristic does not work for some
    reason, you may either specify the number of batches explicitly
    or let the ``ProgressBar`` count the actual number of batches in
    the previous epoch.

    For jupyter notebooks a non-ASCII progress bar can be printed
github skorch-dev / skorch / skorch / callbacks / training.py View on Github external
from itertools import product

import numpy as np
from skorch.callbacks import Callback
from skorch.exceptions import SkorchException
from skorch.utils import noop
from skorch.utils import open_file_like
from skorch.utils import freeze_parameter
from skorch.utils import unfreeze_parameter


# Public API of this module: the names exported by ``from ... import *``.
__all__ = ['Checkpoint', 'EarlyStopping', 'ParamMapper', 'Freezer',
           'Unfreezer', 'Initializer', 'LoadInitState', 'TrainEndCheckpoint']


class Checkpoint(Callback):
    """Save the model during training if the given metric improved.

    This callback works by default in conjunction with the validation
    scoring callback since it creates a ``valid_loss_best`` value
    in the history which the callback uses to determine if this
    epoch is save-worthy.

    You can also specify your own metric to monitor or supply a
    callback that dynamically evaluates whether the model should
    be saved in this epoch.

    Some or all of the following can be saved:

      - model parameters (see ``f_params`` parameter);
      - optimizer state (see ``f_optimizer`` parameter);
      - training history (see ``f_history`` parameter);
github skorch-dev / skorch / skorch / callbacks / lr_scheduler.py View on Github external
def _check_lr(name, optimizer, lr):
    """Return one learning rate for each param group."""
    n = len(optimizer.param_groups)
    if not isinstance(lr, (list, tuple)):
        return lr * np.ones(n)

    if len(lr) != n:
        raise ValueError("{} lr values were passed for {} but there are "
                         "{} param groups.".format(n, name, len(lr)))
    return np.array(lr)


class LRScheduler(Callback):
    """Callback that sets the learning rate of each
    parameter group according to some policy.

    Parameters
    ----------

    policy : str or _LRScheduler class (default='WarmRestartLR')
      Learning rate policy name or scheduler to be used.

    monitor : str or callable (default=None)
      Value of the history to monitor or function/callable. In
      the latter case, the callable receives the net instance as
      argument and is expected to return the score (float) used to
      determine the learning rate adjustment.

    kwargs
github skorch-dev / skorch / skorch / callbacks / logging.py View on Github external
self.pbar.close()


def rename_tensorboard_key(key):
    """Map a history key to the tag name used in TensorBoard.

    Keys beginning with ``train`` or ``valid`` are treated as losses and
    receive a ``Loss/`` prefix; every other key is returned unchanged.

    """
    looks_like_loss = key.startswith(('train', 'valid'))
    return 'Loss/' + key if looks_like_loss else key


class TensorBoard(Callback):
    """Logs results from history to TensorBoard

    "TensorBoard provides the visualization and tooling needed for
    machine learning experimentation" (tensorboard_)

    Use this callback to automatically log all interesting values from
    your net's history to tensorboard after each epoch.

    The best way to log additional information is to subclass this
    callback and add your code to one of the ``on_*`` methods.

    Examples
    --------
    >>> # Example to log the bias parameter as a histogram
    >>> def extract_bias(module):
    ...     return module.hidden.bias
github skorch-dev / skorch / skorch / callbacks.py View on Github external
"validation scores for checkpointing.".format(e.args[0]))

        if do_checkpoint:
            target = self.target
            if isinstance(self.target, str):
                target = self.target.format(
                    net=net,
                    last_epoch=net.history[-1],
                    last_batch=net.history[-1, 'batches', -1],
                )
            if net.verbose > 0:
                print("Checkpoint! Saving model to {}.".format(target))
            net.save_params(target)


class ProgressBar(Callback):
    """Display a progress bar for each epoch including duration, estimated
    remaining time and user-defined metrics.

    For jupyter notebooks a non-ASCII progress bar is printed instead.
    To use this feature, you need to have `ipywidgets
    `
    installed.

    Parameters:
    -----------

    batches_per_epoch : int (default=None)
      The progress bar determines the number of batches per epoch
      automatically after one epoch but you can also specify this
      number yourself using this parameter.
github skorch-dev / skorch / skorch / callbacks.py View on Github external
"""Called at the end of each batch."""
        pass

    def _get_param_names(self):
        return (key for key in self.__dict__ if not key.endswith('_'))

    def get_params(self, deep=True):
        """Return this object's parameters by delegating to
        ``BaseEstimator.get_params``.

        NOTE(review): presumably scikit-learn's ``BaseEstimator``, which
        gathers names through ``_get_param_names``; confirm against the
        module's imports.
        """
        delegate = BaseEstimator.get_params
        return delegate(self, deep=deep)

    def set_params(self, **params):
        """Set each keyword argument as an attribute on this object.

        Returns ``self`` so calls can be chained.
        """
        for name in params:
            setattr(self, name, params[name])
        return self


class EpochTimer(Callback):
    """Record how long each epoch takes.

    The wall-clock time between ``on_epoch_begin`` and ``on_epoch_end``
    (measured with ``time.time``) is written to the history under the
    key ``dur``.

    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Start timestamp of the current epoch; None before the first epoch.
        self.epoch_start_time_ = None

    def on_epoch_begin(self, net, **kwargs):
        now = time.time()
        self.epoch_start_time_ = now

    def on_epoch_end(self, net, **kwargs):
        elapsed = time.time() - self.epoch_start_time_
        net.history.record('dur', elapsed)
github skorch-dev / skorch / skorch / callbacks.py View on Github external
history with the name ``dur``.

    """
    def __init__(self, **kwargs):
        """Forward keyword arguments to the parent and reset the timer."""
        super().__init__(**kwargs)

        # No epoch has started yet, so there is no start timestamp.
        self.epoch_start_time_ = None

    def on_epoch_begin(self, net, **kwargs):
        """Remember the wall-clock time at which this epoch starts."""
        now = time.time()
        self.epoch_start_time_ = now

    def on_epoch_end(self, net, **kwargs):
        """Write the seconds elapsed since ``on_epoch_begin`` to the
        history under ``dur``."""
        elapsed = time.time() - self.epoch_start_time_
        net.history.record('dur', elapsed)


class ScoringBase(Callback):
    """Base class for scoring.

    Subclass and implement an ``on_*`` method before using.
    """
    def __init__(
            self,
            scoring,
            lower_is_better=True,
            on_train=False,
            name=None,
            target_extractor=to_numpy,
    ):
        self.scoring = scoring
        self.lower_is_better = lower_is_better
        self.on_train = on_train
        self.name = name
github skorch-dev / skorch / skorch / callbacks.py View on Github external
elif self.batches_per_epoch == 'count':
            # No limit is known until the end of the first epoch.
            batches_per_epoch = None

        if self._use_notebook():
            self.pbar = tqdm.tqdm_notebook(total=batches_per_epoch)
        else:
            self.pbar = tqdm.tqdm(total=batches_per_epoch)

    def on_epoch_end(self, net, **kwargs):
        """Close the epoch's progress bar.

        In ``'count'`` mode the number of batches the bar saw this epoch
        (``pbar.n``) is stored as ``batches_per_epoch``, so later epochs
        start with a known total.
        """
        counting = self.batches_per_epoch == 'count'
        if counting:
            self.batches_per_epoch = self.pbar.n
        self.pbar.close()


class GradientNormClipping(Callback):
    """Clips gradient norm of a module's parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified
    in-place.

    See ``torch.nn.utils.clip_grad_norm`` for more information.

    Parameters
    ----------
    gradient_clip_value : float (default=None)
      If not None, clip the norm of all model parameter gradients to this
      value. The type of the norm is determined by the
      ``gradient_clip_norm_type`` parameter and defaults to L2.

    gradient_clip_norm_type : float (default=2)