Skip to content

Commit 52c7507

Browse files
authored
Merge pull request #6 from Checho3388/3-allow-multipliers-in-directivesestimator
3 allow multipliers in directivesestimator
2 parents 4d14566 + 6fb43f4 commit 52c7507

19 files changed

+425
-266
lines changed

README.md

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ query = """
4444
complexity = get_complexity(
4545
query=query,
4646
schema=schema,
47-
estimator=SimpleEstimator(complexity=1, multiplier=10)
47+
estimator=SimpleEstimator(complexity=10)
4848
)
4949
if complexity > 10:
5050
raise Exception("Query is too complex")
@@ -73,7 +73,7 @@ it by another **constant** which is propagated along the depth of the query.
7373
from graphql_complexity import SimpleEstimator
7474

7575

76-
estimator = SimpleEstimator(complexity=2, multiplier=1)
76+
estimator = SimpleEstimator(complexity=2)
7777
```
7878

7979
Given the following GraphQL query:
@@ -150,13 +150,10 @@ from graphql_complexity import ComplexityEstimator
150150

151151

152152
class CustomEstimator(ComplexityEstimator):
153-
def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
153+
def get_field_complexity(self, node, type_info, path) -> int:
154154
if node.name.value == "specificField":
155155
return 100
156156
return 1
157-
158-
def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
159-
return 1
160157
```
161158

162159

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "graphql_complexity"
3-
version = "0.2.0"
3+
version = "0.3.0"
44
description = "A python library that provides complexity calculation helpers for GraphQL"
55
authors = ["Checho3388 <ezequiel.grondona@gmail.com>"]
66
packages = [

src/graphql_complexity/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import dataclasses
2+
3+
4+
@dataclasses.dataclass(frozen=True)
5+
class Config:
6+
count_arg_name: str | None = "first" # ToDo: Improve Unset
7+
count_missing_arg_value: int = 1

src/graphql_complexity/estimators/base.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,5 @@
33

44
class ComplexityEstimator(abc.ABC):
55
@abc.abstractmethod
6-
def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
6+
def get_field_complexity(self, node, type_info, path) -> int:
77
"""Return the complexity of the field."""
8-
9-
@abc.abstractmethod
10-
def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
11-
"""Return the multiplier that will be applied to the children of the given node."""

src/graphql_complexity/estimators/directive.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,5 @@ def collect_from_schema(schema: str, directive_name: str) -> dict[str, int]:
7171
visit(ast, visitor)
7272
return collector
7373

74-
def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
74+
def get_field_complexity(self, node, type_info, path) -> int:
7575
return self.__complexity_map.get(node.name.value, self.__missing_complexity)
76-
77-
def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
78-
# ToDo: Implement this method
79-
return 1

src/graphql_complexity/estimators/simple.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,16 @@
22

33

44
class SimpleEstimator(ComplexityEstimator):
5-
"""Simple complexity estimator that returns a constant complexity and multiplier for all fields.
6-
Constants can be set in the constructor.
5+
"""Simple complexity estimator that returns a constant complexity for all fields.
6+
Constant can be set in the constructor."""
77

8-
Example:
9-
Given the following query:
10-
```qgl
11-
query {
12-
user {
13-
name
14-
email
15-
}
16-
}
17-
```
18-
As the complexity and multiplier are constant, the complexity of the fields will be:
19-
- user: 1 * 1 = 1
20-
- name: 1 * 1 = 1
21-
- email: 1 * 1 = 1
22-
And the total complexity will be 3.
23-
"""
24-
25-
def __init__(self, complexity: int = 1, multiplier: int = 1):
8+
def __init__(self, complexity: int = 1):
269
if complexity < 0:
2710
raise ValueError(
2811
"'complexity' must be a positive integer (greater or equal than 0)"
2912
)
30-
if multiplier < 0:
31-
raise ValueError(
32-
"'multiplier' must be a positive integer (greater or equal than 0)"
33-
)
3413
self.__complexity_constant = complexity
35-
self.__multiplier_constant = multiplier
3614
super().__init__()
3715

38-
def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
16+
def get_field_complexity(self, *_, **__) -> int:
3917
return self.__complexity_constant
40-
41-
def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
42-
return self.__multiplier_constant
Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
from .complexity import (
2-
get_ast_complexity,
3-
get_complexity,
4-
)
1+
from .complexity import get_complexity
52

63
__all__ = [
74
'get_complexity',
8-
'get_ast_complexity'
95
]
Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
1-
from graphql import parse, visit, TypeInfo, TypeInfoVisitor, GraphQLSchema
1+
from graphql import GraphQLSchema, TypeInfo, TypeInfoVisitor, parse, visit
22

3+
from . import nodes
34
from .visitor import ComplexityVisitor
5+
from ..config import Config
46
from ..estimators import ComplexityEstimator
57

68

7-
def get_complexity(query: str, schema: GraphQLSchema, estimator: ComplexityEstimator) -> int:
9+
def get_complexity(query: str, schema: GraphQLSchema, estimator: ComplexityEstimator, config: Config = None) -> int:
810
"""Calculate the complexity of a query using the provided estimator."""
9-
ast = parse(query)
10-
return get_ast_complexity(ast, schema=schema, estimator=estimator)
11+
tree = build_complexity_tree(query, schema, estimator, config)
12+
13+
return tree.evaluate()
1114

1215

13-
def get_ast_complexity(ast, schema: GraphQLSchema, estimator: ComplexityEstimator) -> int:
16+
def build_complexity_tree(
17+
query: str,
18+
schema: GraphQLSchema,
19+
estimator: ComplexityEstimator,
20+
config: Config | None = None,
21+
) -> nodes.ComplexityNode:
1422
"""Calculate the complexity of a query using the provided estimator."""
23+
ast = parse(query)
1524
type_info = TypeInfo(schema)
1625

17-
visitor = ComplexityVisitor(estimator=estimator, type_info=type_info)
26+
visitor = ComplexityVisitor(estimator=estimator, type_info=type_info, config=config)
1827
visit(ast, TypeInfoVisitor(type_info, visitor))
1928

20-
return visitor.evaluate()
29+
return visitor.complexity_tree
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import dataclasses
2+
import logging
3+
from typing import Any
4+
5+
from graphql import (
6+
GraphQLList,
7+
TypeInfo,
8+
get_named_type,
9+
is_introspection_type, FieldNode
10+
)
11+
12+
from graphql_complexity.config import Config
13+
from graphql_complexity.evaluator.utils import get_node_argument_value
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
@dataclasses.dataclass(slots=True, kw_only=True)
19+
class ComplexityNode:
20+
name: str
21+
parent: 'ComplexityNode' = None
22+
children: list['ComplexityNode'] = dataclasses.field(default_factory=list)
23+
24+
def evaluate(self) -> int:
25+
raise NotImplementedError
26+
27+
def describe(self, depth=0) -> str:
28+
"""Return a friendly representation of the node and its children complexity."""
29+
return (
30+
f"{chr(9) * depth}{self.name} ({self.__class__.__name__}) = {self.evaluate()}" +
31+
f"{chr(10) if self.children else ''}" +
32+
'\n'.join(c.describe(depth=depth+1) for c in self.children)
33+
)
34+
35+
def add_child(self, node: 'ComplexityNode') -> None:
36+
"""Add a child to the current node."""
37+
self.children.append(node)
38+
node.parent = self
39+
40+
41+
@dataclasses.dataclass(slots=True, kw_only=True)
42+
class RootNode(ComplexityNode):
43+
def evaluate(self) -> int:
44+
return sum(child.evaluate() for child in self.children)
45+
46+
47+
@dataclasses.dataclass(slots=True, kw_only=True)
48+
class FragmentSpreadNode(ComplexityNode):
49+
fragments_definition: dict
50+
51+
def evaluate(self):
52+
fragment = self.fragments_definition.get(self.name)
53+
if not fragment:
54+
return 0
55+
return fragment.evaluate()
56+
57+
58+
@dataclasses.dataclass(slots=True, kw_only=True)
59+
class Field(ComplexityNode):
60+
complexity: int
61+
62+
def evaluate(self) -> int:
63+
return self.complexity + sum(child.evaluate() for child in self.children)
64+
65+
66+
@dataclasses.dataclass(slots=True, kw_only=True)
67+
class ListField(Field):
68+
count: int = 1
69+
70+
def evaluate(self) -> int:
71+
return self.complexity + self.count * sum(child.evaluate() for child in self.children)
72+
73+
74+
@dataclasses.dataclass(slots=True, kw_only=True)
75+
class SkippedField(ComplexityNode):
76+
wraps: ComplexityNode
77+
78+
@classmethod
79+
def wrap(cls, node: ComplexityNode):
80+
wrapper = cls(
81+
name=node.name,
82+
parent=node.parent,
83+
children=node.children,
84+
wraps=node,
85+
)
86+
node.parent.children.remove(node)
87+
node.parent.add_child(wrapper)
88+
return wrapper
89+
90+
def evaluate(self) -> int:
91+
return 0
92+
93+
94+
@dataclasses.dataclass(slots=True, kw_only=True)
95+
class MetaField(ComplexityNode):
96+
97+
def evaluate(self) -> int:
98+
return 0
99+
100+
101+
def build_node(
102+
node: FieldNode,
103+
type_info: TypeInfo,
104+
complexity: int,
105+
variables: dict[str, Any],
106+
config: Config,
107+
) -> ComplexityNode:
108+
"""Build a complexity node from a field node."""
109+
type_ = type_info.get_type()
110+
unwrapped_type = get_named_type(type_)
111+
if unwrapped_type is not None and is_introspection_type(unwrapped_type):
112+
return MetaField(name=node.name.value)
113+
if isinstance(type_, GraphQLList):
114+
return build_list_node(node, complexity, variables, config)
115+
return Field(
116+
name=node.name.value,
117+
complexity=complexity,
118+
)
119+
120+
121+
def build_list_node(node: FieldNode, complexity: int, variables: dict[str, Any], config: Config) -> ListField:
122+
"""Build a list complexity node from a field node."""
123+
if config.count_arg_name:
124+
try:
125+
count = int(
126+
get_node_argument_value(node=node, arg_name=config.count_arg_name, variables=variables)
127+
)
128+
except ValueError:
129+
logger.debug("Missing or invalid value for argument '%s' in node '%s'", config.count_arg_name, node)
130+
count = config.count_missing_arg_value
131+
else:
132+
count = 1
133+
return ListField(
134+
name=node.name.value,
135+
complexity=complexity,
136+
count=count,
137+
)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Any
2+
3+
from graphql import DirectiveNode, FieldNode, VariableNode
4+
5+
6+
def get_node_argument_value(node: FieldNode | DirectiveNode, arg_name: str, variables: dict[str, Any]) -> Any:
7+
"""Returns the value of the argument given by parameter."""
8+
arg = next(
9+
(arg for arg in node.arguments if arg.name.value == arg_name),
10+
None
11+
)
12+
if not arg:
13+
raise ValueError(f"Value for {arg_name!r} not found in {node.name.value!r} arguments")
14+
15+
if isinstance(arg.value, VariableNode):
16+
return variables.get(arg.value.name.value)
17+
18+
return arg.value.value

0 commit comments

Comments
 (0)