Skip to content

Commit e139b8e

Browse files
committed
refactor: Split complexity into building the tree and then the evaluation to allow testing the describe function
1 parent a4a5b30 commit e139b8e

File tree

4 files changed

+123
-30
lines changed

4 files changed

+123
-30
lines changed
Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,29 @@
11
from graphql import GraphQLSchema, TypeInfo, TypeInfoVisitor, parse, visit
22

3+
from . import nodes
34
from .visitor import ComplexityVisitor
45
from ..config import Config
56
from ..estimators import ComplexityEstimator
67

78

89
def get_complexity(query: str, schema: GraphQLSchema, estimator: ComplexityEstimator, config: Config = None) -> int:
10+
"""Calculate the complexity of a query using the provided estimator."""
11+
tree = build_complexity_tree(query, schema, estimator, config)
12+
13+
return tree.evaluate()
14+
15+
16+
def build_complexity_tree(
17+
query: str,
18+
schema: GraphQLSchema,
19+
estimator: ComplexityEstimator,
20+
config: Config | None = None,
21+
) -> nodes.ComplexityNode:
922
"""Calculate the complexity of a query using the provided estimator."""
1023
ast = parse(query)
1124
type_info = TypeInfo(schema)
1225

1326
visitor = ComplexityVisitor(estimator=estimator, type_info=type_info, config=config)
1427
visit(ast, TypeInfoVisitor(type_info, visitor))
1528

16-
return visitor.evaluate()
29+
return visitor.complexity_tree

src/graphql_complexity/evaluator/nodes.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,57 +21,54 @@ class ComplexityNode:
2121
parent: 'ComplexityNode' = None
2222
children: list['ComplexityNode'] = dataclasses.field(default_factory=list)
2323

24-
def evaluate(self, *args, **kwargs) -> int:
24+
def evaluate(self) -> int:
2525
raise NotImplementedError
2626

27-
def describe(self, depth=0):
27+
def describe(self, depth=0) -> str:
28+
"""Return a friendly representation of the node and its children complexity."""
2829
return (
2930
f"{chr(9) * depth}{self.name} ({self.__class__.__name__}) = {self.evaluate()}" +
3031
f"{chr(10) if self.children else ''}" +
3132
'\n'.join(c.describe(depth=depth+1) for c in self.children)
3233
)
3334

34-
def add_child(self, node: 'ComplexityNode'):
35+
def add_child(self, node: 'ComplexityNode') -> None:
36+
"""Add a child to the current node."""
3537
self.children.append(node)
3638
node.parent = self
3739

3840

3941
@dataclasses.dataclass(slots=True, kw_only=True)
4042
class RootNode(ComplexityNode):
41-
def evaluate(self, *args, **kwargs) -> int:
42-
return sum(
43-
child.evaluate(*args, **kwargs) for child in self.children
44-
)
43+
def evaluate(self) -> int:
44+
return sum(child.evaluate() for child in self.children)
4545

4646

4747
@dataclasses.dataclass(slots=True, kw_only=True)
48-
class FragmentNode(ComplexityNode):
48+
class FragmentSpreadNode(ComplexityNode):
49+
fragments_definition: dict
4950

50-
def evaluate(self, *, fragments_definition: dict[str, "ComplexityNode"]):
51-
fragment = fragments_definition.get(self.name)
51+
def evaluate(self):
52+
fragment = self.fragments_definition.get(self.name)
5253
if not fragment:
5354
return 0
54-
return fragment.evaluate(fragments_definition=fragments_definition)
55+
return fragment.evaluate()
5556

5657

5758
@dataclasses.dataclass(slots=True, kw_only=True)
5859
class Field(ComplexityNode):
5960
complexity: int
6061

61-
def evaluate(self, *args, **kwargs) -> int:
62-
return self.complexity + sum(
63-
child.evaluate(*args, **kwargs) for child in self.children
64-
)
62+
def evaluate(self) -> int:
63+
return self.complexity + sum(child.evaluate() for child in self.children)
6564

6665

6766
@dataclasses.dataclass(slots=True, kw_only=True)
6867
class ListField(Field):
6968
count: int = 1
7069

71-
def evaluate(self, *args, **kwargs) -> int:
72-
return self.complexity + self.count * sum(
73-
child.evaluate(*args, **kwargs) for child in self.children
74-
)
70+
def evaluate(self) -> int:
71+
return self.complexity + self.count * sum(child.evaluate() for child in self.children)
7572

7673

7774
@dataclasses.dataclass(slots=True, kw_only=True)
@@ -90,14 +87,14 @@ def wrap(cls, node: ComplexityNode):
9087
node.parent.add_child(wrapper)
9188
return wrapper
9289

93-
def evaluate(self, *args, **kwargs) -> int:
90+
def evaluate(self) -> int:
9491
return 0
9592

9693

9794
@dataclasses.dataclass(slots=True, kw_only=True)
9895
class MetaField(ComplexityNode):
9996

100-
def evaluate(self, *args, **kwargs) -> int:
97+
def evaluate(self) -> int:
10198
return 0
10299

103100

src/graphql_complexity/evaluator/visitor.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,17 @@ def __init__(
4040
self.variables = variables or {}
4141
self.type_info = type_info
4242
self.fragments: dict[str, nodes.ComplexityNode] = {}
43-
self.current_node = nodes.RootNode(name="root")
43+
self.root = nodes.RootNode(name="root")
44+
self.current_node = self.root
4445
self._previous_current_node = None
45-
self._ignore_until_leave = None
4646
super().__init__()
4747

48-
def evaluate(self) -> int:
49-
"""Evaluate the complexity of the operations after visiting the document."""
50-
return self.current_node.evaluate(
51-
fragments_definition=self.fragments
52-
)
48+
@property
49+
def complexity_tree(self) -> nodes.ComplexityNode:
50+
"""Return the complexity tree after visiting the document.
51+
The tree is represented by a RootNode with the complexity of the operations
52+
represented as Node children. Each node is evaluated returning the complexity."""
53+
return self.root
5354

5455
def enter_variable_definition(self, node, key, parent, path, ancestors):
5556
input_variable = self.variables.get(node.variable.name.value)
@@ -85,7 +86,12 @@ def leave_fragment_definition(self, node, *args, **kwargs):
8586

8687
def enter_fragment_spread(self, node, *args, **kwargs):
8788
"""Add a lazy fragment to the current complexity list."""
88-
self.current_node.add_child(nodes.FragmentNode(name=node.name.value))
89+
self.current_node.add_child(
90+
nodes.FragmentSpreadNode(
91+
name=node.name.value,
92+
fragments_definition=self.fragments
93+
)
94+
)
8995

9096

9197
def should_include_field(node: DirectiveNode, variables: dict[str, Any]) -> bool:

tests/test_tree.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import pytest
2+
from graphql import build_schema
3+
4+
from graphql_complexity import SimpleEstimator
5+
from graphql_complexity.evaluator.complexity import build_complexity_tree
6+
from graphql_complexity.evaluator.nodes import ComplexityNode
7+
from tests import ut_utils
8+
9+
10+
def _build_complexity_tree(query: str, estimator=None):
11+
estimator = estimator or SimpleEstimator(1)
12+
schema = build_schema(ut_utils.schema)
13+
return build_complexity_tree(query, schema, estimator)
14+
15+
16+
def test_tree_describes_simple_query():
17+
query = """
18+
query {
19+
version
20+
}
21+
"""
22+
23+
tree = _build_complexity_tree(query)
24+
25+
assert tree.describe() == """root (RootNode) = 1
26+
\tversion (Field) = 1"""
27+
28+
29+
def test_tree_describes_skipped_fields():
30+
query = """query Foo ($shouldSkip: Boolean = false) {
31+
version @include(if: $shouldSkip)
32+
}"""
33+
34+
tree = _build_complexity_tree(query)
35+
36+
assert tree.describe() == """root (RootNode) = 0
37+
\tversion (SkippedField) = 0"""
38+
39+
40+
def test_tree_describes_lists():
41+
query = """query {
42+
droid {
43+
friends {
44+
name
45+
}
46+
}
47+
}"""
48+
49+
tree = _build_complexity_tree(query)
50+
51+
assert tree.describe() == """root (RootNode) = 3
52+
\tdroid (Field) = 3
53+
\t\tfriends (ListField) = 2
54+
\t\t\tname (Field) = 1"""
55+
56+
57+
def test_describes_fragment_spread():
58+
query = """query {
59+
...fragmentName
60+
}
61+
fragment fragmentName on Droid {
62+
id
63+
}"""
64+
65+
tree = _build_complexity_tree(query)
66+
67+
assert tree.describe() == """root (RootNode) = 1
68+
\tfragmentName (FragmentSpreadNode) = 1"""
69+
70+
71+
def test_complexity_node_can_not_be_evaluated():
72+
node = ComplexityNode(
73+
name="root"
74+
)
75+
76+
with pytest.raises(NotImplementedError):
77+
node.evaluate()

0 commit comments

Comments
 (0)