# Imports reconstructed so this snippet runs as a pytest module; the xenonpy
# module paths are assumptions inferred from the Trainer API exercised below.
from copy import deepcopy

import pandas as pd
import pytest
import torch

from xenonpy.model.training import Adam, ClipValue, ExponentialLR, MSELoss, ReduceLROnPlateau, Trainer
from xenonpy.model.training.extension import Persist, TensorConverter


def test_trainer_1(data):  # placeholder name; the snippet began mid-function
    # Reconstructed setup: a bare Trainer has no model, loss, or optimizer yet.
    trainer = Trainer()

    # A non-module assignment is rejected (the `match` string mirrors the
    # library's own error message, so its grammar is kept as-is).
    with pytest.raises(TypeError, match='parameter `m` must be a instance of '):
        trainer.model = {}
    trainer.model = data[0]
    assert isinstance(trainer.model, torch.nn.Module)

    # Fitting without a loss function must fail.
    with pytest.raises(RuntimeError, match='no loss function for training'):
        trainer.fit(*data[1])
    trainer.loss_func = MSELoss()
    assert trainer.loss_type == 'train_mse_loss'
    assert trainer.loss_func.__class__ == MSELoss

    # Fitting without an optimizer must fail.
    with pytest.raises(RuntimeError, match='no optimizer for training'):
        trainer.fit(*data[1])
    trainer.optimizer = Adam()
    assert isinstance(trainer.optimizer, torch.optim.Adam)
    assert isinstance(trainer._optimizer_state, dict)
    assert isinstance(trainer._init_states, dict)

    trainer.lr_scheduler = ExponentialLR(gamma=0.99)
    assert isinstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ExponentialLR)
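
# test_persist_1: round-trips a fitted model through the Persist extension and
# Trainer.load, checking which artifacts are written to disk and what state a
# reloaded Trainer does (and does not) carry.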
def test_persist_1(data):
    model = deepcopy(data[0])
    trainer = Trainer(model=model, optimizer=Adam(lr=0.1), loss_func=MSELoss(), epochs=200)
    trainer.extend(TensorConverter(), Persist('model_dir'))
    trainer.fit(*data[1], *data[1])  # same data apparently reused as the validation set

    # the Persist extension exposes the on-disk checker and its artifacts
    persist = trainer['persist']
    checker = persist._checker
    assert isinstance(persist, Persist)
    assert isinstance(checker.model, torch.nn.Module)
    assert isinstance(checker.describe, dict)
    assert isinstance(checker.files, list)
    assert set(checker.files) == {'model', 'init_state', 'model_structure', 'describe', 'training_info', 'final_state'}

    # loading from a checker restores the model but not the training setup
    trainer = Trainer.load(checker)
    assert isinstance(trainer.training_info, pd.DataFrame)
    assert isinstance(trainer.model, torch.nn.Module)
    assert isinstance(trainer._training_info, list)
    assert trainer.optimizer is None
    assert trainer.lr_scheduler is None
    assert trainer.x_val is None
    assert trainer.y_val is None
    assert trainer.validate_dataset is None
    assert trainer._optimizer_state is None
    assert trainer.total_epochs == 0
    assert trainer.total_iterations == 0
    assert trainer.loss_type is None
    assert trainer.loss_func is None

    trainer = Trainer.load(from_=checker.path,
                           optimizer=Adam(),
                           loss_func=MSELoss(),
                           lr_scheduler=ExponentialLR(gamma=0.99),
                           clip_grad=ClipValue(clip_value=0.1))
    assert isinstance(trainer._scheduler, ExponentialLR)
    assert isinstance(trainer._optim, Adam)
    assert isinstance(trainer.clip_grad, ClipValue)
    assert isinstance(trainer.loss_func, MSELoss)
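
# test_trainer_fit_4: calling the trainer as an iterator yields one info dict
# per epoch, so the caller can inspect progress and halt training early via
# `early_stop`; the stop reason is recorded in `_early_stopping`.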
def test_trainer_fit_4(data):
    model = deepcopy(data[0])
    trainer = Trainer(model=model,
                      optimizer=Adam(),
                      loss_func=MSELoss(),
                      clip_grad=ClipValue(0.1),
                      lr_scheduler=ReduceLROnPlateau(),
                      epochs=10)

    count = 1
    for i in trainer(*data[1]):
        assert isinstance(i, dict)
        assert i['i_epoch'] == count
        if count == 3:
            trainer.early_stop('stop')
        count += 1

    assert trainer.total_epochs == 3
    assert trainer._early_stopping == (True, 'stop')
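
# A minimal usage sketch assembled only from calls exercised above; `net`,
# `x`, and `y` are hypothetical names (any torch.nn.Module and its training
# tensors):
#
#     trainer = Trainer(model=net, optimizer=Adam(lr=0.1), loss_func=MSELoss(), epochs=50)
#     trainer.extend(TensorConverter(), Persist('model_dir'))
#     trainer.fit(x, y)
#     reloaded = Trainer.load(trainer['persist']._checker)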