How to use the flytekit.models.literals module in flytekit

To help you get started, we’ve selected a few flytekit examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github lyft / flytekit / tests / flytekit / unit / bin / test_python_entrypoint.py View on Github external
internal_overrides={
                                     'project': 'test',
                                     'domain': 'development'
                                 }):
        with _utils.AutoDeletingTempDir("dir") as dir:
            literal_map = _type_helpers.pack_python_std_map_to_literal_map(
                {'a': 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs))

            input_dir = os.path.join(dir.name, "1")
            os.mkdir(input_dir)    # auto cleanup will take this subdir into account

            input_file = os.path.join(input_dir, "inputs.pb")
            _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file)

            # construct indexlookup.pb which has array: [1]
            mapped_index = _literals.Literal(_literals.Scalar(primitive=_literals.Primitive(integer=1)))
            index_lookup_collection = _literals.LiteralCollection([mapped_index])
            index_lookup_file = os.path.join(dir.name, "indexlookup.pb")
            _utils.write_proto_to_file(index_lookup_collection.to_flyte_idl(), index_lookup_file)

            # fake arrayjob task by setting environment variables
            orig_env_index_var_name = os.environ.get('BATCH_JOB_ARRAY_INDEX_VAR_NAME')
            orig_env_array_index = os.environ.get('AWS_BATCH_JOB_ARRAY_INDEX')
            os.environ['BATCH_JOB_ARRAY_INDEX_VAR_NAME'] = 'AWS_BATCH_JOB_ARRAY_INDEX'
            os.environ['AWS_BATCH_JOB_ARRAY_INDEX'] = '0'

            _execute_task(
                _task_defs.add_one.task_module,
                _task_defs.add_one.task_function_name,
                dir.name,
                dir.name,
                False
github lyft / flytekit / tests / flytekit / unit / engines / flyte / test_engine.py View on Github external
from flytekit.engines.flyte import engine
from flytekit.models import literals, execution as _execution_models, common as _common_models, launch_plan as \
    _launch_plan_models, task as _task_models
from flytekit.models.admin import common as _common
from flytekit.models.core import errors, identifier
from flytekit.sdk import test_utils


# Canned literal maps used as fixed input/output values in this test module.

# Input map: variable 'a' bound to the integer primitive 1.
_INPUT_MAP = literals.LiteralMap(
    dict(
        a=literals.Literal(
            scalar=literals.Scalar(
                primitive=literals.Primitive(integer=1),
            ),
        ),
    )
)
# Output map: variable 'b' bound to the integer primitive 2.
_OUTPUT_MAP = literals.LiteralMap(
    dict(
        b=literals.Literal(
            scalar=literals.Scalar(
                primitive=literals.Primitive(integer=2),
            ),
        ),
    )
)


@pytest.fixture(scope="function", autouse=True)
def temp_config():
    with TemporaryConfiguration(
            os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../../common/configs/local.config'),
            internal_overrides={
                'image': 'myflyteimage:{}'.format(
                    os.environ.get('IMAGE_VERSION', 'sha')
                ),
                'project': 'myflyteproject',
                'domain': 'development'
            }
    ):
github lyft / flytekit / flytekit / common / workflow.py View on Github external
id = id if id is not None else _identifier.Identifier(
            _identifier_model.ResourceType.WORKFLOW,
            _internal_config.PROJECT.get(),
            _internal_config.DOMAIN.get(),
            _uuid.uuid4().hex,
            _internal_config.VERSION.get()
        )
        metadata = metadata if metadata is not None else _workflow_models.WorkflowMetadata()

        interface = interface if interface is not None else _interface.TypedInterface(
            {v.name: v.var for v in inputs},
            {v.name: v.var for v in outputs}
        )

        output_bindings = output_bindings if output_bindings is not None else \
            [_literal_models.Binding(v.name, v.binding_data) for v in outputs]

        super(SdkWorkflow, self).__init__(
            id=id,
            metadata=metadata,
            metadata_defaults=_workflow_models.WorkflowMetadataDefaults(),
            interface=interface,
            nodes=nodes,
            outputs=output_bindings,
        )
        self._user_inputs = inputs
        self._upstream_entities = set(n.executable_sdk_object for n in nodes)
github lyft / flytekit / flytekit / common / tasks / sagemaker / training_job_task.py View on Github external
# Setting flyte-level timeout to 0, and let SageMaker take the StoppingCondition and terminate the training
        # job gracefully
        timeout = _datetime.timedelta(seconds=0)

        super(SdkSimpleTrainingJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
            metadata=_task_models.TaskMetadata(
                runtime=_task_models.RuntimeMetadata(
                    type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    version=__version__,
                    flavor='sagemaker'
                ),
                discoverable=cacheable,
                timeout=timeout,
                retries=_literal_models.RetryStrategy(retries=retries),
                interruptible=interruptible,
                discovery_version=cache_version,
                deprecated_error_message="",
            ),
            interface=_interface.TypedInterface(
                inputs={
                    "static_hyperparameters": _interface_model.Variable(
                        type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
                        description="",
                    ),
                    "train": _interface_model.Variable(
                        type=_idl_types.LiteralType(
                            blob=_core_types.BlobType(
                                format="csv",
                                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
                            ),
github lyft / flytekit / flytekit / common / tasks / raw_container.py View on Github external
self._data_loading_config = _task_models.DataLoadingConfig(
            input_path=input_data_dir,
            output_path=output_data_dir,
            format=metadata_format,
            enabled=True,
            io_strategy=io_strategy,
        )

        metadata = _task_models.TaskMetadata(
            discoverable,
            # This needs to have the proper version reflected in it
            _task_models.RuntimeMetadata(
                _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__,
                "python"),
            timeout or _datetime.timedelta(seconds=0),
            _literals.RetryStrategy(retries),
            interruptible,
            discovery_version,
            None
        )

        # The interface is defined using the inputs and outputs
        i = _interface.TypedInterface(inputs=types_to_variable(inputs), outputs=types_to_variable(outputs))

        # This sets the base SDKTask with container etc
        super(SdkRawContainerTask, self).__init__(
            _constants.SdkTaskType.RAW_CONTAINER_TASK,
            metadata,
            i,
            None,
            container=_get_container_definition(
                image=image,
github lyft / flytekit / flytekit / common / interface.py View on Github external
)

        extra_inputs = set(binding_data.keys()) ^ set(map_of_bindings.keys())
        if len(extra_inputs) > 0:
            raise _user_exceptions.FlyteAssertion(
                "Too many inputs were specified for the interface.  Extra inputs were: {}".format(extra_inputs)
            )

        seen_nodes = set()
        min_upstream = list()
        for n in all_upstream_nodes:
            if n not in seen_nodes:
                seen_nodes.add(n)
                min_upstream.append(n)

        return [_literal_models.Binding(k, bd) for k, bd in _six.iteritems(binding_data)], min_upstream
github lyft / flytekit / flytekit / engines / unit / engine.py View on Github external
[DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for sub_binding_data in
                 binding_data.collection.bindings]))
        elif binding_data.promise:
            if binding_data.promise.node_id not in fulfilled_promises:
                raise _system_exception.FlyteSystemAssertion(
                    "Expecting output of node [{}] but that hasn't been produced.".format(binding_data.promise.node_id))
            node_output = fulfilled_promises[binding_data.promise.node_id]
            if binding_data.promise.var not in node_output:
                raise _system_exception.FlyteSystemAssertion(
                    "Expecting output [{}] of node [{}] but that hasn't been produced.".format(
                        binding_data.promise.var,
                        binding_data.promise.node_id))

            return binding_data.promise.sdk_type.from_python_std(node_output[binding_data.promise.var])
        elif binding_data.map:
            return _literals.Literal(map=_literals.LiteralMap(
                {
                    k: DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for k, sub_binding_data in
                    _six.iteritems(binding_data.map.bindings)
                }))
github lyft / flytekit / flytekit / common / interface.py View on Github external
from __future__ import absolute_import

import six as _six

from flytekit.common import sdk_bases as _sdk_bases, promise as _promise
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.types import helpers as _type_helpers, containers as _containers, primitives as _primitives
from flytekit.models import interface as _interface_models, literals as _literal_models


class BindingData(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _literal_models.BindingData)):

    @staticmethod
    def _has_sub_bindings(m):
        """
        Report whether a container, searched recursively, holds any promise object.

        A promise here is an instance of ``_promise.Input`` or ``_promise.NodeOutput``.
        Nested lists and dicts are descended into; for dicts only the values are examined.
        The search stops at the first promise found.

        :param dict[Text,T] or list[T] m: the container to inspect
        :rtype: bool
        """
        members = _six.itervalues(m) if isinstance(m, dict) else m
        return any(
            # Recurse into nested containers first, then fall back to a direct promise check,
            # mirroring the original if/elif ordering.
            (isinstance(member, (list, dict)) and BindingData._has_sub_bindings(member))
            or isinstance(member, (_promise.Input, _promise.NodeOutput))
            for member in members
        )

    @classmethod
    def promote_from_model(cls, model):