Skip to content

Commit e6b811f

Browse files
authored
[testing] replace hardcoded paths to allow running tests from anywhere (huggingface#6523)
* [testing] replace hardcoded paths to allow running tests from anywhere * fix the merge conflict
1 parent 9d1b4db commit e6b811f

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

src/transformers/testing_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import os
23
import re
34
import shutil
@@ -144,6 +145,15 @@ def require_torch_and_cuda(test_case):
144145
return test_case
145146

146147

148+
def get_tests_dir():
149+
"""
150+
returns the full path to the `tests` dir, so that the tests can be invoked from anywhere
151+
"""
152+
# this function caller's __file__
153+
caller__file__ = inspect.stack()[1][1]
154+
return os.path.abspath(os.path.dirname(caller__file__))
155+
156+
147157
#
148158
# Helper functions for dealing with testing text outputs
149159
# The original code came from:

tests/test_tokenization_fast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
TransfoXLTokenizer,
1616
is_torch_available,
1717
)
18-
from transformers.testing_utils import require_torch
18+
from transformers.testing_utils import get_tests_dir, require_torch
1919
from transformers.tokenization_distilbert import DistilBertTokenizerFast
2020
from transformers.tokenization_openai import OpenAIGPTTokenizerFast
2121
from transformers.tokenization_roberta import RobertaTokenizerFast
@@ -42,7 +42,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
4242
TOKENIZERS_CLASSES = frozenset([])
4343

4444
def setUp(self) -> None:
45-
with open("tests/fixtures/sample_text.txt", encoding="utf-8") as f_data:
45+
with open(f"{get_tests_dir()}/fixtures/sample_text.txt", encoding="utf-8") as f_data:
4646
self._data = f_data.read().replace("\n\n", "\n").strip()
4747

4848
def test_all_tokenizers(self):

tests/test_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from transformers import AutoTokenizer, TrainingArguments, is_torch_available
7-
from transformers.testing_utils import require_torch
7+
from transformers.testing_utils import get_tests_dir, require_torch
88

99

1010
if is_torch_available():
@@ -20,7 +20,7 @@
2020
)
2121

2222

23-
PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
23+
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
2424

2525

2626
class RegressionDataset:
@@ -262,7 +262,7 @@ def test_trainer_eval_mrpc(self):
262262
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
263263
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
264264
data_args = GlueDataTrainingArguments(
265-
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
265+
task_name="mrpc", data_dir=f"{get_tests_dir()}/fixtures/tests_samples/MRPC", overwrite_cache=True
266266
)
267267
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
268268

0 commit comments

Comments
 (0)