Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def __init__(self, search_space, skopt_kwargs):
# type: (Dict[str, BaseDistribution], Dict[str, Any]) -> None
# Fragment (truncated at the unfinished `raise NotImplementedError(` below):
# constructor that converts each optuna distribution in `search_space` into a
# scikit-optimize (`skopt.space`) dimension.
# NOTE(review): `skopt_kwargs` is not used in the lines visible here; its use
# (and whether `dimension` is appended to `dimensions`) is outside this fragment.
self._search_space = search_space
dimensions = []
# Parameters are visited in sorted (name) order.
for name, distribution in sorted(self._search_space.items()):
if isinstance(distribution, distributions.UniformDistribution):
# Convert the upper bound from exclusive (optuna) to inclusive (skopt).
# np.nextafter(high, -inf) is the largest float strictly below `high`.
high = np.nextafter(distribution.high, float('-inf'))
dimension = space.Real(distribution.low, high)
elif isinstance(distribution, distributions.LogUniformDistribution):
# Convert the upper bound from exclusive (optuna) to inclusive (skopt).
high = np.nextafter(distribution.high, float('-inf'))
dimension = space.Real(distribution.low, high, prior='log-uniform')
elif isinstance(distribution, distributions.IntUniformDistribution):
# Integer bounds are passed through unchanged (no nextafter adjustment).
dimension = space.Integer(distribution.low, distribution.high)
elif isinstance(distribution, distributions.DiscreteUniformDistribution):
# Encode the discrete grid as an integer step index in [0, count];
# presumably decoded back as low + index * q elsewhere — verify.
# NOTE(review): float floor division can undercount steps, e.g.
# (1.0 - 0.0) // 0.1 == 9.0 rather than 10; `count` is also a float here.
# Consider int(round((high - low) / q)) — confirm against the decoder.
count = (distribution.high - distribution.low) // distribution.q
dimension = space.Integer(0, count)
elif isinstance(distribution, distributions.CategoricalDistribution):
dimension = space.Categorical(distribution.choices)
else:
# Any other distribution type is rejected (message is outside this fragment).
raise NotImplementedError(
def set_trial_param(self, trial_id, param_name, param_value_internal, distribution):
    # type: (int, str, float, distributions.BaseDistribution) -> bool
    """Store ``param_value_internal`` as parameter ``param_name`` of trial ``trial_id``.

    The whole operation runs while holding ``self._lock``. It first calls
    ``self.check_trial_is_updatable`` with the trial's current state. If another
    trial already registered a distribution under ``param_name``, it then calls
    ``distributions.check_distribution_compatibility`` against that distribution.

    Returns:
        ``False`` if the trial already has a value for ``param_name``, in which
        case nothing is written. Otherwise the value is stored in external form,
        via ``distribution.to_external_repr``, the distribution is recorded both
        globally and on the trial, and ``True`` is returned.
    """
    with self._lock:
        trial = self.trials[trial_id]
        self.check_trial_is_updatable(trial_id, trial.state)

        # One parameter name must map to compatible distributions across trials.
        if param_name in self.param_distribution:
            registered = self.param_distribution[param_name]
            distributions.check_distribution_compatibility(registered, distribution)

        # Parameters are write-once per trial.
        already_set = param_name in trial.params
        if already_set:
            return False

        self.param_distribution[param_name] = distribution
        external_value = distribution.to_external_repr(param_value_internal)
        trial.params[param_name] = external_value
        trial.distributions[param_name] = distribution
        return True
# Fragment: tail of a `sample_independent`-style method whose header and the
# condition guarding this early return are outside the visible lines.
# Fallback path: delegate to the random sampler.
return self._random_sampler.sample_independent(
study, trial, param_name, param_distribution)
# Split the observed (values, scores) into a "below" and an "above" group;
# the split criterion lives in `_split_observation_pairs` (not shown).
below_param_values, above_param_values = self._split_observation_pairs(values, scores)
# Dispatch on the concrete distribution type to the matching sampler helper.
if isinstance(param_distribution, distributions.UniformDistribution):
return self._sample_uniform(param_distribution, below_param_values, above_param_values)
elif isinstance(param_distribution, distributions.LogUniformDistribution):
return self._sample_loguniform(param_distribution, below_param_values,
above_param_values)
elif isinstance(param_distribution, distributions.DiscreteUniformDistribution):
return self._sample_discrete_uniform(param_distribution, below_param_values,
above_param_values)
elif isinstance(param_distribution, distributions.IntUniformDistribution):
return self._sample_int(param_distribution, below_param_values, above_param_values)
elif isinstance(param_distribution, distributions.CategoricalDistribution):
# The categorical helper yields an index; map it back to the actual choice.
index = self._sample_categorical_index(param_distribution, below_param_values,
above_param_values)
return param_distribution.choices[index]
else:
# Unsupported distribution: report it together with the supported type names.
distribution_list = [
distributions.UniformDistribution.__name__,
distributions.LogUniformDistribution.__name__,
distributions.DiscreteUniformDistribution.__name__,
distributions.IntUniformDistribution.__name__,
distributions.CategoricalDistribution.__name__
]
raise NotImplementedError("The distribution {} is not implemented. "
"The parameter distribution should be one of the {}".format(
param_distribution, distribution_list))
def _is_relative_param(self, name, distribution):
    # type: (str, BaseDistribution) -> bool
    """Tell whether ``name`` can reuse the value produced by ``sample_relative``.

    Returns ``False`` when ``name`` has no entry in ``self.relative_params``.
    Otherwise the parameter must also appear in ``self.relative_search_space``,
    and ``ValueError`` is raised if it does not. The relative distribution is
    checked for compatibility with ``distribution``. The stored value is then
    converted to the internal representation, and the result of
    ``distribution._contains`` on it is returned.
    """
    if name in self.relative_params:
        if name in self.relative_search_space:
            distributions.check_distribution_compatibility(
                self.relative_search_space[name], distribution)
            internal_value = distribution.to_internal_repr(self.relative_params[name])
            return distribution._contains(internal_value)
        raise ValueError("The parameter '{}' was sampled by `sample_relative` method "
                         "but it is not contained in the relative search space.".format(name))
    return False
def suggest_uniform(self, name, low, high):
    # type: (str, float, float) -> float
    """Suggest a float for parameter ``name`` from ``UniformDistribution(low, high)``.

    The actual suggestion (and any caching of an existing value) is delegated
    to ``self._suggest``.
    """
    uniform = distributions.UniformDistribution(low=low, high=high)
    return self._suggest(name, uniform)
def suggest_discrete_uniform(self, name, low, high, q):
    # type: (str, float, float, float) -> float
    """Suggest a value for ``name`` from ``DiscreteUniformDistribution(low, high, q)``.

    ``high`` is first passed through ``_adjust_discrete_uniform_high`` (defined
    elsewhere). Presumably this aligns ``high`` to the ``q`` step; verify against
    that helper. The suggestion itself is delegated to ``self._suggest``.
    """
    adjusted_high = _adjust_discrete_uniform_high(name, low, high, q)
    return self._suggest(
        name,
        distributions.DiscreteUniformDistribution(low=low, high=adjusted_high, q=q),
    )
def infer_relative_search_space(self, study, trial):
    # type: (Study, FrozenTrial) -> Dict[str, BaseDistribution]
    """Return the subset of the study's intersection search space to sample jointly.

    Starts from ``samplers.intersection_search_space(study)`` and keeps every
    distribution except non-categorical ones whose ``single()`` is true.
    ``skopt`` cannot handle a numeric dimension that contains just one value;
    ``Trial`` takes care of such parameters during suggestion. ``trial`` is
    unused.
    """
    return {
        param_name: param_distribution
        for param_name, param_distribution in samplers.intersection_search_space(study).items()
        # Single-valued categoricals are kept; single-valued numeric ones are skipped.
        if not param_distribution.single()
        or isinstance(param_distribution, distributions.CategoricalDistribution)
    }
def sample_independent(self, study, trial, param_name, param_distribution):
# type: (Study, FrozenTrial, str, BaseDistribution) -> Any
# Fragment (truncated inside `distribution_list` below): sample one parameter
# independently from past observations of `param_name` in `study`.
values, scores = _get_observation_pairs(study, param_name)
n = len(values)
# Too few observations so far: fall back to the random sampler.
if n < self._n_startup_trials:
return self._random_sampler.sample_independent(
study, trial, param_name, param_distribution)
# Split observations into "below" and "above" groups; the split criterion
# lives in `_split_observation_pairs` (not shown).
below_param_values, above_param_values = self._split_observation_pairs(values, scores)
# Dispatch on the concrete distribution type to the matching sampler helper.
if isinstance(param_distribution, distributions.UniformDistribution):
return self._sample_uniform(param_distribution, below_param_values, above_param_values)
elif isinstance(param_distribution, distributions.LogUniformDistribution):
return self._sample_loguniform(param_distribution, below_param_values,
above_param_values)
elif isinstance(param_distribution, distributions.DiscreteUniformDistribution):
return self._sample_discrete_uniform(param_distribution, below_param_values,
above_param_values)
elif isinstance(param_distribution, distributions.IntUniformDistribution):
return self._sample_int(param_distribution, below_param_values, above_param_values)
elif isinstance(param_distribution, distributions.CategoricalDistribution):
# The categorical helper yields an index; map it back to the actual choice.
index = self._sample_categorical_index(param_distribution, below_param_values,
above_param_values)
return param_distribution.choices[index]
else:
# Unsupported distribution: build the list of supported type names for the error.
distribution_list = [
distributions.UniformDistribution.__name__,
# Fragment: tail of another `sample_independent` variant. `n` and
# `observation_pairs` are defined above this fragment, outside the visible lines.
# This variant passes four lists to `_split_observation_pairs`: index lists
# plus the first and second element of each observation pair (presumably the
# parameter value and the score — verify against the caller).
below_param_values, above_param_values = self._split_observation_pairs(
list(range(n)), [p[0] for p in observation_pairs], list(range(n)),
[p[1] for p in observation_pairs])
# Dispatch on the concrete distribution type to the matching sampler helper.
if isinstance(param_distribution, distributions.UniformDistribution):
return self._sample_uniform(param_distribution, below_param_values, above_param_values)
elif isinstance(param_distribution, distributions.LogUniformDistribution):
return self._sample_loguniform(param_distribution, below_param_values,
above_param_values)
elif isinstance(param_distribution, distributions.DiscreteUniformDistribution):
return self._sample_discrete_uniform(param_distribution, below_param_values,
above_param_values)
elif isinstance(param_distribution, distributions.IntUniformDistribution):
return self._sample_int(param_distribution, below_param_values, above_param_values)
elif isinstance(param_distribution, distributions.CategoricalDistribution):
# Unlike the index-based variant, this helper's result is returned directly.
return self._sample_categorical(param_distribution, below_param_values,
above_param_values)
else:
# Unsupported distribution: report it together with the supported type names.
distribution_list = [
distributions.UniformDistribution.__name__,
distributions.LogUniformDistribution.__name__,
distributions.DiscreteUniformDistribution.__name__,
distributions.IntUniformDistribution.__name__,
distributions.CategoricalDistribution.__name__
]
raise NotImplementedError("The distribution {} is not implemented. "
"The parameter distribution should be one of the {}".format(
param_distribution, distribution_list))
def _check_compatibility_with_previous_trial_param_distributions(self, session):
    # type: (orm.Session) -> None
    """Compare this parameter's distribution with one already stored in the study.

    Looks up the owning trial, via ``TrialModel.find_or_raise_by_id``, to get
    its ``study_id``. It then fetches the first ``TrialParamModel`` row in that
    study with the same ``param_name``. If such a row exists, both JSON
    distributions are decoded and passed to
    ``distributions.check_distribution_compatibility``, which presumably
    raises on a mismatch. Only that single earlier row is compared.
    """
    trial = TrialModel.find_or_raise_by_id(self.trial_id, session)
    previous_record = (
        session.query(TrialParamModel)
        .join(TrialModel)
        .filter(TrialModel.study_id == trial.study_id)
        .filter(TrialParamModel.param_name == self.param_name)
        .first()
    )
    if previous_record is None:
        # Nothing recorded yet for this parameter name in the study.
        return

    previous_distribution = distributions.json_to_distribution(
        previous_record.distribution_json)
    current_distribution = distributions.json_to_distribution(self.distribution_json)
    distributions.check_distribution_compatibility(previous_distribution, current_distribution)