How to use the sagemaker.predictor.RealTimePredictor class in sagemaker

To help you get started, we’ve selected a few sagemaker examples, based on popular ways the library is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github aws / sagemaker-python-sdk / tests / integ / test_ipinsights.py View on Github external
num_entity_vectors=10,
                vector_dim=100,
                sagemaker_session=sagemaker_session,
            )

        record_set = prepare_record_set_from_local_files(
            data_path, ipinsights.data_location, num_records, FEATURE_DIM, sagemaker_session
        )
        ipinsights.fit(records=record_set, job_name=job_name)

    with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
        model = IPInsightsModel(
            ipinsights.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session
        )
        predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
        assert isinstance(predictor, RealTimePredictor)

        predict_input = [["user_1", "1.1.1.1"]]
        result = predictor.predict(predict_input)

        assert len(result["predictions"]) == 1
        assert 0 > result["predictions"][0]["dot_product"] > -1  # We expect ~ -0.22
github aws / sagemaker-python-sdk / tests / unit / test_tf_predictor.py View on Github external
def test_predict_tensor_request_csv(sagemaker_session):
    data = [6.4, 3.2, 0.5, 1.5]
    tensor_proto = tf.make_tensor_proto(
        values=np.asarray(data), shape=[1, len(data)], dtype=tf.float32
    )
    predictor = RealTimePredictor(
        serializer=tf_csv_serializer,
        deserializer=tf_json_deserializer,
        sagemaker_session=sagemaker_session,
        endpoint=ENDPOINT,
    )

    mock_response(
        json.dumps(CLASSIFICATION_RESPONSE).encode("utf-8"), sagemaker_session, JSON_CONTENT_TYPE
    )

    result = predictor.predict(tensor_proto)

    sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with(
        Accept=JSON_CONTENT_TYPE,
        Body="6.4,3.2,0.5,1.5",
        ContentType=CSV_CONTENT_TYPE,
github aws / sagemaker-python-sdk / tests / unit / test_predictor.py View on Github external
def test_delete_endpoint_only():
    """Delete an endpoint with ``delete_endpoint_config=False``.

    The session must be asked to delete the endpoint itself, and must not
    be asked to delete the endpoint config.
    """
    session = empty_sagemaker_session()
    RealTimePredictor(ENDPOINT, sagemaker_session=session).delete_endpoint(
        delete_endpoint_config=False
    )

    session.delete_endpoint.assert_called_with(ENDPOINT)
    session.delete_endpoint_config.assert_not_called()
github aws / sagemaker-python-sdk / tests / unit / test_predictor.py View on Github external
def test_predict_call_with_headers_and_csv():
    """Predict through a predictor configured with CSV accept and serializer.

    Checks the keyword arguments handed to the runtime client's
    ``invoke_endpoint``: CSV ``Accept`` and ``ContentType`` headers, the
    input list rendered as the body ``"1,2"``, and the endpoint name.
    """
    session = ret_csv_sagemaker_session()
    predictor = RealTimePredictor(
        ENDPOINT, session, accept=CSV_CONTENT_TYPE, serializer=csv_serializer
    )

    result = predictor.predict([1, 2])

    invoke_endpoint = session.sagemaker_runtime_client.invoke_endpoint
    assert invoke_endpoint.called

    # ``call_args`` is a (positional, keyword) pair; only the keywords matter here.
    _, kwargs = invoke_endpoint.call_args
    assert kwargs == {
        "Accept": CSV_CONTENT_TYPE,
        "Body": "1,2",
        "ContentType": CSV_CONTENT_TYPE,
        "EndpointName": ENDPOINT,
    }
github aws / sagemaker-python-sdk / tests / unit / test_predictor.py View on Github external
def test_predict_call_pass_through():
    """Predict through a predictor built with no serializer or headers.

    The input must reach ``invoke_endpoint`` unchanged as ``Body``, with
    ``EndpointName`` as the only other keyword argument, and the predictor
    must return ``RETURN_VALUE``.
    """
    session = empty_sagemaker_session()
    payload = "untouched"

    result = RealTimePredictor(ENDPOINT, session).predict(payload)

    invoke_endpoint = session.sagemaker_runtime_client.invoke_endpoint
    assert invoke_endpoint.called

    # ``call_args`` is a (positional, keyword) pair; only the keywords matter here.
    _, kwargs = invoke_endpoint.call_args
    assert kwargs == {"Body": payload, "EndpointName": ENDPOINT}

    assert result == RETURN_VALUE
github aws / sagemaker-python-sdk / tests / integ / test_inference_pipeline.py View on Github external
model_data=sparkml_model_data,
            env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
            sagemaker_session=sagemaker_session,
        )
        xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
        xgb_model = Model(
            model_data=xgb_model_data, image=xgb_image, sagemaker_session=sagemaker_session
        )
        model = PipelineModel(
            models=[sparkml_model, xgb_model],
            role="SageMakerRole",
            sagemaker_session=sagemaker_session,
            name=endpoint_name,
        )
        model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
        predictor = RealTimePredictor(
            endpoint=endpoint_name,
            sagemaker_session=sagemaker_session,
            serializer=json_serializer,
            content_type=CONTENT_TYPE_CSV,
            accept=CONTENT_TYPE_CSV,
        )

        with open(VALID_DATA_PATH, "r") as f:
            valid_data = f.read()
            assert predictor.predict(valid_data) == "0.714013934135"

        with open(INVALID_DATA_PATH, "r") as f:
            invalid_data = f.read()
            assert predictor.predict(invalid_data) is None

    model.delete_model()
github aws / sagemaker-python-sdk / tests / unit / test_predictor.py View on Github external
def test_delete_model_fail():
    """Check that ``delete_model`` raises a retry error when deletion fails.

    The SageMaker client's ``delete_model`` is replaced with a mock that
    always raises, and the exception propagated by the predictor must
    contain the expected retry message.
    """
    sagemaker_session = empty_sagemaker_session()
    sagemaker_session.sagemaker_client.delete_model = Mock(
        side_effect=Exception("Could not find model.")
    )
    expected_error_message = "One or more models cannot be deleted, please retry."

    predictor = RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session)

    with pytest.raises(Exception) as exception:
        predictor.delete_model()

    # The check must sit outside the ``with`` block: anything placed after the
    # raising call inside it is skipped. The caught exception is exposed as
    # ``ExceptionInfo.value``.
    assert expected_error_message in str(exception.value)
github aws / sagemaker-python-sdk / tests / unit / test_predictor.py View on Github external
def test_predict_call_with_headers_and_json():
    sagemaker_session = json_sagemaker_session()
    predictor = RealTimePredictor(
        ENDPOINT,
        sagemaker_session,
        content_type="not/json",
        accept="also/not-json",
        serializer=json_serializer,
    )

    data = [1, 2]
    result = predictor.predict(data)

    assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called

    expected_request_args = {
        "Accept": "also/not-json",
        "Body": json.dumps(data),
        "ContentType": "not/json",
github aws / sagemaker-python-sdk / src / sagemaker / mxnet / model.py View on Github external
from __future__ import absolute_import

import logging

from pkg_resources import parse_version

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.mxnet.defaults import MXNET_VERSION
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer

logger = logging.getLogger("sagemaker")


class MXNetPredictor(RealTimePredictor):
    """A RealTimePredictor for inference against MXNet Endpoints.

    This is able to serialize Python lists, dictionaries, and numpy arrays to
    multidimensional tensors for MXNet inference.
    """

    def __init__(self, endpoint_name, sagemaker_session=None):
        """Initialize an ``MXNetPredictor``.

        Args:
            endpoint_name (str): The name of the endpoint to perform inference
                on.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, the estimator creates one
                using the default AWS configuration chain.
github aws / sagemaker-python-sdk / src / sagemaker / tensorflow / model.py View on Github external
"""Placeholder docstring"""
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import RealTimePredictor
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer

logger = logging.getLogger("sagemaker")


class TensorFlowPredictor(RealTimePredictor):
    """A ``RealTimePredictor`` for inference against TensorFlow endpoint.

    This is able to serialize Python lists, dictionaries, and numpy arrays to
    multidimensional tensors for inference
    """

    def __init__(self, endpoint_name, sagemaker_session=None):
        """Initialize an ``TensorFlowPredictor``.

        Args:
            endpoint_name (str): The name of the endpoint to perform inference
                on.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, the estimator creates one
                using the default AWS configuration chain.