Skip to content

Commit 5786519

Browse files
ydshiehnovice03
authored andcommitted
Fix Wav2Vec2 CI OOM (huggingface#24190)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 0131e27 commit 5786519

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

tests/models/wav2vec2/test_modeling_tf_wav2vec2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import copy
20+
import gc
2021
import glob
2122
import inspect
2223
import math
@@ -709,6 +710,11 @@ def test_compute_mask_indices_overlap(self):
709710
@require_tf
710711
@slow
711712
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
713+
def tearDown(self):
714+
super().tearDown()
715+
# clean-up as much as possible GPU memory occupied by PyTorch
716+
gc.collect()
717+
712718
def _load_datasamples(self, num_samples):
713719
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
714720
# automatic decoding with librispeech

tests/models/wav2vec2/test_modeling_wav2vec2.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
""" Testing suite for the PyTorch Wav2Vec2 model. """
1616

17+
import gc
1718
import math
1819
import multiprocessing
1920
import os
@@ -1374,6 +1375,12 @@ def test_sample_negatives_with_mask(self):
13741375
@require_soundfile
13751376
@slow
13761377
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
1378+
def tearDown(self):
1379+
super().tearDown()
1380+
# clean-up as much as possible GPU memory occupied by PyTorch
1381+
gc.collect()
1382+
torch.cuda.empty_cache()
1383+
13771384
def _load_datasamples(self, num_samples):
13781385
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
13791386
# automatic decoding with librispeech

0 commit comments

Comments
 (0)