Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_sdk_launch_plan_node():
# Build a one-task workflow, create its launch plan, assign the launch plan an
# explicit identifier, wrap it in an SdkWorkflowNode and check that the node's
# launchplan_ref carries that identifier's project/domain/name/version.
# NOTE(review): this snippet is truncated inside the "floating ID" Identifier(...)
# call below; the rest of the test is not visible here.
@_tasks.inputs(a=_types.Types.Integer)
@_tasks.outputs(b=_types.Types.Integer)
@_tasks.python_task()
def testy_test(wf_params, a, b):
# Task body intentionally empty; output b is never set.
pass
@_workflow.workflow_class
class test_workflow(object):
a = _workflow.Input(_types.Types.Integer)
# NOTE(review): the task is bound to the literal 1, not to the workflow input `a`
# declared above; confirm this is intended.
test = testy_test(a=1)
b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer)
lp = test_workflow.create_launch_plan()
# Overwrite the launch plan's private identifier so the reference fields are known.
# NOTE(review): ResourceType.TASK is used for a launch-plan id; confirm whether
# ResourceType.LAUNCH_PLAN was intended.
lp._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'name', 'version')
n = _component_nodes.SdkWorkflowNode(sdk_launch_plan=lp)
assert n.launchplan_ref.project == 'project'
assert n.launchplan_ref.domain == 'domain'
assert n.launchplan_ref.name == 'name'
assert n.launchplan_ref.version == 'version'
# Test floating ID
lp._id = _identifier.Identifier(
_identifier.ResourceType.TASK,
'new_project',
@_sdk_tasks.inputs(a=_Types.Integer)
@_sdk_tasks.outputs(b=_Types.Integer, c=_Types.Integer)
@_sdk_tasks.python_task()
def demo_task_for_promote(wf_params, a, b, c):
    """Set output ``b`` to ``a + 1`` and then output ``c`` to ``a + 2``."""
    # (offset, output) pairs, applied in order: b first, then c.
    for offset, output in ((1, b), (2, c)):
        output.set(a + offset)
from __future__ import absolute_import
from flytekit.sdk.tasks import inputs
from flytekit.sdk.types import Types
from flytekit.sdk.workflow import workflow_class, Input, Output
from flytekit.common.tasks.presto_task import SdkPrestoTask
# Result schema of the query below: column "a" is a string, column "b" an integer.
schema = Types.Schema([
    ("a", Types.String),
    ("b", Types.Integer),
])

# Presto task with two string inputs: ``ds`` is templated into the SQL
# statement and ``rg`` is templated into the routing group. The query
# result is typed by ``schema``.
# NOTE(review): the original example left ``catalog="hive"`` and
# ``schema="city"`` keyword arguments commented out (disabled).
presto_task = SdkPrestoTask(
    statement="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds}}' LIMIT 10",
    routing_group="{{ .Inputs.rg }}",
    output_schema=schema,
    task_inputs=inputs(ds=Types.String, rg=Types.String),
)
@workflow_class()
class PrestoWorkflow(object):
    # Required string input; no default value is supplied.
    ds = Input(Types.String, help="Test string with no default", required=True)
    # Disabled in the original example: a second required string input,
    # routing_group = Input(Types.String, required=True, help="Test string with no default")
@inputs(num=Types.Integer)
@outputs(out=Types.Integer)
@python_task
def inner_task(wf_params, num, out):
    """Log a progress message, then copy integer input ``num`` to output ``out``."""
    logger = wf_params.logging
    logger.info("Running inner task... setting output to input")
    out.set(num)
@outputs(out1=Types.String)
@sidecar_task(
    primary_container_name="a container",
    pod_spec=get_pod_spec(),
    environment={"foo": "bar"},
    gpu_limit='2',
    cpu_request='10',
)
def simple_sidecar_task(wf_params, in1, out1):
    """Placeholder sidecar task whose body does nothing (``out1`` is never set).

    NOTE(review): ``in1`` has no ``@inputs`` declaration in this snippet;
    confirm one exists where the snippet was taken from.
    """
from __future__ import absolute_import
from flytekit.common import constants
from flytekit.common.tasks import sdk_runnable
from flytekit.common.types import helpers
from flytekit.models import interface
from flytekit.sdk import types
import six
# Input variables for the GOOD_NOTEBOOK fixture: integer ``a`` and string ``name``.
GOOD_INPUTS = dict(
    a=types.Types.Integer,
    name=types.Types.String,
)
# Single integer variable ``x``; presumably the notebook fixture's outputs
# (the outputs mapping that would consume it is truncated in this view).
GOOD_OUTPUTS = dict(x=types.Types.Integer)
# Notebook task fixture for notebooks/good.ipynb. Each GOOD_INPUTS entry is
# converted to an interface.Variable carrying the SDK type's Flyte literal type
# and an empty description.
# NOTE(review): this snippet is truncated inside the ``outputs`` mapping, which
# presumably mirrors the inputs using GOOD_OUTPUTS; confirm against the full file.
GOOD_NOTEBOOK = sdk_runnable.RunnableNotebookTask(
notebook_path="notebooks/good.ipynb",
inputs={
k: interface.Variable(
helpers.python_std_to_sdk_type(v).to_flyte_literal_type(),
''
)
for k, v in six.iteritems(GOOD_INPUTS)
},
outputs={
k: interface.Variable(
def test_raw_container_task_definition_no_outputs():
    """Check a raw container task declared with one input and no outputs.

    The task must serialize to a non-None value. Calling it with ``x=3`` must
    yield an object whose first input binding holds the integer 3.
    """
    tk = SdkRawContainerTask(
        inputs={"x": Types.Integer},
        image="my-image",
        command=["echo", "hello, world!"],
        gpu_limit="1",
        gpu_request="1",
    )
    # PEP 8: spell the identity test as ``is not None`` rather than ``not ... is None``.
    assert tk.serialize() is not None
    task_instance = tk(x=3)
    assert task_instance.inputs[0].binding.scalar.primitive.integer == 3
@_tasks.inputs(num=_Types.Integer)
@_tasks.outputs(out=_Types.Integer)
@_tasks.python_task
def inner_task(wf_params, num, out):
    """Echo the integer input ``num`` into output ``out`` after logging a message."""
    message = "Running inner task... setting output to input"
    wf_params.logging.info(message)
    out.set(num)
def test_simple_training_job_task():
    """Verify the class, interface, task type and metadata of ``simple_training_job_task``.

    Every declared input and the ``model`` output must have an empty
    description and the Flyte literal type of the matching SDK type. The task
    type must be the SageMaker training-job type, and the metadata must hold
    empty/zero values.
    """
    task = simple_training_job_task
    assert isinstance(task, SdkSimpleTrainingJobTask)
    assert isinstance(task, _sdk_task.SdkTask)

    # (input name, expected SDK type). Every input, stopping_condition included,
    # is expected to carry an empty description.
    expected_inputs = (
        ('train', _sdk_types.Types.MultiPartCSV),
        ('validation', _sdk_types.Types.MultiPartCSV),
        ('static_hyperparameters', _sdk_types.Types.Generic),
        ('stopping_condition', _sdk_types.Types.Proto(_pb2_StoppingCondition)),
    )
    for name, sdk_type in expected_inputs:
        variable = task.interface.inputs[name]
        # The input name is the assertion message, so a failure identifies the input.
        assert variable.description == '', name
        assert variable.type == sdk_type.to_flyte_literal_type(), name

    model = task.interface.outputs['model']
    assert model.description == ''
    assert model.type == _sdk_types.Types.Blob.to_flyte_literal_type()

    assert task.type == _common_constants.SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK
    assert task.metadata.timeout == _datetime.timedelta(seconds=0)
    assert task.metadata.deprecated_error_message == ''
    assert task.metadata.discoverable is False
    assert task.metadata.discovery_version == ''
# NOTE(review): this fragment begins mid-expression, inside a BlobType argument
# list; the enclosing call (and the "train" input key) are not visible here.
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
),
),
description="",
),
# "validation" input: a multipart blob with format "csv" and an empty description.
"validation": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="csv",
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
),
),
description="",
),
# "stopping_condition" input: literal type derived from the StoppingCondition
# protobuf message via Types.Proto; empty description passed positionally.
"stopping_condition": _interface_model.Variable(
_sdk_types.Types.Proto(_training_job_pb2.StoppingCondition).to_flyte_literal_type(), ""
)
},
# Single output "model": a single-part blob with an empty format string.
outputs={
"model": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="",
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
description=""
)
}
),
# Custom payload: the training-job model converted to its IDL message, then to a dict.
custom=MessageToDict(self._training_job_model.to_flyte_idl()),
)