Skip to content

Commit

Permalink
[Enhance] Update digit_version function. (open-mmlab#402)
Browse files Browse the repository at this point in the history
* Update digit_version

* Add digit_version unit test
  • Loading branch information
mzr1996 authored Aug 12, 2021
1 parent f8f1700 commit c0a4916
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 12 deletions.
53 changes: 42 additions & 11 deletions mmcls/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,49 @@
import warnings

import mmcv
from packaging.version import parse

from .version import __version__


def digit_version(version_str):
digit_version = []
for x in version_str.split('.'):
if x.isdigit():
digit_version.append(int(x))
elif x.find('rc') != -1:
patch_version = x.split('rc')
digit_version.append(int(patch_version[0]) - 1)
digit_version.append(int(patch_version[1]))
return digit_version
def digit_version(version_str: str, length: int = 4):
"""Convert a version string into a tuple of integers.
This method is usually used for comparing two versions. For pre-release
versions: alpha < beta < rc.
Args:
version_str (str): The version string.
length (int): The maximum number of version levels. Default: 4.
Returns:
tuple[int]: The version info in digits (integers).
"""
version = parse(version_str)
assert version.release, f'failed to parse version {version_str}'
release = list(version.release)
release = release[:length]
if len(release) < length:
release = release + [0] * (length - len(release))
if version.is_prerelease:
mapping = {'a': -3, 'b': -2, 'rc': -1}
val = -4
# version.pre can be None
if version.pre:
if version.pre[0] not in mapping:
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
'version checking may go wrong')
else:
val = mapping[version.pre[0]]
release.extend([val, version.pre[-1]])
else:
release.extend([val, 0])

elif version.is_postrelease:
release.extend([1, version.post])
else:
release.extend([0, 0])
return tuple(release)


mmcv_minimum_version = '1.3.8'
Expand All @@ -25,4 +56,4 @@ def digit_version(version_str):
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'

__all__ = ['__version__']
__all__ = ['__version__', 'digit_version']
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmcls
known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,numpy,onnxruntime,pytest,seaborn,torch,torchvision,ts
known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
20 changes: 20 additions & 0 deletions tests/test_utils/test_version_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from mmcls import digit_version


def test_digit_version():
assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
assert digit_version('1.0') == digit_version('1.0.0')
assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
assert digit_version('1.0.0dev') < digit_version('1.0.0a')
assert digit_version('1.0.0a') < digit_version('1.0.0a1')
assert digit_version('1.0.0a') < digit_version('1.0.0b')
assert digit_version('1.0.0b') < digit_version('1.0.0rc')
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
assert digit_version('1.0.0') < digit_version('1.0.0post')
assert digit_version('1.0.0post') < digit_version('1.0.0post1')
assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)

0 comments on commit c0a4916

Please sign in to comment.