Skip to content

Commit

Permalink
Refactor AST handling of espresso visualizers
Browse files Browse the repository at this point in the history
Fixes a regression introduced by 1cc931f
  • Loading branch information
jngrad committed Jan 8, 2020
1 parent 5c299fb commit baa5f9d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 13 deletions.
36 changes: 24 additions & 12 deletions testsuite/scripts/importlib_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,31 +596,37 @@ def __init__(self):
"visualization_mayavi"}
self.namespace_visualizers = {
"espressomd." + x for x in self.visualizers}
self.visu_aliases = []
self.visu_linenos = []
self.visu_items = {}

def register_import(self, lineno, from_str, module_str, alias):
if lineno not in self.visu_items:
self.visu_items[lineno] = []
if from_str:
line = "from {} import {}".format(from_str, module_str)
else:
line = "import {}".format(module_str)
if alias:
line += " as {}".format(alias)
self.visu_items[lineno].append(line)

def visit_Import(self, node):
# get visualizer alias
for child in node.names:
if child.name in self.namespace_visualizers:
self.visu_aliases.append(child.asname or child.name)
self.visu_linenos.append(node.lineno)
self.register_import(node.lineno, None, child.name, child.asname)

def visit_ImportFrom(self, node):
if node.module in self.namespace_visualizers:
for child in node.names:
if child.name == "*":
raise ValueError("cannot use MagicMock() on a wildcard "
"import at line {}".format(node.lineno))
elif child.name in {"openGLLive", "mayaviLive"}:
self.visu_aliases.append(child.asname or child.name)
self.visu_linenos.append(node.lineno)
self.register_import(node.lineno, node.module, child.name, child.asname)
# get visualizer alias
if node.module == "espressomd":
for child in node.names:
if child.name in self.visualizers:
self.visu_aliases.append(child.asname or child.name)
self.visu_linenos.append(node.lineno)
self.register_import(node.lineno, node.module, child.name, child.asname)


def mock_es_visualization(code):
Expand Down Expand Up @@ -661,10 +667,16 @@ def check_for_deferred_ImportError(line, alias):
visitor = GetEspressomdVisualizerImports()
visitor.visit(ast.parse(protect_ipython_magics(code)))
lines = code.split("\n")
for alias, lineno in zip(visitor.visu_aliases, visitor.visu_linenos):
for lineno, imports in visitor.visu_items.items():
line = lines[lineno - 1]
checks = check_for_deferred_ImportError(line, alias)
lines[lineno - 1] = r_es_vis_mock.format(line, checks, alias)
indentation = line[:len(line) - len(line.lstrip())]
lines[lineno - 1] = ""
for import_str in imports:
alias = import_str.split()[-1]
checks = check_for_deferred_ImportError(import_str, alias)
import_str_new = "\n".join(indentation + x for x in
r_es_vis_mock.format(import_str, checks, alias).split("\n"))
lines[lineno - 1] += import_str_new

return "\n".join(lines)

Expand Down
31 changes: 30 additions & 1 deletion testsuite/scripts/test_importlib_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,29 @@ def test_mock_es_visualization(self):
"""
self.assertEqual(iw.mock_es_visualization(statement), expected[1:])

statement = "import espressomd.visualization, espressomd.visualization as test"
expected = """
try:
import espressomd.visualization
if hasattr(espressomd.visualization.mayaviLive, 'deferred_ImportError') or \\
hasattr(espressomd.visualization.openGLLive, 'deferred_ImportError'):
raise ImportError()
except ImportError:
from unittest.mock import MagicMock
import espressomd
espressomd.visualization = MagicMock()
try:
import espressomd.visualization as test
if hasattr(test.mayaviLive, 'deferred_ImportError') or \\
hasattr(test.openGLLive, 'deferred_ImportError'):
raise ImportError()
except ImportError:
from unittest.mock import MagicMock
import espressomd
test = MagicMock()
"""
self.assertEqual(iw.mock_es_visualization(statement), expected[1:])

statement = "from espressomd import visualization"
expected = """
try:
Expand Down Expand Up @@ -217,7 +240,13 @@ def test_mock_es_visualization(self):
statement = "from espressomd.visualization_mayavi import a as b, mayaviLive"
expected = """
try:
from espressomd.visualization_mayavi import a as b, mayaviLive
from espressomd.visualization_mayavi import a as b
except ImportError:
from unittest.mock import MagicMock
import espressomd
b = MagicMock()
try:
from espressomd.visualization_mayavi import mayaviLive
except ImportError:
from unittest.mock import MagicMock
import espressomd
Expand Down

0 comments on commit baa5f9d

Please sign in to comment.