-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
317 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
ARG BASE_IMAGE_TAG=1.12.0-py3 | ||
FROM tensorflow/tensorflow:$BASE_IMAGE_TAG | ||
RUN python3 -m pip install keras | ||
COPY ./src /pipelines/component/src | ||
ENTRYPOINT python3 /pipelines/component/src/train.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
## Train classification model using Keras ## | ||
|
||
Usage: | ||
|
||
''' | ||
#Load the component | ||
train_op = comp.load_component(url='https://raw.githubusercontent.com/Ark-kun/pipelines/Added-sample-component/components/sample/keras/train_classifier/component.yaml') | ||
|
||
#Use the component as part of the pipeline | ||
def pipeline(): | ||
train_task = train_op( | ||
training_set_features_path=os.path.join(testdata_root, 'training_set_features.tsv'), | ||
training_set_labels_path=os.path.join(testdata_root, 'training_set_labels.tsv'), | ||
output_model_uri=os.path.join(temp_dir_name, 'outputs/output_model/data'), | ||
model_config=Path(testdata_root).joinpath('model_config.json').read_text(), | ||
number_of_classes=2, | ||
number_of_epochs=10, | ||
batch_size=32, | ||
) | ||
''' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#!/bin/bash -e | ||
# Copyright 2018 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
image_name=gcr.io/ml-pipeline/sample/keras/train_classifier:latest | ||
image_tag=latest | ||
full_image_name=${image_name}:${image_tag} | ||
base_image_tag=1.12.0-py3 | ||
|
||
cd "$(dirname "$0")" | ||
|
||
docker build --build-arg BASE_IMAGE_TAG=$base_image_tag -t "$full_image_name" . | ||
docker push "$full_image_name" | ||
|
||
#Output the strict image name (which contains the sha256 image digest) | ||
#This name can be used by the subsequent steps to refer to the exact image that was built even if another image with the same name was pushed. | ||
image_name_with_digest=$(docker inspect --format="{{index .RepoDigests 0}}" "$IMAGE_NAME") | ||
strict_image_name_output_file=./versions/image_digests_for_tags/$image_tag | ||
mkdir -p "$(dirname "$strict_image_name_output_file")" | ||
echo $image_name_with_digest | tee "$strict_image_name_output_file" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
name: Keras - Train classifier | ||
description: Trains classifier using Keras sequential model | ||
inputs: | ||
- {name: Training set features path, type: {GcsUri: TSV}, description: 'Local or GCS path to the training set features table.'} | ||
- {name: Training set labels path, type: {GcsUri: TSV}, description: 'Local or GCS path to the training set labels (each label is a class index from 0 to num-classes - 1).'} | ||
- {name: Output model URI, type: {GcsUri: Keras model}, description: 'Local or GCS path specifying where to save the trained model. The model (topology + weights + optimizer state) is saved in HDF5 format and can be loaded back by calling keras.models.load_model'} #Remove GcsUri and move to outputs once artifact passing support is checked in. | ||
- {name: Model config, type: {GcsUri: Keras model config json}, description: 'JSON string containing the serialized model structure. Can be obtained by calling model.to_json() on a Keras model.'} | ||
- {name: Number of classes, type: Integer, description: 'Number of classifier classes.'} | ||
- {name: Number of epochs, type: Integer, default: '100', description: 'Number of epochs to train the model. An epoch is an iteration over the entire `x` and `y` data provided.'} | ||
- {name: Batch size, type: Integer, default: '32', description: 'Number of samples per gradient update.'} | ||
outputs: | ||
- {name: Output model URI, type: {GcsUri: Keras model}, description: 'GCS path where the trained model has been saved. The model (topology + weights + optimizer state) is saved in HDF5 format and can be loaded back by calling keras.models.load_model'} #Remove GcsUri and make it a proper output once artifact passing support is checked in. | ||
implementation: | ||
container: | ||
image: gcr.io/ml-pipeline/sample/keras/train_classifier | ||
command: [python3, /pipelines/component/src/train.py] | ||
args: [ | ||
--training-set-features-path, {inputValue: Training set features path}, | ||
--training-set-labels-path, {inputValue: Training set labels path}, | ||
--output-model-path, {inputValue: Output model URI}, | ||
--model-config-json, {inputValue: Model config}, | ||
--num-classes, {inputValue: Number of classes}, | ||
--num-epochs, {inputValue: Number of epochs}, | ||
--batch-size, {inputValue: Batch size}, | ||
|
||
--output-model-path-file, {outputPath: Output model URI}, | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#!/bin/bash -e | ||
# Copyright 2018 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
cd $(dirname $0) | ||
python3 -m unittest discover --verbose --start-dir tests --top-level-directory=.. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright 2018 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import json | ||
import os | ||
from pathlib import Path | ||
|
||
import keras | ||
import numpy as np | ||
|
||
parser = argparse.ArgumentParser(description='Train classifier model using Keras') | ||
|
||
parser.add_argument('--training-set-features-path', type=str, help='Local or GCS path to the training set features table.') | ||
parser.add_argument('--training-set-labels-path', type=str, help='Local or GCS path to the training set labels (each label is a class index from 0 to num-classes - 1).') | ||
parser.add_argument('--output-model-path', type=str, help='Local or GCS path specifying where to save the trained model. The model (topology + weights + optimizer state) is saved in HDF5 format and can be loaded back by calling keras.models.load_model') | ||
parser.add_argument('--model-config-json', type=str, help='JSON string containing the serialized model structure. Can be obtained by calling model.to_json() on a Keras model.') | ||
parser.add_argument('--num-classes', type=int, help='Number of classifier classes.') | ||
parser.add_argument('--num-epochs', type=int, default=100, help='Number of epochs to train the model. An epoch is an iteration over the entire `x` and `y` data provided.') | ||
parser.add_argument('--batch-size', type=int, default=32, help='Number of samples per gradient update.') | ||
|
||
parser.add_argument('--output-model-path-file', type=str, help='Path to a local file containing the output model URI. Needed for data passing until the artifact support is checked in.') #TODO: Remove after the team agrees to let me check in artifact support. | ||
args = parser.parse_args() | ||
|
||
# The data, split between train and test sets: | ||
#(x_train, y_train), (x_test, y_test) = cifar10.load_data() | ||
x_train = np.loadtxt(args.training_set_features_path) | ||
y_train = np.loadtxt(args.training_set_labels_path) | ||
print('x_train shape:', x_train.shape) | ||
print(x_train.shape[0], 'train samples') | ||
|
||
# Convert class vectors to binary class matrices. | ||
y_train = keras.utils.to_categorical(y_train, args.num_classes) | ||
|
||
model = keras.models.model_from_json(args.model_config_json) | ||
|
||
model.add(keras.layers.Activation('softmax')) | ||
|
||
# initiate RMSprop optimizer | ||
opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6) | ||
|
||
# Let's train the model using RMSprop | ||
model.compile(loss='categorical_crossentropy', | ||
optimizer=opt, | ||
metrics=['accuracy']) | ||
|
||
x_train = x_train.astype('float32') | ||
x_train /= 255 | ||
|
||
model.fit( | ||
x_train, | ||
y_train, | ||
batch_size=args.batch_size, | ||
epochs=args.num_epochs, | ||
shuffle=True | ||
) | ||
|
||
# Save model and weights | ||
if not args.output_model_path.startswith('gs://'): | ||
save_dir = os.path.dirname(args.output_model_path) | ||
if not os.path.isdir(save_dir): | ||
os.makedirs(save_dir) | ||
|
||
model.save(args.output_model_path) | ||
print('Saved trained model at %s ' % args.output_model_path) | ||
|
||
Path(args.output_model_path_file).parent.mkdir(parents=True, exist_ok=True) | ||
Path(args.output_model_path_file).write_text(args.output_model_path) |
64 changes: 64 additions & 0 deletions
64
components/sample/keras/train_classifier/tests/test_component.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright 2018 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import subprocess | ||
import tempfile | ||
import unittest | ||
from contextlib import contextmanager | ||
from pathlib import Path | ||
|
||
import kfp.components as comp | ||
|
||
@contextmanager | ||
def components_local_output_dir_context(output_dir: str): | ||
old_dir = comp._components._outputs_dir | ||
try: | ||
comp._components._outputs_dir = output_dir | ||
yield output_dir | ||
finally: | ||
comp._components._outputs_dir = old_dir | ||
|
||
class KerasTrainClassifierTestCase(unittest.TestCase): | ||
def test_handle_training_xor(self): | ||
tests_root = os.path.abspath(os.path.dirname(__file__)) | ||
component_root = os.path.abspath(os.path.join(tests_root, '..')) | ||
testdata_root = os.path.abspath(os.path.join(tests_root, 'testdata')) | ||
|
||
train_op = comp.load_component(os.path.join(component_root, 'component.yaml')) | ||
|
||
with tempfile.TemporaryDirectory() as temp_dir_name: | ||
with components_local_output_dir_context(temp_dir_name): | ||
train_task = train_op( | ||
training_set_features_path=os.path.join(testdata_root, 'training_set_features.tsv'), | ||
training_set_labels_path=os.path.join(testdata_root, 'training_set_labels.tsv'), | ||
output_model_uri=os.path.join(temp_dir_name, 'outputs/output_model/data'), | ||
model_config=Path(testdata_root).joinpath('model_config.json').read_text(), | ||
number_of_classes=2, | ||
number_of_epochs=10, | ||
batch_size=32, | ||
) | ||
|
||
full_command = train_task.command + train_task.arguments | ||
full_command[0] = 'python' | ||
full_command[1] = os.path.join(component_root, 'src', 'train.py') | ||
|
||
process = subprocess.run(full_command) | ||
|
||
(output_model_uri_file, ) = (train_task.file_outputs['output-model-uri'], ) | ||
output_model_uri = Path(output_model_uri_file).read_text() | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
66 changes: 66 additions & 0 deletions
66
components/sample/keras/train_classifier/tests/testdata/model_config.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
{ | ||
"class_name": "Sequential", | ||
"config": { | ||
"name": "sequential_1", | ||
"layers": [ | ||
{ | ||
"class_name": "Dense", | ||
"config": { | ||
"name": "dense_1", | ||
"trainable": true, | ||
"units": 2, | ||
"activation": "linear", | ||
"use_bias": true, | ||
"kernel_initializer": { | ||
"class_name": "VarianceScaling", | ||
"config": { | ||
"scale": 1.0, | ||
"mode": "fan_avg", | ||
"distribution": "uniform", | ||
"seed": null | ||
} | ||
}, | ||
"bias_initializer": { | ||
"class_name": "Zeros", | ||
"config": {} | ||
}, | ||
"kernel_regularizer": null, | ||
"bias_regularizer": null, | ||
"activity_regularizer": null, | ||
"kernel_constraint": null, | ||
"bias_constraint": null | ||
} | ||
}, | ||
{ | ||
"class_name": "Dense", | ||
"config": { | ||
"name": "dense_2", | ||
"trainable": true, | ||
"units": 2, | ||
"activation": "linear", | ||
"use_bias": true, | ||
"kernel_initializer": { | ||
"class_name": "VarianceScaling", | ||
"config": { | ||
"scale": 1.0, | ||
"mode": "fan_avg", | ||
"distribution": "uniform", | ||
"seed": null | ||
} | ||
}, | ||
"bias_initializer": { | ||
"class_name": "Zeros", | ||
"config": {} | ||
}, | ||
"kernel_regularizer": null, | ||
"bias_regularizer": null, | ||
"activity_regularizer": null, | ||
"kernel_constraint": null, | ||
"bias_constraint": null | ||
} | ||
} | ||
] | ||
}, | ||
"keras_version": "2.2.4", | ||
"backend": "tensorflow" | ||
} |
4 changes: 4 additions & 0 deletions
4
components/sample/keras/train_classifier/tests/testdata/training_set_features.tsv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
0 0 | ||
0 1 | ||
1 0 | ||
1 1 |
4 changes: 4 additions & 0 deletions
4
components/sample/keras/train_classifier/tests/testdata/training_set_labels.tsv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
0 | ||
1 | ||
1 | ||
0 |