Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def run_experiment(experiment_tags, data_dir, results_dir, start_fresh=False, use_cuda=False, workers=None,
                   experiments_file=None, *args, **kwargs):
    """Prepare an experiment: load its config, build its objects and restore state.

    Steps performed, in order:

    1. Check that ``data_dir`` and ``results_dir`` exist.
    2. Load the configuration selected by ``experiment_tags`` from
       ``experiments_file`` and log it.
    3. Build the model, optimizer, trainer and trainer parameters from the
       configuration.
    4. Create an ``ExperimentManager`` rooted at
       ``results_dir/<tags joined by '_'>``, optionally wiping earlier
       results first.
    5. Optionally move the model to CUDA and enable the cuDNN benchmark flag.
    6. Reload the training state of the last saved iteration, if there is one.
    7. Merge ``kwargs`` into the trainer parameters.

    Args:
        experiment_tags: Sequence of strings identifying the experiment; it is
            passed to ``load_experiment_config`` and joined with ``'_'`` to
            name the experiment directory.
        data_dir: Path that must already exist; forwarded to
            ``experiment_config_parser``.
        results_dir: Path that must already exist; parent of the experiment
            directory.
        start_fresh: When true, ``manager.delete_dirs()`` is called before the
            directories are (re)created.
        use_cuda: When true, the model is moved to the GPU with ``.cuda()``.
        workers: Forwarded unchanged to ``experiment_config_parser``.
        experiments_file: Forwarded unchanged to ``load_experiment_config``.
        *args: Accepted but not used in the lines shown here.
        **kwargs: Merged into ``trainer_params``, overriding configured values
            that share a key.

    Raises:
        RuntimeError: If ``data_dir`` or ``results_dir`` does not exist.
    """
    # Fail fast, before any configuration is loaded or directories are touched.
    if not os.path.exists(data_dir):
        raise RuntimeError('Cannot find data_dir directory: {}'.format(data_dir))
    if not os.path.exists(results_dir):
        raise RuntimeError('Cannot find results_dir directory: {}'.format(results_dir))

    cfg = load_experiment_config(experiments_file, experiment_tags)
    logger.info(cfg)

    model, optimizer, trainer, trainer_params = experiment_config_parser(cfg, workers=workers, data_dir=data_dir)

    experiment_dir = os.path.join(results_dir, '_'.join(experiment_tags))
    manager = ExperimentManager(experiment_dir, model, optimizer)

    if start_fresh:
        logger.info('Starting fresh option enabled. Clearing all previous results...')
        manager.delete_dirs()
    # Called on every run, so the directories exist whether or not they were
    # just cleared.
    manager.make_dirs()

    if use_cuda:
        manager.model = manager.model.cuda()
        # Imported lazily so the cuDNN module is only pulled in on the CUDA path.
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = True

    # A positive iteration number is treated as "a saved state exists".
    last_iter = manager.get_last_model_iteration()
    if last_iter > 0:
        logger.info('Continue experiment from iteration: {}'.format(last_iter))
        manager.load_train_state(last_iter)

    # Caller-supplied keyword arguments take precedence over configured values.
    # NOTE(review): `trainer` and `trainer_params` are not consumed in the
    # lines available for review -- presumably training is launched after
    # this point; confirm against the full file.
    trainer_params.update(kwargs)