Skip to content

Commit c07abee

Browse files
authored
Remove the option that turned off the Grappler layout optimizer. (tensorflow#7384)
1 parent 8304e64 commit c07abee

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

official/transformer/v2/transformer_main.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -127,17 +127,15 @@ def __init__(self, flags_obj):
127127
# We should have a better way in the tf.keras.mixed_precision API of doing
128128
# this.
129129
policy = tf.keras.mixed_precision.experimental.Policy(
130-
'infer_float32_vars')
130+
"infer_float32_vars")
131131
tf.keras.mixed_precision.experimental.set_policy(policy)
132132

133133
def train(self):
134134
"""Trains the model."""
135135
params, flags_obj, is_train = self.params, self.flags_obj, True
136136
# Sets config options.
137137
keras_utils.set_session_config(
138-
enable_xla=flags_obj.enable_xla,
139-
enable_grappler_layout_optimizer=
140-
flags_obj.enable_grappler_layout_optimizer)
138+
enable_xla=flags_obj.enable_xla)
141139

142140
_ensure_dir(flags_obj.model_dir)
143141
if self.distribution_strategy:
@@ -154,7 +152,7 @@ def train(self):
154152

155153
train_ds = data_pipeline.train_input_fn(params)
156154
map_data_fn = data_pipeline.map_data_for_transformer_fn
157-
train_ds = train_ds.map(map_data_fn,
155+
train_ds = train_ds.map(map_data_fn,
158156
num_parallel_calls=params["num_parallel_calls"])
159157

160158
callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)

0 commit comments

Comments (0)