Skip to content

Commit

Permalink
[Function optimization] add unittest for downloading with file-lock (#…
Browse files Browse the repository at this point in the history
…4972)

* add unittest for file lock

* add file
  • Loading branch information
wj-Mcat authored Mar 24, 2023
1 parent c436823 commit f54e80d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
5 changes: 3 additions & 2 deletions paddlenlp/utils/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):


def get_path_from_url_with_filelock(
url: str, root_dir: str, md5sum: Optional[str] = None, check_exist: bool = True
url: str, root_dir: str, md5sum: Optional[str] = None, check_exist: bool = True, timeout: float = -1
) -> str:
"""construct `get_path_from_url` for `model_utils` to enable downloading multiprocess-safe
Expand All @@ -148,6 +148,7 @@ def get_path_from_url_with_filelock(
root_dir (str): the local download path
md5sum (str, optional): md5sum string for file. Defaults to None.
check_exist (bool, optional): whether check the file is exist. Defaults to True.
timeout (int, optional): the timeout for downloading. Defaults to -1.
Returns:
str: the path of downloaded file
Expand All @@ -163,7 +164,7 @@ def get_path_from_url_with_filelock(

os.makedirs(os.path.dirname(lock_file_path), exist_ok=True)

with FileLock(lock_file_path):
with FileLock(lock_file_path, timeout=timeout):
result = get_path_from_url(url=url, root_dir=root_dir, md5sum=md5sum, check_exist=check_exist)
return result

Expand Down
60 changes: 60 additions & 0 deletions tests/utils/test_downloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import hashlib
import os
import unittest
from tempfile import TemporaryDirectory

from paddlenlp.utils.downloader import get_path_from_url_with_filelock


class LockFileTest(unittest.TestCase):
test_url = (
"https://bj.bcebos.com/paddlenlp/models/transformers/roformerv2/roformer_v2_chinese_char_small/vocab.txt"
)

def test_downloading_with_exist_file(self):

with TemporaryDirectory() as tempdir:
lock_file_name = hashlib.md5((self.test_url + tempdir).encode("utf-8")).hexdigest()
lock_file_path = os.path.join(tempdir, ".lock", lock_file_name)
os.makedirs(os.path.dirname(lock_file_path), exist_ok=True)

# create lock file
with open(lock_file_path, "w", encoding="utf-8") as f:
f.write("temp test")

# downloading with exist lock file
config_file = get_path_from_url_with_filelock(self.test_url, root_dir=tempdir)
self.assertIsNotNone(config_file)

def test_downloading_with_opened_exist_file(self):

with TemporaryDirectory() as tempdir:
lock_file_name = hashlib.md5((self.test_url + tempdir).encode("utf-8")).hexdigest()
lock_file_path = os.path.join(tempdir, ".lock", lock_file_name)
os.makedirs(os.path.dirname(lock_file_path), exist_ok=True)

# create lock file
with open(lock_file_path, "w", encoding="utf-8") as f:
f.write("temp test")

# downloading with opened lock file
open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC
_ = os.open(lock_file_path, open_mode)
config_file = get_path_from_url_with_filelock(self.test_url, root_dir=tempdir)
self.assertIsNotNone(config_file)

0 comments on commit f54e80d

Please sign in to comment.