From 4b4e188ba83d28b5dd6ff66479e7448e5b925030 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 19:26:46 -0800 Subject: [PATCH] perf: levenshtein optimization (#3780) optimize compile time. `levenshtein` is a hotspot since it is called a lot during type analysis to construct exceptions (which are then caught as part of the validation routines). this commit delays calling `levenshtein` until the last minute, and also adds a mechanism to `VyperException` so that hints can be constructed lazily in general. on a couple test contracts, compilation time comes down 7%. however, as a portion of the time spent in the frontend, compilation time comes down 20-30%. this will become important as projects become larger (that is, many imports but only some functions are actually present in codegen) and compilation time is dominated by the frontend. --- .../syntax/modules/test_initializers.py | 2 +- tests/functional/syntax/test_for_range.py | 39 +++++++++++++++---- vyper/exceptions.py | 13 ++++++- vyper/semantics/analysis/levenshtein_utils.py | 10 ++++- vyper/semantics/analysis/utils.py | 4 +- vyper/semantics/namespace.py | 4 +- vyper/semantics/types/base.py | 4 +- vyper/semantics/types/user.py | 4 +- vyper/semantics/types/utils.py | 5 +-- 9 files changed, 62 insertions(+), 23 deletions(-) diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index d0965ae61d..66a201a33d 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -1178,4 +1178,4 @@ def test_ownership_decl_errors_not_swallowed(make_input_bundle): input_bundle = make_input_bundle({"lib1.vy": lib1}) with pytest.raises(UndeclaredDefinition) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "'lib2' has not been declared. " + assert e.value._message == "'lib2' has not been declared." diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index a486d11738..94eed58dd4 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -21,6 +21,7 @@ def foo(): """, StructureException, "Invalid syntax for loop iterator", + None, "a[1]", ), ( @@ -32,6 +33,7 @@ def bar(): """, StructureException, "Bound must be at least 1", + None, "0", ), ( @@ -44,6 +46,7 @@ def foo(): """, StateAccessViolation, "Bound must be a literal", + None, "x", ), ( @@ -55,6 +58,7 @@ def foo(): """, StructureException, "Please remove the `bound=` kwarg when using range with constants", + None, "5", ), ( @@ -66,6 +70,7 @@ def foo(): """, StructureException, "Bound must be at least 1", + None, "0", ), ( @@ -78,6 +83,7 @@ def bar(): """, ArgumentException, "Invalid keyword argument 'extra'", + None, "extra=3", ), ( @@ -89,6 +95,7 @@ def bar(): """, StructureException, "End must be greater than start", + None, "0", ), ( @@ -101,6 +108,7 @@ def bar(): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -113,6 +121,7 @@ def bar(): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -125,6 +134,7 @@ def repeat(n: uint256) -> uint256: """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "n * 10", ), ( @@ -137,6 +147,7 @@ def bar(): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x + 1", ), ( @@ -148,6 +159,7 @@ def bar(): """, StructureException, "End must be greater than start", + None, "1", ), ( @@ -160,6 +172,7 @@ def bar(): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -172,6 +185,7 @@ def foo(): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -184,6 +198,7 @@ def repeat(n: uint256) -> uint256: """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "n", ), ( @@ -196,6 +211,7 @@ def foo(x: int128): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -207,6 +223,7 @@ def bar(x: uint256): """, StateAccessViolation, "Value must be a literal integer, unless a bound is specified", + None, "x", ), ( @@ -221,6 +238,7 @@ def foo(): """, TypeMismatch, "Given reference has type int128, expected uint256", + None, "FOO", ), ( @@ -234,6 +252,7 @@ def foo(): """, StructureException, "Bound must be at least 1", + None, "FOO", ), ( @@ -244,7 +263,8 @@ def foo(): pass """, UnknownType, - "No builtin or user-defined type named 'DynArra'. Did you mean 'DynArray'?", + "No builtin or user-defined type named 'DynArra'.", + "Did you mean 'DynArray'?", "DynArra", ), ( @@ -262,7 +282,8 @@ def foo(): pass """, UnknownType, - "No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?", + "No builtin or user-defined type named 'uint9'.", + "Did you mean 'uint96', or maybe 'uint8'?", "uint9", ), ( @@ -278,7 +299,8 @@ def foo(): pass """, UnknownType, - "No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?", + "No builtin or user-defined type named 'uint9'.", + "Did you mean 'uint96', or maybe 'uint8'?", "uint9", ), ] @@ -289,15 +311,18 @@ def foo(): f"{i:02d}: {for_code_regex.search(code).group(1)}" # type: ignore[union-attr] f" raises {type(err).__name__}" ) - for i, (code, err, msg, src) in enumerate(fail_list) + for i, (code, err, msg, hint, src) in enumerate(fail_list) ] -@pytest.mark.parametrize("bad_code,error_type,message,source_code", fail_list, ids=fail_test_names) -def test_range_fail(bad_code, error_type, message, source_code): +@pytest.mark.parametrize( + "bad_code,error_type,message,hint,source_code", fail_list, ids=fail_test_names +) +def test_range_fail(bad_code, error_type, message, hint, source_code): with pytest.raises(error_type) as exc_info: compiler.compile_code(bad_code) - assert message == exc_info.value.message + assert message == exc_info.value._message + assert hint == exc_info.value.hint assert source_code == exc_info.value.args[1].get_original_node().node_source_code diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 53ad6f7bb8..f57cdabe9d 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -79,11 +79,20 @@ def with_annotation(self, *annotations): exc.annotations = annotations return exc + @property + def hint(self): + # some hints are expensive to compute, so we wait until the last + # minute when the formatted message is actually requested to compute + # them. + if callable(self._hint): + return self._hint() + return self._hint + @property def message(self): msg = self._message - if self._hint: - msg += f"\n\n (hint: {self._hint})" + if self.hint: + msg += f"\n\n (hint: {self.hint})" return msg def __str__(self): diff --git a/vyper/semantics/analysis/levenshtein_utils.py b/vyper/semantics/analysis/levenshtein_utils.py index 1d8f87dfbd..fc6e497d43 100644 --- a/vyper/semantics/analysis/levenshtein_utils.py +++ b/vyper/semantics/analysis/levenshtein_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Callable def levenshtein_norm(source: str, target: str) -> float: @@ -73,7 +73,13 @@ def levenshtein(source: str, target: str) -> int: return matrix[len(source)][len(target)] -def get_levenshtein_error_suggestions(key: str, namespace: Dict[str, Any], threshold: float) -> str: +def get_levenshtein_error_suggestions(*args, **kwargs) -> Callable: + return lambda: _get_levenshtein_error_suggestions(*args, **kwargs) + + +def _get_levenshtein_error_suggestions( + key: str, namespace: dict[str, Any], threshold: float +) -> str: """ Generate an error message snippet for the suggested closest values in the provided namespace with the shortest normalized Levenshtein distance from the given key if that distance diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 034cd8c46e..fa4dfcc1d1 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -208,9 +208,9 @@ def _raise_invalid_reference(name, node): if name in self.namespace: _raise_invalid_reference(name, node) - suggestions_str = get_levenshtein_error_suggestions(name, t.members, 0.4) + hint = get_levenshtein_error_suggestions(name, t.members, 0.4) raise UndeclaredDefinition( - f"Storage variable '{name}' has not been declared. {suggestions_str}", node + f"Storage variable '{name}' has not been declared.", node, hint=hint ) from None def types_from_BinOp(self, node): diff --git a/vyper/semantics/namespace.py b/vyper/semantics/namespace.py index 4df2511a29..d59343edfb 100644 --- a/vyper/semantics/namespace.py +++ b/vyper/semantics/namespace.py @@ -45,8 +45,8 @@ def __setitem__(self, attr, obj): def __getitem__(self, key): if key not in self: - suggestions_str = get_levenshtein_error_suggestions(key, self, 0.2) - raise UndeclaredDefinition(f"'{key}' has not been declared. {suggestions_str}") + hint = get_levenshtein_error_suggestions(key, self, 0.2) + raise UndeclaredDefinition(f"'{key}' has not been declared.", hint=hint) return super().__getitem__(key) def __enter__(self): diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index c5e10b52be..37de263319 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -321,8 +321,8 @@ def get_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": if not self.members: raise StructureException(f"{self} instance does not have members", node) - suggestions_str = get_levenshtein_error_suggestions(key, self.members, 0.3) - raise UnknownAttribute(f"{self} has no member '{key}'. {suggestions_str}", node) + hint = get_levenshtein_error_suggestions(key, self.members, 0.3) + raise UnknownAttribute(f"{self} has no member '{key}'.", node, hint=hint) def __repr__(self): return self._id diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 92a455e3d8..0c9b5d70da 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -399,9 +399,9 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": keys = list(self.member_types.keys()) for i, (key, value) in enumerate(zip(node.args[0].keys, node.args[0].values)): if key is None or key.get("id") not in members: - suggestions_str = get_levenshtein_error_suggestions(key.get("id"), members, 1.0) + hint = get_levenshtein_error_suggestions(key.get("id"), members, 1.0) raise UnknownAttribute( - f"Unknown or duplicate struct member. {suggestions_str}", key or value + "Unknown or duplicate struct member.", key or value, hint=hint ) expected_key = keys[i] if key.id != expected_key: diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 96c661021f..0546668900 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -146,10 +146,9 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: raise InvalidType(err_msg, node) if node.id not in namespace: # type: ignore - suggestions_str = get_levenshtein_error_suggestions(node.node_source_code, namespace, 0.3) + hint = get_levenshtein_error_suggestions(node.node_source_code, namespace, 0.3) raise UnknownType( - f"No builtin or user-defined type named '{node.node_source_code}'. {suggestions_str}", - node, + f"No builtin or user-defined type named '{node.node_source_code}'.", node, hint=hint ) from None typ_ = namespace[node.id]