diff --git a/testsuite/scripts/importlib_wrapper.py b/testsuite/scripts/importlib_wrapper.py index 5a5229c3b8e..9f924a0e555 100644 --- a/testsuite/scripts/importlib_wrapper.py +++ b/testsuite/scripts/importlib_wrapper.py @@ -596,15 +596,24 @@ def __init__(self): "visualization_mayavi"} self.namespace_visualizers = { "espressomd." + x for x in self.visualizers} - self.visu_aliases = [] - self.visu_linenos = [] + self.visu_items = {} + + def register_import(self, lineno, from_str, module_str, alias): + if lineno not in self.visu_items: + self.visu_items[lineno] = [] + if from_str: + line = "from {} import {}".format(from_str, module_str) + else: + line = "import {}".format(module_str) + if alias: + line += " as {}".format(alias) + self.visu_items[lineno].append(line) def visit_Import(self, node): # get visualizer alias for child in node.names: if child.name in self.namespace_visualizers: - self.visu_aliases.append(child.asname or child.name) - self.visu_linenos.append(node.lineno) + self.register_import(node.lineno, None, child.name, child.asname) def visit_ImportFrom(self, node): if node.module in self.namespace_visualizers: @@ -612,15 +621,12 @@ def visit_ImportFrom(self, node): if child.name == "*": raise ValueError("cannot use MagicMock() on a wildcard " "import at line {}".format(node.lineno)) - elif child.name in {"openGLLive", "mayaviLive"}: - self.visu_aliases.append(child.asname or child.name) - self.visu_linenos.append(node.lineno) + self.register_import(node.lineno, node.module, child.name, child.asname) # get visualizer alias if node.module == "espressomd": for child in node.names: if child.name in self.visualizers: - self.visu_aliases.append(child.asname or child.name) - self.visu_linenos.append(node.lineno) + self.register_import(node.lineno, node.module, child.name, child.asname) def mock_es_visualization(code): @@ -661,10 +667,16 @@ def check_for_deferred_ImportError(line, alias): visitor = GetEspressomdVisualizerImports() visitor.visit(ast.parse(protect_ipython_magics(code))) lines = code.split("\n") - for alias, lineno in zip(visitor.visu_aliases, visitor.visu_linenos): + for lineno, imports in visitor.visu_items.items(): line = lines[lineno - 1] - checks = check_for_deferred_ImportError(line, alias) - lines[lineno - 1] = r_es_vis_mock.format(line, checks, alias) + indentation = line[:len(line) - len(line.lstrip())] + lines[lineno - 1] = "" + for import_str in imports: + alias = import_str.split()[-1] + checks = check_for_deferred_ImportError(import_str, alias) + import_str_new = "\n".join(indentation + x for x in + r_es_vis_mock.format(import_str, checks, alias).split("\n")) + lines[lineno - 1] += import_str_new return "\n".join(lines) diff --git a/testsuite/scripts/test_importlib_wrapper.py b/testsuite/scripts/test_importlib_wrapper.py index 7934eea0501..be97c16c2e2 100644 --- a/testsuite/scripts/test_importlib_wrapper.py +++ b/testsuite/scripts/test_importlib_wrapper.py @@ -142,6 +142,29 @@ def test_mock_es_visualization(self): """ self.assertEqual(iw.mock_es_visualization(statement), expected[1:]) + statement = "import espressomd.visualization, espressomd.visualization as test" + expected = """ +try: + import espressomd.visualization + if hasattr(espressomd.visualization.mayaviLive, 'deferred_ImportError') or \\ + hasattr(espressomd.visualization.openGLLive, 'deferred_ImportError'): + raise ImportError() +except ImportError: + from unittest.mock import MagicMock + import espressomd + espressomd.visualization = MagicMock() +try: + import espressomd.visualization as test + if hasattr(test.mayaviLive, 'deferred_ImportError') or \\ + hasattr(test.openGLLive, 'deferred_ImportError'): + raise ImportError() +except ImportError: + from unittest.mock import MagicMock + import espressomd + test = MagicMock() +""" + self.assertEqual(iw.mock_es_visualization(statement), expected[1:]) + statement = "from espressomd import visualization" expected = """ try: @@ -217,7 +240,13 @@ def test_mock_es_visualization(self): statement = "from espressomd.visualization_mayavi import a as b, mayaviLive" expected = """ try: - from espressomd.visualization_mayavi import a as b, mayaviLive + from espressomd.visualization_mayavi import a as b +except ImportError: + from unittest.mock import MagicMock + import espressomd + b = MagicMock() +try: + from espressomd.visualization_mayavi import mayaviLive except ImportError: from unittest.mock import MagicMock import espressomd