Skip to content

Commit 5356499

Browse files
committed
refactor: Big refactor to support include and skip directives
1 parent 57d08c3 commit 5356499

16 files changed

+622
-200
lines changed

src/graphql_complexity/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from .estimators import (
2+
ComplexityEstimator,
3+
DirectivesEstimator,
4+
SimpleEstimator
5+
)
6+
from .visitor import ComplexityVisitor
7+
8+
__all__ = [
9+
"ComplexityVisitor",
10+
"SimpleEstimator",
11+
"ComplexityEstimator",
12+
"DirectivesEstimator",
13+
]

src/graphql_complexity/estimators.py

Lines changed: 0 additions & 120 deletions
This file was deleted.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .base import ComplexityEstimator
2+
from .directive import DirectivesEstimator
3+
from .simple import SimpleEstimator
4+
5+
__all__ = ["ComplexityEstimator", "SimpleEstimator", "DirectivesEstimator"]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import abc
2+
3+
4+
class ComplexityEstimator(abc.ABC):
5+
@abc.abstractmethod
6+
def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
7+
"""Return the complexity of the field."""
8+
9+
@abc.abstractmethod
10+
def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
11+
"""Return the multiplier that will be applied to the children of the given node."""
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from graphql import Visitor, parse, visit
2+
3+
from graphql_complexity.estimators.base import ComplexityEstimator
4+
5+
DIRECTIVE_ESTIMATOR_FIELD_COMPLEXITY_NAME = "value"
6+
DEFAULT_COMPLEXITY_DIRECTIVE_NAME = "complexity"
7+
DEFAULT_COMPLEXITY_VALUE = 1
8+
9+
10+
class DirectivesVisitor(Visitor):
11+
def __init__(self, directive_name: str, collector: dict[str, int]):
12+
self._collector = collector
13+
self.directive_name = directive_name
14+
super().__init__()
15+
16+
def enter_field_definition(self, node, key, parent, path, ancestors):
17+
for directive in node.directives:
18+
if directive.name.value == self.directive_name:
19+
complexity = next(
20+
arg.value.value
21+
for arg in directive.arguments
22+
if arg.name.value == DIRECTIVE_ESTIMATOR_FIELD_COMPLEXITY_NAME
23+
)
24+
self._collector[node.name.value] = int(complexity)
25+
break
26+
27+
28+
class DirectivesEstimator(ComplexityEstimator):
29+
"""Complexity estimator that uses directives to get the complexity of the fields.
30+
The complexity is calculated by visiting the schema and using the complexity
31+
directive to get the complexity of each field.
32+
33+
Example:
34+
Given the following schema:
35+
```qgl
36+
directive @complexity(
37+
value: Int!
38+
) on FIELD_DEFINITION
39+
40+
type Query {
41+
oneField: String @complexity(value: 5)
42+
otherField: String @complexity(value: 1)
43+
}
44+
```
45+
The complexity of the fields will be:
46+
- oneField: 5
47+
- otherField: 1
48+
And the total complexity will be 6.
49+
"""
50+
51+
def __init__(
52+
self,
53+
schema: str,
54+
directive_name: str = DEFAULT_COMPLEXITY_DIRECTIVE_NAME,
55+
missing_complexity: int = DEFAULT_COMPLEXITY_VALUE,
56+
):
57+
self.__directive_name = directive_name
58+
self.__missing_complexity = int(missing_complexity)
59+
self.__complexity_map = self.collect_from_schema(
60+
schema=schema, directive_name=directive_name
61+
)
62+
super().__init__()
63+
64+
@staticmethod
65+
def collect_from_schema(schema: str, directive_name: str) -> dict[str, int]:
66+
collector = {}
67+
ast = parse(schema)
68+
visitor = DirectivesVisitor(collector=collector, directive_name=directive_name)
69+
visit(ast, visitor)
70+
return collector
71+
72+
def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
73+
return self.__complexity_map.get(node.name.value, self.__missing_complexity)
74+
75+
def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
76+
# ToDo: Implement this method
77+
return 1
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from graphql_complexity.estimators.base import ComplexityEstimator
2+
3+
4+
class SimpleEstimator(ComplexityEstimator):
5+
"""Simple complexity estimator that returns a constant complexity and multiplier for all fields.
6+
Constants can be set in the constructor.
7+
8+
Example:
9+
Given the following query:
10+
```qgl
11+
query {
12+
user {
13+
name
14+
email
15+
}
16+
}
17+
```
18+
As the complexity and multiplier are constant, the complexity of the fields will be:
19+
- user: 1 * 1 = 1
20+
- name: 1 * 1 = 1
21+
- email: 1 * 1 = 1
22+
And the total complexity will be 3.
23+
"""
24+
25+
def __init__(self, complexity: int = 1, multiplier: int = 1):
26+
if complexity < 0:
27+
raise ValueError(
28+
"'complexity' must be a positive integer (greater or equal than 0)"
29+
)
30+
if multiplier < 0:
31+
raise ValueError(
32+
"'multiplier' must be a positive integer (greater or equal than 0)"
33+
)
34+
self.__complexity_constant = complexity
35+
self.__multiplier_constant = multiplier
36+
super().__init__()
37+
38+
def get_field_complexity(self, node, key, parent, path, ancestors) -> int:
39+
return self.__complexity_constant
40+
41+
def get_field_multiplier(self, node, key, parent, path, ancestors) -> int:
42+
return self.__multiplier_constant

src/graphql_complexity/extensions.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

src/graphql_complexity/extensions/__init__.py

Whitespace-only changes.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from typing import Type
2+
3+
from graphql import GraphQLError, visit
4+
from strawberry.extensions import SchemaExtension
5+
6+
from graphql_complexity import (
7+
ComplexityEstimator,
8+
ComplexityVisitor,
9+
SimpleEstimator
10+
)
11+
12+
13+
def build_complexity_extension(
14+
estimator: ComplexityEstimator | None = None,
15+
max_complexity: int | None = None,
16+
) -> Type[SchemaExtension]:
17+
estimator = estimator or SimpleEstimator(1, 1)
18+
19+
class ComplexityExtension(SchemaExtension):
20+
visitor = None
21+
estimated_complexity: int = None
22+
23+
def on_validate(
24+
self,
25+
):
26+
self.visitor = ComplexityVisitor(estimator=estimator)
27+
visit(self.execution_context.graphql_document, self.visitor)
28+
29+
self.estimated_complexity = self.visitor.evaluate()
30+
31+
if max_complexity and self.estimated_complexity > max_complexity:
32+
error = GraphQLError(
33+
f"Query is too complex. Max complexity is {max_complexity}, estimated "
34+
f"complexity is {self.estimated_complexity}"
35+
)
36+
self.execution_context.errors = [error]
37+
38+
def get_results(self):
39+
return {"complexity": {"value": self.estimated_complexity}}
40+
41+
return ComplexityExtension

src/graphql_complexity/services.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from graphql import parse, visit
2+
3+
from . import ComplexityEstimator, ComplexityVisitor
4+
5+
6+
def get_complexity(query: str, estimator: ComplexityEstimator) -> int:
7+
ast = parse(query)
8+
visitor = ComplexityVisitor(estimator=estimator)
9+
visit(ast, visitor)
10+
return visitor.evaluate()

0 commit comments

Comments
 (0)