From a9d6f66234b5dd2859a0dc116ef3e38a52d0f81d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xueguang=20Ma=20=E9=A9=AC=E9=9B=AA=E5=85=89?=
Date: Sun, 2 May 2021 23:09:31 -0400
Subject: [PATCH] Fix unittests tests (#535)

* fix tokenize tests
* remove temp_dir
---
 pyserini/tokenize_json_collection.py |  2 +-
 tests/test_load_qrels.py             | 10 +++++++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/pyserini/tokenize_json_collection.py b/pyserini/tokenize_json_collection.py
index ee30890ae..849381f85 100644
--- a/pyserini/tokenize_json_collection.py
+++ b/pyserini/tokenize_json_collection.py
@@ -41,7 +41,7 @@ def main(args):
     else:
         tokenizer = T5Tokenizer.from_pretrained('castorini/doc2query-t5-base-msmarco')
     if (os.path.isdir(args.input)):
-        for i, inf in enumerate(os.listdir(args.input)):
+        for i, inf in enumerate(sorted(os.listdir(args.input))):
             if not os.path.isdir(args.output):
                 os.mkdir(args.output)
             outf = os.path.join(args.output, 'docs{:02d}.json'.format(i))
diff --git a/tests/test_load_qrels.py b/tests/test_load_qrels.py
index 1fad4fcce..aa45722c7 100644
--- a/tests/test_load_qrels.py
+++ b/tests/test_load_qrels.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
+import shutil
 import unittest
 
 from pyserini import search
@@ -26,6 +27,9 @@ def read_file_lines(path):
 
 
 class TestGetQrels(unittest.TestCase):
+    def setUp(self):
+        os.environ['PYSERINI_CACHE'] = 'temp_dir'
+
     def test_robust04(self):
         qrels_path = search.get_qrels_file('robust04')
         lines = read_file_lines(qrels_path)
@@ -247,6 +251,10 @@ def test_trec2019_bl(self):
         self.assertEqual(mid_line, "853 0 2444d88d62539b0b88dc919909cb9701 2")
         self.assertEqual(last_line, "885 0 fde80cb0-b4f0-11e2-bbf2-a6f9e9d79e19 0")
 
+    def tearDown(self):
+        if os.path.exists('temp_dir'):
+            shutil.rmtree('temp_dir')
+
 
 if __name__ == '__main__':
     unittest.main()