How to use the sagemaker.amazon.amazon_estimator.RecordSet class in sagemaker

To help you get started, we’ve selected a few sagemaker examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github aws / sagemaker-python-sdk / tests / unit / test_pca.py View on Github external
def test_call_fit(base_fit, sagemaker_session):
    """PCA.fit should forward the training RecordSet and mini-batch size to the base fit."""
    estimator = PCA(base_job_name="pca", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

    record_set = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )

    estimator.fit(record_set, MINI_BATCH_SIZE)

    # The mocked base fit must have received exactly (record_set, MINI_BATCH_SIZE).
    base_fit.assert_called_once()
    positional = base_fit.call_args[0]
    assert len(positional) == 2
    assert positional[0] == record_set
    assert positional[1] == MINI_BATCH_SIZE
github aws / sagemaker-python-sdk / tests / unit / test_randomcutforest.py View on Github external
def test_model_image(sagemaker_session):
    """After fitting, the created model should point at the randomcutforest image."""
    estimator = RandomCutForest(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    training_data = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    estimator.fit(training_data, MINI_BATCH_SIZE)

    expected_image = registry(REGION, "randomcutforest") + "/randomcutforest:1"
    assert estimator.create_model().image == expected_image
github aws / sagemaker-python-sdk / tests / unit / test_randomcutforest.py View on Github external
def test_predictor_type(sagemaker_session):
    """Deploying the fitted model should yield a RandomCutForestPredictor."""
    estimator = RandomCutForest(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    training_data = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    estimator.fit(training_data, MINI_BATCH_SIZE)
    deployed = estimator.create_model().deploy(1, TRAIN_INSTANCE_TYPE)

    assert isinstance(deployed, RandomCutForestPredictor)
github aws / sagemaker-python-sdk / tests / unit / test_ntm.py View on Github external
def test_call_fit_none_mini_batch_size(sagemaker_session):
    """NTM.fit should accept a RecordSet without an explicit mini-batch size."""
    estimator = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

    record_set = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    # No mini_batch_size passed — fit must tolerate the default (None).
    estimator.fit(record_set)
github aws / sagemaker-python-sdk / tests / unit / test_kmeans.py View on Github external
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
    """KMeans should fall back to a mini-batch size of 5000 when none is given."""
    estimator = KMeans(base_job_name="kmeans", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

    record_set = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    estimator._prepare_for_training(record_set)

    assert estimator.mini_batch_size == 5000
github aws / sagemaker-python-sdk / tests / unit / test_ntm.py View on Github external
def test_prepare_for_training_wrong_value_upper_mini_batch_size(sagemaker_session):
    """NTM should reject a mini-batch size above the allowed maximum (10000)."""
    estimator = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

    record_set = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    # 10001 is one past the documented upper bound, so validation must fail.
    with pytest.raises(ValueError):
        estimator._prepare_for_training(record_set, 10001)
github aws / sagemaker-python-sdk / tests / unit / test_ipinsights.py View on Github external
def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
    """IPInsights should reject a non-integer mini-batch size."""
    estimator = IPInsights(
        base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS
    )

    record_set = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )

    # A string is neither a valid type nor a valid value for mini_batch_size.
    with pytest.raises((TypeError, ValueError)):
        estimator._prepare_for_training(record_set, "some")
github aws / sagemaker-python-sdk / tests / unit / test_randomcutforest.py View on Github external
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
    """RandomCutForest should use its default mini-batch size when none is given."""
    estimator = RandomCutForest(
        base_job_name="randomcutforest", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS
    )

    record_set = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    estimator._prepare_for_training(record_set)

    assert estimator.mini_batch_size == MINI_BATCH_SIZE
github aws / sagemaker-python-sdk / tests / unit / test_job.py View on Github external
def test_format_inputs_to_input_config_list_not_all_records():
    """A mixed list (RecordSet plus a non-RecordSet) must be rejected with a clear error."""
    record_set = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)
    mixed_inputs = [record_set, "mock"]

    with pytest.raises(ValueError) as ex:
        _Job._format_inputs_to_input_config(mixed_inputs)

    assert "List compatible only with RecordSets or FileSystemRecordSets." in str(ex)
github aws / sagemaker-python-sdk / src / sagemaker / workflow / airflow.py View on Github external
            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
                  Amazon :class:~`Record` objects serialized and stored in S3. For
                  use with an estimator for an Amazon algorithm.

            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
                  :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects,
                  where each instance is a different channel of training data.
        mini_batch_size:
    """
    if isinstance(inputs, list):
        for record in inputs:
            if isinstance(record, amazon_estimator.RecordSet) and record.channel == "train":
                estimator.feature_dim = record.feature_dim
                break
    elif isinstance(inputs, amazon_estimator.RecordSet):
        estimator.feature_dim = inputs.feature_dim
    else:
        raise TypeError("Training data must be represented in RecordSet or list of RecordSets")
    estimator.mini_batch_size = mini_batch_size