7
7
8
8
from graphql_relay import to_global_id
9
9
from graphql_relay .connection .connectiontypes import Edge
10
- from graphene import relay , Argument , Boolean , Int , String , Field , List , NonNull , Dynamic
11
- from graphene .relay .connection import PageInfo
10
+ from graphene import Argument , Boolean , Int , String , Field , List , NonNull , Dynamic
11
+ from graphene .relay import Connection
12
+ from graphene .relay .connection import PageInfo , ConnectionField
13
+
14
+
15
+ from .registry import get_global_registry
12
16
13
17
14
18
__author__ = 'ekampf'
@@ -49,7 +53,7 @@ def connection_from_ndb_query(query, args=None, connection_type=None, edge_type=
49
53
so pagination will only work if the array is static.
50
54
'''
51
55
args = args or {}
52
- connection_type = connection_type or relay . Connection
56
+ connection_type = connection_type or Connection
53
57
edge_type = edge_type or Edge
54
58
pageinfo_type = pageinfo_type or PageInfo
55
59
@@ -91,7 +95,7 @@ def connection_from_ndb_query(query, args=None, connection_type=None, edge_type=
91
95
)
92
96
93
97
94
- class NdbConnectionField (relay . ConnectionField ):
98
+ class NdbConnectionField (ConnectionField ):
95
99
def __init__ (self , type , transform_edges = None , * args , ** kwargs ):
96
100
super (NdbConnectionField , self ).__init__ (
97
101
type ,
@@ -104,13 +108,23 @@ def __init__(self, type, transform_edges=None, *args, **kwargs):
104
108
105
109
self .transform_edges = transform_edges
106
110
111
+ @property
112
+ def type (self ):
113
+ from .types import NdbObjectType
114
+ _type = super (ConnectionField , self ).type
115
+ assert issubclass (_type , NdbObjectType ), (
116
+ "NdbConnectionField only accepts NdbObjectType types"
117
+ )
118
+ assert _type ._meta .connection , "The type {} doesn't have a connection" .format (_type .__name__ )
119
+ return _type ._meta .connection
120
+
107
121
@property
108
122
def model (self ):
109
123
return self .type ._meta .node ._meta .model
110
124
111
125
@staticmethod
112
- def connection_resolver (resolver , connection , model , transform_edges , root , args , context , info ):
113
- ndb_query = resolver (root , args , context , info )
126
+ def connection_resolver (resolver , connection , model , transform_edges , root , info , ** args ):
127
+ ndb_query = resolver (root , info , ** args )
114
128
if ndb_query is None :
115
129
ndb_query = model .query ()
116
130
@@ -121,26 +135,29 @@ def connection_resolver(resolver, connection, model, transform_edges, root, args
121
135
edge_type = connection .Edge ,
122
136
pageinfo_type = PageInfo ,
123
137
transform_edges = transform_edges ,
124
- context = context
138
+ context = info . context
125
139
)
126
140
127
141
def get_resolver (self , parent_resolver ):
128
- return partial (self .connection_resolver , parent_resolver , self .type , self .model , self .transform_edges )
142
+ return partial (
143
+ self .connection_resolver , parent_resolver , self .type , self .model , self .transform_edges
144
+ )
129
145
130
146
131
147
class DynamicNdbKeyStringField (Dynamic ):
132
- def __init__ (self , ndb_key_prop , * args , ** kwargs ):
148
+ def __init__ (self , ndb_key_prop , registry = None , * args , ** kwargs ):
133
149
kind = ndb_key_prop ._kind
150
+ if not registry :
151
+ registry = get_global_registry ()
134
152
135
153
def get_type ():
136
- from .types import NdbObjectTypeMeta
137
154
kind_name = kind if isinstance (kind , six .string_types ) else kind .__name__
138
155
139
- if not NdbObjectTypeMeta .REGISTRY .get (kind_name ):
156
+ _type = registry .get_type_for_model_name (kind_name )
157
+ if not _type :
140
158
return None
141
159
142
- global_type_name = NdbObjectTypeMeta .REGISTRY [kind_name ].__name__
143
- return NdbKeyStringField (ndb_key_prop , global_type_name )
160
+ return NdbKeyStringField (ndb_key_prop , _type ._meta .name )
144
161
145
162
super (DynamicNdbKeyStringField , self ).__init__ (
146
163
get_type ,
@@ -149,13 +166,15 @@ def get_type():
149
166
150
167
151
168
class DynamicNdbKeyReferenceField (Dynamic ):
152
- def __init__ (self , ndb_key_prop , * args , ** kwargs ):
169
+ def __init__ (self , ndb_key_prop , registry = None , * args , ** kwargs ):
153
170
kind = ndb_key_prop ._kind
171
+ if not registry :
172
+ registry = get_global_registry ()
154
173
155
174
def get_type ():
156
- from .types import NdbObjectTypeMeta
157
175
kind_name = kind if isinstance (kind , six .string_types ) else kind .__name__
158
- _type = NdbObjectTypeMeta .REGISTRY .get (kind_name )
176
+
177
+ _type = registry .get_type_for_model_name (kind_name )
159
178
if not _type :
160
179
return None
161
180
@@ -187,8 +206,8 @@ def __init__(self, ndb_key_prop, graphql_type_name, *args, **kwargs):
187
206
188
207
super (NdbKeyStringField , self ).__init__ (_type , * args , ** kwargs )
189
208
190
- def resolve_key_to_string (self , entity , args , context , info ):
191
- is_global_id = not args . get ( ' ndb' , False )
209
+ def resolve_key_to_string (self , entity , info , ndb = False ):
210
+ is_global_id = not ndb
192
211
key_value = self .__ndb_key_prop ._get_user_value (entity )
193
212
if not key_value :
194
213
return None
@@ -218,7 +237,7 @@ def __init__(self, ndb_key_prop, graphql_type, *args, **kwargs):
218
237
219
238
super (NdbKeyReferenceField , self ).__init__ (_type , * args , ** kwargs )
220
239
221
- def resolve_key_reference (self , entity , args , context , info ):
240
+ def resolve_key_reference (self , entity , info ):
222
241
key_value = self .__ndb_key_prop ._get_user_value (entity )
223
242
if not key_value :
224
243
return None
0 commit comments