diff --git a/ChangeLog b/ChangeLog index 73cc908ebc..2d9266ee85 100644 --- a/ChangeLog +++ b/ChangeLog @@ -20,6 +20,7 @@ What's New in astroid 3.3.6? ============================ Release date: TBA +* Fix precedence of `path` arg in `modpath_from_file_with_callback` to be higher than `sys.path` What's New in astroid 3.3.5? diff --git a/astroid/modutils.py b/astroid/modutils.py index bf84b3b933..957be61cbe 100644 --- a/astroid/modutils.py +++ b/astroid/modutils.py @@ -278,7 +278,7 @@ def modpath_from_file_with_callback( filename = os.path.expanduser(_path_from_filename(filename)) paths_to_check = sys.path.copy() if path: - paths_to_check += path + paths_to_check = path + paths_to_check for pathname in itertools.chain( paths_to_check, map(_cache_normalize_path, paths_to_check) ): diff --git a/astroid/util.py b/astroid/util.py index 510b81cc13..3ddbc09040 100644 --- a/astroid/util.py +++ b/astroid/util.py @@ -5,7 +5,10 @@ from __future__ import annotations +import contextlib +import sys import warnings +from collections.abc import Iterator, Sequence from typing import TYPE_CHECKING, Any, Final, Literal from astroid.exceptions import InferenceError @@ -157,3 +160,26 @@ def safe_infer( return None # there is some kind of ambiguity except StopIteration: return value + + +def _augment_sys_path(additional_paths: Sequence[str]) -> list[str]: + original = list(sys.path) + changes = [] + seen = set() + for additional_path in additional_paths: + if additional_path not in seen: + changes.append(additional_path) + seen.add(additional_path) + + sys.path[:] = changes + sys.path + return original + + +@contextlib.contextmanager +def augmented_sys_path(additional_paths: Sequence[str]) -> Iterator[None]: + """Augment 'sys.path' by adding entries from additional_paths.""" + original = _augment_sys_path(additional_paths) + try: + yield + finally: + sys.path[:] = original diff --git a/tests/test_modutils.py b/tests/test_modutils.py index 85452b0f77..6b815d986c 100644 --- a/tests/test_modutils.py +++ b/tests/test_modutils.py @@ -22,6 +22,7 @@ from astroid import modutils from astroid.const import PY310_PLUS from astroid.interpreter._import import spec +from astroid.util import augmented_sys_path from . import resources @@ -175,6 +176,37 @@ def test_import_symlink_with_source_outside_of_path(self) -> None: finally: os.remove(linked_file_name) + def test_modpath_from_file_path_order(self) -> None: + """Test for ordering of paths. + The test does the following: + 1. Add a tmp directory to beginning of sys.path via augmented_sys_path + 2. Create a module file in sub directory of tmp directory + 3. If the sub directory is passed as additional directory, module name + should be relative to the subdirectory since additional directory has + higher precedence.""" + with tempfile.TemporaryDirectory() as tmp_dir: + with augmented_sys_path([tmp_dir]): + mod_name = "module" + sub_dirname = "subdir" + sub_dir = tmp_dir + "/" + sub_dirname + os.mkdir(sub_dir) + module_file = f"{sub_dir}/{mod_name}.py" + + with open(module_file, "w+", encoding="utf-8"): + pass + + # Without additional directory, return relative to tmp_dir + self.assertEqual( + modutils.modpath_from_file(module_file), [sub_dirname, mod_name] + ) + + # With sub directory as additional directory, return relative to + # sub directory + self.assertEqual( + modutils.modpath_from_file(f"{sub_dir}/{mod_name}.py", [sub_dir]), + [mod_name], + ) + def test_import_symlink_both_outside_of_path(self) -> None: with tempfile.NamedTemporaryFile() as tmpfile: linked_file_name = os.path.join(tempfile.gettempdir(), "symlinked_file.py")