Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_fetch_latest(mock_get_engine):
    """
    Check that ``SdkTask.fetch_latest`` surfaces the task handed back by the engine.

    :param mock_get_engine: Mock standing in for the engine accessor (presumably
        injected by a ``mock.patch`` decorator outside this view); its
        ``return_value`` is replaced with a fake engine below.
    """
    expected_id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"
    )
    expected_task = _task_models.Task(expected_id, _MagicMock())

    # The fake engine answers fetch_latest_task with the prepared admin task.
    fake_engine = _MagicMock()
    fake_engine.fetch_latest_task = _MagicMock(return_value=expected_task)
    mock_get_engine.return_value = fake_engine

    fetched = _task.SdkTask.fetch_latest("p1", "d1", "n1")
    assert fetched.id == expected_task.id
# Every literal type exercised by the tests: scalars, then collections, then
# nested collections (concatenated in that order).
LIST_OF_ALL_LITERAL_TYPES = (
    LIST_OF_SCALAR_LITERAL_TYPES
    + LIST_OF_COLLECTION_LITERAL_TYPES
    + LIST_OF_NESTED_COLLECTION_LITERAL_TYPES
)

# One single-input / single-output interface per literal type.
LIST_OF_INTERFACES = [
    interface.TypedInterface(
        {'a': interface.Variable(literal_type, "description 1")},
        {'b': interface.Variable(literal_type, "description 2")}
    )
    for literal_type in LIST_OF_ALL_LITERAL_TYPES
]

# One resource entry for each resource name, in CPU, GPU, MEMORY, STORAGE order.
LIST_OF_RESOURCE_ENTRIES = [
    task.Resources.ResourceEntry(resource_name, amount)
    for resource_name, amount in (
        (task.Resources.ResourceName.CPU, "1"),
        (task.Resources.ResourceName.GPU, "1"),
        (task.Resources.ResourceName.MEMORY, "1G"),
        (task.Resources.ResourceName.STORAGE, "1G"),
    )
]

LIST_OF_RESOURCE_ENTRY_LISTS = [LIST_OF_RESOURCE_ENTRIES]

# Every (requests, limits) pairing drawn from the entry lists above.
LIST_OF_RESOURCES = [
    task.Resources(requests, limits)
    for requests, limits in product(LIST_OF_RESOURCE_ENTRY_LISTS, repeat=2)
]
_task_models.Resources.ResourceEntry(
_task_models.Resources.ResourceName.CPU,
cpu_limit
)
)
if gpu_limit:
limits.append(
_task_models.Resources.ResourceEntry(
_task_models.Resources.ResourceName.GPU,
gpu_limit
)
)
if memory_limit:
limits.append(
_task_models.Resources.ResourceEntry(
_task_models.Resources.ResourceName.MEMORY,
memory_limit
)
)
if environment is None:
environment = {}
return _task_models.Container(
image=image,
command=command,
args=args,
resources=_task_models.Resources(limits=limits, requests=requests),
env=environment,
config={},
data_loading_config=data_loading_config,
)
:param Text gpu_request:
:param Text memory_request:
:param Text storage_limit:
:param Text cpu_limit:
:param Text gpu_limit:
:param Text memory_limit:
:param bool discoverable:
:param datetime.timedelta timeout:
:param dict[Text, Text] environment:
:param dict[Text, T] custom:
"""
self._task_function = task_function
super(SdkRunnableTask, self).__init__(
task_type,
_task_models.TaskMetadata(
discoverable,
_task_models.RuntimeMetadata(
_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
__version__,
'python'
),
timeout,
_literal_models.RetryStrategy(retries),
interruptible,
discovery_version,
deprecated
),
_interface.TypedInterface({}, {}),
custom,
container=self._get_container_definition(
storage_request=storage_request,
def serialize(self):
    """
    Wrap this task in a ``TaskSpec`` model and return its Flyte IDL (protobuf) form.

    :rtype: flyteidl.admin.task_pb2.TaskSpec
    """
    spec = _task_model.TaskSpec(self)
    return spec.to_flyte_idl()
:param Text catalog: The catalog to set for the given Presto query
:param Text schema: The schema to set for the given Presto query
:param dict[Text,flytekit.common.types.base_sdk_types.FlyteSdkType] task_inputs: Optional inputs to the Presto task
:param bool discoverable:
:param Text discovery_version: String describing the version for task discovery purposes
:param int retries: Number of retries to attempt
:param datetime.timedelta timeout:
"""
# Set as class fields which are used down below to configure implicit
# parameters
self._routing_group = routing_group or ""
self._catalog = catalog or ""
self._schema = schema or ""
metadata = _task_model.TaskMetadata(
discoverable,
# This needs to have the proper version reflected in it
_task_model.RuntimeMetadata(
_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__,
"python"),
timeout or _datetime.timedelta(seconds=0),
_literals.RetryStrategy(retries),
interruptible,
discovery_version,
"This is deprecated!"
)
presto_query = _presto_models.PrestoQuery(
routing_group=routing_group or "",
catalog=catalog or "",
schema=schema or "",
def get_task(self, id):
"""
This returns a single task for a given identifier.
:param flytekit.models.core.identifier.Identifier id: The ID representing a given task.
:raises: TODO
:rtype: flytekit.models.task.Task
"""
return _task.Task.from_flyte_idl(
super(SynchronousFlyteClient, self).get_task(
_common_pb2.ObjectGetRequest(
id=id.to_flyte_idl()
)
def execution_id(self):
    """
    Identifier the underlying engine assigned to the workflow execution. Every task
    execution inside the same workflow or sub-workflow execution sees the same value.

    .. note::
        Do NOT base production logic on this value. It is meant only as a tag on
        output data, linking that data back to the workflow run that produced it.

    :rtype: Text
    """
    current_id = self._execution_id
    return current_id
class SdkRunnableContainer(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _task_models.Container)):
def __init__(
self,
command,
args,
resources,
env,
config,
):
super(SdkRunnableContainer, self).__init__(
"",
command,
args,
resources,
env or {},
config
# Build the pieces of the Python task: a plain dict of the task's settings
# (kept as a simple hash for now; its final use is still undecided), the task
# metadata, and the typed interface. Then assemble the PythonTask.
#
# Values are stored directly. Previously each per-key assignment ended with a
# trailing comma (``task_obj['retries'] = retries,``), which silently wrapped
# every value except 'custom' in a one-element tuple.
task_obj = {
    'task_type': _common_constants.SdkTaskType.PYTHON_TASK,
    'retries': retries,
    'storage_request': storage_request,
    'cpu_request': cpu_request,
    'gpu_request': gpu_request,
    'memory_request': memory_request,
    'storage_limit': storage_limit,
    'cpu_limit': cpu_limit,
    'gpu_limit': gpu_limit,
    'memory_limit': memory_limit,
    'environment': environment,
    'custom': {},
}
metadata = _task_model.TaskMetadata(
    cache,
    _task_model.RuntimeMetadata(
        _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK,
        # NOTE(review): hard-coded runtime version; other task builders pass
        # ``__version__`` here. Confirm this literal is intended.
        '1.2.3',
        'python'
    ),
    # A missing/falsy timeout falls back to a zero-length timedelta.
    timeout or _datetime.timedelta(seconds=0),
    _literal_models.RetryStrategy(retries),
    interruptible,
    cache_version,
    deprecated
)
# Interface is derived from the wrapped function's annotations plus the declared
# outputs (an empty list when none were given).
interface = get_interface_from_task_info(fn.__annotations__, outputs or [])
task_instance = PythonTask(fn, interface, metadata, outputs, task_obj)
from flytekit.common import interface as _interfaces, nodes as _nodes, sdk_bases as _sdk_bases
from flytekit.common.core import identifier as _identifier
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.mixins import registerable as _registerable, hash as _hash_mixin
from flytekit.configuration import internal as _internal_config
from flytekit.engines import loader as _engine_loader
from flytekit.models import common as _common_model, task as _task_model
from flytekit.models.core import workflow as _workflow_model, identifier as _identifier_model
from flytekit.common.exceptions import user as _user_exceptions
class SdkTask(
_six.with_metaclass(
_sdk_bases.ExtendedSdkType,
_hash_mixin.HashOnReferenceMixin,
_task_model.TaskTemplate,
_registerable.RegisterableEntity,
)
):
def __init__(self, type, metadata, interface, custom, container=None):
"""
:param Text type: This is used to define additional extensions for use by Propeller or SDK.
:param TaskMetadata metadata: This contains information needed at runtime to determine behavior such as
whether or not outputs are discoverable, timeouts, and retries.
:param flytekit.common.interface.TypedInterface interface: The interface definition for this task.
:param dict[Text, T] custom: Arbitrary type for use by plugins.
:param Container container: Provides the necessary entrypoint information for execution. For instance,
a Container might be specified with the necessary command line arguments.
"""
super(SdkTask, self).__init__(
_identifier.Identifier(