[Sample] Lint and clean up parameterized TFX sample #2594

Merged: 2 commits, Nov 12, 2019
Changes from 1 commit
76 changes: 46 additions & 30 deletions samples/contrib/parameterized_tfx_oss/parameterized_tfx_oss.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import tensorflow as tf

from typing import Text

@@ -30,33 +28,35 @@
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.kubeflow import kubeflow_dag_runner
from tfx.proto import evaluator_pb2
from tfx.utils.dsl_utils import csv_input
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.extensions.google_cloud_ai_platform.trainer import executor as ai_platform_trainer_executor
from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.kubeflow.proto import kubeflow_pb2

# Define pipeline params used for pipeline execution.
# Path to the module file, should be a GCS path.
_taxi_module_file_param = dsl.PipelineParam(
name='module-file',
value='gs://ml-pipeline-playground/tfx_taxi_simple/modules/taxi_utils.py')
value='gs://ml-pipeline-playground/tfx_taxi_simple/modules/taxi_utils.py'
)

# Path to the CSV data file, under which there should be a data.csv file.
_data_root_param = dsl.PipelineParam(
name='data-root',
value='gs://ml-pipeline-playground/tfx_taxi_simple/data')
name='data-root', value='gs://ml-pipeline-playground/tfx_taxi_simple/data'
)

# Path of pipeline root, should be a GCS path.
pipeline_root = os.path.join('gs://your-bucket', 'tfx_taxi_simple', kfp.dsl.RUN_ID_PLACEHOLDER)
pipeline_root = os.path.join(
'gs://your-bucket', 'tfx_taxi_simple', kfp.dsl.RUN_ID_PLACEHOLDER
)

def _create_test_pipeline(pipeline_root: Text, csv_input_location: Text,
taxi_module_file: Text, enable_cache: bool):

def _create_test_pipeline(
pipeline_root: Text, csv_input_location: Text, taxi_module_file: Text,
enable_cache: bool
):
"""Creates a simple Kubeflow-based Chicago Taxi TFX pipeline.

Args:
@@ -74,29 +74,38 @@ def _create_test_pipeline(pipeline_root: Text, csv_input_location: Text,
example_gen = CsvExampleGen(input_base=examples)
statistics_gen = StatisticsGen(input_data=example_gen.outputs.examples)
infer_schema = SchemaGen(
stats=statistics_gen.outputs.output, infer_feature_shape=False)
stats=statistics_gen.outputs.output, infer_feature_shape=False
)
validate_stats = ExampleValidator(
stats=statistics_gen.outputs.output, schema=infer_schema.outputs.output)
stats=statistics_gen.outputs.output, schema=infer_schema.outputs.output
)
transform = Transform(
input_data=example_gen.outputs.examples,
schema=infer_schema.outputs.output,
module_file=taxi_module_file)
module_file=taxi_module_file
Contributor:

Nit: I usually add a comma to the last line so that more lines can be added without changing existing lines.

Author:

Thanks. Good idea.
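A minimal standalone sketch of that trailing-comma convention (the describe helper below is hypothetical and not part of this sample): with a trailing comma after the last argument, a later argument only adds a new line to a future diff instead of also touching the existing last line.

# Hypothetical helper, used only to illustrate the trailing-comma style.
def describe(name, value):
  return '{}={}'.format(name, value)

# With a trailing comma after the last argument, appending another argument
# later adds a new line without modifying the line above it.
print(
    describe(
        'module-file',
        'gs://ml-pipeline-playground/tfx_taxi_simple/modules/taxi_utils.py',
    )
)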

)
trainer = Trainer(
module_file=taxi_module_file,
transformed_examples=transform.outputs.transformed_examples,
schema=infer_schema.outputs.output,
transform_output=transform.outputs.transform_output,
train_args=trainer_pb2.TrainArgs(num_steps=10),
eval_args=trainer_pb2.EvalArgs(num_steps=5))
eval_args=trainer_pb2.EvalArgs(num_steps=5)
)
model_analyzer = Evaluator(
examples=example_gen.outputs.examples,
model_exports=trainer.outputs.output,
feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
evaluator_pb2.SingleSlicingSpec(
column_for_slicing=['trip_start_hour'])
]))
feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(
specs=[
evaluator_pb2.SingleSlicingSpec(
column_for_slicing=['trip_start_hour']
)
]
)
)
model_validator = ModelValidator(
examples=example_gen.outputs.examples, model=trainer.outputs.output)
examples=example_gen.outputs.examples, model=trainer.outputs.output
)

# Hack: ensuring push_destination can be correctly parameterized and interpreted.
# pipeline root will be specified as a dsl.PipelineParam with the name
@@ -108,14 +117,18 @@ def _create_test_pipeline(pipeline_root: Text, csv_input_location: Text,
model_blessing=model_validator.outputs.blessing,
push_destination=pusher_pb2.PushDestination(
filesystem=pusher_pb2.PushDestination.Filesystem(
base_directory=os.path.join(str(_pipeline_root_param), 'model_serving'))))
base_directory=os.path.
join(str(_pipeline_root_param), 'model_serving')
)
)
)

return pipeline.Pipeline(
pipeline_name='parameterized_tfx_oss',
pipeline_root=pipeline_root,
components=[
example_gen, statistics_gen, infer_schema, validate_stats, transform,
trainer, model_analyzer, model_validator, pusher
example_gen, statistics_gen, infer_schema, validate_stats, transform,
trainer, model_analyzer, model_validator, pusher
],
enable_cache=enable_cache,
)
@@ -125,13 +138,16 @@ def _create_test_pipeline(pipeline_root: Text, csv_input_location: Text,

enable_cache = True
pipeline = _create_test_pipeline(
pipeline_root,
str(_data_root_param),
str(_taxi_module_file_param),
enable_cache=enable_cache)
pipeline_root,
str(_data_root_param),
str(_taxi_module_file_param),
enable_cache=enable_cache
)
config = kubeflow_dag_runner.KubeflowDagRunnerConfig(
kubeflow_metadata_config=kubeflow_dag_runner.get_default_kubeflow_metadata_config(),
tfx_image='tensorflow/tfx:0.16.0.dev20191101')
kubeflow_metadata_config=kubeflow_dag_runner.
get_default_kubeflow_metadata_config(),
tfx_image='tensorflow/tfx:0.16.0.dev20191101'
)
kfp_runner = kubeflow_dag_runner.KubeflowDagRunner(config=config)
# Make sure kfp_runner recognizes those parameters.
kfp_runner._params.extend([_data_root_param, _taxi_module_file_param])