Skip to content

Commit

Permalink
Implement preferred_import_style
Browse files Browse the repository at this point in the history
This is a configuration option to select the import style that rope will
use when adding new imports.

Co-authored-by: Nicolas Zermati <nicoolas25@gmail.com>
  • Loading branch information
lieryan and nicoolas25 committed Apr 4, 2024
1 parent b6bd7cb commit 9ce1d9d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 1 deletion.
7 changes: 6 additions & 1 deletion rope/refactor/importutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import rope.base.codeanalyze
import rope.base.evaluate
from rope.base import libutils
from rope.base.prefs import get_preferred_import_style
from rope.base.prefs import ImportStyle
from rope.base.change import ChangeContents, ChangeSet
from rope.refactor import occurrences, rename
from rope.refactor.importutils import actions, module_imports
Expand Down Expand Up @@ -299,20 +301,23 @@ def get_module_imports(project, pymodule):


def add_import(project, pymodule, module_name, name=None):
preferred_import_style = get_preferred_import_style(project.prefs)
imports = get_module_imports(project, pymodule)
candidates = []
names = []
selected_import = None
# from mod import name
if name is not None:
from_import = FromImport(module_name, 0, [(name, None)])
if preferred_import_style == ImportStyle.from_global:
selected_import = from_import
names.append(name)
candidates.append(from_import)
# from pkg import mod
if "." in module_name:
pkg, mod = module_name.rsplit(".", 1)
from_import = FromImport(pkg, 0, [(mod, None)])
if project.prefs.get("prefer_module_from_imports"):
if preferred_import_style == ImportStyle.from_module:
selected_import = from_import
candidates.append(from_import)
if name:
Expand Down
69 changes: 69 additions & 0 deletions ropetest/refactor/movetest.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,75 @@ def a_function():
self.mod3.read(),
)

def test_adding_imports_preferred_import_style_is_normal_import(self) -> None:
self.project.prefs.imports.preferred_import_style = "normal-import"
self.origin_module.write(dedent("""\
class AClass(object):
pass
def a_function():
pass
"""))
self.mod3.write(dedent("""\
import origin_module
a_var = origin_module.AClass()
origin_module.a_function()"""))
# Move to destination_module_in_pkg which is in a different package
self._move(self.origin_module, self.origin_module.read().index("AClass") + 1, self.destination_module_in_pkg)
self.assertEqual(
dedent("""\
import origin_module
import pkg.destination_module_in_pkg
a_var = pkg.destination_module_in_pkg.AClass()
origin_module.a_function()"""),
self.mod3.read(),
)

def test_adding_imports_preferred_import_style_is_from_module(self) -> None:
self.project.prefs.imports.preferred_import_style = "from-module"
self.origin_module.write(dedent("""\
class AClass(object):
pass
def a_function():
pass
"""))
self.mod3.write(dedent("""\
import origin_module
a_var = origin_module.AClass()
origin_module.a_function()"""))
# Move to destination_module_in_pkg which is in a different package
self._move(self.origin_module, self.origin_module.read().index("AClass") + 1, self.destination_module_in_pkg)
self.assertEqual(
dedent("""\
import origin_module
from pkg import destination_module_in_pkg
a_var = destination_module_in_pkg.AClass()
origin_module.a_function()"""),
self.mod3.read(),
)

def test_adding_imports_preferred_import_style_is_from_global(self) -> None:
self.project.prefs.imports.preferred_import_style = "from-global"
self.origin_module.write(dedent("""\
class AClass(object):
pass
def a_function():
pass
"""))
self.mod3.write(dedent("""\
import origin_module
a_var = origin_module.AClass()
origin_module.a_function()"""))
# Move to destination_module_in_pkg which is in a different package
self._move(self.origin_module, self.origin_module.read().index("AClass") + 1, self.destination_module_in_pkg)
self.assertEqual(
dedent("""\
import origin_module
from pkg.destination_module_in_pkg import AClass
a_var = AClass()
origin_module.a_function()"""),
self.mod3.read(),
)

def test_adding_imports_noprefer_from_module(self) -> None:
self.project.prefs["prefer_module_from_imports"] = False
self.origin_module.write(dedent("""\
Expand Down

0 comments on commit 9ce1d9d

Please sign in to comment.