How to use the flyteidl.plugins.sagemaker.training_job_pb2 module in flyteidl

To help you get started, we’ve selected a few flyteidl examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: lyft/flytekit — flytekit/models/sagemaker/training_job.py (View on GitHub)
def to_flyte_idl(self):
        """
        Serialize this stopping condition into its protobuf message.

        :rtype: _training_job.StoppingCondition
        """
        max_runtime = self.max_runtime_in_seconds
        max_wait = self.max_wait_time_in_seconds
        return _training_job.StoppingCondition(
            max_runtime_in_seconds=max_runtime,
            max_wait_time_in_seconds=max_wait,
        )
GitHub: lyft/flytekit — flytekit/models/sagemaker/training_job.py (View on GitHub)
def to_flyte_idl(self):
        """
        Serialize this training job configuration into its protobuf message.

        :rtype: _training_job.TrainingJobConfig
        """
        # Gather the fields in declaration order, then build the message.
        fields = {
            "instance_count": self.instance_count,
            "instance_type": self.instance_type,
            "volume_size_in_gb": self.volume_size_in_gb,
        }
        return _training_job.TrainingJobConfig(**fields)
GitHub: lyft/flytekit — flytekit/models/sagemaker/training_job.py (View on GitHub)
def from_flyte_idl(cls, pb2_object):
        """
        Build an instance of ``cls`` from its protobuf message.

        The enum mapping is lenient: an ``input_mode`` other than ``PIPE``
        becomes ``InputMode.FILE``, and an ``algorithm_name`` other than
        ``XGBOOST`` becomes ``AlgorithmName.CUSTOM``; no value is rejected.

        :param pb2_object: protobuf message exposing ``input_mode``,
            ``algorithm_name``, ``algorithm_version`` and
            ``metric_definitions``.
        :return: a new ``cls`` instance.
        """
        is_pipe = pb2_object.input_mode == _training_job.InputMode.PIPE
        mode = (
            _sdk_sagemaker_types.InputMode.PIPE
            if is_pipe
            else _sdk_sagemaker_types.InputMode.FILE
        )

        is_xgboost = pb2_object.algorithm_name == _training_job.AlgorithmName.XGBOOST
        name = (
            _sdk_sagemaker_types.AlgorithmName.XGBOOST
            if is_xgboost
            else _sdk_sagemaker_types.AlgorithmName.CUSTOM
        )

        version = pb2_object.algorithm_version
        metrics = list(map(MetricDefinition.from_flyte_idl, pb2_object.metric_definitions))

        return cls(
            input_mode=mode,
            algorithm_name=name,
            algorithm_version=version,
            metric_definitions=metrics,
        )
GitHub: lyft/flytekit — flytekit/models/sagemaker/training_job.py (View on GitHub)
def to_flyte_idl(self):
        """
        Serialize this training job into its protobuf message.

        :return: _training_job.TrainingJob assembled from the protobuf forms
            of ``algorithm_specification`` and ``training_job_config``.
        """
        # Convert the nested models first (specification, then config),
        # then wrap them in the outer message.
        spec_idl = self.algorithm_specification.to_flyte_idl()
        config_idl = self.training_job_config.to_flyte_idl()
        return _training_job.TrainingJob(
            algorithm_specification=spec_idl,
            training_job_config=config_idl,
        )
GitHub: lyft/flytekit — flytekit/models/sagemaker/training_job.py (View on GitHub)
input_mode = _training_job.InputMode.FILE
        elif self.input_mode == _sdk_sagemaker_types.InputMode.PIPE:
            input_mode = _training_job.InputMode.PIPE
        else:
            raise _user_exceptions.FlyteValidationException(
                "Invalid SageMaker Input Mode Specified: [{}]".format(self.input_mode))

        if self.algorithm_name == _sdk_sagemaker_types.AlgorithmName.CUSTOM:
            alg_name = _training_job.AlgorithmName.CUSTOM
        elif self.algorithm_name == _sdk_sagemaker_types.AlgorithmName.XGBOOST:
            alg_name = _training_job.AlgorithmName.XGBOOST
        else:
            raise _user_exceptions.FlyteValidationException(
                "Invalid SageMaker Algorithm Name Specified: [{}]".format(self.algorithm_name))

        return _training_job.AlgorithmSpecification(
            input_mode=input_mode,
            algorithm_name=alg_name,
            algorithm_version=self.algorithm_version,
            metric_definitions=[m.to_flyte_idl() for m in self.metric_definitions],
        )
GitHub: lyft/flytekit — flytekit/models/sagemaker/training_job.py (View on GitHub)
def to_flyte_idl(self):
        """
        Serialize this algorithm specification into its protobuf message.

        :return: _training_job.AlgorithmSpecification
        :raises _user_exceptions.FlyteValidationException: if ``input_mode``
            is neither FILE nor PIPE, or ``algorithm_name`` is neither CUSTOM
            nor XGBOOST.
        """

        def _to_idl(value, candidates, error_template):
            # Return the protobuf value paired with the first SDK value that
            # compares equal to ``value``; reject anything unrecognized.
            for sdk_value, idl_value in candidates:
                if value == sdk_value:
                    return idl_value
            raise _user_exceptions.FlyteValidationException(error_template.format(value))

        input_mode = _to_idl(
            self.input_mode,
            (
                (_sdk_sagemaker_types.InputMode.FILE, _training_job.InputMode.FILE),
                (_sdk_sagemaker_types.InputMode.PIPE, _training_job.InputMode.PIPE),
            ),
            "Invalid SageMaker Input Mode Specified: [{}]",
        )
        alg_name = _to_idl(
            self.algorithm_name,
            (
                (_sdk_sagemaker_types.AlgorithmName.CUSTOM, _training_job.AlgorithmName.CUSTOM),
                (_sdk_sagemaker_types.AlgorithmName.XGBOOST, _training_job.AlgorithmName.XGBOOST),
            ),
            "Invalid SageMaker Algorithm Name Specified: [{}]",
        )

        return _training_job.AlgorithmSpecification(
            input_mode=input_mode,
            algorithm_name=alg_name,
            algorithm_version=self.algorithm_version,
            metric_definitions=[metric.to_flyte_idl() for metric in self.metric_definitions],
        )
GitHub: lyft/flytekit — flytekit/common/tasks/sagemaker/training_job_task.py (View on GitHub)
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
                            ),
                        ),
                        description="",
                    ),
                    "validation": _interface_model.Variable(
                        type=_idl_types.LiteralType(
                            blob=_core_types.BlobType(
                                format="csv",
                                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
                            ),
                        ),
                        description="",
                    ),
                    "stopping_condition": _interface_model.Variable(
                        _sdk_types.Types.Proto(_training_job_pb2.StoppingCondition).to_flyte_literal_type(), ""
                    )
                },
                outputs={
                    "model": _interface_model.Variable(
                        type=_idl_types.LiteralType(
                            blob=_core_types.BlobType(
                                format="",
                                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
                            )
                        ),
                        description=""
                    )
                }
            ),
            custom=MessageToDict(self._training_job_model.to_flyte_idl()),
        )