# Kubeflow Pipelines component: generate TensorFlow Data Validation (TFDV)
# statistics from a CSV dataset, either locally (DirectRunner) or on
# Google Cloud Dataflow when `use_dataflow` == 'true'.
name: Generate tfdv stats
inputs:
- {name: input_data, type: String}
- {name: output_path, type: String}
- {name: job_name, type: String}
- {name: use_dataflow, type: String}
- {name: project_id, type: String}
- {name: region, type: String}
- {name: gcs_temp_location, type: String}
- {name: gcs_staging_location, type: String}
- {name: whl_location, type: String, default: '', optional: true}
- {name: requirements_file, type: String, default: requirements.txt, optional: true}
outputs:
- {name: stats_path, type: String}
implementation:
  container:
    image: gcr.io/google-samples/tfdv-tests:v1
    command:
    - sh
    - -ec
    - |
      program_path=$(mktemp)
      printf "%s" "$0" > "$program_path"
      python3 -u "$program_path" "$@"
    - |
      def generate_tfdv_stats(input_data, output_path, job_name, use_dataflow,
          project_id, region, gcs_temp_location, gcs_staging_location,
          whl_location = '', requirements_file = 'requirements.txt'
          ):
          """Generate TFDV statistics for a CSV dataset.

          Args:
            input_data: GCS/local path of the input CSV data.
            output_path: where the serialized statistics are written.
            job_name: base Dataflow job name (a timestamp suffix is added).
            use_dataflow: the string 'true' selects the DataflowRunner;
              anything else runs the Beam pipeline locally.
            project_id, region, gcs_temp_location, gcs_staging_location:
              Dataflow settings; only used when use_dataflow == 'true'.
            whl_location: path to a tfdv wheel, required for Dataflow so the
              workers can install tfdv as an extra package.
              # pip download tensorflow_data_validation --no-deps --platform manylinux2010_x86_64 --only-binary=:all:
              # CHANGE this if your download resulted in a different filename.
            requirements_file: pip requirements file shipped to Dataflow workers.

          Returns:
            A one-tuple containing the statistics output path.
          """
          import logging
          import sys
          import time

          import tensorflow_data_validation as tfdv
          import tensorflow_data_validation.statistics.stats_impl
          from apache_beam.options.pipeline_options import PipelineOptions, GoogleCloudOptions, StandardOptions, SetupOptions

          logging.getLogger().setLevel(logging.INFO)
          logging.info("output path: %s", output_path)
          logging.info("Building pipeline options")
          # Create and set your PipelineOptions.
          options = PipelineOptions()

          if use_dataflow == 'true':
              logging.info("using Dataflow")
              if not whl_location:
                  # Without the wheel the Dataflow workers cannot install
                  # tfdv, so fail fast with a nonzero exit code.
                  logging.warning('tfdv whl file required with dataflow runner.')
                  sys.exit(1)
              # For Cloud execution, set the Cloud Platform project, job_name,
              # staging location, temp_location and specify DataflowRunner.
              google_cloud_options = options.view_as(GoogleCloudOptions)
              google_cloud_options.project = project_id
              # Timestamp suffix keeps repeated job names unique.
              google_cloud_options.job_name = '{}-{}'.format(job_name, str(int(time.time())))
              google_cloud_options.staging_location = gcs_staging_location
              google_cloud_options.temp_location = gcs_temp_location
              google_cloud_options.region = region
              options.view_as(StandardOptions).runner = 'DataflowRunner'

              setup_options = options.view_as(SetupOptions)
              # whl_location should point to the downloaded tfdv wheel file.
              setup_options.extra_packages = [whl_location]
              # BUGFIX: honor the requirements_file input; previously this
              # was hard-coded to 'requirements.txt', silently ignoring the
              # caller-supplied value (the default is unchanged, so existing
              # callers see identical behavior).
              setup_options.requirements_file = requirements_file

          tfdv.generate_statistics_from_csv(
              data_location=input_data, output_path=output_path,
              pipeline_options=options)

          return (output_path, )

      def _serialize_str(str_value: str) -> str:
          if not isinstance(str_value, str):
              raise TypeError('Value "{}" has type "{}" instead of str.'.format(str(str_value), str(type(str_value))))
          return str_value

      import argparse
      _parser = argparse.ArgumentParser(prog='Generate tfdv stats', description='')
      _parser.add_argument("--input-data", dest="input_data", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--output-path", dest="output_path", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--job-name", dest="job_name", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--use-dataflow", dest="use_dataflow", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--project-id", dest="project_id", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--region", dest="region", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--gcs-temp-location", dest="gcs_temp_location", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--gcs-staging-location", dest="gcs_staging_location", type=str, required=True, default=argparse.SUPPRESS)
      _parser.add_argument("--whl-location", dest="whl_location", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("--requirements-file", dest="requirements_file", type=str, required=False, default=argparse.SUPPRESS)
      _parser.add_argument("----output-paths", dest="_output_paths", type=str, nargs=1)
      _parsed_args = vars(_parser.parse_args())
      _output_files = _parsed_args.pop("_output_paths", [])

      _outputs = generate_tfdv_stats(**_parsed_args)

      _output_serializers = [
          _serialize_str,

      ]

      import os
      for idx, output_file in enumerate(_output_files):
          try:
              os.makedirs(os.path.dirname(output_file))
          except OSError:
              pass
          with open(output_file, 'w') as f:
              f.write(_output_serializers[idx](_outputs[idx]))
    args:
    - --input-data
    - {inputValue: input_data}
    - --output-path
    - {inputValue: output_path}
    - --job-name
    - {inputValue: job_name}
    - --use-dataflow
    - {inputValue: use_dataflow}
    - --project-id
    - {inputValue: project_id}
    - --region
    - {inputValue: region}
    - --gcs-temp-location
    - {inputValue: gcs_temp_location}
    - --gcs-staging-location
    - {inputValue: gcs_staging_location}
    - if:
        cond: {isPresent: whl_location}
        then:
        - --whl-location
        - {inputValue: whl_location}
    - if:
        cond: {isPresent: requirements_file}
        then:
        - --requirements-file
        - {inputValue: requirements_file}
    - '----output-paths'
    - {outputPath: stats_path}
0 commit comments