Skip to content

Commit bcd6911

Browse files
qlzh727tensorflower-gardener
authored andcommitted
Make the GPU distribute test to load the cuda lazy, to reduce GPU memory usage, and reduce the flakyness.
PiperOrigin-RevId: 549689113
1 parent 0bfaafd commit bcd6911

File tree

1 file changed

+21
-7
lines changed
  • keras/layers/preprocessing

1 file changed

+21
-7
lines changed

keras/layers/preprocessing/BUILD

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,9 @@ distribute_py_test(
257257
name = "category_encoding_distribution_test",
258258
srcs = ["category_encoding_distribution_test.py"],
259259
disable_mlir_bridge = False,
260+
env = {
261+
"CUDA_MODULE_LOADING": "LAZY",
262+
},
260263
main = "category_encoding_distribution_test.py",
261264
python_version = "PY3",
262265
shard_count = 4,
@@ -265,7 +268,6 @@ distribute_py_test(
265268
"no_oss", # b/189866692
266269
"noguitar", # b/190034522
267270
"nomultivm", # TODO(b/170502145)
268-
"requires-mem:28g", # spawns multiple processes.
269271
],
270272
tpu_tags = [
271273
"no_oss", # b/155502591
@@ -284,14 +286,16 @@ distribute_py_test(
284286
distribute_py_test(
285287
name = "image_preprocessing_distribution_test",
286288
srcs = ["image_preprocessing_distribution_test.py"],
289+
env = {
290+
"CUDA_MODULE_LOADING": "LAZY",
291+
},
287292
main = "image_preprocessing_distribution_test.py",
288293
python_version = "PY3",
289294
shard_count = 4,
290295
tags = [
291296
"multi_and_single_gpu",
292297
"nomultivm", # TODO(b/170502145)
293298
"notpu", # TODO(b/210148622)
294-
"requires-mem:28g", # spawns multiple processes.
295299
],
296300
tpu_tags = [
297301
"no_oss",
@@ -326,6 +330,9 @@ tf_py_test(
326330
distribute_py_test(
327331
name = "discretization_distribution_test",
328332
srcs = ["discretization_distribution_test.py"],
333+
env = {
334+
"CUDA_MODULE_LOADING": "LAZY",
335+
},
329336
main = "discretization_distribution_test.py",
330337
python_version = "PY3",
331338
shard_count = 4,
@@ -334,7 +341,6 @@ distribute_py_test(
334341
"no_oss", # TODO(b/189956080)
335342
"noguitar", # b/190034522
336343
"nomultivm", # TODO(b/170502145)
337-
"requires-mem:28g", # spawns multiple processes.
338344
],
339345
deps = [
340346
":discretization",
@@ -366,13 +372,15 @@ distribute_py_test(
366372
name = "hashing_distribution_test",
367373
srcs = ["hashing_distribution_test.py"],
368374
disable_mlir_bridge = False,
375+
env = {
376+
"CUDA_MODULE_LOADING": "LAZY",
377+
},
369378
main = "hashing_distribution_test.py",
370379
python_version = "PY3",
371380
shard_count = 4,
372381
tags = [
373382
"multi_and_single_gpu",
374383
"nomultivm", # TODO(b/170502145)
375-
"requires-mem:28g", # spawns multiple processes.
376384
],
377385
deps = [
378386
":hashing",
@@ -420,13 +428,15 @@ distribute_py_test(
420428
name = "index_lookup_distribution_test",
421429
srcs = ["index_lookup_distribution_test.py"],
422430
disable_mlir_bridge = False,
431+
env = {
432+
"CUDA_MODULE_LOADING": "LAZY",
433+
},
423434
main = "index_lookup_distribution_test.py",
424435
python_version = "PY3",
425436
shard_count = 4,
426437
tags = [
427438
"multi_and_single_gpu",
428439
"nomultivm", # TODO(b/170502145)
429-
"requires-mem:28g", # spawns multiple processes.
430440
],
431441
tpu_tags = ["no_oss"],
432442
deps = [
@@ -496,13 +506,15 @@ tf_py_test(
496506
distribute_py_test(
497507
name = "normalization_distribution_test",
498508
srcs = ["normalization_distribution_test.py"],
509+
env = {
510+
"CUDA_MODULE_LOADING": "LAZY",
511+
},
499512
main = "normalization_distribution_test.py",
500513
python_version = "PY3",
501514
shard_count = 8,
502515
tags = [
503516
"no_oss",
504517
"nomultivm", # TODO(b/170502145)
505-
"requires-mem:28g", # spawns multiple processes.
506518
],
507519
deps = [
508520
":normalization",
@@ -534,13 +546,15 @@ distribute_py_test(
534546
name = "text_vectorization_distribution_test",
535547
srcs = ["text_vectorization_distribution_test.py"],
536548
disable_mlir_bridge = False,
549+
env = {
550+
"CUDA_MODULE_LOADING": "LAZY",
551+
},
537552
main = "text_vectorization_distribution_test.py",
538553
python_version = "PY3",
539554
shard_count = 8,
540555
tags = [
541556
"multi_and_single_gpu",
542557
"nomultivm", # TODO(b/170502145)
543-
"requires-mem:28g", # spawns multiple processes.
544558
],
545559
tpu_tags = [
546560
"no_oss", # b/155502591

0 commit comments

Comments
 (0)