Skip to content

Commit

Permalink
Improve processing of Jupyter notebook cells
Browse files Browse the repository at this point in the history
Rewrite most of the regular expressions as AST node visitors to
improve the reliability of the importlib_wrapper. Better support
parsing of Jupyter cells containing matplotlib import statements,
which can lead to plots not being displayed in the HTML output.
  • Loading branch information
jngrad committed Jan 7, 2020
1 parent 7f73967 commit 1cc931f
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -608,15 +608,8 @@
"source": [
"import matplotlib.pyplot as plt\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"%matplotlib notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib notebook\n",
"\n",
"trajectory_file = 'posVsTime.dat'\n",
"trajectory = np.loadtxt(trajectory_file)[:,1:4]\n",
"# optional: trajectory smoothing with a running average\n",
Expand Down
63 changes: 45 additions & 18 deletions doc/tutorials/html_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
from nbconvert.preprocessors import ExecutePreprocessor
import re
import os
import ast
import sys
import uuid
import argparse
sys.path.append('@CMAKE_SOURCE_DIR@/testsuite/scripts')
from importlib_wrapper import substitute_variable_values, mock_es_visualization
import importlib_wrapper as iw

parser = argparse.ArgumentParser(description='Process IPython notebooks.')
parser.add_argument('--input', type=str,
Expand Down Expand Up @@ -71,30 +72,56 @@ def set_code_cells(nb, new_cells):
code = re.sub('^(#\n)+', '', code.replace(m.group(0), ''), re.M)
# strip first component in relative paths
code = re.sub('(?<=[\'\"])\.\./', './', code)
# if matplotlib is used in this script, split cell to keep the import
# statement separate and avoid a know bug in the Jupyter backend which
# causes the plot object to be represented as a string instead of a
# canvas when created in the cell where matplotlib is imported
# (https://github.com/jupyter/notebook/issues/3523)
if 'import matplotlib' in code:
cells_code = re.split('^((?:|.*\n)import matplotlib.*?)\n', code,
maxsplit=1, flags=re.DOTALL)[1:]
else:
cells_code = [code]
# create new cells
cell_md = nbformat.v4.new_markdown_cell(source='Solution from ' + filepath)
filename = os.path.relpath(filepath)
if len(filename) > len(filepath):
filename = filepath
cell_md = nbformat.v4.new_markdown_cell(source='Solution from ' + filename)
nb['cells'].append(cell_md)
for cell_code in cells_code:
cell_code = nbformat.v4.new_code_cell(source=cell_code.strip())
nb['cells'].append(cell_code)
cell_code = nbformat.v4.new_code_cell(source=code.strip())
nb['cells'].append(cell_code)


# disable plot interactivity
for i in range(len(nb['cells'])):
cell = nb['cells'][i]
if cell['cell_type'] == 'code' and 'matplotlib' in cell['source']:
cell['source'] = re.sub('^%matplotlib +notebook', '%matplotlib inline',
cell['source'], flags=re.M)


# if matplotlib is used in this script, split cell to keep the import
# statement separate and avoid a know bug in the Jupyter backend which
# causes the plot object to be represented as a string instead of a
# canvas when created in the cell where matplotlib is imported for the
# first time (https://github.com/jupyter/notebook/issues/3523)
for i in range(len(nb['cells'])):
cell = nb['cells'][i]
if cell['cell_type'] == 'code' and 'matplotlib' in cell['source']:
code = iw.protect_ipython_magics(cell['source'])
# split cells after matplotlib imports
mapping = iw.delimit_statements(code)
tree = ast.parse(code)
visitor = iw.GetMatplotlibImports()
visitor.visit(tree)
if visitor.matplotlib_first:
code = iw.deprotect_ipython_magics(code)
lines = code.split('\n')
lineno_end = mapping[visitor.matplotlib_first]
split_code = '\n'.join(lines[lineno_end:]).lstrip('\n')
new_cell = nbformat.v4.new_code_cell(source=split_code)
nb['cells'].insert(i + 1, new_cell)
lines = lines[:lineno_end]
nb['cells'][i]['source'] = '\n'.join(lines).rstrip('\n')
break

# substitute global variables and disable OpenGL/Mayavi GUI
cell_separator = '\n##{}\n'.format(uuid.uuid4().hex)
src = cell_separator.join(get_code_cells(nb))
parameters = dict(x.split('=', 1) for x in new_values)
src = substitute_variable_values(src, strings_as_is=True, keep_original=False,
**parameters)
src_no_gui = mock_es_visualization(src)
src = iw.substitute_variable_values(src, strings_as_is=True,
keep_original=False, **parameters)
src_no_gui = iw.mock_es_visualization(src)

# update notebook with new code
set_code_cells(nb, src_no_gui.split(cell_separator))
Expand Down
188 changes: 135 additions & 53 deletions testsuite/scripts/importlib_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

import re
import os
import io
import sys
import ast
import tokenize
import unittest
import importlib
import espressomd
Expand Down Expand Up @@ -207,24 +210,86 @@ def set_random_seeds(code):
return code


def delimit_statements(code):
"""
For every Python statement, map the line number where it starts to the
line number where it ends.
"""
statements = []
statement_start = None
for tok in tokenize.tokenize(io.BytesIO(code.encode("utf-8")).readline):
if tok.exact_type == tokenize.ENDMARKER:
break
elif tok.exact_type == tokenize.ENCODING:
pass
elif tok.start == tok.end:
pass
elif tok.exact_type == tokenize.NEWLINE or tok.exact_type == tokenize.NL and prev_tok.exact_type == tokenize.COMMENT:
statements.append((statement_start, tok.start[0]))
statement_start = None
elif tok.exact_type == tokenize.NL:
pass
elif statement_start is None:
statement_start = tok.start[0]
prev_tok = tok
return dict(statements)


def protect_ipython_magics(code):
return re.sub("^(%+)(?=[a-z])", "#_IPYTHON_MAGIC_\g<1>", code, flags=re.M)


def deprotect_ipython_magics(code):
return re.sub("^#_IPYTHON_MAGIC_(%+)(?=[a-z])", "\g<1>", code, flags=re.M)


class GetMatplotlibImports(ast.NodeVisitor):
"""
Find all line numbers where ``matplotlib`` is imported.
"""

def __init__(self):
self.matplotlib_first = None
self.matplotlib_aliases = []
self.pyplot_aliases = []

def visit_Import(self, node):
# get line number of the first matplotlib import
for child in node.names:
if child.name.split(".")[0] == "matplotlib":
self.matplotlib_first = self.matplotlib_first or node.lineno
# get matplotlib aliases
for child in node.names:
if child.name == "matplotlib":
self.matplotlib_aliases.append(child.asname or child.name)
# get pyplot aliases
for child in node.names:
if child.name == "matplotlib.pyplot":
self.pyplot_aliases.append(child.asname or child.name)
elif child.name == "matplotlib":
name = (child.asname or "matplotlib") + ".pyplot"
self.pyplot_aliases.append(name)

def visit_ImportFrom(self, node):
# get line number of the first matplotlib import
if node.module.split(".")[0] == "matplotlib":
self.matplotlib_first = self.matplotlib_first or node.lineno
# get pyplot aliases
for child in node.names:
if node.module == "matplotlib" and child.name == "pyplot":
self.pyplot_aliases.append(child.asname or child.name)


def disable_matplotlib_gui(code):
"""
Use the matplotlib Agg backend (no GUI).
"""
# find under which name matplotlib was imported
re_mpl_aliases = [
re.compile(r"^[\t\ ]*import[\t\ ]+(matplotlib)[\t\ ]*$", re.M),
re.compile(r"^[\t\ ]*import[\t\ ]+matplotlib[\t\ ]+as[\t\ ]+([^\s;]+)",
re.M)]
aliases_mpl = set(x for re_m in re_mpl_aliases for x in re_m.findall(code))
# find under which name pyplot was imported
re_plt_aliases = [re.compile(pat, flags=re.M) for pat in [
r"^[\t\ ]*import[\t\ ]+(matplotlib\.pyplot)[\t\ ]*(?=$|#)",
r"^[\t\ ]*import[\t\ ]+matplotlib\.pyplot[\t\ ]+as[\t\ ]+([^\s;]+)[\t\ ]*(?=$|#)",
r"^[\t\ ]*from[\t\ ]+matplotlib[\t\ ]+import[\t\ ]+(pyplot)[\t\ ]*(?=$|#)",
r"^[\t\ ]*from[\t\ ]+matplotlib[\t\ ]+import[\t\ ]+pyplot"
r"[\t\ ]+as[\t\ ]+([^\s;]+)[\t\ ]*(?=$|#)"]]
aliases_plt = set(x for re_p in re_plt_aliases for x in re_p.findall(code))
visitor = GetMatplotlibImports()
visitor.visit(ast.parse(protect_ipython_magics(code)))
# find under which names matplotlib is accessible
aliases_mpl = set(visitor.matplotlib_aliases)
# find under which names pyplot is accessible
aliases_plt = set(visitor.pyplot_aliases)
# remove any custom backend
for alias in aliases_mpl:
code = re.sub(r"^[\t\ ]*" + alias + r"\.use\(([\"']+).+?\1[\t\ ]*\)",
Expand All @@ -242,24 +307,56 @@ def disable_matplotlib_gui(code):
return code


class GetEspressomdVisualizerImports(ast.NodeVisitor):
"""
Find line numbers and aliases of imported ESPResSo visualizers.
"""

def __init__(self):
self.visualizers = {
"visualization",
"visualization_opengl",
"visualization_mayavi"}
self.namespace_visualizers = {
"espressomd." + x for x in self.visualizers}
self.visu_aliases = []
self.visu_linenos = []

def visit_Import(self, node):
# get visualizer alias
for child in node.names:
if child.name in self.namespace_visualizers:
self.visu_aliases.append(child.asname or child.name)
self.visu_linenos.append(node.lineno)

def visit_ImportFrom(self, node):
if node.module in self.namespace_visualizers:
for child in node.names:
if child.name == "*":
raise ValueError("cannot use MagicMock() on a wildcard "
"import at line {}".format(node.lineno))
elif child.name in {"openGLLive", "mayaviLive"}:
self.visu_aliases.append(child.asname or child.name)
self.visu_linenos.append(node.lineno)
# get visualizer alias
if node.module == "espressomd":
for child in node.names:
if child.name in self.visualizers:
self.visu_aliases.append(child.asname or child.name)
self.visu_linenos.append(node.lineno)


def mock_es_visualization(code):
"""
Replace `import espressomd.visualization_<backend>` by a `MagicMock()` when
the visualization module is not installed, by catching the `ImportError()`
exception. Please note that `espressomd.visualization` is deferring the
exception, thus requiring additional checks. Import aliases are supported,
however please don't use `from espressomd.visualization import *` because
it hides the namespace of classes to be mocked.
Replace ``import espressomd.visualization_<backend>`` by a ``MagicMock()``
when the visualization module is unavailable, by catching the
``ImportError()`` exception. Please note that ``espressomd.visualization``
is deferring the exception, thus requiring additional checks.
Import aliases are supported, however please don't use
``from espressomd.visualization import *`` because it hides the namespace
of classes to be mocked.
"""
# consider all legal import statements in Python3
# (the ordering follows regex precedence rules)
re_es_vis_import = re.compile(r"""
^from\ espressomd\ import\ (?:visualization(?:_opengl|_mayavi)?)\ as\ (\S+)
|^from\ espressomd\ import\ (visualization(?:_opengl|_mayavi)?)
|^from\ espressomd\.visualization(?:_opengl|_mayavi)?\ import\ ([^\n]+)
|^import\ espressomd\.visualization(?:_opengl|_mayavi)?\ as\ (\S+)
|^import\ (espressomd\.visualization(?:_opengl|_mayavi)?)
""".replace(r"\ ", r"[\t\ ]+"), re.VERBOSE | re.M)
# replacement template
r_es_vis_mock = r"""
try:
Expand All @@ -269,12 +366,6 @@ def mock_es_visualization(code):
import espressomd
{2} = MagicMock()
""".lstrip()
# cannot handle "from espressomd.visualization import *"
re_es_vis_import_namespace = re.compile(
r"^from\ espressomd\.visualization(?:_opengl|_mayavi)?\ import\ \*"
.replace(r"\ ", r"[\t\ ]+"), re.M)
m = re_es_vis_import_namespace.search(code)
assert m is None, "cannot use MagicMock() at line '" + m.group(0) + "'"

def check_for_deferred_ImportError(line, alias):
if "_opengl" not in line and "_mayavi" not in line:
Expand All @@ -290,24 +381,15 @@ def check_for_deferred_ImportError(line, alias):
else:
return ""

def substitution_es_vis_import(m):
aliases = [x for x in m.groups() if x is not None][0].split(',')
guards = []
for alias in aliases:
line = m.group(0)
if len(aliases) >= 2 and 'from espressomd.visualization' in line:
line = line.split('import')[0] + 'import ' + alias.strip()
if ' as ' in alias:
alias = alias.split(' as ')[1]
alias = alias.strip()
checks = check_for_deferred_ImportError(line, alias)
s = r_es_vis_mock.format(line, checks, alias)
guards.append(s)
return '\n'.join(guards)

# handle deferred ImportError
code = re_es_vis_import.sub(substitution_es_vis_import, code)
return code
visitor = GetEspressomdVisualizerImports()
visitor.visit(ast.parse(protect_ipython_magics(code)))
lines = code.split("\n")
for alias, lineno in zip(visitor.visu_aliases, visitor.visu_linenos):
line = lines[lineno - 1]
checks = check_for_deferred_ImportError(line, alias)
lines[lineno - 1] = r_es_vis_mock.format(line, checks, alias)

return "\n".join(lines)


def skip_future_imports_dependency(filepath):
Expand Down
Loading

0 comments on commit 1cc931f

Please sign in to comment.