How to use the fastai.basic_train.LearnerCallback class in fastai

To help you get started, we’ve selected a few fastai examples based on popular ways LearnerCallback is used in public projects.

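At its core, a LearnerCallback is a callback that holds a reference to its Learner and overrides whichever on_* events it needs. A minimal hand-rolled sketch (the callback name and print logic are illustrative, not taken from the projects below):

from fastai.basic_train import Learner, LearnerCallback

class PrintStatusCallback(LearnerCallback):
    "Print the smoothed training loss at the end of each epoch."
    def on_epoch_end(self, epoch, smooth_loss, **kwargs):
        # `epoch` and `smooth_loss` are part of the state fastai passes to every event
        print(f'epoch {epoch}: smooth loss {float(smooth_loss):.4f}')

# Attach via `callback_fns` so the callback is re-created, bound to `learn`,
# on every call to `fit`, e.g.:
#   learn = Learner(data, model, callback_fns=[PrintStatusCallback])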

github fastai / fastai / fastai/callbacks/oversampling.py View on Github

from ..torch_core import *
from ..basic_data import DataBunch
from ..callback import *
from ..basic_train import Learner,LearnerCallback
from torch.utils.data.sampler import WeightedRandomSampler

__all__ = ['OverSamplingCallback']

class OverSamplingCallback(LearnerCallback):
    def __init__(self,learn:Learner,weights:torch.Tensor=None):
        super().__init__(learn)
        self.weights = weights

    def on_train_begin(self, **kwargs):
        ds,dl = self.data.train_ds,self.data.train_dl
        self.labels = ds.y.items
        assert np.issubdtype(self.labels.dtype, np.integer), "Can only oversample integer values"
        _,self.label_counts = np.unique(self.labels,return_counts=True)
        if self.weights is None: self.weights = torch.DoubleTensor((1/self.label_counts)[self.labels])
        self.total_len_oversample = int(self.data.c*np.max(self.label_counts))
        sampler = WeightedRandomSampler(self.weights, self.total_len_oversample)
        self.data.train_dl = dl.new(shuffle=False, sampler=sampler)
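
A hedged usage sketch: the callback is usually attached through callback_fns so it is instantiated fresh for each training run. This assumes a fastai version that exports OverSamplingCallback from fastai.callbacks, and uses MNIST_SAMPLE purely as a stand-in for an imbalanced, integer-labelled dataset:

from fastai.vision import *
from fastai.callbacks import OverSamplingCallback

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)   # stand-in dataset
learn = cnn_learner(data, models.resnet18, metrics=[accuracy],
                    callback_fns=[OverSamplingCallback])
learn.fit_one_cycle(4)
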
github fastai / fastai / fastai/distributed.py View on Github

from .torch_core import *
from .basic_train import Learner,LearnerCallback
from torch.nn.parallel import DistributedDataParallel, DataParallel
from torch.utils.data.distributed import DistributedSampler

from fastai.text import TextLMDataBunch

__all__ = ['DistributedRecorder', 'DistributedTrainer', 'read_metrics', 'setup_distrib']

def rnn_reset(self):
    if hasattr(self.module, 'reset'): self.module.reset()
# patch DistributedDataParallel so fastai's RNN `reset` calls reach the wrapped module
DistributedDataParallel.reset = rnn_reset

class ParallelTrainer(LearnerCallback):
    _order = -20
    def on_train_begin(self, **kwargs): self.learn.model = DataParallel(self.learn.model)
    def on_train_end  (self, **kwargs): self.learn.model = self.learn.model.module

class DistributedTrainer(LearnerCallback):
    _order = -20 # Needs to run before the recorder
    def __init__(self, learn:Learner, cuda_id:int=0):
        super().__init__(learn)
        self.cuda_id,self.train_sampler = cuda_id,None

    def _change_dl(self, dl, shuffle):
        "Swap `dl`'s sampler for a distributed one (`OurDistributedSampler` is defined later in this file)."
        old_dl = dl
        sampler = OurDistributedSampler(dl.dataset, shuffle=shuffle)
        new_dl = dl.new(shuffle=False, sampler=sampler)
        return old_dl,new_dl,sampler
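
You rarely construct DistributedTrainer directly: later in this same file fastai patches a to_distributed method onto Learner that attaches both DistributedTrainer and DistributedRecorder. A sketch of a launch script along the lines of the fastai distributed-training docs (script name and dataset are placeholders):

# train.py, launched with:
#   python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py
import argparse
from fastai.vision import *
from fastai.distributed import *

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = cnn_learner(data, models.resnet18, metrics=accuracy)
learn = learn.to_distributed(args.local_rank)  # attaches DistributedTrainer + DistributedRecorder
learn.fit_one_cycle(1)
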
github jantic / DeOldify / deoldify/save.py View on Github

from fastai.basic_train import Learner, LearnerCallback
from fastai.vision.gan import GANLearner


class GANSaveCallback(LearnerCallback):
    """A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."""

    def __init__(
        self,
        learn: GANLearner,
        learn_gen: Learner,
        filename: str,
        save_iters: int = 1000,
    ):
        super().__init__(learn)
        self.learn_gen = learn_gen
        self.filename = filename
        self.save_iters = save_iters

    def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
        if iteration == 0:
            return
        # checkpoint the generator every `save_iters` iterations
        if iteration % self.save_iters == 0:
            self._save_gen_learner(iteration=iteration, epoch=epoch)

    def _save_gen_learner(self, iteration: int, epoch: int):
        filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
        self.learn_gen.save(filename)
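
A hedged sketch of wiring the callback up: gan_learn stands for a DeOldify GANLearner and gen_learn for the generator Learner it wraps, both built elsewhere; the filename and save interval are arbitrary:

from functools import partial
from deoldify.save import GANSaveCallback

# checkpoint the generator every 1000 iterations while the GAN trains
gan_learn.callback_fns.append(
    partial(GANSaveCallback, learn_gen=gen_learn,
            filename='gen_checkpoint', save_iters=1000))
gan_learn.fit(1, lr=1e-4)
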
github fastai / fastai / fastai/callbacks/general_sched.py View on Github

from ..core import *      # dataclass, ifnone, typing helpers
from ..callback import *  # Scheduler
from ..basic_train import Learner, LearnerCallback

__all__ = ['GeneralScheduler', 'TrainingPhase']

@dataclass
class TrainingPhase():
    "Schedule hyper-parameters for a phase of `length` iterations."
    length:int
    
    def __post_init__(self): self.scheds = dict()
    def schedule_hp(self, name, vals, anneal=None):
        "Adds a schedule for `name` between `vals` using `anneal`."
        self.scheds[name] = Scheduler(vals, self.length, anneal)
        return self

class GeneralScheduler(LearnerCallback):
    "Schedule multiple `TrainingPhase` for a `Learner`."
    def __init__(self, learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None):
        super().__init__(learn)
        self.phases,self.start_epoch = phases,start_epoch

    def on_train_begin(self, epoch:int, **kwargs:Any)->None:
        "Initialize the schedulers for training."
        res = {'epoch':self.start_epoch} if self.start_epoch is not None else None
        self.start_epoch = ifnone(self.start_epoch, epoch)
        self.scheds = [p.scheds for p in self.phases]
        self.opt = self.learn.opt
        for k,v in self.scheds[0].items(): 
            v.restart()
            self.opt.set_stat(k, v.start)
        self.idx_s = 0
        return res
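
A sketch of the common "flat then cosine-anneal" schedule built from these pieces (the helper name and hyper-parameters are illustrative; annealing_cos is fastai's cosine annealing function, and a scalar value passed to schedule_hp anneals toward 0):

from fastai.basic_train import Learner
from fastai.callback import annealing_cos
from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase

def fit_flat_cos(learn:Learner, n_epochs:int=4, lr:float=4e-3, start_pct:float=0.72):
    "Hold `lr` flat for `start_pct` of training, then cosine-anneal it to 0."
    n = n_epochs * len(learn.data.train_dl)   # total number of iterations
    anneal_start = int(n * start_pct)
    phases = [TrainingPhase(anneal_start).schedule_hp('lr', lr),
              TrainingPhase(n - anneal_start).schedule_hp('lr', lr, anneal=annealing_cos)]
    learn.fit(n_epochs, lr, callbacks=[GeneralScheduler(learn, phases)])
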
github fastai / fastai / fastai/callbacks/mixup.py View on Github

"Implements [mixup](https://arxiv.org/abs/1710.09412) training method"
from ..torch_core import *
from ..callback import *
from ..basic_train import Learner, LearnerCallback

__all__ = ["MixUpCallback", "MixUpLoss"]

class MixUpCallback(LearnerCallback):
    "Callback that creates the mixed-up input and target."
    def __init__(self, learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True):
        super().__init__(learn)
        self.alpha,self.stack_x,self.stack_y = alpha,stack_x,stack_y
    
    def on_train_begin(self, **kwargs):
        if self.stack_y: self.learn.loss_func = MixUpLoss(self.learn.loss_func)
        
    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        "Applies mixup to `last_input` and `last_target` if `train`."
        if not train: return
        lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0))
        lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)
        lambd = last_input.new(lambd)
        shuffle = torch.randperm(last_target.size(0)).to(last_input.device)
        x1, y1 = last_input[shuffle], last_target[shuffle]
        if self.stack_x:
            new_input = [last_input, last_input[shuffle], lambd]
        else:
            out_shape = [lambd.size(0)] + [1 for _ in range(len(x1.shape) - 1)]
            new_input = (last_input * lambd.view(out_shape) + x1 * (1-lambd).view(out_shape))
        if self.stack_y:
            new_target = torch.cat([last_target[:,None].float(), y1[:,None].float(), lambd[:,None].float()], 1)
        else:
            if len(last_target.shape) == 2:
                lambd = lambd.unsqueeze(1).float()
            new_target = last_target.float() * lambd + y1.float() * (1-lambd)
        return {'last_input': new_input, 'last_target': new_target}
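
Enabling this in user code is usually a one-liner, because fastai also patches a mixup convenience method onto Learner (in train.py) that attaches the callback:

from fastai.vision import *

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
# `.mixup()` attaches MixUpCallback; with stack_y=True the loss is wrapped in MixUpLoss
learn = cnn_learner(data, models.resnet18, metrics=[accuracy]).mixup(alpha=0.4)
learn.fit_one_cycle(4)
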
github fastai / fastai / fastai/vision/gan.py View on Github

class GANLoss(GANModule):
    "Wrapper around `loss_funcC` (for the critic) and `loss_funcG` (for the generator)."
    def __init__(self, loss_funcG:Callable, loss_funcC:Callable, gan_model:GANModule):
        super().__init__()
        self.loss_funcG,self.loss_funcC,self.gan_model = loss_funcG,loss_funcC,gan_model

    def generator(self, output, target):
        "Evaluate the `output` with the critic then uses `self.loss_funcG` to combine it with `target`."
        fake_pred = self.gan_model.critic(output)
        return self.loss_funcG(fake_pred, target, output)

    def critic(self, real_pred, input):
        "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.loss_funcD`."
        fake = self.gan_model.generator(input.requires_grad_(False)).requires_grad_(True)
        fake_pred = self.gan_model.critic(fake)
        return self.loss_funcC(real_pred, fake_pred)

class GANTrainer(LearnerCallback):
    "Handles GAN Training."
    _order=-20
    def __init__(self, learn:Learner, switch_eval:bool=False, clip:float=None, beta:float=0.98, gen_first:bool=False,
                 show_img:bool=True):
        super().__init__(learn)
        self.switch_eval,self.clip,self.beta,self.gen_first,self.show_img = switch_eval,clip,beta,gen_first,show_img
        self.generator,self.critic = self.model.generator,self.model.critic

    def _set_trainable(self):
        train_model = self.generator if     self.gen_mode else self.critic
        loss_model  = self.generator if not self.gen_mode else self.critic
        requires_grad(train_model, True)
        requires_grad(loss_model, False)
        if self.switch_eval:
            train_model.train()
            loss_model.eval()
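
GANTrainer is normally created for you by the GANLearner factory methods rather than instantiated directly. A sketch of the WGAN setup along the lines of the fastai GAN docs; data is assumed to be a GAN DataBunch of 64px images (e.g. built with GANItemList.from_folder):

from functools import partial
from torch import optim
from fastai.vision.gan import GANLearner, basic_generator, basic_critic

generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
critic = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)
# GANLearner.wgan attaches a GANTrainer (and a fixed switcher) under the hood
learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                        opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=0.)
learn.fit(2, 2e-4)
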
github fastai / fastai / fastai/callbacks/tensorboard.py View on Github

from ..basic_train import Learner, LearnerCallback
from ..basic_data import DatasetType
from ..torch_core import *
from queue import Queue
import statistics
import torchvision.utils as vutils
from abc import ABC
# This is an optional dependency in fastai; it must be installed separately.
try: from tensorboardX import SummaryWriter
except ImportError: print("To use this tracker, please run 'pip install tensorboardx'. Also you must have Tensorboard running to see results")

__all__=['LearnerTensorboardWriter', 'GANTensorboardWriter', 'ImageGenTensorboardWriter']

#---Example usage (applies to any of the callbacks)--- 
# proj_id = 'Colorize'
# tboard_path = Path('data/tensorboard/' + proj_id)
# learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=tboard_path, name='GanLearner'))

class LearnerTensorboardWriter(LearnerCallback):
    "Broadly useful callback for Learners that writes to Tensorboard.  Writes model histograms, losses/metrics, and gradient stats."
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
        super().__init__(learn=learn)
        self.base_dir,self.name,self.loss_iters,self.hist_iters,self.stats_iters  = base_dir,name,loss_iters,hist_iters,stats_iters
        log_dir = base_dir/name
        self.tbwriter = SummaryWriter(str(log_dir))
        self.hist_writer = HistogramTBWriter()
        self.stats_writer = ModelStatsTBWriter()
        self.graph_writer = GraphTBWriter()
        self.data = None
        self.metrics_root = '/metrics/'
        self._update_batches_if_needed()

    def _get_new_batch(self, ds_type:DatasetType)->Collection[Tensor]:
        "Retrieves new batch of DatasetType, and detaches it."
        return self.learn.data.one_batch(ds_type=ds_type, detach=True, denorm=False, cpu=False)
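
Hooking the writer up mirrors the example-usage comment near the top of the file; learn is an existing Learner, and the project id and path are arbitrary:

from functools import partial
from pathlib import Path
from fastai.callbacks.tensorboard import LearnerTensorboardWriter

proj_id = 'Colorize'
tboard_path = Path('data/tensorboard/' + proj_id)
learn.callback_fns.append(partial(LearnerTensorboardWriter,
                                  base_dir=tboard_path, name='Learner'))
learn.fit_one_cycle(1)   # then inspect with: tensorboard --logdir data/tensorboard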