How to use the `flyteidl.plugins.sagemaker.hpo_job_pb2.HPOJobConfig` class in flyteidl

To help you get started, we've selected a few flyteidl examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: GitHub, lyft/flytekit — tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py (view on GitHub)
assert simple_training_job_task.interface.inputs['train'].description == ''
    assert simple_training_job_task.interface.inputs['train'].type == \
        _sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
    assert simple_training_job_task.interface.inputs['validation'].description == ''
    assert simple_training_job_task.interface.inputs['validation'].type == \
        _sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
    assert simple_training_job_task.interface.inputs['static_hyperparameters'].description == ''
    assert simple_training_job_task.interface.inputs['static_hyperparameters'].type == \
        _sdk_types.Types.Generic.to_flyte_literal_type()
    assert simple_training_job_task.interface.inputs['stopping_condition'].type == \
        _sdk_types.Types.Proto(_pb2_StoppingCondition).to_flyte_literal_type()

    # Checking if the hpo-specific input is defined
    assert simple_xgboost_hpo_job_task.interface.inputs['hpo_job_config'].description == ''
    assert simple_xgboost_hpo_job_task.interface.inputs['hpo_job_config'].type == \
           _sdk_types.Types.Proto(_pb2_HPOJobConfig).to_flyte_literal_type()
    assert simple_xgboost_hpo_job_task.interface.outputs['model'].description == ''
    assert simple_xgboost_hpo_job_task.interface.outputs['model'].type == \
           _sdk_types.Types.Blob.to_flyte_literal_type()
    assert simple_xgboost_hpo_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_HPO_JOB_TASK

    # Checking if the spec of the TrainingJob is embedded into the custom field of this SdkSimpleHPOJobTask
    assert simple_xgboost_hpo_job_task.to_flyte_idl().custom["trainingJob"] == \
           simple_training_job_task.to_flyte_idl().custom

    assert simple_xgboost_hpo_job_task.metadata.timeout == _datetime.timedelta(seconds=0)
    assert simple_xgboost_hpo_job_task.metadata.discoverable is True
    assert simple_xgboost_hpo_job_task.metadata.discovery_version == '1'
    assert simple_xgboost_hpo_job_task.metadata.retries.retries == 2
    """
    assert simple_xgboost_hpo_job_task.task_module == __name__
Source: GitHub, lyft/flytekit — flytekit/models/sagemaker/hpo_job.py (view on GitHub)
elif self._tuning_strategy == _sdk_sagemaker_types.HyperparameterTuningStrategy.RANDOM:
            idl_strategy = _idl_hpo_job.HPOJobConfig.HyperparameterTuningStrategy.RANDOM
        else:
            raise _user_exceptions.FlyteValidationException(
                "Invalid Hyperparameter Tuning Strategy: {}".format(self._tuning_strategy))

        if self._training_job_early_stopping_type == _sdk_sagemaker_types.TrainingJobEarlyStoppingType.OFF:
            idl_training_early_stopping_type = _idl_hpo_job.HPOJobConfig.TrainingJobEarlyStoppingType.OFF
        elif self._training_job_early_stopping_type == _sdk_sagemaker_types.TrainingJobEarlyStoppingType.AUTO:
            idl_training_early_stopping_type = _idl_hpo_job.HPOJobConfig.TrainingJobEarlyStoppingType.AUTO
        else:
            raise _user_exceptions.FlyteValidationException(
                "Invalid Training Job Early Stopping Type (in HPO Config): {}".format(
                    self._training_job_early_stopping_type))

        return _idl_hpo_job.HPOJobConfig(
            hyperparameter_ranges=self._hyperparameter_ranges.to_flyte_idl(),
            tuning_strategy=idl_strategy,
            tuning_objective=self._tuning_objective.to_flyte_idl(),
            training_job_early_stopping_type=idl_training_early_stopping_type,
        )
Source: GitHub, lyft/flytekit — flytekit/models/sagemaker/hpo_job.py (view on GitHub)
def to_flyte_idl(self):

        if self._tuning_strategy == _sdk_sagemaker_types.HyperparameterTuningStrategy.BAYESIAN:
            idl_strategy = _idl_hpo_job.HPOJobConfig.HyperparameterTuningStrategy.BAYESIAN
        elif self._tuning_strategy == _sdk_sagemaker_types.HyperparameterTuningStrategy.RANDOM:
            idl_strategy = _idl_hpo_job.HPOJobConfig.HyperparameterTuningStrategy.RANDOM
        else:
            raise _user_exceptions.FlyteValidationException(
                "Invalid Hyperparameter Tuning Strategy: {}".format(self._tuning_strategy))

        if self._training_job_early_stopping_type == _sdk_sagemaker_types.TrainingJobEarlyStoppingType.OFF:
            idl_training_early_stopping_type = _idl_hpo_job.HPOJobConfig.TrainingJobEarlyStoppingType.OFF
        elif self._training_job_early_stopping_type == _sdk_sagemaker_types.TrainingJobEarlyStoppingType.AUTO:
            idl_training_early_stopping_type = _idl_hpo_job.HPOJobConfig.TrainingJobEarlyStoppingType.AUTO
        else:
            raise _user_exceptions.FlyteValidationException(
                "Invalid Training Job Early Stopping Type (in HPO Config): {}".format(
                    self._training_job_early_stopping_type))

        return _idl_hpo_job.HPOJobConfig(
            hyperparameter_ranges=self._hyperparameter_ranges.to_flyte_idl(),
            tuning_strategy=idl_strategy,