From c3269380bae288332e51e3bad40734a152b7d4f6 Mon Sep 17 00:00:00 2001 From: Jonathan Hogg Date: Mon, 23 Sep 2024 13:31:29 +0100 Subject: [PATCH] Allow simplifier to import modules The simplifier can now parse an imported module (if the filename is static), simplify that module's AST and directly import static names from it. This allows imported inlineable functions to be inlined. Threw away the old module caching behaviour as I never trusted it anyway and it *should* be the case that module names can *always* be statically imported as modules aren't allowed to reference any dynamic names or (during import) state. --- src/flitter/cache.py | 44 +++++++++++++++++++++----- src/flitter/engine/control.py | 6 ++-- src/flitter/language/tree.pyx | 59 +++++++++++++++++++++++++++++------ src/flitter/language/vm.pyx | 5 +-- src/flitter/model.pxd | 2 +- src/flitter/model.pyx | 7 +++-- tests/test_simplifier.py | 55 +++++++++++++++++++++++++++----- tests/test_vm.py | 10 ++---- 8 files changed, 144 insertions(+), 44 deletions(-) diff --git a/src/flitter/cache.py b/src/flitter/cache.py index 90575b3e..69527623 100644 --- a/src/flitter/cache.py +++ b/src/flitter/cache.py @@ -100,11 +100,42 @@ def read_bytes(self): self._cache[key] = data return data + def read_flitter_top(self): + key = 'flitter_top' + current_top = self._cache.get(key, False) + if self.check_unmodified() and current_top is not False: + return current_top + top = None + if self._mtime is None: + if current_top is False: + logger.warning("Program not found: {}", self._path) + elif current_top is not None: + logger.error("Program disappeared: {}", self._path) + else: + from .language.parser import parse, ParseError + try: + parse_time = -system_clock() + source = self._path.read_text(encoding='utf8') + top = parse(source) + parse_time += system_clock() + logger.debug("Parsed '{}' in {:.1f}ms", self._path, parse_time*1000) + except ParseError as exc: + logger.error("Error parsing {} at line {} column {}:\n{}", + self._path.name, exc.line, exc.column, exc.context) + except Exception as exc: + logger.opt(exception=exc).error("Error reading program: {}", self._path) + self._cache[key] = top + return top + def read_flitter_program(self, static=None, dynamic=None, simplify=True): - key = 'flitter', tuple(sorted(static.items())) if static else None, tuple(dynamic) if dynamic else (), simplify + key = 'flitter_program', tuple(sorted(static.items())) if static else None, tuple(dynamic) if dynamic else (), simplify current_program = self._cache.get(key, False) if self.check_unmodified() and current_program is not False: - return current_program + for path in current_program.top.dependencies if current_program is not None else (): + if not path.check_unmodified(): + break + else: + return current_program if self._mtime is None: if current_program is False: logger.warning("Program not found: {}", self._path) @@ -123,18 +154,17 @@ def read_flitter_program(self, static=None, dynamic=None, simplify=True): now = system_clock() parse_time += now simplify_time = -now - top = initial_top.simplify(static=static, dynamic=dynamic) if simplify else initial_top + top, context = initial_top.simplify(static=static, dynamic=dynamic, path=self, return_context=True) if simplify else (initial_top, None) now = system_clock() simplify_time += now compile_time = -now - program = top.compile(initial_lnames=tuple(dynamic) if dynamic else ()) + program = top.compile(initial_lnames=tuple(dynamic) if dynamic else (), initial_errors=context.errors if simplify else None) program.set_top(top) program.set_path(self) program.use_simplifier(simplify) compile_time += system_clock() - logger.debug("Read program: {}", self._path) - logger.debug("Compiled to {} instructions in {:.1f}/{:.1f}/{:.1f}ms", - len(program), parse_time*1000, simplify_time*1000, compile_time*1000) + logger.debug("Compiled '{}' to {} instructions in {:.1f}/{:.1f}/{:.1f}ms", + self._path, len(program), parse_time*1000, simplify_time*1000, compile_time*1000) except ParseError as exc: if current_program is None: logger.error("Error parsing {} at line {} column {}:\n{}", diff --git a/src/flitter/engine/control.py b/src/flitter/engine/control.py index ca4ee14a..e5d5dd0e 100644 --- a/src/flitter/engine/control.py +++ b/src/flitter/engine/control.py @@ -61,7 +61,6 @@ def __init__(self, target_fps=60, screen=0, fullscreen=False, vsync=False, state self.current_page = None self.current_path = None self._references = {} - self._modules = {} def load_page(self, filename): page_number = len(self.pages) @@ -190,8 +189,7 @@ async def run(self): logger.log(level, "Loaded page {}: {}", self.current_page, self.current_path) run_program = current_program = program self.handle_pragmas(program.pragmas, frame_time) - errors = set() - logs = set() + errors = logs = None self.state_generation0 ^= self.state_generation1 self.state_generation1 = self.state_generation2 self.state_generation2 = set() @@ -203,7 +201,7 @@ async def run(self): if run_program is not None: context = Context(names={key: Vector.coerce(value) for key, value in dynamic.items()}, - state=self.state, references=self._references, modules=self._modules) + state=self.state, references=self._references) run_program.run(context, record_stats=self.vm_stats) else: context = Context() diff --git a/src/flitter/language/tree.pyx b/src/flitter/language/tree.pyx index 8273b470..00b2b1fe 100644 --- a/src/flitter/language/tree.pyx +++ b/src/flitter/language/tree.pyx @@ -12,6 +12,7 @@ from loguru import logger from libc.stdint cimport int64_t from .. import name_patch +from ..cache import SharedCache from ..model cimport Vector, Node, Context, StateDict, null_, true_, false_, minusone_ from .vm cimport Program, Instruction, InstructionInt, InstructionVector, OpCode, static_builtins, dynamic_builtins @@ -46,8 +47,10 @@ cdef bint sequence_pack(list expressions): cdef class Expression: - def compile(self, tuple initial_lnames=(), bint log_errors=True): + def compile(self, tuple initial_lnames=(), set initial_errors=None, bint log_errors=True): cdef Program program = Program.__new__(Program, initial_lnames) + if initial_errors: + program.compiler_errors.update(initial_errors) self._compile(program, list(initial_lnames)) program.link() if log_errors: @@ -55,7 +58,7 @@ cdef class Expression: logger.warning("Compiler error: {}", error) return program - def simplify(self, StateDict state=None, dict static=None, dynamic=None, bint return_context=False): + def simplify(self, StateDict state=None, dict static=None, dynamic=None, Context parent=None, path=None, bint return_context=False): cdef dict context_vars = {} cdef str key if static is not None: @@ -65,6 +68,8 @@ cdef class Expression: for key in dynamic: context_vars[key] = None cdef Context context = Context(state=state, names=context_vars) + context.path = path + context.parent = parent cdef Expression expr = self try: expr = expr._simplify(context) @@ -74,7 +79,7 @@ cdef class Expression: return expr, context else: for error in context.errors: - logger.warning("Simplification error: {}", error) + logger.warning("Simplifier error: {}", error) return expr cdef void _compile(self, Program program, list lnames): @@ -87,10 +92,12 @@ cdef class Expression: cdef class Top(Expression): cdef readonly tuple pragmas cdef readonly Expression body + cdef readonly set dependencies - def __init__(self, tuple pragmas, Expression body): + def __init__(self, tuple pragmas, Expression body, set dependencies=None): self.pragmas = pragmas self.body = body + self.dependencies = dependencies if dependencies is not None else set() cdef void _compile(self, Program program, list lnames): cdef Binding binding @@ -108,12 +115,12 @@ cdef class Top(Expression): cdef Expression _simplify(self, Context context): cdef Expression body = self.body._simplify(context) - if body is self.body: + if body is self.body and context.dependencies == self.dependencies: return self - return Top(self.pragmas, body) + return Top(self.pragmas, body, self.dependencies ^ context.dependencies) def __repr__(self): - return f'Top({self.pragmas!r}, {self.body!r})' + return f'Top({self.pragmas!r}, {self.body!r}, {self.dependencies!r})' cdef class Export(Expression): @@ -143,6 +150,8 @@ cdef class Export(Expression): cdef dict static_exports = dict(self.static_exports) if self.static_exports else {} cdef bint touched = False for name, value in context.names.items(): + if value is not None: + context.exports[name] = value if isinstance(value, Vector) and (name not in static_exports or value != static_exports[name]): static_exports[name] = value touched = True @@ -178,14 +187,44 @@ cdef class Import(Expression): cdef Expression _simplify(self, Context context): cdef str name cdef Expression filename = self.filename._simplify(context) + cdef Top top + cdef Context import_context + cdef dict import_static_names = None + if isinstance(filename, Literal) and context.path is not None: + name = (filename).value.as_string() + path = SharedCache.get_with_root(name, context.path) + import_context = context.parent + while import_context is not None: + if import_context.path is path: + context.errors.add(f"Circular import of '{name}'") + import_static_names = {name: null_ for name in self.names} + break + import_context = import_context.parent + else: + if (top := path.read_flitter_top()) is not None: + top, import_context = top.simplify(path=path, parent=context, return_context=True) + import_static_names = import_context.exports + context.errors.update(import_context.errors) + cdef dict let_names = {} cdef dict saved = dict(context.names) + cdef list remaining = [] for name in self.names: - context.names[name] = None - expr = self.expr._simplify(context) + if import_static_names is not None and name in import_static_names: + let_names[name] = import_static_names[name] + context.dependencies.add(path) + else: + context.names[name] = None + remaining.append(name) + cdef Expression expr = self.expr + if let_names: + expr = Let(tuple(PolyBinding((name,), value if isinstance(value, Function) else Literal(value)) for name, value in let_names.items()), expr) + expr = expr._simplify(context) context.names = saved + if not remaining: + return expr if filename is self.filename and expr is self.expr: return self - return Import(self.names, filename, expr) + return Import(tuple(remaining), filename, expr) def __repr__(self): return f'Import({self.names!r}, {self.filename!r}, {self.expr!r})' diff --git a/src/flitter/language/vm.pyx b/src/flitter/language/vm.pyx index dfe627bc..c2040383 100644 --- a/src/flitter/language/vm.pyx +++ b/src/flitter/language/vm.pyx @@ -618,13 +618,11 @@ cdef inline dict import_module(Context context, str filename, bint record_stats, PySet_Add(context.errors, f"Circular import of '{filename}'") return None import_context = import_context.parent - if program in context.modules: - return context.modules[program] + context.errors.update(program.compiler_errors) cdef VectorStack stack=context.stack, lnames=context.lnames cdef int64_t stack_top=stack.top, lnames_top=lnames.top import_context = Context.__new__(Context) import_context.parent = context - import_context.modules = context.modules import_context.errors = context.errors import_context.logs = context.logs import_context.state = StateDict() @@ -642,7 +640,6 @@ cdef inline dict import_module(Context context, str filename, bint record_stats, drop(stack, 1) assert stack.top == stack_top, "Bad stack" assert lnames.top == lnames_top, "Bad lnames" - context.modules[program] = import_context.exports return import_context.exports diff --git a/src/flitter/model.pxd b/src/flitter/model.pxd index 07f6d478..1b5625e6 100644 --- a/src/flitter/model.pxd +++ b/src/flitter/model.pxd @@ -204,10 +204,10 @@ cdef class Context: cdef readonly Node root cdef readonly object path cdef readonly Context parent - cdef readonly dict modules cdef readonly dict exports cdef readonly set errors cdef readonly set logs cdef readonly dict references cdef readonly object stack cdef readonly object lnames + cdef readonly set dependencies diff --git a/src/flitter/model.pyx b/src/flitter/model.pyx index 86802a31..b32ba7c7 100644 --- a/src/flitter/model.pyx +++ b/src/flitter/model.pyx @@ -2545,17 +2545,18 @@ cdef class DummyStateDict(StateDict): cdef class Context: def __init__(self, dict names=None, StateDict state=None, Node root=None, - object path=None, Context parent=None, dict references=None, dict modules=None, - dict exports=None, set errors=None, set logs=None, stack=None, lnames=None): + object path=None, Context parent=None, dict references=None, + dict exports=None, set errors=None, set logs=None, stack=None, lnames=None, + set dependencies=None): self.names = names if names is not None else {} self.state = state self.root = root if root is not None else Node('root') self.path = path self.parent = parent - self.modules = modules if modules is not None else {} self.exports = exports if exports is not None else {} self.errors = errors if errors is not None else set() self.logs = logs if logs is not None else set() self.references = references self.stack = stack self.lnames = lnames + self.dependencies = dependencies if dependencies is not None else set() diff --git a/tests/test_simplifier.py b/tests/test_simplifier.py index cbb5b6a6..bd5fda3e 100644 --- a/tests/test_simplifier.py +++ b/tests/test_simplifier.py @@ -2,6 +2,8 @@ Flitter language simplifier unit tests """ +from pathlib import Path +import tempfile import unittest import unittest.mock @@ -20,18 +22,18 @@ class SimplifierTestCase(unittest.TestCase): def assertSimplifiesTo(self, x, y, state=None, dynamic=None, static=None, with_errors=None): - xx, context = x.simplify(state=state, dynamic=dynamic, static=static, return_context=True) + xx, context = x.simplify(state=state, dynamic=dynamic, static=static, path='test.fl', return_context=True) xxx = xx.simplify(state=state, dynamic=dynamic, static=static) - self.assertIs(xxx, xx, msg="Simplification not complete") + self.assertIs(xxx, xx, msg="Simplification not complete in one step") self.assertEqual(repr(xx), repr(y)) self.assertEqual(context.errors, set() if with_errors is None else with_errors) if with_errors: with unittest.mock.patch('flitter.language.tree.logger') as mock_logger: - x.simplify(state=state, dynamic=dynamic, static=static) + x.simplify(state=state, dynamic=dynamic, static=static, path='test.fl') errors = set() for name, args, kwargs in mock_logger.warning.mock_calls: self.assertEqual(len(args), 2) - self.assertEqual(args[0], "Simplification error: {}") + self.assertEqual(args[0], "Simplifier error: {}") errors.add(args[1]) self.assertEqual(errors, with_errors) if static is not None: @@ -1089,10 +1091,47 @@ def test_true_condition_2_of_3(self): class TestImport(SimplifierTestCase): - def test_recursive(self): - """Imports are left alone except for the sub-expression being simplified""" - self.assertSimplifiesTo(Import(('x', 'y'), Name('m'), Name('z')), Import(('x', 'y'), Literal('module.fl'), Literal(5)), - static={'m': 'module.fl', 'z': 5}) + def setUp(self): + self.test_module = Path(tempfile.mktemp('.fl')) + self.test_module.write_text(""" +let x=5 y=10 +""", encoding='utf8') + self.circular_module_a = Path(tempfile.mktemp('.fl')) + self.circular_module_b = Path(tempfile.mktemp('.fl')) + module_a = f""" +import y from {str(self.circular_module_b)!r} +let x=5 +""" + module_b = f""" +import x from {str(self.circular_module_a)!r} +let y=10+(x or 3) +""" + self.circular_module_a.write_text(module_a, encoding='utf8') + self.circular_module_b.write_text(module_b, encoding='utf8') + + def tearDown(self): + if self.test_module.exists(): + self.test_module.unlink() + if self.circular_module_a.exists(): + self.circular_module_a.unlink() + if self.circular_module_b.exists(): + self.circular_module_b.unlink() + + def test_dynamic(self): + """Import module name and sub-expression are simplified""" + self.assertSimplifiesTo(Import(('x', 'y'), Name('m'), Name('z')), Import(('x', 'y'), Name("m'"), Literal(5)), + static={'m': Name("m'"), 'z': 5}, dynamic={"m'"}) + + def test_static(self): + with unittest.mock.patch('flitter.cache.logger'): + self.assertSimplifiesTo(Import(('x', 'y'), Literal(str(self.test_module)), Add(Name('x'), Name('y'))), Literal(15)) + + def test_recursive_import(self): + with unittest.mock.patch('flitter.cache.logger'): + self.assertSimplifiesTo(Import(('x', 'y'), Literal(str(self.circular_module_a)), Add(Name('x'), Name('y'))), Literal(18), + with_errors={f"Circular import of '{self.circular_module_a}'"}) + self.assertSimplifiesTo(Import(('x', 'y'), Literal(str(self.circular_module_b)), Add(Name('x'), Name('y'))), Literal(20), + with_errors={f"Circular import of '{self.circular_module_b}'"}) class TestFunction(SimplifierTestCase): diff --git a/tests/test_vm.py b/tests/test_vm.py index b6431635..2035fe51 100644 --- a/tests/test_vm.py +++ b/tests/test_vm.py @@ -941,11 +941,11 @@ def setUp(self): self.circular_module_b = Path(tempfile.mktemp('.fl')) module_a = f""" import y from {str(self.circular_module_b)!r} -let x=5 +let x=5 + y """ module_b = f""" import x from {str(self.circular_module_a)!r} -let y=10 +let y=10 + x """ self.circular_module_a.write_text(module_a, encoding='utf8') self.circular_module_b.write_text(module_b, encoding='utf8') @@ -964,7 +964,6 @@ def test_import_one_name(self): self.program.local_load(0) stack = self.program.execute(self.context) self.assertEqual(stack, [5]) - self.assertEqual(len(self.context.modules), 1) def test_import_two_names(self): self.program.literal(str(self.test_module)) @@ -973,7 +972,6 @@ def test_import_two_names(self): self.program.local_load(0) stack = self.program.execute(self.context) self.assertEqual(stack, [5, 10]) - self.assertEqual(len(self.context.modules), 1) def test_import_twice(self): self.program.literal(str(self.test_module)) @@ -984,7 +982,6 @@ def test_import_twice(self): self.program.local_load(0) stack = self.program.execute(self.context) self.assertEqual(stack, [5, 10]) - self.assertEqual(len(self.context.modules), 1) def test_import_bad_name(self): self.program.literal(str(self.test_module)) @@ -992,7 +989,6 @@ def test_import_bad_name(self): self.program.local_load(0) stack = self.program.execute(self.context) self.assertEqual(stack, [null]) - self.assertEqual(len(self.context.modules), 1) self.assertEqual(self.context.errors, {f"Unable to import 'z' from '{self.test_module}'"}) def test_import_missing_module(self): @@ -1009,7 +1005,7 @@ def test_circular_import(self): self.program.local_load(1) self.program.local_load(0) stack = self.program.execute(self.context) - self.assertEqual(stack, [5, 10]) + self.assertEqual(stack, [null, null]) self.assertEqual(self.context.errors, {f"Circular import of '{self.circular_module_a}'"})