How to use the optuna.type_checking module in optuna

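To help you get started, we’ve selected a few optuna examples based on popular ways the optuna.type_checking module is used in public projects. In every example, imports that are referenced only from `# type:` comments are placed under `if type_checking.TYPE_CHECKING:`.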

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github optuna / optuna / tests / samplers_tests / tpe_tests / test_parzen_estimator.py View on Github external
import itertools

import numpy as np
import pytest

from optuna.samplers.tpe.parzen_estimator import _ParzenEstimator
from optuna.samplers.tpe.sampler import default_weights
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from typing import Dict  # NOQA
    from typing import List  # NOQA


class TestParzenEstimator(object):
    @staticmethod
    @pytest.mark.parametrize(
        'mus, prior, magic_clip, endpoints',
        itertools.product(
            ([], [0.4], [-0.4, 0.4]),  # mus
            (True, False),  # prior
            (True, False),  # magic_clip
            (True, False),  # endpoints
        ))
    def test_calculate_shape_check(mus, prior, magic_clip, endpoints):
        # type: (List[float], bool, bool, bool) -> None
github optuna / optuna / tests / test_distributions.py View on Github external
import copy
import json
import pytest

from optuna import distributions
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from typing import Any  # NOQA
    from typing import Dict  # NOQA
    from typing import List  # NOQA

# One sample distribution per short key: uniform ('u'), log-uniform ('l'),
# discrete-uniform ('du'), int-uniform ('iu') and two categoricals ('c1', 'c2').
# EXAMPLE_JSONS appears to use the same keys for the serialized forms.
EXAMPLE_DISTRIBUTIONS = {
    'u': distributions.UniformDistribution(low=1.0, high=2.0),
    'l': distributions.LogUniformDistribution(low=0.001, high=100),
    'du': distributions.DiscreteUniformDistribution(low=1.0, high=10.0, q=2.0),
    'iu': distributions.IntUniformDistribution(low=1, high=10),
    # Float choices, including a non-finite value.
    'c1': distributions.CategoricalDistribution(choices=(2.71, float('-inf'))),
    # String choices.
    'c2': distributions.CategoricalDistribution(choices=('Roppongi', 'Azabu')),
}  # type: Dict[str, Any]

EXAMPLE_JSONS = {
    'u': '{"name": "UniformDistribution", "attributes": {"low": 1.0, "high": 2.0}}',
    'l': '{"name": "LogUniformDistribution", "attributes": {"low": 0.001, "high": 100}}',
github optuna / optuna / tests / visualization_tests / test_intermediate_plot.py View on Github external
from optuna.study import create_study
from optuna.testing.visualization import prepare_study_with_trials
from optuna import type_checking
from optuna.visualization.intermediate_values import plot_intermediate_values

if type_checking.TYPE_CHECKING:
    from optuna.trial import Trial  # NOQA


def test_plot_intermediate_values():
    # type: () -> None

    # Test with no trials.
    study = prepare_study_with_trials(no_trials=True)
    figure = plot_intermediate_values(study)
    assert not figure.data

    def objective(trial, report_intermediate_values):
        # type: (Trial, bool) -> float

        if report_intermediate_values:
            trial.report(1.0, step=0)
github optuna / optuna / tests / integration_tests / test_pytorch_ignite.py View on Github external
from ignite.engine import Engine
from mock import Mock
from mock import patch
import pytest

import optuna
from optuna.testing.integration import create_running_trial
from optuna.testing.integration import DeterministicPruner
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from typing import Iterable  # NOQA


def test_pytorch_ignite_pruning_handler():
    # type: () -> None

    def update(engine, batch):
        # type: (Engine, Iterable) -> None

        pass

    trainer = Engine(update)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
github optuna / optuna / tests / test_trial.py View on Github external
from mock import patch
import numpy as np
import pytest
import warnings

from optuna import distributions
from optuna import samplers
from optuna import storages
from optuna.study import create_study
from optuna.testing.integration import DeterministicPruner
from optuna.testing.sampler import DeterministicRelativeSampler
from optuna.trial import FixedTrial
from optuna.trial import Trial
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from datetime import datetime  # NOQA
    import typing  # NOQA

# Reusable decorator that runs a test once per storage factory: the in-memory
# storage and an RDB storage backed by an in-memory SQLite database. Each value
# is a zero-argument callable (the decorated tests annotate the parameter as
# typing.Callable[[], storages.BaseStorage]).
parametrize_storage = pytest.mark.parametrize(
    'storage_init_func',
    [
        storages.InMemoryStorage,
        lambda: storages.RDBStorage('sqlite:///:memory:'),
    ],
)


@parametrize_storage
def test_suggest_uniform(storage_init_func):
    # type: (typing.Callable[[], storages.BaseStorage]) -> None

    mock = Mock()
    mock.side_effect = [1., 2., 3.]
    sampler = samplers.RandomSampler()
github optuna / optuna / tests / visualization_tests / test_contour.py View on Github external
import pytest

from optuna.distributions import LogUniformDistribution
from optuna.structs import StudyDirection
from optuna.study import create_study
from optuna.testing.visualization import prepare_study_with_trials
from optuna import type_checking
from optuna.visualization.contour import _generate_contour_subplot
from optuna.visualization.contour import plot_contour

if type_checking.TYPE_CHECKING:
    from typing import List, Optional  # NOQA

    from optuna.trial import Trial  # NOQA


@pytest.mark.parametrize(
    'params', [
        [],
        ['param_a'],
        ['param_a', 'param_b'],
        ['param_a', 'param_b', 'param_c'],
        ['param_a', 'param_b', 'param_c', 'param_d'],
        None,
    ]
)
def test_plot_contour(params):
github optuna / optuna / optuna / visualization / optimization_history.py View on Github external
from optuna.logging import get_logger
from optuna.structs import StudyDirection
from optuna.structs import TrialState
from optuna import type_checking
from optuna.visualization.utils import _check_plotly_availability
from optuna.visualization.utils import is_available

if type_checking.TYPE_CHECKING:
    from optuna.study import Study  # NOQA

if is_available():
    from optuna.visualization.plotly_imports import go

# Module-level logger obtained from optuna.logging.get_logger, named after this module.
logger = get_logger(__name__)


def plot_optimization_history(study):
    # type: (Study) -> go.Figure
    """Plot optimization history of all trials in a study.

    Example:

        The following code snippet shows how to plot optimization history.
github optuna / optuna / optuna / storages / base.py View on Github external
import abc

from optuna import structs
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from typing import Any  # NOQA
    from typing import Dict  # NOQA
    from typing import List  # NOQA
    from typing import Optional  # NOQA

    from optuna import distributions  # NOQA

# Prefix for default study names. NOTE(review): presumably prepended to
# auto-generated names when the caller supplies none; confirm against the
# storage implementations that read it.
DEFAULT_STUDY_NAME_PREFIX = 'no-name-'


class BaseStorage(object, metaclass=abc.ABCMeta):
    """Base class for storages.

    This class is not supposed to be directly accessed by library users.

    Storage classes abstract a backend database and provide library internal interfaces to
github optuna / optuna / optuna / visualization / slice.py View on Github external
from optuna.logging import get_logger
from optuna.structs import TrialState
from optuna import type_checking
from optuna.visualization.utils import _check_plotly_availability
from optuna.visualization.utils import _is_log_scale
from optuna.visualization.utils import is_available

if type_checking.TYPE_CHECKING:
    from typing import List  # NOQA
    from typing import Optional  # NOQA

    from optuna.structs import FrozenTrial  # NOQA
    from optuna.study import Study  # NOQA
    from optuna.visualization.plotly_imports import Scatter  # NOQA

if is_available():
    from optuna.visualization.plotly_imports import go
    from optuna.visualization.plotly_imports import make_subplots

# Module-level logger obtained from optuna.logging.get_logger, named after this module.
logger = get_logger(__name__)


def plot_slice(study, params=None):
    # type: (Study, Optional[List[str]]) -> go.Figure
github optuna / optuna / optuna / visualization / intermediate_values.py View on Github external
from optuna.logging import get_logger
from optuna.structs import TrialState
from optuna import type_checking
from optuna.visualization.utils import _check_plotly_availability
from optuna.visualization.utils import is_available

if type_checking.TYPE_CHECKING:
    from optuna.study import Study  # NOQA

if is_available():
    from optuna.visualization.plotly_imports import go

# Module-level logger obtained from optuna.logging.get_logger, named after this module.
logger = get_logger(__name__)


def plot_intermediate_values(study):
    # type: (Study) -> go.Figure
    """Plot intermediate values of all trials in a study.

    Example:

        The following code snippet shows how to plot intermediate values.