11import { inspect } from '../jsutils/inspect' ;
22import { isAsyncIterable } from '../jsutils/isAsyncIterable' ;
3+ import { isPromise } from '../jsutils/isPromise' ;
34import type { Maybe } from '../jsutils/Maybe' ;
5+ import type { Path } from '../jsutils/Path' ;
46import { addPath , pathToArray } from '../jsutils/Path' ;
7+ import type { PromiseOrValue } from '../jsutils/PromiseOrValue' ;
58
69import { GraphQLError } from '../error/GraphQLError' ;
710import { locatedError } from '../error/locatedError' ;
811
9- import type { DocumentNode } from '../language/ast' ;
12+ import type { DocumentNode , FieldNode } from '../language/ast' ;
1013
1114import type { GraphQLFieldResolver } from '../type/definition' ;
1215import type { GraphQLSchema } from '../type/schema' ;
@@ -47,9 +50,11 @@ import { getArgumentValues } from './values';
4750 *
4851 * Accepts either an object with named arguments, or individual arguments.
4952 */
50- export async function subscribe (
53+ export function subscribe (
5154 args : ExecutionArgs ,
52- ) : Promise < AsyncGenerator < ExecutionResult , void , void > | ExecutionResult > {
55+ ) : PromiseOrValue <
56+ AsyncGenerator < ExecutionResult , void , void > | ExecutionResult
57+ > {
5358 const {
5459 schema,
5560 document,
@@ -61,7 +66,7 @@ export async function subscribe(
6166 subscribeFieldResolver,
6267 } = args ;
6368
64- const resultOrStream = await createSourceEventStream (
69+ const resultOrStream = createSourceEventStream (
6570 schema ,
6671 document ,
6772 rootValue ,
@@ -71,6 +76,42 @@ export async function subscribe(
7176 subscribeFieldResolver ,
7277 ) ;
7378
79+ if ( isPromise ( resultOrStream ) ) {
80+ return resultOrStream . then ( ( resolvedResultOrStream ) =>
81+ mapSourceToResponse (
82+ schema ,
83+ document ,
84+ resolvedResultOrStream ,
85+ contextValue ,
86+ variableValues ,
87+ operationName ,
88+ fieldResolver ,
89+ ) ,
90+ ) ;
91+ }
92+
93+ return mapSourceToResponse (
94+ schema ,
95+ document ,
96+ resultOrStream ,
97+ contextValue ,
98+ variableValues ,
99+ operationName ,
100+ fieldResolver ,
101+ ) ;
102+ }
103+
104+ function mapSourceToResponse (
105+ schema : GraphQLSchema ,
106+ document : DocumentNode ,
107+ resultOrStream : ExecutionResult | AsyncIterable < unknown > ,
108+ contextValue ?: unknown ,
109+ variableValues ?: Maybe < { readonly [ variable : string ] : unknown } > ,
110+ operationName ?: Maybe < string > ,
111+ fieldResolver ?: Maybe < GraphQLFieldResolver < any , any > > ,
112+ ) : PromiseOrValue <
113+ AsyncGenerator < ExecutionResult , void , void > | ExecutionResult
114+ > {
74115 if ( ! isAsyncIterable ( resultOrStream ) ) {
75116 return resultOrStream ;
76117 }
@@ -81,7 +122,7 @@ export async function subscribe(
81122 // the GraphQL specification. The `execute` function provides the
82123 // "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
83124 // "ExecuteQuery" algorithm, for which `execute` is also used.
84- const mapSourceToResponse = ( payload : unknown ) =>
125+ return mapAsyncIterator ( resultOrStream , ( payload : unknown ) =>
85126 execute ( {
86127 schema,
87128 document,
@@ -90,10 +131,8 @@ export async function subscribe(
90131 variableValues,
91132 operationName,
92133 fieldResolver,
93- } ) ;
94-
95- // Map every source value to a ExecutionResult value as described above.
96- return mapAsyncIterator ( resultOrStream , mapSourceToResponse ) ;
134+ } ) ,
135+ ) ;
97136}
98137
99138/**
@@ -124,15 +163,15 @@ export async function subscribe(
124163 * or otherwise separating these two steps. For more on this, see the
125164 * "Supporting Subscriptions at Scale" information in the GraphQL specification.
126165 */
127- export async function createSourceEventStream (
166+ export function createSourceEventStream (
128167 schema : GraphQLSchema ,
129168 document : DocumentNode ,
130169 rootValue ?: unknown ,
131170 contextValue ?: unknown ,
132171 variableValues ?: Maybe < { readonly [ variable : string ] : unknown } > ,
133172 operationName ?: Maybe < string > ,
134173 subscribeFieldResolver ?: Maybe < GraphQLFieldResolver < any , any > > ,
135- ) : Promise < AsyncIterable < unknown > | ExecutionResult > {
174+ ) : PromiseOrValue < AsyncIterable < unknown > | ExecutionResult > {
136175 // If arguments are missing or incorrectly typed, this is an internal
137176 // developer mistake which should throw an early error.
138177 assertValidExecutionArguments ( schema , document , variableValues ) ;
@@ -155,17 +194,20 @@ export async function createSourceEventStream(
155194 }
156195
157196 try {
158- const eventStream = await executeSubscription ( exeContext ) ;
197+ const eventStream = executeSubscription ( exeContext ) ;
198+ if ( isPromise ( eventStream ) ) {
199+ return eventStream . then ( undefined , ( error ) => ( { errors : [ error ] } ) ) ;
200+ }
159201
160202 return eventStream ;
161203 } catch ( error ) {
162204 return { errors : [ error ] } ;
163205 }
164206}
165207
166- async function executeSubscription (
208+ function executeSubscription (
167209 exeContext : ExecutionContext ,
168- ) : Promise < AsyncIterable < unknown > > {
210+ ) : PromiseOrValue < AsyncIterable < unknown > | ExecutionResult > {
169211 const { schema, fragments, operation, variableValues, rootValue } =
170212 exeContext ;
171213
@@ -220,22 +262,44 @@ async function executeSubscription(
220262 // Call the `subscribe()` resolver or the default resolver to produce an
221263 // AsyncIterable yielding raw payloads.
222264 const resolveFn = fieldDef . subscribe ?? exeContext . subscribeFieldResolver ;
223- const eventStream = await resolveFn ( rootValue , args , contextValue , info ) ;
224-
225- if ( eventStream instanceof Error ) {
226- throw eventStream ;
227- }
265+ const eventStream = resolveFn ( rootValue , args , contextValue , info ) ;
228266
229- // Assert field returned an event stream, otherwise yield an error.
230- if ( ! isAsyncIterable ( eventStream ) ) {
231- throw new GraphQLError (
232- 'Subscription field must return Async Iterable. ' +
233- `Received: ${ inspect ( eventStream ) } .` ,
267+ if ( isPromise ( eventStream ) ) {
268+ return eventStream . then (
269+ ( resolvedEventStream ) =>
270+ ensureAsyncIterable ( resolvedEventStream , fieldNodes , path ) ,
271+ ( error ) => {
272+ throw locatedError ( error , fieldNodes , pathToArray ( path ) ) ;
273+ } ,
234274 ) ;
235275 }
236276
237- return eventStream ;
277+ return ensureAsyncIterable ( eventStream , fieldNodes , path ) ;
238278 } catch ( error ) {
239279 throw locatedError ( error , fieldNodes , pathToArray ( path ) ) ;
240280 }
241281}
282+
283+ function ensureAsyncIterable (
284+ eventStream : unknown ,
285+ fieldNodes : ReadonlyArray < FieldNode > ,
286+ path : Path ,
287+ ) : AsyncIterable < unknown > {
288+ if ( eventStream instanceof Error ) {
289+ throw locatedError ( eventStream , fieldNodes , pathToArray ( path ) ) ;
290+ }
291+
292+ // Assert field returned an event stream, otherwise yield an error.
293+ if ( ! isAsyncIterable ( eventStream ) ) {
294+ throw locatedError (
295+ new GraphQLError (
296+ 'Subscription field must return Async Iterable. ' +
297+ `Received: ${ inspect ( eventStream ) } .` ,
298+ ) ,
299+ fieldNodes ,
300+ pathToArray ( path ) ,
301+ ) ;
302+ }
303+
304+ return eventStream ;
305+ }
0 commit comments