Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
assert simple_training_job_task.interface.inputs['train'].description == ''
assert simple_training_job_task.interface.inputs['train'].type == \
_sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
assert simple_training_job_task.interface.inputs['validation'].description == ''
assert simple_training_job_task.interface.inputs['validation'].type == \
_sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
assert simple_training_job_task.interface.inputs['static_hyperparameters'].description == ''
assert simple_training_job_task.interface.inputs['static_hyperparameters'].type == \
_sdk_types.Types.Generic.to_flyte_literal_type()
assert simple_training_job_task.interface.inputs['stopping_condition'].type == \
_sdk_types.Types.Proto(_pb2_StoppingCondition).to_flyte_literal_type()
# Checking if the hpo-specific input is defined
assert simple_xgboost_hpo_job_task.interface.inputs['hpo_job_config'].description == ''
assert simple_xgboost_hpo_job_task.interface.inputs['hpo_job_config'].type == \
_sdk_types.Types.Proto(_pb2_HPOJobConfig).to_flyte_literal_type()
assert simple_xgboost_hpo_job_task.interface.outputs['model'].description == ''
assert simple_xgboost_hpo_job_task.interface.outputs['model'].type == \
_sdk_types.Types.Blob.to_flyte_literal_type()
assert simple_xgboost_hpo_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_HPO_JOB_TASK
# Checking if the spec of the TrainingJob is embedded into the custom field of this SdkSimpleHPOJobTask
assert simple_xgboost_hpo_job_task.to_flyte_idl().custom["trainingJob"] == \
simple_training_job_task.to_flyte_idl().custom
assert simple_xgboost_hpo_job_task.metadata.timeout == _datetime.timedelta(seconds=0)
assert simple_xgboost_hpo_job_task.metadata.discoverable is True
assert simple_xgboost_hpo_job_task.metadata.discovery_version == '1'
assert simple_xgboost_hpo_job_task.metadata.retries.retries == 2
"""
assert simple_xgboost_hpo_job_task.task_module == __name__
elif self._tuning_strategy == _sdk_sagemaker_types.HyperparameterTuningStrategy.RANDOM:
idl_strategy = _idl_hpo_job.HPOJobConfig.HyperparameterTuningStrategy.RANDOM
else:
raise _user_exceptions.FlyteValidationException(
"Invalid Hyperparameter Tuning Strategy: {}".format(self._tuning_strategy))
if self._training_job_early_stopping_type == _sdk_sagemaker_types.TrainingJobEarlyStoppingType.OFF:
idl_training_early_stopping_type = _idl_hpo_job.HPOJobConfig.TrainingJobEarlyStoppingType.OFF
elif self._training_job_early_stopping_type == _sdk_sagemaker_types.TrainingJobEarlyStoppingType.AUTO:
idl_training_early_stopping_type = _idl_hpo_job.HPOJobConfig.TrainingJobEarlyStoppingType.AUTO
else:
raise _user_exceptions.FlyteValidationException(
"Invalid Training Job Early Stopping Type (in HPO Config): {}".format(
self._training_job_early_stopping_type))
return _idl_hpo_job.HPOJobConfig(
hyperparameter_ranges=self._hyperparameter_ranges.to_flyte_idl(),
tuning_strategy=idl_strategy,
tuning_objective=self._tuning_objective.to_flyte_idl(),
training_job_early_stopping_type=idl_training_early_stopping_type,
)
def to_flyte_idl(self):
if self._tuning_strategy == _sdk_sagemaker_types.HyperparameterTuningStrategy.BAYESIAN:
idl_strategy = _idl_hpo_job.HPOJobConfig.HyperparameterTuningStrategy.BAYESIAN
elif self._tuning_strategy == _sdk_sagemaker_types.HyperparameterTuningStrategy.RANDOM:
idl_strategy = _idl_hpo_job.HPOJobConfig.HyperparameterTuningStrategy.RANDOM
else:
raise _user_exceptions.FlyteValidationException(
"Invalid Hyperparameter Tuning Strategy: {}".format(self._tuning_strategy))
if self._training_job_early_stopping_type == _sdk_sagemaker_types.TrainingJobEarlyStoppingType.OFF:
idl_training_early_stopping_type = _idl_hpo_job.HPOJobConfig.TrainingJobEarlyStoppingType.OFF
elif self._training_job_early_stopping_type == _sdk_sagemaker_types.TrainingJobEarlyStoppingType.AUTO:
idl_training_early_stopping_type = _idl_hpo_job.HPOJobConfig.TrainingJobEarlyStoppingType.AUTO
else:
raise _user_exceptions.FlyteValidationException(
"Invalid Training Job Early Stopping Type (in HPO Config): {}".format(
self._training_job_early_stopping_type))
return _idl_hpo_job.HPOJobConfig(
hyperparameter_ranges=self._hyperparameter_ranges.to_flyte_idl(),
tuning_strategy=idl_strategy,