def test_sdk_launch_plan_node():
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    @_workflow.workflow_class
    class test_workflow(object):
        a = _workflow.Input(_types.Types.Integer)
        test = testy_test(a=1)
        b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer)

    lp = test_workflow.create_launch_plan()

    lp._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'name', 'version')
    n = _component_nodes.SdkWorkflowNode(sdk_launch_plan=lp)
    assert n.launchplan_ref.project == 'project'
    assert n.launchplan_ref.domain == 'domain'
    assert n.launchplan_ref.name == 'name'
    assert n.launchplan_ref.version == 'version'

    # Test floating ID: re-assigning the launch plan's id is reflected in the node's reference.
    lp._id = _identifier.Identifier(
        _identifier.ResourceType.TASK,
        'new_project',
        'new_domain',
        'new_name',
        'new_version'
    )
    assert n.launchplan_ref.project == 'new_project'
    assert n.launchplan_ref.domain == 'new_domain'
    assert n.launchplan_ref.name == 'new_name'
    assert n.launchplan_ref.version == 'new_version'
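
    # Extra check (a sketch, not in the original excerpt; assumes the node model's
    # to_flyte_idl() and the flyteidl WorkflowNode.launchplan_ref field):
    assert n.to_flyte_idl().launchplan_ref.project == 'new_project'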

# Wrap the task in a node, embed that node in a workflow template, and verify
# that the resulting closure round-trips through its flyteidl representation.
task_node = _workflow.TaskNode(task.id)
node = _workflow.Node(
    id='my_node',
    metadata=node_metadata,
    inputs=[b0],
    upstream_node_ids=[],
    output_aliases=[],
    task_node=task_node)
template = _workflow.WorkflowTemplate(
    id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version"),
    metadata=_workflow.WorkflowMetadata(),
    interface=typed_interface,
    nodes=[node],
    outputs=[b1, b2],
)
obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
assert len(obj.tasks) == 1
obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
assert obj == obj2
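
# A small extension of the same round trip (sketch, not in the original excerpt):
# the protobuf message returned by to_flyte_idl() can also be re-parsed from its
# serialized wire format before converting back to the model.
idl_message = obj.to_flyte_idl()
reparsed = type(idl_message).FromString(idl_message.SerializeToString())
assert _workflow_closure.WorkflowClosure.from_flyte_idl(reparsed) == obj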

task_metadata = _task.TaskMetadata(
    True,
    _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
    timedelta(days=1),
    _literals.RetryStrategy(3),
    "0.1.1b0",
    "This is deprecated!"
)
cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])
task = _task.TaskTemplate(
    _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version"),
    "python",
    task_metadata,
    typed_interface,
    {'a': 1, 'b': {'c': 2, 'd': 3}},
    container=_task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        resources,
        {},
        {}
    )
)
task_node = _workflow.TaskNode(task.id)

def test_serialize():
    workflow_to_test = _workflow.workflow(
        {},
        inputs={
            'required_input': _workflow.Input(_types.Types.Integer),
            'default_input': _workflow.Input(_types.Types.Integer, default=5)
        }
    )
    workflow_to_test._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v")
    lp = workflow_to_test.create_launch_plan(
        fixed_inputs={'required_input': 5},
        role='iam_role',
    )
    with _configuration.TemporaryConfiguration(
        _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), '../../common/configs/local.config'),
        internal_overrides={
            'image': 'myflyteimage:v123',
            'project': 'myflyteproject',
            'domain': 'development'
        }
    ):
        s = lp.serialize()
        assert s.workflow_id == _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v").to_flyte_idl()
        assert s.auth_role.assumable_iam_role == 'iam_role'
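
        # Sketch, not part of the original excerpt; assumes the usual flyteidl
        # LaunchPlanSpec / LiteralMap layout for the serialized spec: the fixed
        # input should appear as an integer literal.
        assert s.fixed_inputs.literals['required_input'].scalar.primitive.integer == 5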

def test_non_system_nodes():
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'my_task', 'version')

    required_input = promise.Input('required', primitives.Integer)
    n1 = my_task(a=required_input).assign_id_and_return('n1')

    # Hand-built node that binds input 'a' to the literal 3 and wraps my_task.
    n_start = nodes.SdkNode(
        'start-node',
        [],
        [
            _literals.Binding(
                'a',
                interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3)
            )
        ],
        None,
        sdk_task=my_task,
    )

def test_get_task_execution_inputs(mock_client_factory, execution_data_locations):
    mock_client = MagicMock()
    mock_client.get_task_execution_data = MagicMock(
        return_value=_execution_models.TaskExecutionGetDataResponse(
            execution_data_locations[0],
            execution_data_locations[1]
        )
    )
    mock_client_factory.return_value = mock_client

    m = MagicMock()
    type(m).id = PropertyMock(
        return_value=identifier.TaskExecutionIdentifier(
            identifier.Identifier(
                identifier.ResourceType.TASK,
                'project',
                'domain',
                'task-name',
                'version'
            ),
            identifier.NodeExecutionIdentifier(
                "node-a",
                identifier.WorkflowExecutionIdentifier(
                    "project",
                    "domain",
                    "name",
                )
            ),
            0
        )
    )

def test_workflow_node():
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'my_task', 'version')

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'my_list_task',
                                              'version')

    input_list = [
        promise.Input('required', primitives.Integer),
        promise.Input('not_required', primitives.Integer, default=5, help='Not required.')
    ]

    n1 = my_task(a=input_list[0]).assign_id_and_return('n1')
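
    # Sketch (not in the original excerpt): the list task accepts a mix of workflow
    # inputs and upstream node outputs for its list-typed input, using the same
    # promise-binding pattern as above.
    n2 = my_list_task(a=[input_list[0], n1.outputs.b]).assign_id_and_return('n2')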

@classmethod
def from_flyte_idl(cls, proto):
    """
    :param flyteidl.core.identifier_pb2.TaskExecutionIdentifier proto:
    :rtype: TaskExecutionIdentifier
    """
    return cls(
        task_id=Identifier.from_flyte_idl(proto.task_id),
        node_execution_id=NodeExecutionIdentifier.from_flyte_idl(proto.node_execution_id),
        retry_attempt=proto.retry_attempt
    )
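
# A minimal round-trip sketch (not part of the original module): build an identifier
# with the same constructor arguments that from_flyte_idl passes to cls, then convert
# through the protobuf message and back. Assumes ResourceType, Identifier,
# WorkflowExecutionIdentifier, NodeExecutionIdentifier, and TaskExecutionIdentifier
# from the same identifier models module, plus the matching to_flyte_idl() method.
_example_id = TaskExecutionIdentifier(
    task_id=Identifier(ResourceType.TASK, 'project', 'domain', 'name', 'version'),
    node_execution_id=NodeExecutionIdentifier(
        'node-a',
        WorkflowExecutionIdentifier('project', 'domain', 'name'),
    ),
    retry_attempt=0,
)
assert TaskExecutionIdentifier.from_flyte_idl(_example_id.to_flyte_idl()) == _example_id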