How to use the optuna.type_checking.TYPE_CHECKING constant in optuna

To help you get started, we’ve selected a few optuna examples, based on popular ways optuna.type_checking.TYPE_CHECKING is used in public projects. The flag mirrors typing.TYPE_CHECKING: it is False at runtime and True only while a static type checker such as mypy analyzes the code, so imports placed under an `if type_checking.TYPE_CHECKING:` guard are available for type comments and annotations without any runtime import cost.

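The basic pattern looks like this; the sketch below is minimal and self-contained, and the function name is hypothetical:

from optuna import type_checking

if type_checking.TYPE_CHECKING:
    # Executed only by static type checkers such as mypy; at runtime
    # TYPE_CHECKING is False, so this import adds no startup cost.
    from typing import Optional  # NOQA


def clip_to_unit(value):
    # type: (Optional[float]) -> float
    # Clip a value into [0.0, 1.0], treating None as 0.0.
    if value is None:
        return 0.0
    return min(max(value, 0.0), 1.0)

The guard matters because optuna of this era used comment-style type hints for Python 2 compatibility: the imported names are needed by the checker but never touched at runtime.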

github optuna / optuna / tests / test_cli.py
import pytest
import re
import subprocess
from subprocess import CalledProcessError
import tempfile

import optuna
from optuna.cli import Studies
from optuna.storages.base import DEFAULT_STUDY_NAME_PREFIX
from optuna.storages import RDBStorage
from optuna.structs import CLIUsageError
from optuna.testing.storage import StorageSupplier
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from typing import List  # NOQA

    from optuna.trial import Trial  # NOQA


def test_create_study_command():
    # type: () -> None

    with StorageSupplier('new') as storage:
        assert isinstance(storage, RDBStorage)
        storage_url = str(storage.engine.url)

        # Create study.
        command = ['optuna', 'create-study', '--storage', storage_url]
        subprocess.check_call(command)
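For reference, the same CLI call can be made outside the test harness with an explicit SQLite URL; the database file name below is hypothetical:

import subprocess

# Equivalent to running `optuna create-study --storage sqlite:///example.db` in a shell.
subprocess.check_call(['optuna', 'create-study', '--storage', 'sqlite:///example.db'])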
github optuna / optuna / optuna / integration / __init__.py
from optuna.type_checking import TYPE_CHECKING

_import_structure = {
    'chainer': ['ChainerPruningExtension'],
    'chainermn': ['ChainerMNStudy'],
    'cma': ['CmaEsSampler'],
    'keras': ['KerasPruningCallback'],
    'lightgbm': ['LightGBMPruningCallback', 'LightGBMTuner'],
    'pytorch_ignite': ['PyTorchIgnitePruningHandler'],
    'pytorch_lightning': ['PyTorchLightningPruningCallback'],
    'sklearn': ['OptunaSearchCV'],
    'mxnet': ['MXNetPruningCallback'],
    'skopt': ['SkoptSampler'],
    'tensorflow': ['TensorFlowPruningHook'],
    'tfkeras': ['TFKerasPruningCallback'],
    'xgboost': ['XGBoostPruningCallback'],
    'fastai': ['FastAIPruningCallback'],
}


# Public names are the submodule keys plus every class they expose.
__all__ = list(_import_structure.keys()) + sum(_import_structure.values(), [])


if TYPE_CHECKING:
    from optuna.integration.chainer import ChainerPruningExtension  # NOQA
    from optuna.integration.chainermn import ChainerMNStudy  # NOQA
    from optuna.integration.cma import CmaEsSampler  # NOQA
    from optuna.integration.fastai import FastAIPruningCallback  # NOQA
    from optuna.integration.keras import KerasPruningCallback  # NOQA
    from optuna.integration.lightgbm import LightGBMPruningCallback  # NOQA
    from optuna.integration.lightgbm import LightGBMTuner  # NOQA
    from optuna.integration.mxnet import MXNetPruningCallback  # NOQA
    from optuna.integration.pytorch_ignite import PyTorchIgnitePruningHandler  # NOQA
    from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback  # NOQA
    from optuna.integration.sklearn import OptunaSearchCV  # NOQA
    from optuna.integration.skopt import SkoptSampler  # NOQA
    from optuna.integration.tensorflow import TensorFlowPruningHook  # NOQA
    from optuna.integration.tfkeras import TFKerasPruningCallback  # NOQA
    from optuna.integration.xgboost import XGBoostPruningCallback  # NOQA
else:
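The else branch above is truncated; it is where the integration classes are actually made available at runtime. One way to implement such lazy loading is a module-level __getattr__ (PEP 562, Python 3.7+); this sketch is hypothetical and not optuna's actual implementation:

import importlib


def __getattr__(name):
    # Import the submodule that provides `name` on first access and cache
    # the attribute so this hook is not called again for the same name.
    for module_name, attrs in _import_structure.items():
        if name in attrs:
            module = importlib.import_module('optuna.integration.' + module_name)
            value = getattr(module, name)
            globals()[name] = value
            return value
    raise AttributeError('module {!r} has no attribute {!r}'.format(__name__, name))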
github optuna / optuna / optuna / distributions.py
import abc
import json
import six

from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from typing import Any  # NOQA
    from typing import Dict  # NOQA
    from typing import Tuple  # NOQA
    from typing import Union  # NOQA


@six.add_metaclass(abc.ABCMeta)
class BaseDistribution(object):
    """Base class for distributions.

    Note that distribution classes are not supposed to be called by library users.
    They are used by :class:`~optuna.trial.Trial` and :class:`~optuna.samplers` internally.
    """

    def to_external_repr(self, param_value_in_internal_repr):
        # type: (float) -> Any
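The method body is truncated above; by default it returns the internal value unchanged. A concrete subclass shows the round trip (this sketch assumes optuna's 1.x distribution API):

import optuna

# A float distribution over [0.0, 1.0); internal and external representations
# are both plain floats here, so to_external_repr is the identity.
dist = optuna.distributions.UniformDistribution(low=0.0, high=1.0)
print(dist.to_external_repr(0.5))  # 0.5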
github optuna / optuna / optuna / integration / lightgbm_tuner / alias.py
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from typing import Any  # NOQA
    from typing import Dict  # NOQA
    from typing import List  # NOQA


ALIAS_GROUP_LIST = [
    {
        'param_name': 'bagging_fraction',
        'alias_names': ['sub_row', 'subsample', 'bagging'],
        'default_value': None,
    },
    {
        'param_name': 'learning_rate',
        'alias_names': ['shrinkage_rate', 'eta'],
        'default_value': 0.1,  # Start from large `learning_rate` value.
    },
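The list is truncated above. To see how such an alias table can be applied, here is a hypothetical helper (not optuna's actual code) that canonicalizes a raw parameter dict against ALIAS_GROUP_LIST:

def resolve_aliases(params):
    # type: (Dict[str, Any]) -> Dict[str, Any]
    # Hypothetical helper: move alias keys onto the canonical parameter name,
    # then fall back to the group's default value when one is defined.
    resolved = dict(params)
    for group in ALIAS_GROUP_LIST:
        for alias in group['alias_names']:
            if alias in resolved:
                resolved.setdefault(group['param_name'], resolved.pop(alias))
        if group['default_value'] is not None:
            resolved.setdefault(group['param_name'], group['default_value'])
    return resolved


print(resolve_aliases({'eta': 0.05}))  # {'learning_rate': 0.05}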
github optuna / optuna / optuna / integration / pytorch_ignite.py
import optuna
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from optuna.trial import Trial  # NOQA

try:
    from ignite.engine import Engine  # NOQA
    _available = True
except ImportError as e:
    _import_error = e
    # PyTorchIgnitePruningHandler is disabled because pytorch-ignite is not available.
    _available = False


class PyTorchIgnitePruningHandler(object):
    """PyTorch Ignite handler to prune unpromising trials.

    Example:
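The docstring example is truncated above. A typical attachment looks roughly like this sketch, assuming pytorch-ignite is installed; trial, trainer, and evaluator are hypothetical objects from the surrounding objective function:

from ignite.engine import Events

# After each evaluation run, report the 'accuracy' metric to the trial and
# stop training early if the pruner decides the trial is unpromising.
handler = PyTorchIgnitePruningHandler(trial, 'accuracy', trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)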
github optuna / optuna / optuna / storages / in_memory.py
import copy
from datetime import datetime
import threading

from optuna import distributions  # NOQA
from optuna.storages import base
from optuna.storages.base import DEFAULT_STUDY_NAME_PREFIX
from optuna import structs
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from typing import Any  # NOQA
    from typing import Dict  # NOQA
    from typing import List  # NOQA
    from typing import Optional  # NOQA

IN_MEMORY_STORAGE_STUDY_ID = 0
IN_MEMORY_STORAGE_STUDY_UUID = '00000000-0000-0000-0000-000000000000'


class InMemoryStorage(base.BaseStorage):
    """Storage class that stores data in memory of the Python process.

    This class is not supposed to be directly accessed by library users.
    """

    def __init__(self):
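Although this class is internal, it is the storage optuna falls back to when create_study() receives no storage argument, so every default study exercises this code. A minimal sketch:

import optuna

# With no storage argument the study lives in an InMemoryStorage instance
# and is discarded when the Python process exits.
study = optuna.create_study()
study.optimize(lambda trial: trial.suggest_uniform('x', -1, 1) ** 2, n_trials=3)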
github optuna / optuna / optuna / integration / lightgbm_tuner / __init__.py
from optuna.integration.lightgbm_tuner.sklearn import LGBMClassifier, LGBMModel, LGBMRegressor  # NOQA
from optuna.integration.lightgbm_tuner.optimize import LightGBMTuner
from optuna import type_checking

if type_checking.TYPE_CHECKING:
    from typing import Any  # NOQA
    from typing import Dict  # NOQA
    from typing import List  # NOQA
    from typing import Optional  # NOQA


def train(*args, **kwargs):
    # type: (List[Any], Optional[Dict[Any, Any]]) -> Any
    """Wrapper function of LightGBM API: train()

    Arguments and keyword arguments for `lightgbm.train()` can be passed.
    """

    auto_booster = LightGBMTuner(*args, **kwargs)
    booster = auto_booster.run()
    return booster
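A minimal, runnable usage sketch, assuming LightGBM is installed and using synthetic data:

import lightgbm as lgb
import numpy as np

from optuna.integration import lightgbm_tuner

X = np.random.rand(100, 4)
y = (X[:, 0] > 0.5).astype(int)
dtrain = lgb.Dataset(X[:80], label=y[:80])
dval = lgb.Dataset(X[80:], label=y[80:], reference=dtrain)

params = {'objective': 'binary', 'metric': 'binary_logloss'}

# Tunes LightGBM hyperparameters stepwise during training and returns the
# best booster, mirroring the lightgbm.train() interface.
booster = lightgbm_tuner.train(
    params, dtrain, valid_sets=[dval], early_stopping_rounds=10)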