diff --git a/src/reactive_python_engine.py b/src/reactive_python_engine.py index 326d9f7..d1980e3 100644 --- a/src/reactive_python_engine.py +++ b/src/reactive_python_engine.py @@ -52,6 +52,12 @@ class FunctionTempScope(): parent: "TempScope" variables: TempScopeVariables = dataclasses.field(default_factory=TempScopeVariables) + @dataclasses.dataclass + class WithTempScope(): + node: ast.AST + parent: "TempScope" + variables: TempScopeVariables = dataclasses.field(default_factory=TempScopeVariables) + @dataclasses.dataclass class ClassTempScope(): node: ast.AST @@ -95,6 +101,13 @@ class FunctionScope(): variables: Variables = dataclasses.field(default_factory=Variables) children: List["Scope"] = dataclasses.field(default_factory=list) + @dataclasses.dataclass + class WithScope(): + parent: "Scope" + with_node: ast.With + variables: Variables = dataclasses.field(default_factory=Variables) + children: List["Scope"] = dataclasses.field(default_factory=list) + @dataclasses.dataclass class ClassScope(): parent: "Scope" @@ -214,6 +227,14 @@ def visit_Lambda(self, func_node): ArgumentsVisitor(self, subscope).visit(func_node.args) visit_all(subscope, func_node.body) + def visit_With(self, func_node): + self.all_tempscope_data[func_node] = AstNodeData('', self.current_tempscope, False, True, False) + _temp_scope = WithTempScope(node=func_node, parent=self.current_tempscope) + subscope = TempScopeVisitor(_temp_scope, self.all_tempscope_data, self.class_binds_near) + visit_all(self, getattr(func_node, 'type_comment', None)) + visit_all(WithItemsVisitor(self, subscope), func_node.items) + visit_all(subscope, func_node.body) + def _visit_comprehension(self, targets, comprehensions, typ): del typ current_scope = self @@ -253,8 +274,6 @@ def visit_Nonlocal(self, nonlocal_node): for name in nonlocal_node.names: self.current_tempscope.variables.nonlocal_variables.add(name) - - class ArgumentsVisitor(ast.NodeVisitor): """ Util visitor to handle args only """ def __init__(self, expr_scope, arg_scope): @@ -271,7 +290,19 @@ def visit_arguments(self, node): def generic_visit(self, node): self.expr_scope.visit(node) + class WithItemsVisitor(ast.NodeVisitor): + """ Util visitor to WITH-scope args only. Let's see if I get this.. """ + def __init__(self, parent_scope, sub_scope): + self.parent_scope = parent_scope + self.sub_scope = sub_scope + + def visit_withitem(self, node): + if node.optional_vars: + self.sub_scope.visit(node.optional_vars) + self.parent_scope.visit(node.context_expr) + def generic_visit(self, node): + self.parent_scope.visit(node) class ScopeVisitor(ast.NodeVisitor): def __init__(self, all_tempscope_data): @@ -338,6 +369,14 @@ def visit_Lambda(self, node): _add_child(scope, self.node_to_corresponding_scope[node]) super().generic_visit(node) + def visit_With(self, node): + scope, is_input, is_output = self._get_scope(node) + if node not in self.node_to_corresponding_scope: + self.node_to_corresponding_scope[node] = WithScope(with_node=node, parent=scope) + _add_child(scope, self.node_to_corresponding_scope[node]) + super().generic_visit(node) + + def visit_DictComp(self, comp_node): targets=[comp_node.key, comp_node.value] comprehensions=comp_node.generators @@ -415,18 +454,31 @@ def find_tempscope(temp_scope, name, is_assignment, global_acceptable=True): elif type(temp_scope) is FunctionTempScope: if name in temp_scope.variables.global_variables: return get_global_scope(temp_scope) - if name in temp_scope.variables.nonlocal_variables: + elif name in temp_scope.variables.nonlocal_variables: return find_tempscope(get_parent_scope(temp_scope), name, is_assignment, global_acceptable=False) - if name in temp_scope.variables.assigned_variables: + elif name in temp_scope.variables.assigned_variables: return temp_scope - return find_tempscope(get_parent_scope(temp_scope), name, is_assignment, global_acceptable) + else: + return find_tempscope(get_parent_scope(temp_scope), name, is_assignment, global_acceptable) elif type(temp_scope) is ClassTempScope: if temp_scope.class_binds_near: # anything can be in a class scope return temp_scope + elif is_assignment: + return temp_scope + else: + return find_tempscope(temp_scope.parent, name, is_assignment, global_acceptable) + elif type(temp_scope) is WithTempScope: if is_assignment: + return find_tempscope(get_parent_scope(temp_scope), name, is_assignment, global_acceptable) + elif name in temp_scope.variables.global_variables: + return get_global_scope(temp_scope) + elif name in temp_scope.variables.nonlocal_variables: + return find_tempscope(get_parent_scope(temp_scope), name, is_assignment, global_acceptable=False) + elif name in temp_scope.variables.assigned_variables: return temp_scope - return find_tempscope(temp_scope.parent, name, is_assignment, global_acceptable) + else: + return find_tempscope(get_parent_scope(temp_scope), name, is_assignment, global_acceptable) else: raise RuntimeError("Unknown scope type") @@ -435,6 +487,8 @@ def get_global_scope(temp_scope): return temp_scope elif type(temp_scope) is FunctionTempScope: return get_global_scope(get_parent_scope(temp_scope)) + elif type(temp_scope) is WithTempScope: + return get_global_scope(get_parent_scope(temp_scope)) elif type(temp_scope) is ClassTempScope: return find_tempscope(get_parent_scope(temp_scope), temp_scope, is_assignment=False) # TODO: is_assignment is MADE UP ??? @@ -446,14 +500,10 @@ def get_parent_scope(intermediate_scope): def get_name(node): - if type(node) is ast.FunctionDef: - return node.name - elif type(node) is ast.AsyncFunctionDef: + if type(node) in [ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef]: return node.name elif type(node) is ast.Name: return node.id - elif type(node) is ast.ClassDef: - return node.name elif type(node) is ast.alias: return node.asname if node.asname is not None else node.name else: diff --git a/src/test/reactive_python_engine_tests.py b/src/test/reactive_python_engine_tests.py index aaef978..eede679 100644 --- a/src/test/reactive_python_engine_tests.py +++ b/src/test/reactive_python_engine_tests.py @@ -1,3 +1,52 @@ + + + + + + + + +from src.reactive_python_engine import reactive_python_dag_builder_utils__ + +# import "time": +import time +import ast + + +draw_dag = reactive_python_dag_builder_utils__.draw_dag +update_staleness_info_in_new_dag = reactive_python_dag_builder_utils__.update_staleness_info_in_new_dag +get_input_variables_for = reactive_python_dag_builder_utils__.get_input_variables_for +get_output_variables_for = reactive_python_dag_builder_utils__.get_output_variables_for +annotate = reactive_python_dag_builder_utils__.annotate + + + +code = """ +with a+b as f: + c = f + 4 + c = c + d + k = f + e +""" +tree = ast.parse(code).body[0] +inputs, errors = get_input_variables_for(annotate(tree)); inputs +assert inputs == {'a', 'b', 'd', 'e'} +outputs, errors = get_output_variables_for(annotate(tree)); outputs +assert outputs == {'c', 'k', 'f'} + + +d, e = 5, 6 +with open('package.json', 'r') as f: + # Get the first line: + k = f.readline() + c = d+e + +k +f + + + + + sample_python_code = """ import pandas as pd @@ -147,19 +196,6 @@ def test_dag_builder(): # %autoreload 2 -from reactive_python_engine import reactive_python_dag_builder_utils__ - -# import "time": -import time -import ast - - -draw_dag = reactive_python_dag_builder_utils__.draw_dag -update_staleness_info_in_new_dag = reactive_python_dag_builder_utils__.update_staleness_info_in_new_dag -get_input_variables_for = reactive_python_dag_builder_utils__.get_input_variables_for -get_output_variables_for = reactive_python_dag_builder_utils__.get_output_variables_for -annotate = reactive_python_dag_builder_utils__.annotate - code = """ # % [ @@ -187,6 +223,7 @@ def test_dag_builder(): current_time = time.time() + # test code = """ x = True, False, None @@ -989,7 +1026,24 @@ def f(): outputs, errors = get_output_variables_for(annotate(tree)); outputs assert outputs == {'a', 'b'} +code = """ +a, b = (1, c[d]) +""" +tree = ast.parse(code).body[0] +inputs, errors = get_input_variables_for(annotate(tree)); inputs +assert inputs == {'c', 'd'} +outputs, errors = get_output_variables_for(annotate(tree)); outputs +assert outputs == {'a', 'b'} + +code = """ +k.fields = (1, c[d]) +""" +tree = ast.parse(code).body[0] +inputs, errors = get_input_variables_for(annotate(tree)); inputs +assert inputs == {'c', 'd', 'k'} +outputs, errors = get_output_variables_for(annotate(tree)); outputs +assert outputs == {'k'} ######################## TODO/ Wrong stuff: