internal_overrides={
'project': 'test',
'domain': 'development'
}):
with _utils.AutoDeletingTempDir("dir") as dir:
literal_map = _type_helpers.pack_python_std_map_to_literal_map(
{'a': 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs))
input_dir = os.path.join(dir.name, "1")
os.mkdir(input_dir) # auto cleanup will take this subdir into account
input_file = os.path.join(input_dir, "inputs.pb")
_utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file)
# construct indexlookup.pb which has array: [1]
mapped_index = _literals.Literal(_literals.Scalar(primitive=_literals.Primitive(integer=1)))
index_lookup_collection = _literals.LiteralCollection([mapped_index])
index_lookup_file = os.path.join(dir.name, "indexlookup.pb")
_utils.write_proto_to_file(index_lookup_collection.to_flyte_idl(), index_lookup_file)
# fake arrayjob task by setting environment variables
orig_env_index_var_name = os.environ.get('BATCH_JOB_ARRAY_INDEX_VAR_NAME')
orig_env_array_index = os.environ.get('AWS_BATCH_JOB_ARRAY_INDEX')
os.environ['BATCH_JOB_ARRAY_INDEX_VAR_NAME'] = 'AWS_BATCH_JOB_ARRAY_INDEX'
os.environ['AWS_BATCH_JOB_ARRAY_INDEX'] = '0'
_execute_task(
_task_defs.add_one.task_module,
_task_defs.add_one.task_function_name,
dir.name,
dir.name,
False,
)
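
# A minimal sketch (not flytekit's actual entrypoint code) of how the two environment
# variables faked above are meant to be resolved: the first variable names the second,
# the second holds the shard index, and that index selects an entry from the
# indexlookup.pb collection written above. Property names follow flytekit.models.literals.
index_var_name = os.environ['BATCH_JOB_ARRAY_INDEX_VAR_NAME']  # 'AWS_BATCH_JOB_ARRAY_INDEX'
raw_index = int(os.environ[index_var_name])                    # 0
mapped = index_lookup_collection.literals[raw_index]
assert mapped.scalar.primitive.integer == 1                    # the index written above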
import os
import pytest

from flytekit.configuration import TemporaryConfiguration
from flytekit.engines.flyte import engine
from flytekit.models import (
    literals, execution as _execution_models, common as _common_models,
    launch_plan as _launch_plan_models, task as _task_models
)
from flytekit.models.admin import common as _common
from flytekit.models.core import errors, identifier
from flytekit.sdk import test_utils
_INPUT_MAP = literals.LiteralMap(
{
'a': literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))
}
)
_OUTPUT_MAP = literals.LiteralMap(
{
'b': literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=2)))
}
)
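
# A small usage sketch for the maps above, assuming only the flytekit.models.literals API
# already shown on this page: values can be read back through the literals/scalar/primitive
# properties, and to_flyte_idl() produces the corresponding protobuf message.
assert _INPUT_MAP.literals['a'].scalar.primitive.integer == 1
input_map_pb = _INPUT_MAP.to_flyte_idl()  # flyteidl LiteralMap message, ready to serialize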
@pytest.fixture(scope="function", autouse=True)
def temp_config():
with TemporaryConfiguration(
os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../../common/configs/local.config'),
internal_overrides={
'image': 'myflyteimage:{}'.format(
os.environ.get('IMAGE_VERSION', 'sha')
),
'project': 'myflyteproject',
'domain': 'development'
}
):
        yield  # run the test inside the temporary configuration
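
# Because the fixture above is function-scoped and autouse, every test in this module runs
# inside the TemporaryConfiguration context. A hypothetical test relying on it might look
# like this sketch, assuming _internal_config is flytekit.configuration.internal (as used in
# the workflow id construction further down this page) and that the internal_overrides are
# what PROJECT.get() reports:
def test_project_override_is_visible():
    assert _internal_config.PROJECT.get() == 'myflyteproject'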
id = id if id is not None else _identifier.Identifier(
_identifier_model.ResourceType.WORKFLOW,
_internal_config.PROJECT.get(),
_internal_config.DOMAIN.get(),
_uuid.uuid4().hex,
_internal_config.VERSION.get()
)
metadata = metadata if metadata is not None else _workflow_models.WorkflowMetadata()
interface = interface if interface is not None else _interface.TypedInterface(
{v.name: v.var for v in inputs},
{v.name: v.var for v in outputs}
)
output_bindings = output_bindings if output_bindings is not None else \
[_literal_models.Binding(v.name, v.binding_data) for v in outputs]
super(SdkWorkflow, self).__init__(
id=id,
metadata=metadata,
metadata_defaults=_workflow_models.WorkflowMetadataDefaults(),
interface=interface,
nodes=nodes,
outputs=output_bindings,
)
self._user_inputs = inputs
self._upstream_entities = set(n.executable_sdk_object for n in nodes)
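
# A minimal sketch of how the default workflow id above is assembled, assuming the
# flytekit.models.core.identifier module (imported as `identifier` in the test imports
# earlier on this page) and placeholder project/domain/version values:
import uuid

wf_id = identifier.Identifier(
    identifier.ResourceType.WORKFLOW,
    'myflyteproject',      # project, e.g. _internal_config.PROJECT.get()
    'development',         # domain, e.g. _internal_config.DOMAIN.get()
    uuid.uuid4().hex,      # auto-generated workflow name
    'abc123',              # version placeholder, e.g. _internal_config.VERSION.get()
)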
# Set the flyte-level timeout to 0 and let SageMaker honor the StoppingCondition to
# terminate the training job gracefully
timeout = _datetime.timedelta(seconds=0)
super(SdkSimpleTrainingJobTask, self).__init__(
type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
metadata=_task_models.TaskMetadata(
runtime=_task_models.RuntimeMetadata(
type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
version=__version__,
flavor='sagemaker'
),
discoverable=cacheable,
timeout=timeout,
retries=_literal_models.RetryStrategy(retries=retries),
interruptible=interruptible,
discovery_version=cache_version,
deprecated_error_message="",
),
interface=_interface.TypedInterface(
inputs={
"static_hyperparameters": _interface_model.Variable(
type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
description="",
),
"train": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="csv",
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
),
self._data_loading_config = _task_models.DataLoadingConfig(
input_path=input_data_dir,
output_path=output_data_dir,
format=metadata_format,
enabled=True,
io_strategy=io_strategy,
)
metadata = _task_models.TaskMetadata(
discoverable,
# This needs to have the proper version reflected in it
_task_models.RuntimeMetadata(
_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__,
"python"),
timeout or _datetime.timedelta(seconds=0),
_literals.RetryStrategy(retries),
interruptible,
discovery_version,
None  # deprecated_error_message (matches deprecated_error_message="" in the keyword form above)
)
# The interface is defined using the inputs and outputs
i = _interface.TypedInterface(inputs=types_to_variable(inputs), outputs=types_to_variable(outputs))
# This sets the base SDKTask with container etc
super(SdkRawContainerTask, self).__init__(
_constants.SdkTaskType.RAW_CONTAINER_TASK,
metadata,
i,
None,
container=_get_container_definition(
image=image,
)
extra_inputs = set(binding_data.keys()) ^ set(map_of_bindings.keys())  # keys present in only one of the two mappings
if len(extra_inputs) > 0:
raise _user_exceptions.FlyteAssertion(
"Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs)
)
seen_nodes = set()
min_upstream = list()
for n in all_upstream_nodes:
if n not in seen_nodes:
seen_nodes.add(n)
min_upstream.append(n)
return [_literal_models.Binding(k, bd) for k, bd in _six.iteritems(binding_data)], min_upstream
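
# The two checks above in plain form (pure Python, independent of flytekit): the symmetric
# difference flags keys present in only one of the two mappings, and the seen-set loop
# removes duplicates from a list while preserving order.
binding_keys = {'a', 'b', 'c'}
interface_keys = {'a', 'b'}
mismatched = binding_keys ^ interface_keys  # {'c'}

nodes = ['n1', 'n2', 'n1', 'n3']
seen, deduped = set(), []
for n in nodes:
    if n not in seen:
        seen.add(n)
        deduped.append(n)
assert deduped == ['n1', 'n2', 'n3']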
elif binding_data.collection:
    # collection bindings recurse element-wise, mirroring the map branch below
    return _literals.Literal(collection=_literals.LiteralCollection(
        [DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for sub_binding_data in
         binding_data.collection.bindings]))
elif binding_data.promise:
if binding_data.promise.node_id not in fulfilled_promises:
raise _system_exception.FlyteSystemAssertion(
"Expecting output of node [{}] but that hasn't been produced.".format(binding_data.promise.node_id))
node_output = fulfilled_promises[binding_data.promise.node_id]
if binding_data.promise.var not in node_output:
raise _system_exception.FlyteSystemAssertion(
"Expecting output [{}] of node [{}] but that hasn't been produced.".format(
binding_data.promise.var,
binding_data.promise.node_id))
return binding_data.promise.sdk_type.from_python_std(node_output[binding_data.promise.var])
elif binding_data.map:
return _literals.Literal(map=_literals.LiteralMap(
{
k: DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for k, sub_binding_data in
_six.iteritems(binding_data.map.bindings)
}))
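
# Sketch of the shapes the branches above expect (an assumption inferred from the lookups in
# this fragment, not an exact flytekit type): fulfilled_promises maps an upstream node id to
# the dict of python outputs that node produced, and each promise names a node_id plus a var
# to read from it before converting via sdk_type.from_python_std().
fulfilled_promises = {
    'n0': {'out': 3},       # node 'n0' produced output variable 'out' = 3
    'n1': {'squared': 9},
}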
from __future__ import absolute_import
import six as _six
from flytekit.common import sdk_bases as _sdk_bases, promise as _promise
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.types import helpers as _type_helpers, containers as _containers, primitives as _primitives
from flytekit.models import interface as _interface_models, literals as _literal_models
class BindingData(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _literal_models.BindingData)):
@staticmethod
def _has_sub_bindings(m):
"""
:param dict[Text,T] or list[T] m:
:rtype: bool
"""
for v in _six.itervalues(m) if isinstance(m, dict) else m:
if isinstance(v, (list, dict)) and BindingData._has_sub_bindings(v):
return True
elif isinstance(v, (_promise.Input, _promise.NodeOutput)):
return True
return False
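    # Usage sketch (illustrative assumption): _has_sub_bindings walks nested dicts/lists and
    # reports whether any leaf is still a promise, so
    # BindingData._has_sub_bindings({'a': [1, 2], 'b': {'c': 3}}) is False, while the same
    # structure with a _promise.Input or _promise.NodeOutput leaf anywhere inside is True.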
@classmethod
def promote_from_model(cls, model):