# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# --- Cropped-decoding training setup (shallow branch) ---
# NOTE(review): this chunk sits inside an enclosing scope that starts
# before the visible file, and the Experiment(...) call below is
# truncated — code kept byte-identical, comments only.
log.info("Model: \n{:s}".format(str(model)))
# Run one dummy trial through the network to discover how many
# predictions the model emits per input window.
dummy_input = np_to_var(train_set.X[:1, :, :, None])
if cuda:
dummy_input = dummy_input.cuda()
out = model(dummy_input)
# assumes out is (batch, classes, time) so dim 2 is the prediction
# axis — TODO confirm against the model's output layout
n_preds_per_input = out.cpu().data.numpy().shape[2]
optimizer = optim.Adam(model.parameters())
# Iterator yields crops (sub-windows) of each trial rather than whole trials.
iterator = CropsFromTrialsIterator(batch_size=batch_size,
input_time_length=input_time_length,
n_preds_per_input=n_preds_per_input)
# Stop when either the epoch budget is exhausted or validation
# misclassification has not improved for max_increase_epochs epochs.
stop_criterion = Or([MaxEpochs(max_epochs),
NoDecrease('valid_misclass', max_increase_epochs)])
monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
CroppedTrialMisclassMonitor(
input_time_length=input_time_length), RuntimeMonitor()]
model_constraint = MaxNormDefaultConstraint()
# Average the per-timestep predictions over the time axis before NLL loss.
loss_function = lambda preds, targets: F.nll_loss(
th.mean(preds, dim=2, keepdim=False), targets)
# NOTE(review): this Experiment(...) call is truncated in the visible
# source — remaining keyword arguments are missing.
exp = Experiment(model, train_set, valid_set, test_set, iterator=iterator,
loss_function=loss_function, optimizer=optimizer,
model_constraint=model_constraint,
monitors=monitors,
stop_criterion=stop_criterion,
def setup_after_stop_training(self):
    """
    Prepare the second training phase after the first stop.

    Keeps a copy of the first phase's monitoring dataframe, restores
    the best-performing model/optimizer state, and installs a fresh
    stop criterion for the continued run.
    """
    # Snapshot the first-phase epoch stats; they are put back into the
    # monitor channels once the experiment has finished.
    self.before_stop_df = deepcopy(self.epochs_df)
    # Roll model, optimizer and epochs_df back to the best epoch seen
    # so far; must happen before reading the target train loss below.
    self.rememberer.reset_to_best_model(
        self.epochs_df, self.model, self.optimizer
    )
    target_loss = float(self.epochs_df["train_loss"].iloc[-1])
    # Continue until validation loss drops below the best epoch's
    # training loss, but at most until twice the best epoch's index.
    epoch_limit = MaxEpochs(max_epochs=self.rememberer.best_epoch * 2)
    valid_goal = ColumnBelow(
        column_name="valid_loss", target_value=target_loss
    )
    self.stop_criterion = Or(stop_criteria=[epoch_limit, valid_goal])
    log.info("Train loss to reach {:.5f}".format(target_loss))
# --- Trialwise training setup (deep branch) ---
# NOTE(review): the matching `if` for this `elif` is outside the visible
# chunk, and the Experiment(...) call at the end is truncated — code
# kept byte-identical, comments only.
elif model == "deep":
# Replace the "deep" selector string with the constructed network.
model = Deep4Net(
n_chans,
n_classes,
input_time_length=input_time_length,
final_conv_length="auto",
).create_network()
if cuda:
model.cuda()
log.info("Model: \n{:s}".format(str(model)))
optimizer = optim.Adam(model.parameters())
# Trialwise (not cropped) iteration: fixed-size batches of whole trials.
iterator = BalancedBatchSizeIterator(batch_size=batch_size)
# Stop on epoch budget or on stagnating validation misclassification.
stop_criterion = Or(
[
MaxEpochs(max_epochs),
NoDecrease("valid_misclass", max_increase_epochs),
]
)
monitors = [LossMonitor(), MisclassMonitor(), RuntimeMonitor()]
model_constraint = MaxNormDefaultConstraint()
# NOTE(review): this Experiment(...) call runs past the end of the
# visible source — remaining keyword arguments are missing.
exp = Experiment(
model,
train_set,
valid_set,
test_set,
iterator=iterator,