
Commit 7dbdc6c

tfdv components + pipeline + other minor cleanup (#83)
Update the container image used, make minor changes to keep consistent with the notebook version, and recompile the tfdv components.
1 parent b34a283 commit 7dbdc6c

8 files changed: 432 additions, 48 deletions

ml/kubeflow-pipelines/keras_tuner/components/kubeflow-resources/bikesw_training/bw_hptune_standalone.py

Lines changed: 2 additions & 2 deletions
@@ -48,9 +48,9 @@ def create_model(hp):
       linear_feature_columns=sparse.values(),
       dnn_feature_columns=real.values(),
       num_hidden_layers=hp.Int('num_hidden_layers', 2, 5),
-      dnn_hidden_units1=hp.Int('hidden_size', 32, 256, step=32),
+      dnn_hidden_units1=hp.Int('hidden_size', 16, 256, step=32),
       learning_rate=hp.Choice('learning_rate',
-                              values=[1e-1, 1e-2, 1e-3, 1e-4])
+                              values=[5e-1, 1e-1, 1e-2, 1e-3, 1e-4])
      )

  model.summary()

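Note: a minimal sketch, not part of this commit, of how the widened search space above might be exercised with Keras Tuner; the tuner type, trial budget, directory, and objective below are assumptions, and create_model is the hypermodel defined in this file.

# Hedged sketch: RandomSearch samples from the ranges declared via hp.Int / hp.Choice
# above; after this change, hidden_size can start at 16 and learning_rate can be 5e-1.
from kerastuner.tuners import RandomSearch

tuner = RandomSearch(
    create_model,                  # hypermodel from bw_hptune_standalone.py
    objective='val_loss',          # assumed objective
    max_trials=10,                 # hypothetical trial budget
    directory='tuner_dir',         # hypothetical local output dir
    project_name='bikesw_hptune')  # hypothetical project name
tuner.search_space_summary()       # prints the declared hyperparameter ranges
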
ml/kubeflow-pipelines/keras_tuner/components/tfdv/Dockerfile

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@
 FROM gcr.io/deeplearning-platform-release/tf2-cpu.2-3:latest

 ADD requirements.txt /
-ADD tfdv.py /
+# ADD tfdv.py /
+RUN pip install -U tensorflow-data-validation
 RUN pip download tensorflow_data_validation --no-deps --platform manylinux2010_x86_64 --only-binary=:all:
 RUN pip install -U "apache-beam[gcp]"
Lines changed: 16 additions & 44 deletions
@@ -1,4 +1,4 @@
-# Copyright 2020 Google Inc. All Rights Reserved.
+# Copyright 2021 Google Inc. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,25 +12,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import NamedTuple

-def generate_tfdv_stats(input_data: str, output_path: str, job_name: str, use_dataflow: bool,
+
+def generate_tfdv_stats(input_data: str, output_path: str, job_name: str, use_dataflow: str,
     project_id: str, region:str, gcs_temp_location: str, gcs_staging_location: str,
-    whl_location: str = '', requirements_file: str = 'requirements.txt'):
-  import tensorflow_data_validation as tfdv
+    whl_location: str = '', requirements_file: str = 'requirements.txt'
+    ) -> NamedTuple('Outputs', [('stats_path', str)]):
+
   import logging
   import time

-  import tensorflow_data_validation.statistics.stats_impl
   import tensorflow_data_validation as tfdv
+  import tensorflow_data_validation.statistics.stats_impl
   from apache_beam.options.pipeline_options import PipelineOptions, GoogleCloudOptions, StandardOptions, SetupOptions

   # pip download tensorflow_data_validation --no-deps --platform manylinux2010_x86_64 --only-binary=:all:
   # CHANGE this if your download resulted in a different filename.

+  logging.getLogger().setLevel(logging.INFO)
+  logging.info("output path: %s", output_path)
+  logging.info("Building pipeline options")
   # Create and set your PipelineOptions.
   options = PipelineOptions()

-  if use_dataflow:
+  if use_dataflow == 'true':
+    logging.info("using Dataflow")
     if not whl_location:
       logging.warning('tfdv whl file required with dataflow runner.')
       exit(1)
@@ -53,46 +60,11 @@ def generate_tfdv_stats(input_data: str, output_path: str, job_name: str, use_da
       data_location=input_data, output_path=output_path,
       pipeline_options=options)

-
-def main():
-
-  logging.getLogger().setLevel(logging.INFO)
-  parser = argparse.ArgumentParser(description='TVDV')
-
-  parser.add_argument(
-      '--project_id', default='aju-vtests2')
-  parser.add_argument(
-      '--region', default='us-central1')
-  parser.add_argument(
-      '--job_name', required=True)
-  parser.add_argument(
-      '--gcs-staging-location', required=True)
-  parser.add_argument(
-      '--gcs-temp-location', required=True)
-  parser.add_argument(
-      '--output-path', required=True)
-  parser.add_argument(
-      '--data-path', required=True)
-  # TFDV whl required for Dataflow runner. Download whl file with this command:
-  # pip download tensorflow_data_validation --no-deps --platform manylinux2010_x86_64 --only-binary=:all:
-  parser.add_argument('--whl-location')
-  parser.add_argument('--requirements_file', default='requirements.txt')
-  parser.add_argument('--use-dataflow', default=False, help='Run on Dataflow', action='store_true')
-  parser.add_argument('--local', dest='use-dataflow', help='Run locally', action='store_false')
-  args = parser.parse_args()
-
-  use_dataflow = False
-  if args.use_dataflow:
-    use_dataflow = True
-
-  generate_tfdv_stats(args.data_path, args.output_path, args.job_name, use_dataflow,
-      args.project_id, args.region, args.gcs_temp_location, args.gcs_staging_location,
-      args.whl_location, args.requirements_file)
-
+  return (output_path, )


 if __name__ == '__main__':
   import kfp
   kfp.components.func_to_container_op(generate_tfdv_stats,
-      output_component_file='../tfdv_component.yaml', base_image='gcr.io/aju-vtests2/tfdv-tests:v6')
-  # main()
+      output_component_file='../tfdv_component.yaml',
+      base_image='gcr.io/google-samples/tfdv-tests:v1')
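
Note: a hedged usage sketch, not part of this commit. With use_dataflow passed as 'false', the function above skips the Dataflow branch and runs TFDV through Beam's default local runner, so it can be smoke-tested outside a pipeline; the bucket paths and project below are placeholders.

# Hypothetical local invocation of generate_tfdv_stats; all paths and the
# project are placeholders, and 'false' keeps the Beam job on the local runner.
stats_path, = generate_tfdv_stats(
    input_data='gs://my-bucket/bikes/data-*.csv',      # hypothetical CSV glob
    output_path='gs://my-bucket/bikes/stats/eval.pb',  # hypothetical stats output
    job_name='tfdv-local-test',
    use_dataflow='false',
    project_id='my-project', region='us-central1',
    gcs_temp_location='gs://my-bucket/tmp',
    gcs_staging_location='gs://my-bucket/staging')
print(stats_path)  # the function returns (output_path, )
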
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# Copyright 2021 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import NamedTuple
+
+
+def tfdv_detect_drift(
+    stats_older_path: str, stats_new_path: str
+    ) -> NamedTuple('Outputs', [('drift', str)]):
+
+  import logging
+  import time
+
+  import tensorflow_data_validation as tfdv
+  import tensorflow_data_validation.statistics.stats_impl
+
+  logging.getLogger().setLevel(logging.INFO)
+  logging.info('stats_older_path: %s', stats_older_path)
+  logging.info('stats_new_path: %s', stats_new_path)
+
+  if stats_older_path == 'none':
+    return ('true', )
+
+  stats1 = tfdv.load_statistics(stats_older_path)
+  stats2 = tfdv.load_statistics(stats_new_path)
+
+  schema1 = tfdv.infer_schema(statistics=stats1)
+  tfdv.get_feature(schema1, 'duration').drift_comparator.jensen_shannon_divergence.threshold = 0.01
+  drift_anomalies = tfdv.validate_statistics(
+      statistics=stats2, schema=schema1, previous_statistics=stats1)
+  logging.info('drift analysis results: %s', drift_anomalies.drift_skew_info)
+
+  from google.protobuf.json_format import MessageToDict
+  d = MessageToDict(drift_anomalies)
+  val = d['driftSkewInfo'][0]['driftMeasurements'][0]['value']
+  thresh = d['driftSkewInfo'][0]['driftMeasurements'][0]['threshold']
+  logging.info('value %s and threshold %s', val, thresh)
+  res = 'true'
+  if val < thresh:
+    res = 'false'
+  logging.info('train decision: %s', res)
+  return (res, )
+
+
+if __name__ == '__main__':
+  import kfp
+  kfp.components.func_to_container_op(tfdv_detect_drift,
+      output_component_file='../tfdv_drift_component.yaml',
+      base_image='gcr.io/google-samples/tfdv-tests:v1')
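
Note: a hedged usage sketch, not part of this commit, of calling the function above directly on two previously generated stats files; the GCS paths are placeholders. Passing 'none' as stats_older_path short-circuits to 'true' (always train); otherwise the result reflects whether the 'duration' feature drifted past the 0.01 Jensen-Shannon divergence threshold.

# Hypothetical direct call to tfdv_detect_drift; both paths are placeholders.
drift, = tfdv_detect_drift(
    stats_older_path='gs://my-bucket/bikes/stats/week1.pb',  # hypothetical
    stats_new_path='gs://my-bucket/bikes/stats/week2.pb')    # hypothetical
print(drift)  # 'true' if retraining should run, else 'false'
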
Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
+name: Generate tfdv stats
+inputs:
+- {name: input_data, type: String}
+- {name: output_path, type: String}
+- {name: job_name, type: String}
+- {name: use_dataflow, type: String}
+- {name: project_id, type: String}
+- {name: region, type: String}
+- {name: gcs_temp_location, type: String}
+- {name: gcs_staging_location, type: String}
+- {name: whl_location, type: String, default: '', optional: true}
+- {name: requirements_file, type: String, default: requirements.txt, optional: true}
+outputs:
+- {name: stats_path, type: String}
+implementation:
+  container:
+    image: gcr.io/google-samples/tfdv-tests:v1
+    command:
+    - sh
+    - -ec
+    - |
+      program_path=$(mktemp)
+      printf "%s" "$0" > "$program_path"
+      python3 -u "$program_path" "$@"
+    - |
+      def generate_tfdv_stats(input_data, output_path, job_name, use_dataflow,
+          project_id, region, gcs_temp_location, gcs_staging_location,
+          whl_location = '', requirements_file = 'requirements.txt'
+          ):
+
+        import logging
+        import time
+
+        import tensorflow_data_validation as tfdv
+        import tensorflow_data_validation.statistics.stats_impl
+        from apache_beam.options.pipeline_options import PipelineOptions, GoogleCloudOptions, StandardOptions, SetupOptions
+
+        # pip download tensorflow_data_validation --no-deps --platform manylinux2010_x86_64 --only-binary=:all:
+        # CHANGE this if your download resulted in a different filename.
+
+        logging.getLogger().setLevel(logging.INFO)
+        logging.info("output path: %s", output_path)
+        logging.info("Building pipeline options")
+        # Create and set your PipelineOptions.
+        options = PipelineOptions()
+
+        if use_dataflow == 'true':
+          logging.info("using Dataflow")
+          if not whl_location:
+            logging.warning('tfdv whl file required with dataflow runner.')
+            exit(1)
+          # For Cloud execution, set the Cloud Platform project, job_name,
+          # staging location, temp_location and specify DataflowRunner.
+          google_cloud_options = options.view_as(GoogleCloudOptions)
+          google_cloud_options.project = project_id
+          google_cloud_options.job_name = '{}-{}'.format(job_name, str(int(time.time())))
+          google_cloud_options.staging_location = gcs_staging_location
+          google_cloud_options.temp_location = gcs_temp_location
+          google_cloud_options.region = region
+          options.view_as(StandardOptions).runner = 'DataflowRunner'
+
+          setup_options = options.view_as(SetupOptions)
+          # PATH_TO_WHL_FILE should point to the downloaded tfdv wheel file.
+          setup_options.extra_packages = [whl_location]
+          setup_options.requirements_file = 'requirements.txt'
+
+        tfdv.generate_statistics_from_csv(
+            data_location=input_data, output_path=output_path,
+            pipeline_options=options)
+
+        return (output_path, )
+
+      def _serialize_str(str_value: str) -> str:
+          if not isinstance(str_value, str):
+              raise TypeError('Value "{}" has type "{}" instead of str.'.format(str(str_value), str(type(str_value))))
+          return str_value
+
+      import argparse
+      _parser = argparse.ArgumentParser(prog='Generate tfdv stats', description='')
+      _parser.add_argument("--input-data", dest="input_data", type=str, required=True, default=argparse.SUPPRESS)
+      _parser.add_argument("--output-path", dest="output_path", type=str, required=True, default=argparse.SUPPRESS)
+      _parser.add_argument("--job-name", dest="job_name", type=str, required=True, default=argparse.SUPPRESS)
+      _parser.add_argument("--use-dataflow", dest="use_dataflow", type=str, required=True, default=argparse.SUPPRESS)
+      _parser.add_argument("--project-id", dest="project_id", type=str, required=True, default=argparse.SUPPRESS)
+      _parser.add_argument("--region", dest="region", type=str, required=True, default=argparse.SUPPRESS)
+      _parser.add_argument("--gcs-temp-location", dest="gcs_temp_location", type=str, required=True, default=argparse.SUPPRESS)
+      _parser.add_argument("--gcs-staging-location", dest="gcs_staging_location", type=str, required=True, default=argparse.SUPPRESS)
+      _parser.add_argument("--whl-location", dest="whl_location", type=str, required=False, default=argparse.SUPPRESS)
+      _parser.add_argument("--requirements-file", dest="requirements_file", type=str, required=False, default=argparse.SUPPRESS)
+      _parser.add_argument("----output-paths", dest="_output_paths", type=str, nargs=1)
+      _parsed_args = vars(_parser.parse_args())
+      _output_files = _parsed_args.pop("_output_paths", [])
+
+      _outputs = generate_tfdv_stats(**_parsed_args)
+
+      _output_serializers = [
+          _serialize_str,
+
+      ]
+
+      import os
+      for idx, output_file in enumerate(_output_files):
+          try:
+              os.makedirs(os.path.dirname(output_file))
+          except OSError:
+              pass
+          with open(output_file, 'w') as f:
+              f.write(_output_serializers[idx](_outputs[idx]))
+    args:
+    - --input-data
+    - {inputValue: input_data}
+    - --output-path
+    - {inputValue: output_path}
+    - --job-name
+    - {inputValue: job_name}
+    - --use-dataflow
+    - {inputValue: use_dataflow}
+    - --project-id
+    - {inputValue: project_id}
+    - --region
+    - {inputValue: region}
+    - --gcs-temp-location
+    - {inputValue: gcs_temp_location}
+    - --gcs-staging-location
+    - {inputValue: gcs_staging_location}
+    - if:
+        cond: {isPresent: whl_location}
+        then:
+        - --whl-location
+        - {inputValue: whl_location}
+    - if:
+        cond: {isPresent: requirements_file}
+        then:
+        - --requirements-file
+        - {inputValue: requirements_file}
+    - '----output-paths'
+    - {outputPath: stats_path}
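
Note: a hedged sketch, not part of this commit, of loading the compiled component files above with the KFP SDK and chaining them; the component file locations, bucket paths, project, and defaults are assumptions.

# Hedged pipeline-wiring sketch; file locations and all parameter values are placeholders.
import kfp
import kfp.dsl as dsl

tfdv_stats_op = kfp.components.load_component_from_file('tfdv_component.yaml')
tfdv_drift_op = kfp.components.load_component_from_file('tfdv_drift_component.yaml')

@dsl.pipeline(name='tfdv-drift-check', description='Sketch: stats generation, then drift check')
def tfdv_pipeline(
    input_data: str = 'gs://my-bucket/bikes/data-*.csv',     # hypothetical
    output_path: str = 'gs://my-bucket/bikes/stats/new.pb',  # hypothetical
    older_stats: str = 'none'):                              # 'none' forces a train decision
  stats = tfdv_stats_op(
      input_data=input_data, output_path=output_path,
      job_name='tfdv-stats', use_dataflow='false',           # local Beam runner
      project_id='my-project', region='us-central1',         # hypothetical project
      gcs_temp_location='gs://my-bucket/tmp',
      gcs_staging_location='gs://my-bucket/staging')
  drift = tfdv_drift_op(
      stats_older_path=older_stats,
      stats_new_path=stats.outputs['stats_path'])
  # A downstream training step could be gated with
  # dsl.Condition(drift.outputs['drift'] == 'true').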
