Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_median_pruner_n_warmup_steps():
# type: () -> None
pruner = optuna.pruners.MedianPruner(0, 1)
study = optuna.study.create_study()
trial = optuna.trial.Trial(study, study._storage.create_new_trial_id(study.study_id))
trial.report(1, 1)
trial.report(1, 2)
study._storage.set_trial_state(trial._trial_id, TrialState.COMPLETE)
trial = optuna.trial.Trial(study, study._storage.create_new_trial_id(study.study_id))
trial.report(2, 1)
# A pruner is not activated during warm-up steps.
assert not pruner.prune(
study=study, trial=study._storage.get_trial(trial.trial_id), step=1)
trial.report(2, 2)
# A pruner is activated after warm-up steps.
assert pruner.prune(
def test_median_pruner_intermediate_values(direction_value):
    # type: (Tuple[str, float]) -> None
    """Check that MedianPruner only prunes once the trial has reported a value.

    ``direction_value`` is unpacked into the study direction and the
    intermediate value that the second trial reports at step 1.
    """
    direction, intermediate_value = direction_value

    median_pruner = optuna.pruners.MedianPruner(0, 0)
    study = optuna.study.create_study(direction=direction)
    storage = study._storage

    # Reference trial: reports a single value and is then marked COMPLETE.
    reference = optuna.trial.Trial(study, storage.create_new_trial(study._study_id))
    reference.report(1, 1)
    storage.set_trial_state(reference._trial_id, TrialState.COMPLETE)

    # Trial under test: starts out without any reported value.
    candidate = optuna.trial.Trial(study, storage.create_new_trial(study._study_id))

    # A pruner is not activated if a trial has no intermediate values.
    pruned_before_report = median_pruner.prune(
        study=study, trial=storage.get_trial(candidate._trial_id))
    assert not pruned_before_report

    candidate.report(intermediate_value, 1)

    # A pruner is activated if a trial has an intermediate value.
    pruned_after_report = median_pruner.prune(
        study=study, trial=storage.get_trial(candidate._trial_id))
    assert pruned_after_report
# Checks, step by step, whether MedianPruner.prune() agrees with the steps
# listed in `expected_prune_steps` for the given warm-up / interval settings.
def test_median_pruner_interval_steps(
n_warmup_steps, interval_steps, report_steps, expected_prune_steps):
# type: (int, int, int, List[int]) -> None
# The pruner is built with 0 as its first positional argument, followed by
# the warm-up and interval settings under test.
pruner = optuna.pruners.MedianPruner(0, n_warmup_steps, interval_steps)
study = optuna.study.create_study()
# First trial: reports the constant `base_index` (1) at every step from
# `base_index` up to the largest expected prune step, then is marked COMPLETE.
trial = optuna.trial.Trial(study, study._storage.create_new_trial(study._study_id))
n_steps = max(expected_prune_steps)
base_index = 1
for i in range(base_index, base_index + n_steps):
trial.report(base_index, i)
study._storage.set_trial_state(trial._trial_id, TrialState.COMPLETE)
# Second trial: reports the value 2, but only on every `report_steps`-th step
# counted from `base_index`.
trial = optuna.trial.Trial(study, study._storage.create_new_trial(study._study_id))
for i in range(base_index, base_index + n_steps):
if (i - base_index) % report_steps == 0:
trial.report(2, i)
# NOTE(review): the original indentation was lost in this copy, so it cannot
# be told from here whether the assert below runs on every iteration or only
# on iterations that reported a value - confirm against the upstream file.
# The expected result is True exactly when the step is past the warm-up
# period and is listed in `expected_prune_steps`.
assert (pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id))
== (i > n_warmup_steps and i in expected_prune_steps))
def test_median_pruner_intermediate_values_nan():
    # type: () -> None
    """Check how MedianPruner treats trials whose reported value is NaN.

    Two trials each report NaN at step 1; the first must not be pruned
    (no earlier trials exist), the second must be pruned.
    """
    nan = float('nan')

    median_pruner = optuna.pruners.MedianPruner(0, 0)
    study = optuna.study.create_study()
    storage = study._storage

    first = optuna.trial.Trial(study, storage.create_new_trial_id(study.study_id))
    first.report(nan, 1)

    # A pruner is not activated if the study does not have any previous trials.
    first_pruned = median_pruner.prune(
        study=study, trial=storage.get_trial(first.trial_id), step=1)
    assert not first_pruned
    storage.set_trial_state(first._trial_id, TrialState.COMPLETE)

    second = optuna.trial.Trial(study, storage.create_new_trial_id(study.study_id))
    second.report(nan, 1)

    # A pruner is activated if the best intermediate value of this trial is NaN.
    second_pruned = median_pruner.prune(
        study=study, trial=storage.get_trial(second.trial_id), step=1)
    assert second_pruned
    storage.set_trial_state(second._trial_id, TrialState.COMPLETE)
log.info(f'Study : {study_name}')

# Run budget: a fixed number of trials, with the parallel job count taken
# from psutil's CPU count.
n_trials = 100
n_jobs = psutil.cpu_count()
log.info(f'Number of Trials : {n_trials}')
log.info(f'Number of Parallel Jobs : {n_jobs}')

# Seeded TPE sampler so that sampling behaves deterministically
# (a seeded RandomSampler was the alternative considered here).
sampler = TPESampler(seed=SEED)
pruner = optuna.pruners.MedianPruner()
storage_url = f'sqlite:///{study_name}.db'

# A study corresponds to an optimization task, i.e., a set of trials.
# It is stored in a per-study SQLite file and created with
# load_if_exists=True.
study = optuna.create_study(
    study_name=study_name,
    storage=storage_url,
    sampler=sampler,
    pruner=pruner,
    direction='maximize',
    load_if_exists=True,
)
study.optimize(
    objective,
    n_trials=n_trials,
    n_jobs=n_jobs,
    show_progress_bar=True,
)

log.info(f'Best Parameters: {study.best_params}')
log.info(f'Best Value: {study.best_value}')

# Save the per-trial results table next to the database file.
trials_df = study.trials_dataframe()
trials_df.to_pickle(f'{study_name}_df.bz2')

end_time = dt.datetime.now()
log.info(f'Total time taken for the study: {end_time - start_time}')
param['sample_type'] = trial.suggest_categorical('sample_type', ['uniform', 'weighted'])
param['normalize_type'] = trial.suggest_categorical('normalize_type', ['tree', 'forest'])
param['rate_drop'] = trial.suggest_loguniform('rate_drop', 1e-8, 1.0)
param['skip_drop'] = trial.suggest_loguniform('skip_drop', 1e-8, 1.0)
# Add a callback for pruning.
pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'validation-auc')
bst = xgb.train(param, dtrain, evals=[(dtest, 'validation')], callbacks=[pruning_callback])
preds = bst.predict(dtest)
pred_labels = np.rint(preds)
accuracy = sklearn.metrics.accuracy_score(test_y, pred_labels)
return accuracy
if __name__ == '__main__':
    # Maximize the objective over 100 trials, using a median pruner
    # configured with n_warmup_steps=5.
    median_pruner = optuna.pruners.MedianPruner(n_warmup_steps=5)
    study = optuna.create_study(
        direction='maximize',
        pruner=median_pruner,
    )
    study.optimize(objective, n_trials=100)

    print(study.best_trial)
.format(engine.state.epoch, train_acc, validation_acc)
)
trainer.run(train_loader, max_epochs=EPOCHS)
evaluator.run(val_loader)
return evaluator.state.metrics['accuracy']
if __name__ == "__main__":
    # Command-line interface: a single flag that switches pruning on.
    parser = argparse.ArgumentParser(description='PyTorch Ignite example.')
    parser.add_argument(
        '--pruning', '-p',
        action='store_true',
        help='Activate the pruning feature. `MedianPruner` stops unpromising '
             'trials at the early stages of training.',
    )
    args = parser.parse_args()

    # Without --pruning, a no-op pruner is used.
    if args.pruning:
        pruner = optuna.pruners.MedianPruner()
    else:
        pruner = optuna.pruners.NopPruner()

    # Maximize the objective with n_trials=100 and timeout=600.
    study = optuna.create_study(direction='maximize', pruner=pruner)
    study.optimize(objective, n_trials=100, timeout=600)

    # Report how many trials ran, then the best trial's value and parameters.
    print('Number of finished trials: ', len(study.trials))

    print('Best trial:')
    trial = study.best_trial
    print(' Value: ', trial.value)

    print(' Params: ')
    for key, value in trial.params.items():
        print(' {}: {}'.format(key, value))
def __init__(
        self,
        study_name,  # type: str
        storage,  # type: Union[str, storages.BaseStorage]
        sampler=None,  # type: samplers.BaseSampler
        pruner=None  # type: pruners.BasePruner
):
    # type: (...) -> None
    """Bind this object to an existing study looked up by name.

    Args:
        study_name: Name of the study; used to look up its id in the storage.
        storage: Value passed to ``storages.get_storage`` to obtain the
            storage object.
        sampler: Sampler to use; a ``samplers.TPESampler`` is created when
            the argument is falsy.
        pruner: Pruner to use; a ``pruners.MedianPruner`` is created when
            the argument is falsy.
    """
    self.study_name = study_name

    # Resolve the storage first: the study id is looked up through it, and
    # both are handed to the base-class initializer.
    resolved_storage = storages.get_storage(storage)
    resolved_study_id = resolved_storage.get_study_id_from_name(study_name)
    super(Study, self).__init__(resolved_study_id, resolved_storage)

    if sampler:
        self.sampler = sampler
    else:
        self.sampler = samplers.TPESampler()

    if pruner:
        self.pruner = pruner
    else:
        self.pruner = pruners.MedianPruner()

    self.logger = logging.get_logger(__name__)
    self._optimize_lock = threading.Lock()