Skip to content

Commit

Permalink
Add experiments flag to converter (tensorflow#3748)
Browse files Browse the repository at this point in the history
FEATURE
INTERNAL
* added experiments flag and g3 only graph transform

* fix build error

* addressed comments

* add more comments
  • Loading branch information
pyu10055 authored Aug 7, 2020
1 parent dc97f29 commit 178a54b
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 22 deletions.
8 changes: 8 additions & 0 deletions tfjs-converter/python/requirements-exp.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
h5py>=2.8.0
numpy>=1.16.4,<1.19.0
six>=1.12.0
tf-nightly-cpu>=2.4.0.dev20200806,<3
tensorflow-hub==0.7.0
PyInquirer==1.0.3
pylint==1.9.4; python_version < '3.0'
pylint==2.5.0; python_version > '3.0'
27 changes: 26 additions & 1 deletion tfjs-converter/python/run-python-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@

# A script that runs all Python unit tests in tfjs-layers.

function print_usage() {
echo "Usage:"
echo " run-python-tests.sh <requirments_file>"
echo
}

set -e

SCRIPTS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
Expand All @@ -28,8 +34,26 @@ TMP_VENV_DIR="$(mktemp -u).venv"
virtualenv -p "python" "${TMP_VENV_DIR}"
source "${TMP_VENV_DIR}/bin/activate"

pip install -r "${SCRIPTS_DIR}/requirements-dev.txt"
# There is one argument (requirements_file), please update this constant when
# you adding more arguments.
ARGS_COUNT=1

# Default requirements file name.
REQ_FILE="${SCRIPTS_DIR}/requirements-dev.txt"

# Show the usage message if there are too many arguments.
if [[ $# > ARGS_COUNT ]]; then
print_usage
exit 1
fi

# Use the user specified requirements file name.
if [[ $# == 1 ]]; then
REQ_FILE=$1
fi
pip install -r "${REQ_FILE}"

# Run pylint for tensorflowjs directory
cd "${SCRIPTS_DIR}"
pylint --rcfile=.pylintrc tensorflowjs

Expand All @@ -45,5 +69,6 @@ echo
echo "All tests passed."
echo

# Clean up
deactivate
rm -rf "${TMP_VENV_DIR}"
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ py_library(
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_hub_installed",
"//tensorflowjs:expect_tensorflow_installed",
"//tensorflowjs:expect_graph_transforms_installed",
"//tensorflowjs:resource_loader",
"//tensorflowjs:version",
"//tensorflowjs:write_weights",
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
STRIP_DEBUG_OPS = 'strip_debug_ops'
WEIGHT_SHARD_SIZE_BYTES = 'weight_shard_size_bytes'
CONTROL_FLOW_V2 = 'control_flow_v2'
EXPERIMENTS = 'experiments'

def get_converted_by():
"""Get the convertedBy string for storage in model artifacts."""
Expand Down
44 changes: 33 additions & 11 deletions tfjs-converter/python/tensorflowjs/converters/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
skip_op_check=False,
strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4,
control_flow_v2=False):
control_flow_v2=False,
experiments=False):
"""
Convert a keras HDF5-format model to tfjs GraphModel artifacts.
Expand All @@ -120,6 +121,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
strip_debug_ops: Bool whether to allow unsupported debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
control_flow_v2: Bool whether to enable control flow v2 ops.
experiments: Bool enable experimental features.
"""

if not os.path.exists(h5_path):
Expand All @@ -143,7 +146,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
control_flow_v2=control_flow_v2)
control_flow_v2=control_flow_v2,
experiments=experiments)

# Clean up the temporary SavedModel directory.
shutil.rmtree(temp_savedmodel_dir)
Expand Down Expand Up @@ -331,7 +335,9 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
quantization_dtype_map=None,
skip_op_check=False,
strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4):
weight_shard_size_bytes=1024 * 1024 * 4,
control_flow_v2=False,
experiments=False):
"""Converts a TensorFlow.js Layers Model to TensorFlow.js Graph Model.
This conversion often benefits speed of inference, due to the graph
Expand All @@ -348,7 +354,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
strip_debug_ops: Bool whether to allow unsupported debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
control_flow_v2: Bool whether to enable control flow v2 ops.
experiments: Bool enable experimental features.
Raises:
ValueError, if `config_json_path` is not a path to a valid JSON
file, or if h5_path points to an existing directory.
Expand Down Expand Up @@ -382,7 +389,9 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
quantization_dtype_map=quantization_dtype_map,
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
weight_shard_size_bytes=weight_shard_size_bytes,
control_flow_v2=control_flow_v2,
experiments=experiments)

# Clean up temporary HDF5 file.
os.remove(temp_h5_path)
Expand Down Expand Up @@ -575,9 +584,16 @@ def get_arg_parser():
'"tf_frozen_model".')
parser.add_argument(
'--%s' % common.CONTROL_FLOW_V2,
type=str,
type=bool,
default=False,
help='Enable control flow v2 ops, this would improve inference '
'performance on models with branches or loops.')
parser.add_argument(
'--%s' % common.EXPERIMENTS,
type=bool,
default=False,
help='Enable experimental features, you should only enable this flag '
'when using Python3 and TensorFlow nightly build.')
return parser

def convert(arguments):
Expand Down Expand Up @@ -660,7 +676,8 @@ def convert(arguments):
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
control_flow_v2=args.control_flow_v2)
control_flow_v2=args.control_flow_v2,
experiments=args.experiments)
elif (input_format == common.KERAS_SAVED_MODEL and
output_format == common.TFJS_LAYERS_MODEL):
dispatch_keras_saved_model_to_tensorflowjs_conversion(
Expand All @@ -678,7 +695,8 @@ def convert(arguments):
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
control_flow_v2=args.control_flow_v2)
control_flow_v2=args.control_flow_v2,
experiments=args.experiments)
elif (input_format == common.TF_HUB_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
tf_saved_model_conversion_v2.convert_tf_hub_module(
Expand All @@ -689,7 +707,8 @@ def convert(arguments):
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
control_flow_v2=args.control_flow_v2)
control_flow_v2=args.control_flow_v2,
experiments=args.experiments)
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.KERAS_MODEL):
dispatch_tensorflowjs_to_keras_h5_conversion(args.input_path,
Expand All @@ -711,15 +730,18 @@ def convert(arguments):
quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
weight_shard_size_bytes=weight_shard_size_bytes,
control_flow_v2=args.control_flow_v2,
experiments=args.experiments)
elif (input_format == common.TF_FROZEN_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
tf_saved_model_conversion_v2.convert_tf_frozen_model(
args.input_path, args.output_node_names, args.output_path,
quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
weight_shard_size_bytes=weight_shard_size_bytes,
experiments=args.experiments)
else:
raise ValueError(
'Unsupported input_format - output_format pair: %s - %s' %
Expand Down
Loading

0 comments on commit 178a54b

Please sign in to comment.