Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_base_runner_3(data):
    """Check that extensions registered on a BaseRunner are run by its hooks.

    ``data`` is a fixture whose first two items are called to build two
    extension objects (presumably named 'ext1' and 'ext2' -- confirm against
    the fixture definition). The test verifies that ``_before_proc``,
    ``_after_proc`` and ``_step_forward`` invoke both extensions. The
    expected 'ext1_ext2' values suggest the second extension sees the first
    one's result -- TODO confirm against the fixture's extension classes.
    """
    ext1, ext2 = data[0](), data[1]()
    runner = BaseRunner()
    runner.extend(ext1, ext2)
    # Iterating a mapping yields its keys, so set(mapping) == set(mapping.keys()).
    assert {'ext1', 'ext2'} == set(runner._extensions)

    runner._before_proc()
    assert ext1.before == 'ext1'
    assert ext2.before == 'ext1_ext2'

    runner._after_proc()
    assert ext1.after == 'ext1'
    assert ext2.after == 'ext1_ext2'

    step_info = OrderedDict()
    runner._step_forward(step_info=step_info)
    assert step_info['ext1'] == 'ext1'
    assert step_info['ext2'] == 'ext1_ext2'
    # NOTE(review): 'non_exist' is presumably written by one of the fixture's
    # extensions rather than by the runner itself -- confirm.
    assert step_info['non_exist'] == 'you can not see me!'
def test_validator_1():
    """A Validator built with ``each_iteration=False`` must not write 'val_mae'.

    A minimal trainer stand-in is created whose validation targets are the
    inputs shifted by a tiny random offset, and whose ``predict`` simply
    returns its two arguments. After one ``step_forward`` call the step info
    must not contain a 'val_mae' entry.
    """
    inputs = np.random.randn(100)  # input
    targets = inputs + np.random.rand() * 0.001  # true values

    class _Trainer(BaseRunner):
        def __init__(self):
            super().__init__()
            self.x_val, self.y_val = inputs, targets
            self.loss_type = 'train_loss'

        def predict(self, x_, y_):
            return x_, y_

    validator = Validator('regress', each_iteration=False)
    info = OrderedDict(train_loss=0, i_epoch=0)
    validator.step_forward(trainer=_Trainer(), step_info=info)  # noqa
    assert 'val_mae' not in info
def test_base_runner_2(data):
    """``BaseRunner.input_proc`` returns its two inputs unchanged as a tuple."""
    first, second = 1, 2
    assert BaseRunner().input_proc(first, second) == (first, second)
def __init__(self, cuda: Union[bool, str, torch.device] = False):
    """Set up the runner's device and an empty extension registry.

    Parameters
    ----------
    cuda
        Device specification passed to ``self.check_device``; whatever that
        method returns is stored on ``self._device``.
    """
    self._device = self.check_device(cuda)
    # Registry of extensions, keyed by name; starts empty. The tests in this
    # file show ``extend`` populating it with 'ext1'/'ext2' keys.
    self._extensions: BaseRunner.T_Extension_Dict = {}
import pandas as pd
import torch
from torch.nn import Module
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from torch.utils.data import DataLoader
from deprecated import deprecated
from xenonpy.model.training import ClipValue, ClipNorm, Checker
from xenonpy.model.training.base import BaseOptimizer, BaseLRScheduler, BaseRunner
from xenonpy.utils import camel_to_snake
# Public names exported by ``from <this module> import *``.
__all__ = ['Trainer']
class Trainer(BaseRunner):
checkpoint_tuple = namedtuple('checkpoint', 'id iterations model_state')
results_tuple = namedtuple('results', 'total_epochs device training_info checkpoints model')
def __init__(
self,
*,
loss_func: torch.nn.Module = None,
optimizer: BaseOptimizer = None,
model: Module = None,
lr_scheduler: BaseLRScheduler = None,
clip_grad: Union[ClipNorm, ClipValue] = None,
epochs: int = 200,
cuda: Union[bool, str, torch.device] = False,
non_blocking: bool = False,
):
"""
# license that can be found in the LICENSE file.
from copy import deepcopy
from typing import Union, Tuple, Any, Dict
import numpy as np
import torch
from torch.nn import Module
from torch.utils.data import DataLoader
from xenonpy.model.training.base import BaseRunner
# Public names exported by ``from <this module> import *``.
__all__ = ['Predictor']
class Predictor(BaseRunner):
    # NOTE(review): only the start of this class is visible in this view; the
    # remainder of __init__ and any further methods may lie elsewhere.

    def __init__(self,
                 model: Module,
                 *,
                 cuda: Union[bool, str, torch.device] = False,
                 check_points: Union[Dict[int, Dict], None] = None,
                 verbose: bool = True,
                 ):
        """
        Parameters
        ----------
        model
            Model instance (annotated as ``torch.nn.Module``). Not used in the
            visible part of ``__init__`` -- TODO confirm where it is stored.
        cuda
            Device specification forwarded unchanged to
            ``BaseRunner.__init__``.
        check_points
            Optional mapping from int keys to dicts; defaults to ``None``. Not
            used in the visible part of ``__init__`` -- TODO confirm how it
            is consumed.
        verbose
            Flag defaulting to ``True``. Not used in the visible part of
            ``__init__`` -- TODO confirm its effect.
        """
        super().__init__(cuda=cuda)
raise RuntimeError('Number of keys not equal values\' number')
types_ = kwargs.values()
if not all([isinstance(v, t) for v, t in zip(ret, types_)]):
raise TypeError('Returns\' type not match')
names = kwargs.keys()
pair = zip(names, ret)
checker = getattr(self._checker, fn_.__name__)
checker.save(**{k: v for k, v in pair})
return ret if len(ret) > 1 else ret[0]
return _func_2
return _deco
class RegressionRunner(BaseRunner, RegressorMixin):
"""
Run model.
"""
def __init__(self,
epochs=2000,
*,
cuda=False,
check_step=100,
log_step=0,
work_dir=None,
verbose=True,
describe=None):
"""
Parameters
----------
def _checked(o):
    """Return ``o`` unchanged when it is a ``BaseRunner`` instance.

    Raises
    ------
    TypeError
        If ``o`` is not an instance of ``BaseRunner``.
    """
    if isinstance(o, BaseRunner):
        return o
    raise TypeError('persistence only decorate inherent object\'s method')