# `test_utils` and `_PIPELINE_NAME` are defined elsewhere in the originating
# test module.
import os

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration import pipeline as tfx_pipeline


def _create_pipeline():
  pipeline_name = _PIPELINE_NAME
  test_output_dir = 'gs://{}/test_output'.format(test_utils.BUCKET_NAME)
  pipeline_root = os.path.join(test_output_dir, pipeline_name)
  components = test_utils.create_e2e_components(pipeline_root,
                                                test_utils.DATA_ROOT,
                                                test_utils.TAXI_MODULE_FILE)
  return tfx_pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=components,
      log_root='/var/tmp/tfx/logs',
      additional_pipeline_args={
          'WORKFLOW_ID': pipeline_name,
      },
  )
# Variant of the test pipeline above that runs only the first four components.
def _create_pipeline():
  pipeline_name = _PIPELINE_NAME
  test_output_dir = 'gs://{}/test_output'.format(test_utils.BUCKET_NAME)
  pipeline_root = os.path.join(test_output_dir, pipeline_name)
  components = test_utils.create_e2e_components(pipeline_root,
                                                test_utils.DATA_ROOT,
                                                test_utils.TAXI_MODULE_FILE)
  return tfx_pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=components[:4],
      log_root='/var/tmp/tfx/logs',
      additional_pipeline_args={
          'WORKFLOW_ID': pipeline_name,
      },
  )
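
The `WORKFLOW_ID` argument suggests these test pipelines are meant for TFX's Kubeflow orchestrator. A minimal sketch of handing such a pipeline to that runner; the default-config call is an assumption about the surrounding test harness, not code from the source:

# Sketch only: compiles the pipeline into a Kubeflow Pipelines workflow.
from tfx.orchestration.kubeflow import kubeflow_dag_runner

kubeflow_dag_runner.KubeflowDagRunner().run(_create_pipeline())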
from typing import Text

from tfx.components import CsvExampleGen, SchemaGen, StatisticsGen
from tfx.orchestration import metadata, pipeline
from tfx.utils.dsl_utils import external_input


def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     metadata_path: Text) -> pipeline.Pipeline:
  """Implements the Chicago taxi pipeline with TFX."""
  examples = external_input(data_root)
  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = CsvExampleGen(input=examples)
  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
  # Generates schema based on statistics files.
  infer_schema = SchemaGen(statistics=statistics_gen.outputs['statistics'])
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[example_gen, statistics_gen, infer_schema],
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      additional_pipeline_args={},
  )
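
A minimal usage sketch for the parameterized variant above, run locally with TFX's Beam orchestrator; every path below is a hypothetical placeholder:

# Sketch only: wires hypothetical local paths into the pipeline factory and
# runs the result with the local Beam orchestrator.
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

BeamDagRunner().run(
    _create_pipeline(
        pipeline_name='chicago_taxi_local',     # hypothetical
        pipeline_root='/tmp/tfx/pipelines',     # hypothetical
        data_root='/tmp/tfx/data/simple',       # hypothetical
        metadata_path='/tmp/tfx/metadata.db'))  # hypothetical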
# Older TFX API: `csv_input` and the module-level `_data_root`,
# `_pipeline_root`, and `_metadata_db_root` constants are defined elsewhere.
def _create_pipeline():
  """Implements the Chicago taxi pipeline with TFX."""
  examples = csv_input(_data_root)
  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = CsvExampleGen(input=examples)
  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
  # Generates schema based on statistics files.
  infer_schema = SchemaGen(statistics=statistics_gen.outputs['statistics'])
  return pipeline.Pipeline(
      pipeline_name='chicago_taxi_simple',
      pipeline_root=_pipeline_root,
      components=[example_gen, statistics_gen, infer_schema],
      enable_cache=True,
      metadata_db_root=_metadata_db_root,
  )
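
The `metadata_db_root` keyword here is the older spelling; newer TFX versions take a connection config instead, as the parameterized snippet above shows. A sketch of the equivalent modern call, assuming a local SQLite-backed store at a hypothetical path:

# Sketch only: builds a connection config from a bare SQLite file path.
from tfx.orchestration import metadata

metadata_connection_config = metadata.sqlite_metadata_connection_config(
    '/tmp/tfx/metadata.db')  # hypothetical path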
  # ... (snippet truncated: the upstream components are cut off above) ...

  # Performs quality validation of a candidate model (compared to a baseline).
  model_validator = ModelValidator(
      examples=example_gen.outputs.examples, model=trainer.outputs.output)
  # Checks whether the model passed the validation steps and pushes the model
  # to a file destination if check passed.
  pusher = Pusher(
      model_export=trainer.outputs.output,
      model_blessing=model_validator.outputs.blessing,
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=_serving_model_dir)))
  return pipeline.Pipeline(
      pipeline_name='taxi',
      pipeline_root=_pipeline_root,
      components=[
          example_gen, statistics_gen, infer_schema, validate_stats, transform,
          trainer, model_analyzer, model_validator, pusher
      ],
      enable_cache=True,
      metadata_db_root=_metadata_db_root,
      additional_pipeline_args={'logger_args': logger_overrides},
  )
  # ... (snippet truncated: the upstream components are cut off above) ...

  # Performs quality validation of a candidate model (compared to a baseline).
  model_validator = ModelValidator(
      examples=example_gen.outputs['examples'], model=trainer.outputs['output'])
  # Checks whether the model passed the validation steps and pushes the model
  # to a file destination if check passed.
  pusher = Pusher(
      model_export=trainer.outputs['output'],
      model_blessing=model_validator.outputs['blessing'],
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=serving_model_dir)))
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen, statistics_gen, infer_schema, validate_stats, transform,
          trainer, model_analyzer, model_validator, pusher
      ],
      enable_cache=True,
      metadata_connection_config=metadata_connection_config)
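
Here `metadata_connection_config` arrives as a parameter built by the caller. Besides the SQLite helper shown earlier, TFX's `metadata` module also offers a MySQL variant; a sketch with hypothetical connection values:

# Sketch only: a MySQL-backed ML Metadata connection. All values are
# hypothetical placeholders.
from tfx.orchestration import metadata

metadata_connection_config = metadata.mysql_metadata_connection_config(
    host='127.0.0.1',         # hypothetical
    port=3306,                # hypothetical
    database='tfx_metadata',  # hypothetical
    username='tfx',           # hypothetical
    password='tfx')           # hypothetical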
  # ... (snippet truncated: the upstream components are cut off above) ...

  # Performs quality validation of a candidate model (compared to a baseline).
  model_validator = ModelValidator(
      examples=example_gen.outputs['examples'], model=trainer.outputs['model'])
  # Checks whether the model passed the validation steps and pushes the model
  # to Google Cloud AI Platform if check passed.
  pusher = Pusher(
      model=trainer.outputs['model'],
      model_blessing=model_validator.outputs['blessing'],
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=serving_model_dir)))
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen, statistics_gen, infer_schema, validate_stats, transform,
          trainer, model_analyzer, model_validator, pusher
      ],
      # TODO(b/141578059): The multi-processing API might change.
      beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers],
  )
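
The `--direct_num_workers` flag controls the parallelism of Beam's DirectRunner. A sketch of how the caller might derive the worker count, assuming the machine's core count is a reasonable default:

# Sketch only: 0 lets Beam's DirectRunner auto-detect one worker per core;
# a positive value pins the worker count explicitly.
import multiprocessing

direct_num_workers = multiprocessing.cpu_count()  # or 0 for auto-detection
beam_pipeline_args = ['--direct_num_workers=%d' % direct_num_workers]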
  # ... (snippet truncated: the upstream components are cut off above) ...

  # Performs quality validation of a candidate model (compared to a baseline).
  model_validator = ModelValidator(
      examples=example_gen.outputs['examples'], model=trainer.outputs['model'])
  # Checks whether the model passed the validation steps and pushes the model
  # to a file destination if check passed.
  pusher = Pusher(
      model=trainer.outputs['model'],
      model_blessing=model_validator.outputs['blessing'],
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
              base_directory=serving_model_dir)))
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen, statistics_gen, user_schema_importer, infer_schema,
          validate_stats, transform, trainer, model_analyzer, model_validator,
          pusher
      ],
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      # TODO(b/141578059): The multi-processing API might change.
      beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers])
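
This variant's component list includes a `user_schema_importer`, whose declaration falls in the truncated portion. A minimal sketch of how such an importer is typically declared in TFX of this era, with a hypothetical schema URI:

# Sketch only: registers a user-curated schema file as an artifact so
# downstream components can consume it. The URI is a hypothetical placeholder.
from tfx.components.common_nodes.importer_node import ImporterNode
from tfx.types import standard_artifacts

user_schema_importer = ImporterNode(
    instance_name='import_user_schema',
    source_uri='/tmp/tfx/user_schema',  # hypothetical
    artifact_type=standard_artifacts.Schema)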
  # ... (snippet truncated: the upstream components are cut off above) ...

  # Performs quality validation of a candidate model (compared to a baseline).
  model_validator = ModelValidator(
      examples=example_gen.outputs['examples'], model=trainer.outputs['model'])
  # Checks whether the model passed the validation steps and pushes the model
  # to Google Cloud AI Platform if check passed.
  pusher = Pusher(
      custom_executor_spec=executor_spec.ExecutorClassSpec(
          ai_platform_pusher_executor.Executor),
      model=trainer.outputs['model'],
      model_blessing=model_validator.outputs['blessing'],
      custom_config={'ai_platform_serving_args': ai_platform_serving_args})
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[
          example_gen, statistics_gen, infer_schema, validate_stats, transform,
          trainer, model_analyzer, model_validator, pusher
      ],
      additional_pipeline_args={
          'beam_pipeline_args': beam_pipeline_args,
      },
  )
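
The `ai_platform_serving_args` dict passed to the custom Pusher executor configures the Cloud AI Platform deployment. A sketch of its usual shape in TFX examples of this period; every value is a hypothetical placeholder:

# Sketch only: all values below are hypothetical placeholders.
ai_platform_serving_args = {
    'model_name': 'chicago_taxi',    # hypothetical
    'project_id': 'my-gcp-project',  # hypothetical
    'regions': ['us-central1'],      # hypothetical
}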