diff --git a/src/execution/__tests__/subscribe-test.ts b/src/execution/__tests__/subscribe-test.ts index d943ef4006..2e7077ff1a 100644 --- a/src/execution/__tests__/subscribe-test.ts +++ b/src/execution/__tests__/subscribe-test.ts @@ -5,6 +5,8 @@ import { expectJSON } from '../../__testUtils__/expectJSON'; import { resolveOnNextTick } from '../../__testUtils__/resolveOnNextTick'; import { isAsyncIterable } from '../../jsutils/isAsyncIterable'; +import { isPromise } from '../../jsutils/isPromise'; +import type { PromiseOrValue } from '../../jsutils/PromiseOrValue'; import { parse } from '../../language/parser'; @@ -135,9 +137,6 @@ async function expectPromise(promise: Promise) { } return { - toReject() { - expect(caughtError).to.be.an.instanceOf(Error); - }, toRejectWith(message: string) { expect(caughtError).to.be.an.instanceOf(Error); expect(caughtError).to.have.property('message', message); @@ -152,9 +151,9 @@ const DummyQueryType = new GraphQLObjectType({ }, }); -async function subscribeWithBadFn( +function subscribeWithBadFn( subscribeFn: () => unknown, -): Promise { +): PromiseOrValue { const schema = new GraphQLSchema({ query: DummyQueryType, subscription: new GraphQLObjectType({ @@ -165,13 +164,28 @@ async function subscribeWithBadFn( }), }); const document = parse('subscription { foo }'); - const result = await subscribe({ schema, document }); - assert(!isAsyncIterable(result)); - expectJSON(await createSourceEventStream(schema, document)).toDeepEqual( - result, - ); - return result; + const subscribeResult = subscribe({ schema, document }); + const streamResult = createSourceEventStream(schema, document); + + if (isPromise(subscribeResult)) { + assert(isPromise(streamResult)); + return Promise.all([subscribeResult, streamResult]).then((resolved) => + expectEquivalentStreamErrors(resolved[0], resolved[1]), + ); + } + + assert(!isPromise(streamResult)); + return expectEquivalentStreamErrors(subscribeResult, streamResult); +} + +function expectEquivalentStreamErrors( + subscribeResult: ExecutionResult | AsyncGenerator, + createSourceEventStreamResult: ExecutionResult | AsyncIterable, +): ExecutionResult { + assert(!isAsyncIterable(subscribeResult)); + expectJSON(createSourceEventStreamResult).toDeepEqual(subscribeResult); + return subscribeResult; } /* eslint-disable @typescript-eslint/require-await */ @@ -379,24 +393,22 @@ describe('Subscription Initialization Phase', () => { }); // @ts-expect-error (schema must not be null) - (await expectPromise(subscribe({ schema: null, document }))).toRejectWith( + expect(() => subscribe({ schema: null, document })).to.throw( 'Expected null to be a GraphQL schema.', ); // @ts-expect-error - (await expectPromise(subscribe({ document }))).toRejectWith( + expect(() => subscribe({ document })).to.throw( 'Expected undefined to be a GraphQL schema.', ); // @ts-expect-error (document must not be null) - (await expectPromise(subscribe({ schema, document: null }))).toRejectWith( + expect(() => subscribe({ schema, document: null })).to.throw( 'Must provide document.', ); // @ts-expect-error - (await expectPromise(subscribe({ schema }))).toRejectWith( - 'Must provide document.', - ); + expect(() => subscribe({ schema })).to.throw('Must provide document.'); }); it('resolves to an error if schema does not support subscriptions', async () => { @@ -450,11 +462,11 @@ describe('Subscription Initialization Phase', () => { }); // @ts-expect-error - (await expectPromise(subscribe({ schema, document: {} }))).toReject(); + expect(() => subscribe({ schema, document: {} })).to.throw(); }); it('throws an error if subscribe does not return an iterator', async () => { - expectJSON(await subscribeWithBadFn(() => 'test')).toDeepEqual({ + const expectedResult = { errors: [ { message: @@ -463,7 +475,13 @@ describe('Subscription Initialization Phase', () => { path: ['foo'], }, ], - }); + }; + + expectJSON(subscribeWithBadFn(() => 'test')).toDeepEqual(expectedResult); + + const result = subscribeWithBadFn(() => Promise.resolve('test')); + assert(isPromise(result)); + expectJSON(await result).toDeepEqual(expectedResult); }); it('resolves to an error for subscription resolver errors', async () => { @@ -479,12 +497,12 @@ describe('Subscription Initialization Phase', () => { expectJSON( // Returning an error - await subscribeWithBadFn(() => new Error('test error')), + subscribeWithBadFn(() => new Error('test error')), ).toDeepEqual(expectedResult); expectJSON( // Throwing an error - await subscribeWithBadFn(() => { + subscribeWithBadFn(() => { throw new Error('test error'); }), ).toDeepEqual(expectedResult); diff --git a/src/execution/subscribe.ts b/src/execution/subscribe.ts index e54949830c..526b66459e 100644 --- a/src/execution/subscribe.ts +++ b/src/execution/subscribe.ts @@ -1,12 +1,15 @@ import { inspect } from '../jsutils/inspect'; import { isAsyncIterable } from '../jsutils/isAsyncIterable'; +import { isPromise } from '../jsutils/isPromise'; import type { Maybe } from '../jsutils/Maybe'; +import type { Path } from '../jsutils/Path'; import { addPath, pathToArray } from '../jsutils/Path'; +import type { PromiseOrValue } from '../jsutils/PromiseOrValue'; import { GraphQLError } from '../error/GraphQLError'; import { locatedError } from '../error/locatedError'; -import type { DocumentNode } from '../language/ast'; +import type { DocumentNode, FieldNode } from '../language/ast'; import type { GraphQLFieldResolver } from '../type/definition'; import type { GraphQLSchema } from '../type/schema'; @@ -47,9 +50,11 @@ import { getArgumentValues } from './values'; * * Accepts either an object with named arguments, or individual arguments. */ -export async function subscribe( +export function subscribe( args: ExecutionArgs, -): Promise | ExecutionResult> { +): PromiseOrValue< + AsyncGenerator | ExecutionResult +> { const { schema, document, @@ -61,7 +66,7 @@ export async function subscribe( subscribeFieldResolver, } = args; - const resultOrStream = await createSourceEventStream( + const resultOrStream = createSourceEventStream( schema, document, rootValue, @@ -71,6 +76,42 @@ export async function subscribe( subscribeFieldResolver, ); + if (isPromise(resultOrStream)) { + return resultOrStream.then((resolvedResultOrStream) => + mapSourceToResponse( + schema, + document, + resolvedResultOrStream, + contextValue, + variableValues, + operationName, + fieldResolver, + ), + ); + } + + return mapSourceToResponse( + schema, + document, + resultOrStream, + contextValue, + variableValues, + operationName, + fieldResolver, + ); +} + +function mapSourceToResponse( + schema: GraphQLSchema, + document: DocumentNode, + resultOrStream: ExecutionResult | AsyncIterable, + contextValue?: unknown, + variableValues?: Maybe<{ readonly [variable: string]: unknown }>, + operationName?: Maybe, + fieldResolver?: Maybe>, +): PromiseOrValue< + AsyncGenerator | ExecutionResult +> { if (!isAsyncIterable(resultOrStream)) { return resultOrStream; } @@ -81,7 +122,7 @@ export async function subscribe( // the GraphQL specification. The `execute` function provides the // "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the // "ExecuteQuery" algorithm, for which `execute` is also used. - const mapSourceToResponse = (payload: unknown) => + return mapAsyncIterator(resultOrStream, (payload: unknown) => execute({ schema, document, @@ -90,10 +131,8 @@ export async function subscribe( variableValues, operationName, fieldResolver, - }); - - // Map every source value to a ExecutionResult value as described above. - return mapAsyncIterator(resultOrStream, mapSourceToResponse); + }), + ); } /** @@ -124,7 +163,7 @@ export async function subscribe( * or otherwise separating these two steps. For more on this, see the * "Supporting Subscriptions at Scale" information in the GraphQL specification. */ -export async function createSourceEventStream( +export function createSourceEventStream( schema: GraphQLSchema, document: DocumentNode, rootValue?: unknown, @@ -132,7 +171,7 @@ export async function createSourceEventStream( variableValues?: Maybe<{ readonly [variable: string]: unknown }>, operationName?: Maybe, subscribeFieldResolver?: Maybe>, -): Promise | ExecutionResult> { +): PromiseOrValue | ExecutionResult> { // If arguments are missing or incorrectly typed, this is an internal // developer mistake which should throw an early error. assertValidExecutionArguments(schema, document, variableValues); @@ -155,7 +194,10 @@ export async function createSourceEventStream( } try { - const eventStream = await executeSubscription(exeContext); + const eventStream = executeSubscription(exeContext); + if (isPromise(eventStream)) { + return eventStream.then(undefined, (error) => ({ errors: [error] })); + } return eventStream; } catch (error) { @@ -163,9 +205,9 @@ export async function createSourceEventStream( } } -async function executeSubscription( +function executeSubscription( exeContext: ExecutionContext, -): Promise> { +): PromiseOrValue | ExecutionResult> { const { schema, fragments, operation, variableValues, rootValue } = exeContext; @@ -220,22 +262,44 @@ async function executeSubscription( // Call the `subscribe()` resolver or the default resolver to produce an // AsyncIterable yielding raw payloads. const resolveFn = fieldDef.subscribe ?? exeContext.subscribeFieldResolver; - const eventStream = await resolveFn(rootValue, args, contextValue, info); - - if (eventStream instanceof Error) { - throw eventStream; - } + const eventStream = resolveFn(rootValue, args, contextValue, info); - // Assert field returned an event stream, otherwise yield an error. - if (!isAsyncIterable(eventStream)) { - throw new GraphQLError( - 'Subscription field must return Async Iterable. ' + - `Received: ${inspect(eventStream)}.`, + if (isPromise(eventStream)) { + return eventStream.then( + (resolvedEventStream) => + ensureAsyncIterable(resolvedEventStream, fieldNodes, path), + (error) => { + throw locatedError(error, fieldNodes, pathToArray(path)); + }, ); } - return eventStream; + return ensureAsyncIterable(eventStream, fieldNodes, path); } catch (error) { throw locatedError(error, fieldNodes, pathToArray(path)); } } + +function ensureAsyncIterable( + eventStream: unknown, + fieldNodes: ReadonlyArray, + path: Path, +): AsyncIterable { + if (eventStream instanceof Error) { + throw locatedError(eventStream, fieldNodes, pathToArray(path)); + } + + // Assert field returned an event stream, otherwise yield an error. + if (!isAsyncIterable(eventStream)) { + throw locatedError( + new GraphQLError( + 'Subscription field must return Async Iterable. ' + + `Received: ${inspect(eventStream)}.`, + ), + fieldNodes, + pathToArray(path), + ); + } + + return eventStream; +}