Skip to content

Commit

Permalink
[ScopeProvider][optimization] use batch set union in infer_accesses
Browse files Browse the repository at this point in the history
  • Loading branch information
jimmylai committed Mar 3, 2020
1 parent 08a89f1 commit 24878a1
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 37 deletions.
38 changes: 26 additions & 12 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
MutableMapping,
Optional,
Set,
Tuple,
Type,
Union,
)
Expand Down Expand Up @@ -78,6 +77,9 @@ def referents(self) -> Collection["BaseAssignment"]:
def record_assignment(self, assignment: "BaseAssignment") -> None:
self.__assignments.add(assignment)

def record_assignments(self, assignments: Set["BaseAssignment"]) -> None:
self.__assignments |= assignments


class BaseAssignment(abc.ABC):
"""Abstract base class of :class:`Assignment` and :class:`BuitinAssignment`."""
Expand All @@ -97,6 +99,9 @@ def __init__(self, name: str, scope: "Scope") -> None:
def record_access(self, access: Access) -> None:
self.__accesses.add(access)

def record_accesses(self, accesses: Set[Access]) -> None:
self.__accesses |= accesses

@property
def references(self) -> Collection[Access]:
"""Return all accesses of the assignment."""
Expand Down Expand Up @@ -308,7 +313,7 @@ def record_assignment(self, name: str, node: cst.CSTNode) -> None:
def record_access(self, name: str, access: Access) -> None:
self._accesses[name].add(access)

def _getitem_from_self_or_parent(self, name: str) -> Tuple[BaseAssignment, ...]:
def _getitem_from_self_or_parent(self, name: str) -> Set[BaseAssignment]:
"""Overridden by ClassScope to hide it's assignments from child scopes."""
return self[name]

Expand All @@ -321,7 +326,7 @@ def __contains__(self, name: str) -> bool:
return len(self[name]) > 0

@abc.abstractmethod
def __getitem__(self, name: str) -> Tuple[BaseAssignment, ...]:
def __getitem__(self, name: str) -> Set[BaseAssignment]:
"""
Get assignments given a name str by ``scope[name]``.
Expand Down Expand Up @@ -356,7 +361,7 @@ def __getitem__(self, name: str) -> Tuple[BaseAssignment, ...]:
As a result, instead of returning a single declaration,
we're forced to return a collection of all of the assignments we think could have
defined a given name by the time a piece of code is executed.
For the above example, value would resolve to a tuple of both assignments.
For the above example, value would resolve to a set of both assignments.
"""
...

Expand Down Expand Up @@ -446,13 +451,13 @@ def __init__(self) -> None:
self.globals: Scope = self # must be defined before Scope.__init__ is called
super().__init__(parent=self)

def __getitem__(self, name: str) -> Tuple[BaseAssignment, ...]:
def __getitem__(self, name: str) -> Set[BaseAssignment]:
if hasattr(builtins, name):
if not any(
isinstance(i, BuiltinAssignment) for i in self._assignments[name]
):
self._assignments[name].add(BuiltinAssignment(name, self))
return tuple(self._assignments[name])
return self._assignments[name]

def record_global_overwrite(self, name: str) -> None:
pass
Expand Down Expand Up @@ -490,11 +495,11 @@ def record_assignment(self, name: str, node: cst.CSTNode) -> None:
else:
super().record_assignment(name, node)

def __getitem__(self, name: str) -> Tuple[BaseAssignment, ...]:
def __getitem__(self, name: str) -> Set[BaseAssignment]:
if name in self._scope_overwrites:
return self._scope_overwrites[name]._getitem_from_self_or_parent(name)
if name in self._assignments:
return tuple(self._assignments[name])
return self._assignments[name]
else:
return self.parent._getitem_from_self_or_parent(name)

Expand Down Expand Up @@ -531,7 +536,7 @@ def inner_fn():
"""
self.parent._record_assignment_as_parent(name, node)

def _getitem_from_self_or_parent(self, name: str) -> Tuple[BaseAssignment, ...]:
def _getitem_from_self_or_parent(self, name: str) -> Set[BaseAssignment]:
"""
Class variables are only accessible using ClassName.attribute, cls.attribute, or
self.attribute in child scopes. They cannot be accessed with their bare names.
Expand Down Expand Up @@ -755,10 +760,19 @@ def _visit_comp_alike(
return False

def infer_accesses(self) -> None:
# Aggregate access with the same name and batch add with set union as an optimization.
# In worst case, all accesses (m) and assignments (n) refer to the same name,
# the time complexity is O(m x n), this optimizes it as O(m + n).
scope_name_accesses = defaultdict(set)
for access in self.__deferred_accesses:
for assignment in access.scope[access.node.value]:
assignment.record_access(access)
access.record_assignment(assignment)
name = access.node.value
scope_name_accesses[(access.scope, name)].add(access)
access.record_assignments(access.scope[name])

for (scope, name), accesses in scope_name_accesses.items():
for assignment in scope[name]:
assignment.record_accesses(accesses)

self.__deferred_accesses = []

def on_leave(self, original_node: cst.CSTNode) -> None:
Expand Down
49 changes: 24 additions & 25 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_not_in_scope(self) -> None:
"""
)
global_scope = scopes[m]
self.assertEqual(global_scope["not_in_scope"], ())
self.assertEqual(global_scope["not_in_scope"], set())

def test_accesses(self) -> None:
m, scopes = get_scope_metadata_provider(
Expand All @@ -64,7 +64,7 @@ def fn_def():
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
global_foo_assignments = scope_of_module["foo"]
global_foo_assignments = list(scope_of_module["foo"])
self.assertEqual(len(global_foo_assignments), 1)
foo_assignment = global_foo_assignments[0]
self.assertEqual(len(foo_assignment.references), 2)
Expand All @@ -91,7 +91,7 @@ def fn_def():
self.assertIsInstance(scope_of_func_statement, FunctionScope)
func_foo_assignments = scope_of_func_statement["foo"]
self.assertEqual(len(func_foo_assignments), 1)
foo_assignment = func_foo_assignments[0]
foo_assignment = list(func_foo_assignments)[0]
self.assertEqual(len(foo_assignment.references), 1)
fn3_call_arg = ensure_type(
ensure_type(
Expand Down Expand Up @@ -145,7 +145,7 @@ def test_import(self) -> None:
len(scope_of_module[in_scope]), 1, f"{in_scope} should be in scope."
)

assignment = cast(Assignment, scope_of_module[in_scope][0])
assignment = cast(Assignment, list(scope_of_module[in_scope])[0])
self.assertEqual(
assignment.name,
in_scope,
Expand All @@ -171,7 +171,7 @@ def test_imoprt_from(self) -> None:
self.assertEqual(
len(scope_of_module[in_scope]), 1, f"{in_scope} should be in scope."
)
import_assignment = cast(Assignment, scope_of_module[in_scope][0])
import_assignment = cast(Assignment, list(scope_of_module[in_scope])[0])
self.assertEqual(
import_assignment.name,
in_scope,
Expand Down Expand Up @@ -226,7 +226,7 @@ def f():
)
scope_of_module = scopes[m]
self.assertIsInstance(scope_of_module, GlobalScope)
cls_assignments = scope_of_module["Cls"]
cls_assignments = list(scope_of_module["Cls"])
self.assertEqual(len(cls_assignments), 1)
cls_assignment = cast(Assignment, cls_assignments[0])
cls_def = ensure_type(m.body[1], cst.ClassDef)
Expand Down Expand Up @@ -444,14 +444,14 @@ def f():
scope_of_outer_f = scopes[outer_f.body.body[0]]
self.assertIsInstance(scope_of_outer_f, FunctionScope)
self.assertTrue("f" in scope_of_outer_f)
out_f_assignment = scope_of_module["f"][0]
out_f_assignment = list(scope_of_module["f"])[0]
self.assertEqual(cast(Assignment, out_f_assignment).node, outer_f)

inner_f = ensure_type(outer_f.body.body[0], cst.FunctionDef)
scope_of_inner_f = scopes[inner_f.body.body[0]]
self.assertIsInstance(scope_of_inner_f, FunctionScope)
self.assertTrue("f" in scope_of_inner_f)
inner_f_assignment = scope_of_outer_f["f"][0]
inner_f_assignment = list(scope_of_outer_f["f"])[0]
self.assertEqual(cast(Assignment, inner_f_assignment).node, inner_f)

def test_func_param_scope(self) -> None:
Expand Down Expand Up @@ -505,11 +505,11 @@ def f(x: T=1, *vararg, y: T=2, z, **kwarg) -> RET:
self.assertTrue("kwarg" not in scope_of_module)
self.assertTrue("kwarg" in scope_of_f)

self.assertEqual(cast(Assignment, scope_of_f["x"][0]).node, x)
self.assertEqual(cast(Assignment, scope_of_f["vararg"][0]).node, vararg)
self.assertEqual(cast(Assignment, scope_of_f["y"][0]).node, y)
self.assertEqual(cast(Assignment, scope_of_f["z"][0]).node, z)
self.assertEqual(cast(Assignment, scope_of_f["kwarg"][0]).node, kwarg)
self.assertEqual(cast(Assignment, list(scope_of_f["x"])[0]).node, x)
self.assertEqual(cast(Assignment, list(scope_of_f["vararg"])[0]).node, vararg)
self.assertEqual(cast(Assignment, list(scope_of_f["y"])[0]).node, y)
self.assertEqual(cast(Assignment, list(scope_of_f["z"])[0]).node, z)
self.assertEqual(cast(Assignment, list(scope_of_f["kwarg"])[0]).node, kwarg)

def test_lambda_param_scope(self) -> None:
m, scopes = get_scope_metadata_provider(
Expand Down Expand Up @@ -556,11 +556,11 @@ def test_lambda_param_scope(self) -> None:
self.assertTrue("kwarg" not in scope_of_module)
self.assertTrue("kwarg" in scope_of_f)

self.assertEqual(cast(Assignment, scope_of_f["x"][0]).node, x)
self.assertEqual(cast(Assignment, scope_of_f["vararg"][0]).node, vararg)
self.assertEqual(cast(Assignment, scope_of_f["y"][0]).node, y)
self.assertEqual(cast(Assignment, scope_of_f["z"][0]).node, z)
self.assertEqual(cast(Assignment, scope_of_f["kwarg"][0]).node, kwarg)
self.assertEqual(cast(Assignment, list(scope_of_f["x"])[0]).node, x)
self.assertEqual(cast(Assignment, list(scope_of_f["vararg"])[0]).node, vararg)
self.assertEqual(cast(Assignment, list(scope_of_f["y"])[0]).node, y)
self.assertEqual(cast(Assignment, list(scope_of_f["z"])[0]).node, z)
self.assertEqual(cast(Assignment, list(scope_of_f["kwarg"])[0]).node, kwarg)

def test_except_handler(self) -> None:
"""
Expand All @@ -581,7 +581,7 @@ def test_except_handler(self) -> None:
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertTrue("ex" in scope_of_module)
self.assertEqual(
cast(Assignment, scope_of_module["ex"][0]).node,
cast(Assignment, list(scope_of_module["ex"])[0]).node,
ensure_type(
ensure_type(m.body[0], cst.Try).handlers[0].name, cst.AsName
).name,
Expand All @@ -598,7 +598,7 @@ def test_with_asname(self) -> None:
self.assertIsInstance(scope_of_module, GlobalScope)
self.assertTrue("f" in scope_of_module)
self.assertEqual(
cast(Assignment, scope_of_module["f"][0]).node,
cast(Assignment, list(scope_of_module["f"])[0]).node,
ensure_type(
ensure_type(m.body[0], cst.With).items[0].asname, cst.AsName
).name,
Expand Down Expand Up @@ -806,11 +806,10 @@ def g():
scope_of_module = scopes[a_outer_assign]
a_outer_assignments = scope_of_module.assignments[a_outer_access]
self.assertEqual(len(a_outer_assignments), 1)
a_outer_assignment = list(a_outer_assignments)[0]
self.assertEqual(cast(Assignment, a_outer_assignment).node, a_outer_assign)
self.assertEqual(
cast(Assignment, list(a_outer_assignments)[0]).node, a_outer_assign
)
self.assertEqual(
{i.node for i in list(a_outer_assignments)[0].references}, {a_outer_access}
{i.node for i in a_outer_assignment.references}, {a_outer_access}
)

a_outer_assesses = scope_of_module.accesses[a_outer_assign]
Expand Down Expand Up @@ -986,7 +985,7 @@ def test_del_context_names(self) -> None:
{i.node for i in a_assign.references},
{ensure_type(del_a_b.target, cst.Attribute).value},
)
self.assertEqual(scope["b"], ())
self.assertEqual(scope["b"], set())

def test_keyword_arg_in_call(self) -> None:
m, scopes = get_scope_metadata_provider("call(arg=val)")
Expand Down

0 comments on commit 24878a1

Please sign in to comment.