Skip to content

Commit cd52f64

Browse files
reedwmtensorflower-gardener
authored andcommitted
Merge the two LossScale classes.
Before, we had the following loss scale base classes, which did the exact same thing: * tf.keras.mixed_precision.experimental.LossScale, which only worked for the keras OptimizerV2 * An unexposed LossScale in tensorflow/python/training/experimental, which only worked for the V1 Optimizer This change removes the Keras LossScale and merges it into the training LossScale, which now works for both Optimizers. The training LossScale is exposed as tf.train.experimental.LossScale. I moved over some functionality, comments, and style conventions from the Keras LossScale to the training LossScale. Because the LossScale class can not rely on Keras, the Keras OptimizerV2 now calls backend.track_variable on the LossScale variables instead. Note: I intend to cherrypick this into TF 1.14. PiperOrigin-RevId: 248213961
1 parent 9c9afc5 commit cd52f64

19 files changed

+138
-738
lines changed

tensorflow/python/keras/mixed_precision/experimental/BUILD

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ py_library(
9696
srcs = ["loss_scale_optimizer.py"],
9797
srcs_version = "PY2AND3",
9898
deps = [
99-
":loss_scale",
99+
"//tensorflow/python:loss_scale",
100100
"//tensorflow/python/keras/optimizer_v2",
101101
"@absl_py//absl/testing:parameterized",
102102
],
@@ -116,31 +116,6 @@ cuda_py_test(
116116
],
117117
)
118118

119-
py_library(
120-
name = "loss_scale",
121-
srcs = ["loss_scale.py"],
122-
srcs_version = "PY2AND3",
123-
deps = [
124-
"//tensorflow/python:framework",
125-
"@absl_py//absl/testing:parameterized",
126-
],
127-
)
128-
129-
py_test(
130-
name = "loss_scale_test",
131-
size = "medium",
132-
srcs = ["loss_scale_test.py"],
133-
python_version = "PY2",
134-
deps = [
135-
":loss_scale",
136-
"//tensorflow/python:client_testlib",
137-
"//tensorflow/python/distribute:mirrored_strategy",
138-
"//tensorflow/python/distribute:one_device_strategy",
139-
"//tensorflow/python/keras",
140-
"@absl_py//absl/testing:parameterized",
141-
],
142-
)
143-
144119
py_library(
145120
name = "test_util",
146121
srcs = ["test_util.py"],

tensorflow/python/keras/mixed_precision/experimental/keras_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from tensorflow.python.keras import regularizers
3737
from tensorflow.python.keras.engine import base_layer
3838
from tensorflow.python.keras.layers import core
39-
from tensorflow.python.keras.mixed_precision.experimental import loss_scale as loss_scale_module
4039
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
4140
from tensorflow.python.keras.mixed_precision.experimental import policy
4241
from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
@@ -45,6 +44,7 @@
4544
from tensorflow.python.ops import math_ops
4645
from tensorflow.python.ops import variables
4746
from tensorflow.python.platform import test
47+
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
4848
from tensorflow.python.training.tracking import util as trackable_utils
4949
from tensorflow.python.util import nest
5050

0 commit comments

Comments
 (0)