How to use the sagemaker.session.s3_input function in sagemaker

To help you get started, we’ve selected a few examples of sagemaker.session.s3_input, based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github aws / sagemaker-python-sdk / tests / unit / test_tf_estimator.py View on Github external
def test_tf_deploy_model_server_workers_unset(sagemaker_session):
    """Fit and deploy a TF estimator, then check the workers key is absent.

    The upper-cased ``MODEL_SERVER_WORKERS_PARAM_NAME`` must not appear in the
    ``Environment`` mapping recorded on the session after ``deploy``.
    """
    estimator = _build_tf(sagemaker_session)
    training_channel = s3_input("s3://mybucket/train")
    estimator.fit(inputs=training_channel)

    estimator.deploy(initial_instance_count=1, instance_type="ml.c2.2xlarge")

    workers_key = MODEL_SERVER_WORKERS_PARAM_NAME.upper()
    # NOTE(review): assumes sagemaker_session is a mock, so method_calls[3][1][2]
    # is the third positional argument of the fourth recorded call -- confirm.
    recorded_argument = sagemaker_session.method_calls[3][1][2]
    assert workers_key not in recorded_argument["Environment"]
github aws / sagemaker-python-sdk / tests / unit / test_estimator.py View on Github external
def test_fit_verify_job_name(strftime, sagemaker_session):
    """Fit a DummyFramework and check the kwargs handed to ``sagemaker_session.train``.

    Verifies the image, input mode, tags, job name and traffic-encryption flag
    of the first recorded ``train`` call, and that the estimator's latest
    training job carries ``JOB_NAME``.

    Args:
        strftime: not referenced in the body; presumably a patched-time fixture
            that keeps the job name deterministic -- verify against the module.
        sagemaker_session: session whose ``train`` calls are inspected.
    """
    framework_options = {
        "entry_point": SCRIPT_PATH,
        "role": "DummyRole",
        "sagemaker_session": sagemaker_session,
        "train_instance_count": INSTANCE_COUNT,
        "train_instance_type": INSTANCE_TYPE,
        "enable_cloudwatch_metrics": True,
        "tags": TAGS,
        "encrypt_inter_container_traffic": True,
    }
    fw = DummyFramework(**framework_options)
    fw.fit(inputs=s3_input("s3://mybucket/train"))

    # Each recorded call unpacks to three items; only the keyword arguments matter here.
    _call_name, _call_args, train_kwargs = sagemaker_session.train.mock_calls[0]

    hyperparameters = train_kwargs["hyperparameters"]
    assert hyperparameters["sagemaker_enable_cloudwatch_metrics"]

    expected_pairs = (
        ("image", IMAGE_NAME),
        ("input_mode", "File"),
        ("tags", TAGS),
        ("job_name", JOB_NAME),
    )
    for key, expected_value in expected_pairs:
        assert train_kwargs[key] == expected_value

    assert train_kwargs["encrypt_inter_container_traffic"] is True
    assert fw.latest_training_job.name == JOB_NAME
github aws / sagemaker-python-sdk / tests / unit / test_job.py View on Github external
def test_format_input_s3_input():
    input_dict = _Job._format_inputs_to_input_config(
        s3_input(
            "s3://foo/bar",
            distribution="ShardedByS3Key",
            compression="gzip",
            content_type="whizz",
            record_wrapping="bang",
        )
    )
    assert input_dict == [
        {
            "CompressionType": "gzip",
            "ChannelName": "training",
            "ContentType": "whizz",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3DataDistributionType": "ShardedByS3Key",
github aws / sagemaker-python-sdk / tests / unit / test_job.py View on Github external
def test_load_config(estimator):
    """Check that ``_Job._load_config`` builds the expected config sections.

    Covers the first input channel's S3 URI, the role, the output config
    (including the absence of ``KmsKeyId``), the resource config and the
    stop condition.
    """
    config = _Job._load_config(s3_input(BUCKET_NAME), estimator)

    first_channel = config["input_config"][0]
    assert first_channel["DataSource"]["S3DataSource"]["S3Uri"] == BUCKET_NAME
    assert config["role"] == ROLE

    output_section = config["output_config"]
    assert output_section["S3OutputPath"] == S3_OUTPUT_PATH
    assert "KmsKeyId" not in output_section

    resource_section = config["resource_config"]
    assert resource_section["InstanceCount"] == INSTANCE_COUNT
    assert resource_section["InstanceType"] == INSTANCE_TYPE
    assert resource_section["VolumeSizeInGB"] == VOLUME_SIZE

    assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME
github aws / sagemaker-python-sdk / tests / unit / test_job.py View on Github external
def test_dict_of_mixed_input_types():
    input_list = _Job._format_inputs_to_input_config(
        {"a": "s3://foo/bar", "b": s3_input("s3://whizz/bang")}
    )

    expected = [
        {
            "ChannelName": "a",
            "DataSource": {
                "S3DataSource": {
                    "S3DataDistributionType": "FullyReplicated",
                    "S3DataType": "S3Prefix",
                    "S3Uri": "s3://foo/bar",
                }
            },
        },
        {
            "ChannelName": "b",
            "DataSource": {
github aws / sagemaker-python-sdk / src / sagemaker / job.py View on Github external
uri_input,
                content_type=content_type,
                input_mode=input_mode,
                compression=compression,
                target_attribute_name=target_attribute_name,
            )
            return s3_input_result
        if isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"):
            return file_input(uri_input)
        if isinstance(uri_input, str) and validate_uri:
            raise ValueError(
                'URI input {} must be a valid S3 or FILE URI: must start with "s3://" or '
                '"file://"'.format(uri_input)
            )
        if isinstance(uri_input, str):
            s3_input_result = s3_input(
                uri_input,
                content_type=content_type,
                input_mode=input_mode,
                compression=compression,
                target_attribute_name=target_attribute_name,
            )
            return s3_input_result
        if isinstance(uri_input, (s3_input, file_input, FileSystemInput)):
            return uri_input

        raise ValueError(
            "Cannot format input {}. Expecting one of str, s3_input, file_input or "
            "FileSystemInput".format(uri_input)
        )
github aws / sagemaker-python-sdk / src / sagemaker / tuner.py View on Github external
inputs,
        estimator,
        static_hyperparameters,
        metric_definitions,
        estimator_name=None,
        objective_type=None,
        objective_metric_name=None,
        parameter_ranges=None,
    ):
        """Prepare training config for one estimator"""
        training_config = _Job._load_config(inputs, estimator)

        training_config["input_mode"] = estimator.input_mode
        training_config["metric_definitions"] = metric_definitions

        if isinstance(inputs, s3_input):
            if "InputMode" in inputs.config:
                logging.debug(
                    "Selecting s3_input's input_mode (%s) for TrainingInputMode.",
                    inputs.config["InputMode"],
                )
                training_config["input_mode"] = inputs.config["InputMode"]

        if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
            training_config["algorithm_arn"] = estimator.algorithm_arn
        else:
            training_config["image"] = estimator.train_image()

        training_config["enable_network_isolation"] = estimator.enable_network_isolation()
        training_config[
            "encrypt_inter_container_traffic"
        ] = estimator.encrypt_inter_container_traffic
github aws / sagemaker-python-sdk / src / sagemaker / job.py View on Github external
"""
        if isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("s3://"):
            return s3_input(
                model_uri,
                input_mode="File",
                distribution="FullyReplicated",
                content_type="application/x-sagemaker-model",
            )
        if isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("file://"):
            return file_input(model_uri)
        if isinstance(model_uri, string_types) and validate_uri:
            raise ValueError(
                'Model URI must be a valid S3 or FILE URI: must start with "s3://" or ' '"file://'
            )
        if isinstance(model_uri, string_types):
            return s3_input(
                model_uri,
                input_mode="File",
                distribution="FullyReplicated",
                content_type="application/x-sagemaker-model",
            )
        raise ValueError("Cannot format model URI {}. Expecting str".format(model_uri))
github aws / sagemaker-python-sdk / src / sagemaker / amazon / amazon_estimator.py View on Github external
def records_s3_input(self):
        """Build the s3_input that represents the training data.

        Returns:
            s3_input: created from ``self.s3_data`` with ``ShardedByS3Key``
            distribution and ``self.s3_data_type`` as the S3 data type.
        """
        data_location = self.s3_data
        data_type = self.s3_data_type
        return s3_input(data_location, distribution="ShardedByS3Key", s3_data_type=data_type)
github aws / sagemaker-python-sdk / src / sagemaker / job.py View on Github external
content_type=None,
        input_mode=None,
        compression=None,
        target_attribute_name=None,
    ):
        """
        Args:
            uri_input:
            validate_uri:
            content_type:
            input_mode:
            compression:
            target_attribute_name:
        """
        if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
            s3_input_result = s3_input(
                uri_input,
                content_type=content_type,
                input_mode=input_mode,
                compression=compression,
                target_attribute_name=target_attribute_name,
            )
            return s3_input_result
        if isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"):
            return file_input(uri_input)
        if isinstance(uri_input, str) and validate_uri:
            raise ValueError(
                'URI input {} must be a valid S3 or FILE URI: must start with "s3://" or '
                '"file://"'.format(uri_input)
            )
        if isinstance(uri_input, str):
            s3_input_result = s3_input(