Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
literal_map = _type_helpers.pack_python_std_map_to_literal_map(
{'a': 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs))
input_file = os.path.join(input_dir.name, "inputs.pb")
_utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file)
with _utils.AutoDeletingTempDir("out") as output_dir:
_execute_task(
_task_defs.add_one.task_module,
_task_defs.add_one.task_function_name,
input_file,
output_dir.name,
False
)
p = _utils.load_proto_from_file(
_literals_pb2.LiteralMap,
os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME)
)
raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std(
_literal_models.LiteralMap.from_flyte_idl(p),
_type_map_from_variable_map(_task_defs.add_one.interface.outputs)
)
assert raw_map['b'] == 10
assert len(raw_map) == 1
orig_env_array_index = os.environ.get('AWS_BATCH_JOB_ARRAY_INDEX')
os.environ['BATCH_JOB_ARRAY_INDEX_VAR_NAME'] = 'AWS_BATCH_JOB_ARRAY_INDEX'
os.environ['AWS_BATCH_JOB_ARRAY_INDEX'] = '0'
_execute_task(
_task_defs.add_one.task_module,
_task_defs.add_one.task_function_name,
dir.name,
dir.name,
False
)
raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std(
_literal_models.LiteralMap.from_flyte_idl(
_utils.load_proto_from_file(
_literals_pb2.LiteralMap,
os.path.join(input_dir, _constants.OUTPUT_FILE_NAME)
)
),
_type_map_from_variable_map(_task_defs.add_one.interface.outputs)
)
assert raw_map['b'] == 10
assert len(raw_map) == 1
# reset the env vars
if orig_env_index_var_name:
os.environ['BATCH_JOB_ARRAY_INDEX_VAR_NAME'] = orig_env_index_var_name
if orig_env_array_index:
os.environ['AWS_BATCH_JOB_ARRAY_INDEX'] = orig_env_array_index
def demo_task_for_promote(wf_params, a, b, c):
    """Toy task body used by promotion tests.

    Writes ``a + 1`` into output *b* and ``a + 2`` into output *c*;
    *wf_params* is accepted but unused.
    """
    incremented = a + 1
    b.set(incremented)
    c.set(incremented + 1)
@workflow_class()
class OneTaskWFForPromote(object):
    # Single-task workflow fixture: one required integer input feeds
    # demo_task_for_promote, and both of that node's outputs are promoted
    # to workflow outputs.  Declarative attribute order is significant to
    # the @workflow_class decorator, so the body is left untouched.
    wf_input = Input(Types.Integer, required=True)
    # Task node; presumably demo_task_for_promote produces outputs 'b' and
    # 'c' as wired below — TODO confirm against the task's output decorator.
    my_task_node = demo_task_for_promote(a=wf_input)
    # Expose the node's outputs as integer-typed workflow outputs.
    wf_output_b = Output(my_task_node.outputs.b, sdk_type=Integer)
    wf_output_c = Output(my_task_node.outputs.c, sdk_type=Integer)
:rtype: flytekit.models.core.workflow.WorkflowTemplate
"""
workflow_template_pb = _workflow_pb2.WorkflowTemplate()
# So that tests that use this work when run from any directory
basepath = _path.dirname(__file__)
filepath = _path.abspath(_path.join(basepath, "resources/protos", "OneTaskWFForPromote.pb"))
with open(filepath, "rb") as fh:
workflow_template_pb.ParseFromString(fh.read())
wt = _workflow_model.WorkflowTemplate.from_flyte_idl(workflow_template_pb)
return wt
def test_task_serialization():
    # Serialize a sample task under a throwaway configuration and verify the
    # resulting Admin TaskSpec carries the overridden registration settings.
    t = get_sample_task()
    # TemporaryConfiguration isolates the test from the developer's real
    # config; internal_overrides supply the image/project/domain that should
    # surface in the serialized template.
    with TemporaryConfiguration(
        _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), '../../../common/configs/local.config'),
        internal_overrides={
            'image': 'myflyteimage:v123',
            'project': 'myflyteproject',
            'domain': 'development'
        }
    ):
        s = t.serialize()
        assert isinstance(s, _admin_task_pb2.TaskSpec)
        # Task id name is derived from the defining module path.
        assert s.template.id.name == 'tests.flytekit.unit.common_tests.tasks.test_task.my_task'
        # Container image comes from the overridden config above.
        assert s.template.container.image == 'myflyteimage:v123'
n6 = my_list_task(a=n5.outputs.b)
nodes = [n1, n2, n3, n4, n5, n6]
wf_out = [
workflow.Output(
'nested_out',
[n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]],
sdk_type=[[primitives.Integer]]
),
workflow.Output('scalar_out', n1.outputs.b, sdk_type=primitives.Integer)
]
w = workflow.SdkWorkflow(inputs=input_list, outputs=wf_out, nodes=nodes)
serialized = w.serialize()
assert isinstance(serialized, _workflow_pb2.WorkflowSpec)
assert len(serialized.template.nodes) == 6
assert len(serialized.template.interface.inputs.variables.keys()) == 2
assert len(serialized.template.interface.outputs.variables.keys()) == 2
def get_compiled_workflow_closure():
    """
    :rtype: flytekit.models.core.compiler.CompiledWorkflowClosure
    """
    # Resolve the fixture relative to this file so tests that use this work
    # when run from any directory.
    proto_path = _path.abspath(
        _path.join(_path.dirname(__file__), "resources/protos", "CompiledWorkflowClosure.pb")
    )
    closure_pb = _compiler_pb2.CompiledWorkflowClosure()
    with open(proto_path, "rb") as proto_file:
        closure_pb.ParseFromString(proto_file.read())
    return _compiler_model.CompiledWorkflowClosure.from_flyte_idl(closure_pb)
assert simple_training_job_task.interface.inputs['train'].description == ''
assert simple_training_job_task.interface.inputs['train'].type == \
_sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
assert simple_training_job_task.interface.inputs['validation'].description == ''
assert simple_training_job_task.interface.inputs['validation'].type == \
_sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
assert simple_training_job_task.interface.inputs['static_hyperparameters'].description == ''
assert simple_training_job_task.interface.inputs['static_hyperparameters'].type == \
_sdk_types.Types.Generic.to_flyte_literal_type()
assert simple_training_job_task.interface.inputs['stopping_condition'].type == \
_sdk_types.Types.Proto(_pb2_StoppingCondition).to_flyte_literal_type()
# Checking if the hpo-specific input is defined
assert simple_xgboost_hpo_job_task.interface.inputs['hpo_job_config'].description == ''
assert simple_xgboost_hpo_job_task.interface.inputs['hpo_job_config'].type == \
_sdk_types.Types.Proto(_pb2_HPOJobConfig).to_flyte_literal_type()
assert simple_xgboost_hpo_job_task.interface.outputs['model'].description == ''
assert simple_xgboost_hpo_job_task.interface.outputs['model'].type == \
_sdk_types.Types.Blob.to_flyte_literal_type()
assert simple_xgboost_hpo_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_HPO_JOB_TASK
# Checking if the spec of the TrainingJob is embedded into the custom field of this SdkSimpleHPOJobTask
assert simple_xgboost_hpo_job_task.to_flyte_idl().custom["trainingJob"] == \
simple_training_job_task.to_flyte_idl().custom
assert simple_xgboost_hpo_job_task.metadata.timeout == _datetime.timedelta(seconds=0)
assert simple_xgboost_hpo_job_task.metadata.discoverable is True
assert simple_xgboost_hpo_job_task.metadata.discovery_version == '1'
assert simple_xgboost_hpo_job_task.metadata.retries.retries == 2
"""
assert simple_xgboost_hpo_job_task.task_module == __name__
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from flyteidl.core import literals_pb2 as flyteidl_dot_core_dot_literals__pb2
DESCRIPTOR = _descriptor.FileDescriptor(
name='datacatalog/service.proto',
package='datacatalog',
syntax='proto3',
serialized_options=None,
serialized_pb=_b('\n\x19\x64\x61tacatalog/service.proto\x12\x0b\x64\x61tacatalog\x1a\x1c\x66lyteidl/core/literals.proto\"=\n\x14\x43reateDatasetRequest\x12%\n\x07\x64\x61taset\x18\x01 \x01(\x0b\x32\x14.datacatalog.Dataset\"\x17\n\x15\x43reateDatasetResponse\"<\n\x11GetDatasetRequest\x12\'\n\x07\x64\x61taset\x18\x01 \x01(\x0b\x32\x16.datacatalog.DatasetID\";\n\x12GetDatasetResponse\x12%\n\x07\x64\x61taset\x18\x01 \x01(\x0b\x32\x14.datacatalog.Dataset\"x\n\x12GetArtifactRequest\x12\'\n\x07\x64\x61taset\x18\x01 \x01(\x0b\x32\x16.datacatalog.DatasetID\x12\x15\n\x0b\x61rtifact_id\x18\x02 \x01(\tH\x00\x12\x12\n\x08tag_name\x18\x03 \x01(\tH\x00\x42\x0e\n\x0cquery_handle\">\n\x13GetArtifactResponse\x12\'\n\x08\x61rtifact\x18\x01 \x01(\x0b\x32\x15.datacatalog.Artifact\"@\n\x15\x43reateArtifactRequest\x12\'\n\x08\x61rtifact\x18\x01 \x01(\x0b\x32\x15.datacatalog.Artifact\"\x18\n\x16\x43reateArtifactResponse\".\n\rAddTagRequest\x12\x1d\n\x03tag\x18\x01 \x01(\x0b\x32\x10.datacatalog.Tag\"\x10\n\x0e\x41\x64\x64TagResponse\"\xa2\x01\n\x14ListArtifactsRequest\x12\'\n\x07\x64\x61taset\x18\x01 \x01(\x0b\x32\x16.datacatalog.DatasetID\x12-\n\x06\x66ilter\x18\x02 \x01(\x0b\x32\x1d.datacatalog.FilterExpression\x12\x32\n\npagination\x18\x03 \x01(\x0b\x32\x1e.datacatalog.PaginationOptions\"U\n\x15ListArtifactsResponse\x12(\n\tartifacts\x18\x01 \x03(\x0b\x32\x15.datacatalog.Artifact\x12\x12\n\nnext_token\x18\x02 \x01(\t\"x\n\x13ListDatasetsRequest\x12-\n\x06\x66ilter\x18\x01 \x01(\x0b\x32\x1d.datacatalog.FilterExpression\x12\x32\n\npagination\x18\x02 \x01(\x0b\x32\x1e.datacatalog.PaginationOptions\"R\n\x14ListDatasetsResponse\x12&\n\x08\x64\x61tasets\x18\x01 \x03(\x0b\x32\x14.datacatalog.Dataset\x12\x12\n\nnext_token\x18\x02 \x01(\t\"m\n\x07\x44\x61taset\x12\"\n\x02id\x18\x01 \x01(\x0b\x32\x16.datacatalog.DatasetID\x12\'\n\x08metadata\x18\x02 \x01(\x0b\x32\x15.datacatalog.Metadata\x12\x15\n\rpartitionKeys\x18\x03 \x03(\t\"\'\n\tPartition\x12\x0b\n\x03key\x18\x01 
\x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"Y\n\tDatasetID\x12\x0f\n\x07project\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0e\n\x06\x64omain\x18\x03 \x01(\t\x12\x0f\n\x07version\x18\x04 \x01(\t\x12\x0c\n\x04UUID\x18\x05 \x01(\t\"\xdd\x01\n\x08\x41rtifact\x12\n\n\x02id\x18\x01 \x01(\t\x12\'\n\x07\x64\x61taset\x18\x02 \x01(\x0b\x32\x16.datacatalog.DatasetID\x12\'\n\x04\x64\x61ta\x18\x03 \x03(\x0b\x32\x19.datacatalog.ArtifactData\x12\'\n\x08metadata\x18\x04 \x01(\x0b\x32\x15.datacatalog.Metadata\x12*\n\npartitions\x18\x05 \x03(\x0b\x32\x16.datacatalog.Partition\x12\x1e\n\x04tags\x18\x06 \x03(\x0b\x32\x10.datacatalog.Tag\"C\n\x0c\x41rtifactData\x12\x0c\n\x04name\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.flyteidl.core.Literal\"Q\n\x03Tag\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x61rtifact_id\x18\x02 \x01(\t\x12\'\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x16.datacatalog.DatasetID\"m\n\x08Metadata\x12\x32\n\x07key_map\x18\x01 \x03(\x0b\x32!.datacatalog.Metadata.KeyMapEntry\x1a-\n\x0bKeyMapEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"F\n\x10\x46ilterExpression\x12\x32\n\x07\x66ilters\x18\x01 \x03(\x0b\x32!.datacatalog.SinglePropertyFilter\"\x89\x03\n\x14SinglePropertyFilter\x12\x34\n\ntag_filter\x18\x01 \x01(\x0b\x32\x1e.datacatalog.TagPropertyFilterH\x00\x12@\n\x10partition_filter\x18\x02 \x01(\x0b\x32$.datacatalog.PartitionPropertyFilterH\x00\x12>\n\x0f\x61rtifact_filter\x18\x03 \x01(\x0b\x32#.datacatalog.ArtifactPropertyFilterH\x00\x12<\n\x0e\x64\x61taset_filter\x18\x04 \x01(\x0b\x32\".datacatalog.DatasetPropertyFilterH\x00\x12\x46\n\x08operator\x18\n \x01(\x0e\x32\x34.datacatalog.SinglePropertyFilter.ComparisonOperator\" \n\x12\x43omparisonOperator\x12\n\n\x06\x45QUALS\x10\x00\x42\x11\n\x0fproperty_filter\";\n\x16\x41rtifactPropertyFilter\x12\x15\n\x0b\x61rtifact_id\x18\x01 \x01(\tH\x00\x42\n\n\x08property\"3\n\x11TagPropertyFilter\x12\x12\n\x08tag_name\x18\x01 
\x01(\tH\x00\x42\n\n\x08property\"S\n\x17PartitionPropertyFilter\x12,\n\x07key_val\x18\x01 \x01(\x0b\x32\x19.datacatalog.KeyValuePairH\x00\x42\n\n\x08property\"*\n\x0cKeyValuePair\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"k\n\x15\x44\x61tasetPropertyFilter\x12\x11\n\x07project\x18\x01 \x01(\tH\x00\x12\x0e\n\x04name\x18\x02 \x01(\tH\x00\x12\x10\n\x06\x64omain\x18\x03 \x01(\tH\x00\x12\x11\n\x07version\x18\x04 \x01(\tH\x00\x42\n\n\x08property\"\xf1\x01\n\x11PaginationOptions\x12\r\n\x05limit\x18\x01 \x01(\r\x12\r\n\x05token\x18\x02 \x01(\t\x12\x37\n\x07sortKey\x18\x03 \x01(\x0e\x32&.datacatalog.PaginationOptions.SortKey\x12;\n\tsortOrder\x18\x04 \x01(\x0e\x32(.datacatalog.PaginationOptions.SortOrder\"*\n\tSortOrder\x12\x0e\n\nDESCENDING\x10\x00\x12\r\n\tASCENDING\x10\x01\"\x1c\n\x07SortKey\x12\x11\n\rCREATION_TIME\x10\x00\x32\xd1\x04\n\x0b\x44\x61taCatalog\x12V\n\rCreateDataset\x12!.datacatalog.CreateDatasetRequest\x1a\".datacatalog.CreateDatasetResponse\x12M\n\nGetDataset\x12\x1e.datacatalog.GetDatasetRequest\x1a\x1f.datacatalog.GetDatasetResponse\x12Y\n\x0e\x43reateArtifact\x12\".datacatalog.CreateArtifactRequest\x1a#.datacatalog.CreateArtifactResponse\x12P\n\x0bGetArtifact\x12\x1f.datacatalog.GetArtifactRequest\x1a .datacatalog.GetArtifactResponse\x12\x41\n\x06\x41\x64\x64Tag\x12\x1a.datacatalog.AddTagRequest\x1a\x1b.datacatalog.AddTagResponse\x12V\n\rListArtifacts\x12!.datacatalog.ListArtifactsRequest\x1a\".datacatalog.ListArtifactsResponse\x12S\n\x0cListDatasets\x12 .datacatalog.ListDatasetsRequest\x1a!.datacatalog.ListDatasetsResponseb\x06proto3')
,
dependencies=[flyteidl_dot_core_dot_literals__pb2.DESCRIPTOR,])
_SINGLEPROPERTYFILTER_COMPARISONOPERATOR = _descriptor.EnumDescriptor(
name='ComparisonOperator',
full_name='datacatalog.SinglePropertyFilter.ComparisonOperator',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='EQUALS', index=0, number=0,
serialized_options=None,
type=None),
],
containing_type=None,
serialized_options=None,
def get_outputs(self):
    """Fetch this task execution's output literal map from the data plane.

    :rtype: flytekit.models.literals.LiteralMap
    """
    client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client
    url_blob = client.get_task_execution_data(self.sdk_task_execution.id)
    # Nothing was written for this execution -> empty literal map.
    if not url_blob.outputs.bytes > 0:
        return _literals.LiteralMap({})
    # Download the serialized outputs into a scratch dir that cleans itself up.
    with _common_utils.AutoDeletingTempDir() as scratch_dir:
        local_path = _os.path.join(scratch_dir.name, "outputs.pb")
        _data_proxy.Data.get_data(url_blob.outputs.url, local_path)
        proto = _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, local_path)
        return _literals.LiteralMap.from_flyte_idl(proto)
def to_flyte_idl(self):
    """Convert this container model to its protobuf representation.

    :rtype: flyteidl.core.tasks_pb2.Container
    """
    # Environment variables and config entries both serialize as
    # repeated KeyValuePair messages.
    env_pairs = [
        _literals_pb2.KeyValuePair(key=name, value=val)
        for name, val in _six.iteritems(self.env)
    ]
    config_pairs = [
        _literals_pb2.KeyValuePair(key=name, value=val)
        for name, val in _six.iteritems(self.config)
    ]
    return _core_task.Container(
        image=self.image,
        command=self.command,
        args=self.args,
        resources=self.resources.to_flyte_idl(),
        env=env_pairs,
        config=config_pairs,
    )