diff --git a/examples/lightning_base.py b/examples/lightning_base.py index 3f35ffe0f09ab1..8ceee24979b351 100644 --- a/examples/lightning_base.py +++ b/examples/lightning_base.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Dict +import packaging import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_info @@ -33,15 +34,17 @@ logger = logging.getLogger(__name__) -try: - pkg = "pytorch_lightning" - min_ver = "1.0.4" - pkg_resources.require(f"{pkg}>={min_ver}") -except pkg_resources.VersionConflict: - logger.warning( - f"{pkg}>={min_ver} is required for a normal functioning of this module, but found {pkg}=={pkg_resources.get_distribution(pkg).version}. Try pip install -r examples/requirements.txt" - ) +def require_min_ver(pkg, min_ver): + got_ver = pkg_resources.get_distribution(pkg).version + if packaging.version.parse(got_ver) < packaging.version.parse(min_ver): + logger.warning( + f"{pkg}>={min_ver} is required for a normal functioning of this module, but found {pkg}=={got_ver}. " + "Try: pip install -r examples/requirements.txt" + ) + + +require_min_ver("pytorch_lightning", "1.0.4") MODEL_MODES = { "base": AutoModel,