Components - TFX #2671

Merged
merged 26 commits into from
Dec 5, 2019
1de4368
Added CsvExampleGen component
Ark-kun Oct 21, 2019
1751f99
Switched to using some processing code from the component class
Ark-kun Oct 31, 2019
0ccd7c3
Renamed output_examples to example_artifacts for consistency with the…
Ark-kun Oct 31, 2019
ace062e
Fixed the docstring a bit
Ark-kun Oct 31, 2019
8e30a62
Added StatisticsGen
Ark-kun Oct 31, 2019
b8fd5a7
Added SchemaGen
Ark-kun Oct 31, 2019
35ab2e0
Fixed the input_dict construction
Ark-kun Nov 1, 2019
fcef473
Use None defaults
Ark-kun Nov 1, 2019
8a1d1e5
Switched to TFX container image
Ark-kun Nov 1, 2019
d6e6b52
Updated component definitions
Ark-kun Nov 1, 2019
fa7374c
Fixed StatisticsGen and SchemaGen
Ark-kun Nov 2, 2019
a9e784e
Printing component instance in CsvExampleGen
Ark-kun Nov 2, 2019
3a1159a
Moved components to directories
Ark-kun Nov 2, 2019
5645997
Updated the sample TFX pipeline
Ark-kun Nov 2, 2019
eb8f281
Renamed ExamplesPath to Examples for data passing components
Ark-kun Nov 7, 2019
f84d7c9
Corrected output_component_file paths
Ark-kun Nov 7, 2019
1cc4a0f
Added the Transform component
Ark-kun Nov 7, 2019
7cc3350
Added the Trainer component
Ark-kun Nov 7, 2019
9f5fe9c
Added the BigQueryExampleGen component
Ark-kun Nov 7, 2019
91ec94d
Added the ImportExampleGen component
Ark-kun Nov 7, 2019
4892bbf
Added the Evaluator component
Ark-kun Nov 7, 2019
bda6978
Added the ExampleValidator component
Ark-kun Nov 7, 2019
093e2d2
Updated the sample
Ark-kun Nov 26, 2019
034cd32
Upgraded to TFX 0.15.0
Ark-kun Nov 28, 2019
8fec5f1
Upgraded the sample to 0.15.0
Ark-kun Nov 28, 2019
bbf11e3
Silence Flake8 for annotations
Ark-kun Nov 28, 2019
Added the Trainer component
Ark-kun committed Nov 8, 2019
commit 7cc33500eb9e4b37a6b1f24d4e845cef91703d89
171 changes: 171 additions & 0 deletions components/tfx/Trainer/component.py
@@ -0,0 +1,171 @@
from kfp.components import InputPath, OutputPath


def Trainer(
    examples_path: InputPath('Examples'),
    transform_output_path: InputPath('TransformGraph'), # ? = None
    #transform_graph_path: InputPath('TransformGraph'),
    schema_path: InputPath('Schema'),

    output_path: OutputPath('Model'),

    module_file: str = None,
    trainer_fn: str = None,
    train_args: 'JsonObject: tfx.proto.trainer_pb2.TrainArgs' = None,
    eval_args: 'JsonObject: tfx.proto.trainer_pb2.EvalArgs' = None,
    #custom_config: dict = None,
    #custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
):
    """
    A TFX component to train a TensorFlow model.

    The Trainer component is used to train and evaluate a model using given
    inputs and a user-supplied estimator. This component includes a custom
    driver that can optionally fetch a previous model to warm-start from.

    ## Providing an estimator
    The TFX executor will use the estimator provided in the `module_file` file
    to train the model. The Trainer executor will look specifically for the
    `trainer_fn()` function within that file. Before training, the executor will
    call that function expecting the following returned as a dictionary:

      - estimator: The
        [estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)
        to be used by TensorFlow to train the model.
      - train_spec: The
        [configuration](https://www.tensorflow.org/api_docs/python/tf/estimator/TrainSpec)
        to be used by the "train" part of the TensorFlow `train_and_evaluate()`
        call.
      - eval_spec: The
        [configuration](https://www.tensorflow.org/api_docs/python/tf/estimator/EvalSpec)
        to be used by the "eval" part of the TensorFlow `train_and_evaluate()` call.
      - eval_input_receiver_fn: The
        [configuration](https://www.tensorflow.org/tfx/model_analysis/get_started#modify_an_existing_model)
        to be used
        by the [ModelValidator](https://www.tensorflow.org/tfx/guide/modelval)
        component when validating the model.

    An example of `trainer_fn()` can be found in the [user-supplied
    code](https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_utils.py)
    of the TFX Chicago Taxi pipeline example.

    Args:
        examples: A Channel of 'ExamplesPath' type, serving as the source of
            examples that are used in training (required). May be raw or
            transformed.
        transform_output: An optional Channel of 'TransformPath' type, serving as
            the input transform graph if present.
        #transform_graph: Forwards compatibility alias for the 'transform_output'
        #    argument.
        schema: A Channel of 'SchemaPath' type, serving as the schema of training
            and eval data.
        module_file: A path to a Python module file containing the UDF model
            definition. The module_file must implement a function named
            `trainer_fn` at its top level. The function must have the following
            signature.

            def trainer_fn(tf.contrib.training.HParams,
                           tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
                ...

            where the returned Dict has the following key-values.
                'estimator': an instance of tf.estimator.Estimator
                'train_spec': an instance of tf.estimator.TrainSpec
                'eval_spec': an instance of tf.estimator.EvalSpec
                'eval_input_receiver_fn': an instance of tfma.export.EvalInputReceiver

            Exactly one of 'module_file' or 'trainer_fn' must be supplied.
        trainer_fn: A Python path to the UDF model definition function. See
            'module_file' for the required signature of the UDF.
            Exactly one of 'module_file' or 'trainer_fn' must be supplied.
        train_args: A trainer_pb2.TrainArgs instance, containing args used for
            training. Currently only num_steps is available.
        eval_args: A trainer_pb2.EvalArgs instance, containing args used for eval.
            Currently only num_steps is available.
        #custom_config: A dict which contains the training job parameters to be
        #    passed to Google Cloud ML Engine. For the full set of parameters
        #    supported by Google Cloud ML Engine, refer to
        #    https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
        #custom_executor_spec: Optional custom executor spec.
    Returns:
        output: Optional 'ModelExportPath' channel for result of exported models.
    Raises:
        ValueError:
            - When both or neither of 'module_file' and 'trainer_fn' is supplied.
            - When both or neither of 'examples' and 'transformed_examples'
                is supplied.
            - When 'transformed_examples' is supplied but 'transform_output'
                is not supplied.
    """
    from tfx.components.trainer.component import Trainer
    component_class = Trainer
    input_channels_with_splits = {'examples'}
    output_channels_with_splits = {}


    import json
    import os
    from google.protobuf import json_format, message
    from tfx.types import Artifact, channel_utils

    arguments = locals().copy()

    component_class_args = {}

    for name, execution_parameter in component_class.SPEC_CLASS.PARAMETERS.items():
        argument_value_obj = argument_value = arguments.get(name, None)
        if argument_value is None:
            continue
        parameter_type = execution_parameter.type
        if isinstance(parameter_type, type) and issubclass(parameter_type, message.Message): # execution_parameter.type can also be a tuple
            argument_value_obj = parameter_type()
            json_format.Parse(argument_value, argument_value_obj)
        component_class_args[name] = argument_value_obj

    for name, channel_parameter in component_class.SPEC_CLASS.INPUTS.items():
        artifact_path = arguments[name + '_path']
        artifacts = []
        if name in input_channels_with_splits:
            # Recovering splits
            splits = sorted(os.listdir(artifact_path))
            for split in splits:
                artifact = Artifact(type_name=channel_parameter.type_name)
                artifact.split = split
                artifact.uri = os.path.join(artifact_path, split) + '/'
                artifacts.append(artifact)
        else:
            artifact = Artifact(type_name=channel_parameter.type_name)
            artifact.uri = artifact_path + '/' # ?
            artifacts.append(artifact)
        component_class_args[name] = channel_utils.as_channel(artifacts)

    component_class_instance = component_class(**component_class_args)

    input_dict = {name: channel.get() for name, channel in component_class_instance.inputs.get_all().items()}
    output_dict = {name: channel.get() for name, channel in component_class_instance.outputs.get_all().items()}
    exec_properties = component_class_instance.exec_properties

    # Generating paths for output artifacts
    for name, artifacts in output_dict.items():
        base_artifact_path = arguments[name + '_path']
        for artifact in artifacts:
            artifact.uri = os.path.join(base_artifact_path, artifact.split) # Default split is ''

    print('component instance: ' + str(component_class_instance))

    #executor = component_class.EXECUTOR_SPEC.executor_class() # Same
    executor = component_class_instance.executor_spec.executor_class()
    executor.Do(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
    )


if __name__ == '__main__':
    import kfp
    kfp.components.func_to_container_op(
        Trainer,
        base_image='tensorflow/tfx:0.15.0rc0',
        output_component_file='component.yaml'
    )
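
Reviewer-style note on the "Recovering splits" loop in component.py: the logic can be tried without TFX installed. The sketch below is a minimal illustration, not the component's actual code. `FakeArtifact`, `recover_split_artifacts` and `demo` are hypothetical names introduced here; `FakeArtifact` only mimics the three attributes the diff sets on `tfx.types.Artifact`. It assumes, as the diff does, that every entry directly under the input path is one split directory.

```python
import os
import tempfile
from dataclasses import dataclass


@dataclass
class FakeArtifact:
    """Hypothetical stand-in for tfx.types.Artifact (illustration only)."""
    type_name: str
    split: str = ''
    uri: str = ''


def recover_split_artifacts(artifact_path, type_name):
    """Create one artifact per entry under artifact_path, named after the split."""
    artifacts = []
    # Sorting makes the artifact order deterministic, as in the diff.
    for split in sorted(os.listdir(artifact_path)):
        artifacts.append(FakeArtifact(
            type_name=type_name,
            split=split,
            # The diff appends a trailing slash to every split URI.
            uri=os.path.join(artifact_path, split) + '/',
        ))
    return artifacts


def demo():
    """Build a throwaway directory with two splits and recover them."""
    with tempfile.TemporaryDirectory() as root:
        for split in ('train', 'eval'):
            os.makedirs(os.path.join(root, split))
        # 'ExamplesPath' is the type name quoted in the docstring; treat it as an assumption.
        return recover_split_artifacts(root, 'ExamplesPath')
```

Calling `demo()` returns the artifacts in sorted split order. Note that any stray file or directory under the input path would also be treated as a split, which may be worth guarding against in the real component.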