Skip to content

Commit

Permalink
Allow simplifier to import modules
Browse files Browse the repository at this point in the history
The simplifier can now parse an imported module (if the filename is 
static), simplify that module's AST and directly import static names 
from it. This allows imported inlineable functions to be inlined. Threw 
away the old module caching behaviour as I never trusted it anyway and 
it *should* be the case that module names can *always* be statically 
imported as modules aren't allowed to reference any dynamic names or 
(during import) state.
  • Loading branch information
jonathanhogg committed Sep 23, 2024
1 parent 63d537d commit c326938
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 44 deletions.
44 changes: 37 additions & 7 deletions src/flitter/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,42 @@ def read_bytes(self):
self._cache[key] = data
return data

def read_flitter_top(self):
key = 'flitter_top'
current_top = self._cache.get(key, False)
if self.check_unmodified() and current_top is not False:
return current_top
top = None
if self._mtime is None:
if current_top is False:
logger.warning("Program not found: {}", self._path)
elif current_top is not None:
logger.error("Program disappeared: {}", self._path)
else:
from .language.parser import parse, ParseError
try:
parse_time = -system_clock()
source = self._path.read_text(encoding='utf8')
top = parse(source)
parse_time += system_clock()
logger.debug("Parsed '{}' in {:.1f}ms", self._path, parse_time*1000)
except ParseError as exc:
logger.error("Error parsing {} at line {} column {}:\n{}",
self._path.name, exc.line, exc.column, exc.context)
except Exception as exc:
logger.opt(exception=exc).error("Error reading program: {}", self._path)
self._cache[key] = top
return top

def read_flitter_program(self, static=None, dynamic=None, simplify=True):
key = 'flitter', tuple(sorted(static.items())) if static else None, tuple(dynamic) if dynamic else (), simplify
key = 'flitter_program', tuple(sorted(static.items())) if static else None, tuple(dynamic) if dynamic else (), simplify
current_program = self._cache.get(key, False)
if self.check_unmodified() and current_program is not False:
return current_program
for path in current_program.top.dependencies if current_program is not None else ():
if not path.check_unmodified():
break
else:
return current_program
if self._mtime is None:
if current_program is False:
logger.warning("Program not found: {}", self._path)
Expand All @@ -123,18 +154,17 @@ def read_flitter_program(self, static=None, dynamic=None, simplify=True):
now = system_clock()
parse_time += now
simplify_time = -now
top = initial_top.simplify(static=static, dynamic=dynamic) if simplify else initial_top
top, context = initial_top.simplify(static=static, dynamic=dynamic, path=self, return_context=True) if simplify else (initial_top, None)
now = system_clock()
simplify_time += now
compile_time = -now
program = top.compile(initial_lnames=tuple(dynamic) if dynamic else ())
program = top.compile(initial_lnames=tuple(dynamic) if dynamic else (), initial_errors=context.errors if simplify else None)
program.set_top(top)
program.set_path(self)
program.use_simplifier(simplify)
compile_time += system_clock()
logger.debug("Read program: {}", self._path)
logger.debug("Compiled to {} instructions in {:.1f}/{:.1f}/{:.1f}ms",
len(program), parse_time*1000, simplify_time*1000, compile_time*1000)
logger.debug("Compiled '{}' to {} instructions in {:.1f}/{:.1f}/{:.1f}ms",
self._path, len(program), parse_time*1000, simplify_time*1000, compile_time*1000)
except ParseError as exc:
if current_program is None:
logger.error("Error parsing {} at line {} column {}:\n{}",
Expand Down
6 changes: 2 additions & 4 deletions src/flitter/engine/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __init__(self, target_fps=60, screen=0, fullscreen=False, vsync=False, state
self.current_page = None
self.current_path = None
self._references = {}
self._modules = {}

def load_page(self, filename):
page_number = len(self.pages)
Expand Down Expand Up @@ -190,8 +189,7 @@ async def run(self):
logger.log(level, "Loaded page {}: {}", self.current_page, self.current_path)
run_program = current_program = program
self.handle_pragmas(program.pragmas, frame_time)
errors = set()
logs = set()
errors = logs = None
self.state_generation0 ^= self.state_generation1
self.state_generation1 = self.state_generation2
self.state_generation2 = set()
Expand All @@ -203,7 +201,7 @@ async def run(self):

if run_program is not None:
context = Context(names={key: Vector.coerce(value) for key, value in dynamic.items()},
state=self.state, references=self._references, modules=self._modules)
state=self.state, references=self._references)
run_program.run(context, record_stats=self.vm_stats)
else:
context = Context()
Expand Down
59 changes: 49 additions & 10 deletions src/flitter/language/tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from loguru import logger
from libc.stdint cimport int64_t

from .. import name_patch
from ..cache import SharedCache
from ..model cimport Vector, Node, Context, StateDict, null_, true_, false_, minusone_
from .vm cimport Program, Instruction, InstructionInt, InstructionVector, OpCode, static_builtins, dynamic_builtins

Expand Down Expand Up @@ -46,16 +47,18 @@ cdef bint sequence_pack(list expressions):


cdef class Expression:
def compile(self, tuple initial_lnames=(), bint log_errors=True):
def compile(self, tuple initial_lnames=(), set initial_errors=None, bint log_errors=True):
cdef Program program = Program.__new__(Program, initial_lnames)
if initial_errors:
program.compiler_errors.update(initial_errors)
self._compile(program, list(initial_lnames))
program.link()
if log_errors:
for error in program.compiler_errors:
logger.warning("Compiler error: {}", error)
return program

def simplify(self, StateDict state=None, dict static=None, dynamic=None, bint return_context=False):
def simplify(self, StateDict state=None, dict static=None, dynamic=None, Context parent=None, path=None, bint return_context=False):
cdef dict context_vars = {}
cdef str key
if static is not None:
Expand All @@ -65,6 +68,8 @@ cdef class Expression:
for key in dynamic:
context_vars[key] = None
cdef Context context = Context(state=state, names=context_vars)
context.path = path
context.parent = parent
cdef Expression expr = self
try:
expr = expr._simplify(context)
Expand All @@ -74,7 +79,7 @@ cdef class Expression:
return expr, context
else:
for error in context.errors:
logger.warning("Simplification error: {}", error)
logger.warning("Simplifier error: {}", error)
return expr

cdef void _compile(self, Program program, list lnames):
Expand All @@ -87,10 +92,12 @@ cdef class Expression:
cdef class Top(Expression):
cdef readonly tuple pragmas
cdef readonly Expression body
cdef readonly set dependencies

def __init__(self, tuple pragmas, Expression body):
def __init__(self, tuple pragmas, Expression body, set dependencies=None):
self.pragmas = pragmas
self.body = body
self.dependencies = dependencies if dependencies is not None else set()

cdef void _compile(self, Program program, list lnames):
cdef Binding binding
Expand All @@ -108,12 +115,12 @@ cdef class Top(Expression):

cdef Expression _simplify(self, Context context):
cdef Expression body = self.body._simplify(context)
if body is self.body:
if body is self.body and context.dependencies == self.dependencies:
return self
return Top(self.pragmas, body)
return Top(self.pragmas, body, self.dependencies ^ context.dependencies)

def __repr__(self):
return f'Top({self.pragmas!r}, {self.body!r})'
return f'Top({self.pragmas!r}, {self.body!r}, {self.dependencies!r})'


cdef class Export(Expression):
Expand Down Expand Up @@ -143,6 +150,8 @@ cdef class Export(Expression):
cdef dict static_exports = dict(self.static_exports) if self.static_exports else {}
cdef bint touched = False
for name, value in context.names.items():
if value is not None:
context.exports[name] = value
if isinstance(value, Vector) and (name not in static_exports or value != static_exports[name]):
static_exports[name] = value
touched = True
Expand Down Expand Up @@ -178,14 +187,44 @@ cdef class Import(Expression):
cdef Expression _simplify(self, Context context):
cdef str name
cdef Expression filename = self.filename._simplify(context)
cdef Top top
cdef Context import_context
cdef dict import_static_names = None
if isinstance(filename, Literal) and context.path is not None:
name = (<Literal>filename).value.as_string()
path = SharedCache.get_with_root(name, context.path)
import_context = context.parent
while import_context is not None:
if import_context.path is path:
context.errors.add(f"Circular import of '{name}'")
import_static_names = {name: null_ for name in self.names}
break
import_context = import_context.parent
else:
if (top := path.read_flitter_top()) is not None:
top, import_context = top.simplify(path=path, parent=context, return_context=True)
import_static_names = import_context.exports
context.errors.update(import_context.errors)
cdef dict let_names = {}
cdef dict saved = dict(context.names)
cdef list remaining = []
for name in self.names:
context.names[name] = None
expr = self.expr._simplify(context)
if import_static_names is not None and name in import_static_names:
let_names[name] = import_static_names[name]
context.dependencies.add(path)
else:
context.names[name] = None
remaining.append(name)
cdef Expression expr = self.expr
if let_names:
expr = Let(tuple(PolyBinding((name,), value if isinstance(value, Function) else Literal(value)) for name, value in let_names.items()), expr)
expr = expr._simplify(context)
context.names = saved
if not remaining:
return expr
if filename is self.filename and expr is self.expr:
return self
return Import(self.names, filename, expr)
return Import(tuple(remaining), filename, expr)

def __repr__(self):
return f'Import({self.names!r}, {self.filename!r}, {self.expr!r})'
Expand Down
5 changes: 1 addition & 4 deletions src/flitter/language/vm.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -618,13 +618,11 @@ cdef inline dict import_module(Context context, str filename, bint record_stats,
PySet_Add(context.errors, f"Circular import of '{filename}'")
return None
import_context = import_context.parent
if program in context.modules:
return context.modules[program]
context.errors.update(program.compiler_errors)
cdef VectorStack stack=context.stack, lnames=context.lnames
cdef int64_t stack_top=stack.top, lnames_top=lnames.top
import_context = Context.__new__(Context)
import_context.parent = context
import_context.modules = context.modules
import_context.errors = context.errors
import_context.logs = context.logs
import_context.state = StateDict()
Expand All @@ -642,7 +640,6 @@ cdef inline dict import_module(Context context, str filename, bint record_stats,
drop(stack, 1)
assert stack.top == stack_top, "Bad stack"
assert lnames.top == lnames_top, "Bad lnames"
context.modules[program] = import_context.exports
return import_context.exports


Expand Down
2 changes: 1 addition & 1 deletion src/flitter/model.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,10 @@ cdef class Context:
cdef readonly Node root
cdef readonly object path
cdef readonly Context parent
cdef readonly dict modules
cdef readonly dict exports
cdef readonly set errors
cdef readonly set logs
cdef readonly dict references
cdef readonly object stack
cdef readonly object lnames
cdef readonly set dependencies
7 changes: 4 additions & 3 deletions src/flitter/model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2545,17 +2545,18 @@ cdef class DummyStateDict(StateDict):

cdef class Context:
def __init__(self, dict names=None, StateDict state=None, Node root=None,
object path=None, Context parent=None, dict references=None, dict modules=None,
dict exports=None, set errors=None, set logs=None, stack=None, lnames=None):
object path=None, Context parent=None, dict references=None,
dict exports=None, set errors=None, set logs=None, stack=None, lnames=None,
set dependencies=None):
self.names = names if names is not None else {}
self.state = state
self.root = root if root is not None else Node('root')
self.path = path
self.parent = parent
self.modules = modules if modules is not None else {}
self.exports = exports if exports is not None else {}
self.errors = errors if errors is not None else set()
self.logs = logs if logs is not None else set()
self.references = references
self.stack = stack
self.lnames = lnames
self.dependencies = dependencies if dependencies is not None else set()
55 changes: 47 additions & 8 deletions tests/test_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Flitter language simplifier unit tests
"""

from pathlib import Path
import tempfile
import unittest
import unittest.mock

Expand All @@ -20,18 +22,18 @@

class SimplifierTestCase(unittest.TestCase):
def assertSimplifiesTo(self, x, y, state=None, dynamic=None, static=None, with_errors=None):
xx, context = x.simplify(state=state, dynamic=dynamic, static=static, return_context=True)
xx, context = x.simplify(state=state, dynamic=dynamic, static=static, path='test.fl', return_context=True)
xxx = xx.simplify(state=state, dynamic=dynamic, static=static)
self.assertIs(xxx, xx, msg="Simplification not complete")
self.assertIs(xxx, xx, msg="Simplification not complete in one step")
self.assertEqual(repr(xx), repr(y))
self.assertEqual(context.errors, set() if with_errors is None else with_errors)
if with_errors:
with unittest.mock.patch('flitter.language.tree.logger') as mock_logger:
x.simplify(state=state, dynamic=dynamic, static=static)
x.simplify(state=state, dynamic=dynamic, static=static, path='test.fl')
errors = set()
for name, args, kwargs in mock_logger.warning.mock_calls:
self.assertEqual(len(args), 2)
self.assertEqual(args[0], "Simplification error: {}")
self.assertEqual(args[0], "Simplifier error: {}")
errors.add(args[1])
self.assertEqual(errors, with_errors)
if static is not None:
Expand Down Expand Up @@ -1089,10 +1091,47 @@ def test_true_condition_2_of_3(self):


class TestImport(SimplifierTestCase):
def test_recursive(self):
"""Imports are left alone except for the sub-expression being simplified"""
self.assertSimplifiesTo(Import(('x', 'y'), Name('m'), Name('z')), Import(('x', 'y'), Literal('module.fl'), Literal(5)),
static={'m': 'module.fl', 'z': 5})
def setUp(self):
self.test_module = Path(tempfile.mktemp('.fl'))
self.test_module.write_text("""
let x=5 y=10
""", encoding='utf8')
self.circular_module_a = Path(tempfile.mktemp('.fl'))
self.circular_module_b = Path(tempfile.mktemp('.fl'))
module_a = f"""
import y from {str(self.circular_module_b)!r}
let x=5
"""
module_b = f"""
import x from {str(self.circular_module_a)!r}
let y=10+(x or 3)
"""
self.circular_module_a.write_text(module_a, encoding='utf8')
self.circular_module_b.write_text(module_b, encoding='utf8')

def tearDown(self):
if self.test_module.exists():
self.test_module.unlink()
if self.circular_module_a.exists():
self.circular_module_a.unlink()
if self.circular_module_b.exists():
self.circular_module_b.unlink()

def test_dynamic(self):
"""Import module name and sub-expression are simplified"""
self.assertSimplifiesTo(Import(('x', 'y'), Name('m'), Name('z')), Import(('x', 'y'), Name("m'"), Literal(5)),
static={'m': Name("m'"), 'z': 5}, dynamic={"m'"})

def test_static(self):
with unittest.mock.patch('flitter.cache.logger'):
self.assertSimplifiesTo(Import(('x', 'y'), Literal(str(self.test_module)), Add(Name('x'), Name('y'))), Literal(15))

def test_recursive_import(self):
with unittest.mock.patch('flitter.cache.logger'):
self.assertSimplifiesTo(Import(('x', 'y'), Literal(str(self.circular_module_a)), Add(Name('x'), Name('y'))), Literal(18),
with_errors={f"Circular import of '{self.circular_module_a}'"})
self.assertSimplifiesTo(Import(('x', 'y'), Literal(str(self.circular_module_b)), Add(Name('x'), Name('y'))), Literal(20),
with_errors={f"Circular import of '{self.circular_module_b}'"})


class TestFunction(SimplifierTestCase):
Expand Down
Loading

0 comments on commit c326938

Please sign in to comment.