Skip to content

Commit 0ae1a7b

Browse files
committed
DSLField and DSLFragment inherits new DSLSelection method
1 parent b613fef commit 0ae1a7b

File tree

2 files changed

+113
-66
lines changed

2 files changed

+113
-66
lines changed

gql/dsl.py

Lines changed: 106 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def __init__(
245245
self.variable_definitions: DSLVariableDefinitions = DSLVariableDefinitions()
246246

247247
# Concatenate fields without and with alias
248-
all_fields: Tuple["DSLField", ...] = DSLField.get_aliased_fields(
248+
all_fields: Tuple["DSLSelection", ...] = DSLField.get_aliased_fields(
249249
fields, fields_with_alias
250250
)
251251

@@ -265,7 +265,7 @@ def __init__(
265265
)
266266

267267
self.selection_set: SelectionSetNode = SelectionSetNode(
268-
selections=FrozenList(DSLField.get_ast_fields(all_fields))
268+
selections=FrozenList(DSLSelection.get_ast_fields(all_fields))
269269
)
270270

271271

@@ -397,56 +397,35 @@ def __repr__(self) -> str:
397397
return f"<{self.__class__.__name__} {self._type!r}>"
398398

399399

400-
class DSLField:
401-
"""The DSLField represents a GraphQL field for the DSL code.
402-
403-
Instances of this class are generated for you automatically as attributes
404-
of the :class:`DSLType`
400+
class DSLSelection(ABC):
401+
"""DSLSelection is an abstract class which define the
402+
:meth:`select <gql.dsl.DSLSelection.select>` method to select
403+
children fields in the query.
405404
406-
If this field contains children fields, then you need to select which ones
407-
you want in the request using the :meth:`select <gql.dsl.DSLField.select>`
408-
method.
405+
subclasses:
406+
:class:`DSLField`
407+
:class:`DSLFragment`
409408
"""
410409

411410
_type: Union[GraphQLObjectType, GraphQLInterfaceType]
412-
ast_field: FieldNode
413-
field: GraphQLField
414-
415-
def __init__(
416-
self,
417-
name: str,
418-
graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType],
419-
graphql_field: GraphQLField,
420-
):
421-
"""Initialize the DSLField.
422-
423-
.. warning::
424-
Don't instantiate this class yourself.
425-
Use attributes of the :class:`DSLType` instead.
426-
427-
:param name: the name of the field
428-
:param graphql_type: the GraphQL type definition from the schema
429-
:param graphql_field: the GraphQL field definition from the schema
430-
"""
431-
self._type = graphql_type
432-
self.field = graphql_field
433-
self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList())
434-
log.debug(f"Creating {self!r}")
411+
ast_field: Union[FieldNode, InlineFragmentNode]
435412

436413
@staticmethod
437-
def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]:
414+
def get_ast_fields(
415+
fields: Iterable["DSLSelection"],
416+
) -> List[Union[FieldNode, InlineFragmentNode]]:
438417
"""
439418
:meta private:
440419
441420
Equivalent to: :code:`[field.ast_field for field in fields]`
442421
But with a type check for each field in the list.
443422
444423
:raises TypeError: if any of the provided fields are not instances
445-
of the :class:`DSLField` class.
424+
of the :class:`DSLSelection` class.
446425
"""
447426
ast_fields = []
448427
for field in fields:
449-
if isinstance(field, DSLField):
428+
if isinstance(field, DSLSelection):
450429
ast_fields.append(field.ast_field)
451430
else:
452431
raise TypeError(f'Received incompatible field: "{field}".')
@@ -455,8 +434,8 @@ def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]:
455434

456435
@staticmethod
457436
def get_aliased_fields(
458-
fields: Iterable["DSLField"], fields_with_alias: Dict[str, "DSLField"]
459-
) -> Tuple["DSLField", ...]:
437+
fields: Iterable["DSLSelection"], fields_with_alias: Dict[str, "DSLField"]
438+
) -> Tuple["DSLSelection", ...]:
460439
"""
461440
:meta private:
462441
@@ -471,30 +450,32 @@ def get_aliased_fields(
471450
)
472451

473452
def select(
474-
self, *fields: "DSLField", **fields_with_alias: "DSLField"
475-
) -> "DSLField":
453+
self, *fields: "DSLSelection", **fields_with_alias: "DSLField"
454+
) -> "DSLSelection":
476455
r"""Select the new children fields
477456
that we want to receive in the request.
478457
479458
If used multiple times, we will add the new children fields
480459
to the existing children fields.
481460
482461
:param \*fields: new children fields
483-
:type \*fields: DSLField
462+
:type \*fields: DSLSelection (DSLField or DSLFragment)
484463
:param \**fields_with_alias: new children fields with alias as key
485464
:type \**fields_with_alias: DSLField
486465
:return: itself
487466
488467
:raises TypeError: if any of the provided fields are not instances
489-
of the :class:`DSLField` class.
468+
of the :class:`DSLSelection` class.
490469
"""
491470

492471
# Concatenate fields without and with alias
493-
added_fields: Tuple["DSLField", ...] = self.get_aliased_fields(
472+
added_fields: Tuple["DSLSelection", ...] = self.get_aliased_fields(
494473
fields, fields_with_alias
495474
)
496475

497-
added_selections: List[FieldNode] = self.get_ast_fields(added_fields)
476+
added_selections: List[
477+
Union[FieldNode, InlineFragmentNode]
478+
] = self.get_ast_fields(added_fields)
498479

499480
current_selection_set: Optional[SelectionSetNode] = self.ast_field.selection_set
500481

@@ -511,6 +492,58 @@ def select(
511492

512493
return self
513494

495+
@property
496+
def type_name(self):
497+
""":meta private:"""
498+
return self._type.name
499+
500+
def __str__(self) -> str:
501+
return print_ast(self.ast_field)
502+
503+
504+
class DSLField(DSLSelection):
505+
"""The DSLField represents a GraphQL field for the DSL code.
506+
507+
Instances of this class are generated for you automatically as attributes
508+
of the :class:`DSLType`
509+
510+
If this field contains children fields, then you need to select which ones
511+
you want in the request using the :meth:`select <gql.dsl.DSLField.select>`
512+
method.
513+
"""
514+
515+
ast_field: FieldNode
516+
field: GraphQLField
517+
518+
def __init__(
519+
self,
520+
name: str,
521+
graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType],
522+
graphql_field: GraphQLField,
523+
):
524+
"""Initialize the DSLField.
525+
526+
.. warning::
527+
Don't instantiate this class yourself.
528+
Use attributes of the :class:`DSLType` instead.
529+
530+
:param name: the name of the field
531+
:param graphql_type: the GraphQL type definition from the schema
532+
:param graphql_field: the GraphQL field definition from the schema
533+
"""
534+
self._type = graphql_type
535+
self.field = graphql_field
536+
self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList())
537+
log.debug(f"Creating {self!r}")
538+
539+
def select(
540+
self, *fields: "DSLSelection", **fields_with_alias: "DSLField"
541+
) -> "DSLField":
542+
"""Calling :meth:`select <gql.dsl.DSLSelection.select>` method with
543+
corrected typing hints
544+
"""
545+
return cast("DSLField", super().select(*fields, **fields_with_alias))
546+
514547
def __call__(self, **kwargs) -> "DSLField":
515548
return self.args(**kwargs)
516549

@@ -519,7 +552,7 @@ def alias(self, alias: str) -> "DSLField":
519552
520553
.. note::
521554
You can also pass the alias directly at the
522-
:meth:`select <gql.dsl.DSLField.select>` method.
555+
:meth:`select <gql.dsl.DSLSelection.select>` method.
523556
:code:`ds.Query.human.select(my_name=ds.Character.name)` is equivalent to:
524557
:code:`ds.Query.human.select(ds.Character.name.alias("my_name"))`
525558
@@ -579,34 +612,41 @@ def _get_argument(self, name: str) -> GraphQLArgument:
579612

580613
return arg
581614

582-
@property
583-
def type_name(self):
584-
""":meta private:"""
585-
return self._type.name
615+
def __repr__(self) -> str:
616+
return (
617+
f"<{self.__class__.__name__} {self._type.name}"
618+
f"::{self.ast_field.name.value}>"
619+
)
586620

587-
def __str__(self) -> str:
588-
return print_ast(self.ast_field)
589621

590-
def __repr__(self) -> str:
591-
name = self._type.name
592-
try:
593-
name += f"::{self.ast_field.name.value}"
594-
except AttributeError:
595-
pass
596-
return f"<{self.__class__.__name__} {name}>"
622+
class DSLFragment(DSLSelection):
597623

624+
ast_field: InlineFragmentNode
598625

599-
class DSLFragment(DSLField):
600-
def __init__(
601-
self, type_condition: Optional[DSLType] = None,
602-
):
603-
self.ast_field = InlineFragmentNode() # type: ignore
604-
if type_condition:
605-
self.on(type_condition)
626+
def __init__(self):
627+
self.ast_field = InlineFragmentNode()
628+
629+
def select(
630+
self, *fields: "DSLSelection", **fields_with_alias: "DSLField"
631+
) -> "DSLFragment":
632+
"""Calling :meth:`select <gql.dsl.DSLSelection.select>` method with
633+
corrected typing hints
634+
"""
635+
return cast("DSLFragment", super().select(*fields, **fields_with_alias))
606636

607637
def on(self, type_condition: DSLType):
608638
self._type = type_condition._type
609-
self.ast_field.type_condition = NamedTypeNode( # type: ignore
639+
self.ast_field.type_condition = NamedTypeNode(
610640
name=NameNode(value=self._type.name)
611641
)
612642
return self
643+
644+
def __repr__(self) -> str:
645+
type_info = ""
646+
647+
try:
648+
type_info += f" on {self._type.name}"
649+
except AttributeError:
650+
pass
651+
652+
return f"<{self.__class__.__name__}{type_info}>"

tests/starwars/test_dsl.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,13 @@ def test_inline_fragments(ds):
435435
assert query == str(query_dsl)
436436

437437

438+
def test_inline_fragments_repr(ds):
439+
440+
assert repr(DSLFragment()) == "<DSLFragment>"
441+
442+
assert repr(DSLFragment().on(ds.Droid)) == "<DSLFragment on Droid>"
443+
444+
438445
def test_dsl_query_all_fields_should_be_instances_of_DSLField():
439446
with pytest.raises(
440447
TypeError, match="fields must be instances of DSLField. Received type:"

0 commit comments

Comments
 (0)