@@ -197,6 +197,7 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self):
197197 native_serialization_platforms = None ,
198198 xla_flags_per_platform = None ,
199199 expected_platforms = ['tpu' ],
200+ persist_xla_flags = False ,
200201 ),
201202 dict (
202203 testcase_name = 'no_native_serialization_platforms_with_xla_flags' ,
@@ -205,12 +206,14 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self):
205206 'tpu' : [f'--{ k } ={ v } ' for k , v in XLA_FLAGS_DICT .items ()]
206207 },
207208 expected_platforms = ['tpu' ],
209+ persist_xla_flags = True ,
208210 ),
209211 dict (
210212 testcase_name = 'with_native_serialization_platforms_no_xla_flags' ,
211213 native_serialization_platforms = ['cpu' , 'tpu' , 'cuda' ],
212214 xla_flags_per_platform = None ,
213215 expected_platforms = ['cpu' , 'tpu' , 'cuda' ],
216+ persist_xla_flags = False ,
214217 ),
215218 dict (
216219 testcase_name = 'with_native_serialization_platforms_with_xla_flags' ,
@@ -219,25 +222,28 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self):
219222 'tpu' : [f'--{ k } ={ v } ' for k , v in XLA_FLAGS_DICT .items ()]
220223 },
221224 expected_platforms = ['cpu' , 'tpu' , 'cuda' ],
225+ persist_xla_flags = True ,
222226 ),
223227 )
224228 def test_generate_xla_compile_options_flags_and_platforms (
225229 self ,
226230 native_serialization_platforms ,
227231 xla_flags_per_platform ,
228232 expected_platforms ,
233+ persist_xla_flags ,
229234 ):
230235 compile_options_map = compile_options_util .generate_xla_compile_options (
231236 native_serialization_platforms = native_serialization_platforms ,
232237 xla_flags_per_platform = xla_flags_per_platform ,
238+ persist_xla_flags = persist_xla_flags ,
233239 )
234240 self .assertLen (compile_options_map .map , len (expected_platforms ))
235241
236242 for platform in expected_platforms :
237243 self .assertIn (platform , compile_options_map .map )
238244 compile_options = compile_options_map .map [platform ]
239245
240- if platform != 'tpu' :
246+ if platform != 'tpu' or not persist_xla_flags :
241247 self .assertEmpty (
242248 compile_options .executable_build_options .comp_envs .environments
243249 )
@@ -307,49 +313,17 @@ def test_generate_xla_compile_options_xla_flags_platform_not_in_native_serializa
307313 },
308314 )
309315
310- @parameterized .named_parameters (
311- dict (testcase_name = 'strip_xla_flags_true' , strip_xla_flags = True ),
312- dict (testcase_name = 'strip_xla_flags_false' , strip_xla_flags = False ),
313- )
314- def test_generate_xla_compile_options_strip_xla_flags (self , strip_xla_flags ):
315- xla_flags_per_platform = {
316- 'tpu' : [f'--{ k } ={ v } ' for k , v in XLA_FLAGS_DICT .items ()]
317- }
318- compile_options_map = compile_options_util .generate_xla_compile_options (
319- native_serialization_platforms = ['cpu' , 'tpu' , 'cuda' ],
320- xla_flags_per_platform = xla_flags_per_platform ,
321- strip_xla_flags = strip_xla_flags ,
322- )
323- self .assertLen (compile_options_map .map , 3 )
324- for platform in ['cpu' , 'tpu' , 'cuda' ]:
325- self .assertIn (platform , compile_options_map .map )
326- compile_options = compile_options_map .map [platform ]
327-
328- if strip_xla_flags or platform != 'tpu' :
329- self .assertEmpty (
330- compile_options .executable_build_options .comp_envs .environments
331- )
332- else :
333- # For TPU platform when not stripping, it should have xla flags.
334- self .assertLen (
335- compile_options .executable_build_options .comp_envs .environments , 1
336- )
337- actual_env_proto = tpu_comp_env_pb2 .TpuCompilationEnvironment ()
338- compile_options .executable_build_options .comp_envs .environments [
339- 0
340- ].Unpack (actual_env_proto )
341-
342- expected_env_overrides = EXPECTED_ENV
343- expected_env_proto = tpu_comp_env_pb2 .TpuCompilationEnvironment ()
344- expected_env_proto .ParseFromString (
345- tpu_comp_env .create_default_tpu_comp_env ()
346- )
347- expected_env_proto .MergeFrom (expected_env_overrides )
348-
349- self .assertEqual (
350- text_format .MessageToString (actual_env_proto ),
351- text_format .MessageToString (expected_env_proto ),
352- )
316+ def test_generate_xla_compile_options_xla_flags_no_persist_raise_error (self ):
317+ with self .assertRaisesWithLiteralMatch (
318+ ValueError , 'persist_xla_flags must be True if xla_flags are provided.'
319+ ):
320+ compile_options_util .generate_xla_compile_options (
321+ native_serialization_platforms = ['tpu' ],
322+ xla_flags_per_platform = {
323+ 'tpu' : [f'--{ k } ={ v } ' for k , v in XLA_FLAGS_DICT .items ()]
324+ },
325+ persist_xla_flags = False ,
326+ )
353327
354328 @parameterized .named_parameters (
355329 dict (
0 commit comments