Commit fa6859d
[Sample] Lint and clean up parameterized TFX sample (#2594)
* lint

* add commas
Jiaxiao Zheng authored and k8s-ci-robot committed Nov 12, 2019
1 parent d775470 commit fa6859d
Showing 1 changed file with 46 additions and 30 deletions:
samples/contrib/parameterized_tfx_oss/parameterized_tfx_oss.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import argparse
import os
-import tensorflow as tf

from typing import Text

@@ -30,33 +28,35 @@
from tfx.components.statistics_gen.component import StatisticsGen
from tfx.components.trainer.component import Trainer
from tfx.components.transform.component import Transform
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.kubeflow import kubeflow_dag_runner
from tfx.proto import evaluator_pb2
from tfx.utils.dsl_utils import csv_input
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.extensions.google_cloud_ai_platform.trainer import executor as ai_platform_trainer_executor
from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.kubeflow.proto import kubeflow_pb2

# Define pipeline params used for pipeline execution.
# Path to the module file, should be a GCS path.
_taxi_module_file_param = dsl.PipelineParam(
    name='module-file',
-    value='gs://ml-pipeline-playground/tfx_taxi_simple/modules/taxi_utils.py')
+    value='gs://ml-pipeline-playground/tfx_taxi_simple/modules/taxi_utils.py'
+)

# Path to the CSV data file, under which there should be a data.csv file.
_data_root_param = dsl.PipelineParam(
-    name='data-root',
-    value='gs://ml-pipeline-playground/tfx_taxi_simple/data')
+    name='data-root', value='gs://ml-pipeline-playground/tfx_taxi_simple/data'
+)

# Path of pipeline root, should be a GCS path.
-pipeline_root = os.path.join('gs://your-bucket', 'tfx_taxi_simple', kfp.dsl.RUN_ID_PLACEHOLDER)
+pipeline_root = os.path.join(
+    'gs://your-bucket', 'tfx_taxi_simple', kfp.dsl.RUN_ID_PLACEHOLDER
+)
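
Aside on why these module-level parameters work at all: a dsl.PipelineParam renders as a placeholder string when coerced with str(), and the concrete value is substituted only when the pipeline actually runs. A minimal sketch of that behavior, assuming the KFP SDK v1 of this era (the default value here is illustrative, and the exact placeholder text is an SDK internal that varies by version):

    from kfp import dsl

    param = dsl.PipelineParam(name='data-root', value='gs://some/default')
    # str() does not return the value; it returns a placeholder that the
    # backend resolves at run time, e.g. '{{pipelineparam:op=;name=data-root}}'.
    print(str(param))

This is why the sample passes str(_data_root_param) and str(_taxi_module_file_param) into the pipeline definition below instead of plain strings.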

-def _create_test_pipeline(pipeline_root: Text, csv_input_location: Text,
-                          taxi_module_file: Text, enable_cache: bool):
+def _create_test_pipeline(
+    pipeline_root: Text, csv_input_location: Text, taxi_module_file: Text,
+    enable_cache: bool
+):
"""Creates a simple Kubeflow-based Chicago Taxi TFX pipeline.
Args:
@@ -74,29 +74,38 @@ def _create_test_pipeline(pipeline_root: Text, csv_input_location: Text,
  example_gen = CsvExampleGen(input_base=examples)
  statistics_gen = StatisticsGen(input_data=example_gen.outputs.examples)
  infer_schema = SchemaGen(
-      stats=statistics_gen.outputs.output, infer_feature_shape=False)
+      stats=statistics_gen.outputs.output, infer_feature_shape=False,
+  )
  validate_stats = ExampleValidator(
-      stats=statistics_gen.outputs.output, schema=infer_schema.outputs.output)
+      stats=statistics_gen.outputs.output, schema=infer_schema.outputs.output,
+  )
  transform = Transform(
      input_data=example_gen.outputs.examples,
      schema=infer_schema.outputs.output,
-      module_file=taxi_module_file)
+      module_file=taxi_module_file,
+  )
  trainer = Trainer(
      module_file=taxi_module_file,
      transformed_examples=transform.outputs.transformed_examples,
      schema=infer_schema.outputs.output,
      transform_output=transform.outputs.transform_output,
      train_args=trainer_pb2.TrainArgs(num_steps=10),
-      eval_args=trainer_pb2.EvalArgs(num_steps=5))
+      eval_args=trainer_pb2.EvalArgs(num_steps=5),
+  )
  model_analyzer = Evaluator(
      examples=example_gen.outputs.examples,
      model_exports=trainer.outputs.output,
-      feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
-          evaluator_pb2.SingleSlicingSpec(
-              column_for_slicing=['trip_start_hour'])
-      ]))
+      feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(
+          specs=[
+              evaluator_pb2.SingleSlicingSpec(
+                  column_for_slicing=['trip_start_hour']
+              )
+          ]
+      ),
+  )
  model_validator = ModelValidator(
-      examples=example_gen.outputs.examples, model=trainer.outputs.output)
+      examples=example_gen.outputs.examples, model=trainer.outputs.output
+  )

  # Hack: ensuring push_destination can be correctly parameterized and interpreted.
  # pipeline root will be specified as a dsl.PipelineParam with the name
@@ -108,14 +117,18 @@ def _create_test_pipeline(pipeline_root: Text, csv_input_location: Text,
      model_blessing=model_validator.outputs.blessing,
      push_destination=pusher_pb2.PushDestination(
          filesystem=pusher_pb2.PushDestination.Filesystem(
-              base_directory=os.path.join(str(_pipeline_root_param), 'model_serving'))))
+              base_directory=os.path.
+              join(str(_pipeline_root_param), 'model_serving')
+          )
+      ),
+  )
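
The "hack" comment above exists because the PushDestination proto is serialized when the pipeline is constructed, so a live parameter object cannot be attached to it; instead the placeholder text of the param is baked into the proto field, and the Kubeflow runner rewrites it to the real pipeline root at run time. A rough sketch of the idea, assuming _pipeline_root_param is the dsl.PipelineParam defined elsewhere in the full file (its construction here is an assumption for illustration):

    import os
    from kfp import dsl
    from tfx.proto import pusher_pb2

    _pipeline_root_param = dsl.PipelineParam(name='pipeline-root')
    dest = pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            # The proto field receives the placeholder string, not a GCS path;
            # it is substituted with the actual root when the run launches.
            base_directory=os.path.join(str(_pipeline_root_param), 'model_serving')
        )
    )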

  return pipeline.Pipeline(
      pipeline_name='parameterized_tfx_oss',
      pipeline_root=pipeline_root,
      components=[
          example_gen, statistics_gen, infer_schema, validate_stats, transform,
          trainer, model_analyzer, model_validator, pusher
      ],
      enable_cache=enable_cache,
  )
@@ -125,13 +138,16 @@ def _create_test_pipeline(pipeline_root: Text, csv_input_location: Text,

  enable_cache = True
  pipeline = _create_test_pipeline(
      pipeline_root,
      str(_data_root_param),
      str(_taxi_module_file_param),
-      enable_cache=enable_cache)
+      enable_cache=enable_cache,
+  )
  config = kubeflow_dag_runner.KubeflowDagRunnerConfig(
-      kubeflow_metadata_config=kubeflow_dag_runner.get_default_kubeflow_metadata_config(),
-      tfx_image='tensorflow/tfx:0.16.0.dev20191101')
+      kubeflow_metadata_config=kubeflow_dag_runner.
+      get_default_kubeflow_metadata_config(),
+      tfx_image='tensorflow/tfx:0.16.0.dev20191101',
+  )
  kfp_runner = kubeflow_dag_runner.KubeflowDagRunner(config=config)
  # Make sure kfp_runner recognizes those parameters.
  kfp_runner._params.extend([_data_root_param, _taxi_module_file_param])
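
The rest of the file is folded in this view. For context, a typical way to use a runner configured like this (a sketch, not this sample's exact tail: kfp_runner.run() is the TFX KubeflowDagRunner API of this vintage, while create_run_from_pipeline_package exists only in newer KFP v1 SDK releases; host and argument values are illustrative):

    # Compiles the TFX pipeline into a KFP-compatible package, by default
    # '<pipeline_name>.tar.gz' in the current working directory.
    kfp_runner.run(pipeline)

    # The package can then be uploaded and executed against a KFP endpoint,
    # overriding the dsl.PipelineParam defaults by name.
    import kfp
    client = kfp.Client(host='http://localhost:8080')  # illustrative host
    client.create_run_from_pipeline_package(
        'parameterized_tfx_oss.tar.gz',
        arguments={'data-root': 'gs://my-bucket/taxi/data'},  # illustrative
    )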
