diff --git a/src/execution/__tests__/subscribe-test.ts b/src/execution/__tests__/subscribe-test.ts index c8f4aca7f8b..d65b3136cd9 100644 --- a/src/execution/__tests__/subscribe-test.ts +++ b/src/execution/__tests__/subscribe-test.ts @@ -5,6 +5,7 @@ import { expectJSON } from '../../__testUtils__/expectJSON'; import { resolveOnNextTick } from '../../__testUtils__/resolveOnNextTick'; import { isAsyncIterable } from '../../jsutils/isAsyncIterable'; +import { isPromise } from '../../jsutils/isPromise'; import { parse } from '../../language/parser'; @@ -135,9 +136,6 @@ async function expectPromise(promise: Promise) { } return { - toReject() { - expect(caughtError).to.be.an.instanceOf(Error); - }, toRejectWith(message: string) { expect(caughtError).to.be.an.instanceOf(Error); expect(caughtError).to.have.property('message', message); @@ -379,24 +377,22 @@ describe('Subscription Initialization Phase', () => { }); // @ts-expect-error (schema must not be null) - (await expectPromise(subscribe({ schema: null, document }))).toRejectWith( + expect(() => subscribe({ schema: null, document })).to.throw( 'Expected null to be a GraphQL schema.', ); // @ts-expect-error - (await expectPromise(subscribe({ document }))).toRejectWith( + expect(() => subscribe({ document })).to.throw( 'Expected undefined to be a GraphQL schema.', ); // @ts-expect-error (document must not be null) - (await expectPromise(subscribe({ schema, document: null }))).toRejectWith( + expect(() => subscribe({ schema, document: null })).to.throw( 'Must provide document.', ); // @ts-expect-error - (await expectPromise(subscribe({ schema }))).toRejectWith( - 'Must provide document.', - ); + expect(() => subscribe({ schema })).to.throw('Must provide document.'); }); it('resolves to an error if schema does not support subscriptions', async () => { @@ -450,11 +446,17 @@ describe('Subscription Initialization Phase', () => { }); // @ts-expect-error - (await expectPromise(subscribe({ schema, document: {} }))).toReject(); + expect(() => subscribe({ schema, document: {} })).to.throw(); }); it('throws an error if subscribe does not return an iterator', async () => { - (await expectPromise(subscribeWithBadFn(() => 'test'))).toRejectWith( + expect(() => subscribeWithBadFn(() => 'test')).to.throw( + 'Subscription field must return Async Iterable. Received: "test".', + ); + + const result = subscribeWithBadFn(() => Promise.resolve('test')); + assert(isPromise(result)); + (await expectPromise(result)).toRejectWith( 'Subscription field must return Async Iterable. Received: "test".', ); }); @@ -472,12 +474,12 @@ describe('Subscription Initialization Phase', () => { expectJSON( // Returning an error - await subscribeWithBadFn(() => new Error('test error')), + subscribeWithBadFn(() => new Error('test error')), ).toDeepEqual(expectedResult); expectJSON( // Throwing an error - await subscribeWithBadFn(() => { + subscribeWithBadFn(() => { throw new Error('test error'); }), ).toDeepEqual(expectedResult); diff --git a/src/execution/subscribe.ts b/src/execution/subscribe.ts index 7a04480bf44..f86465d3859 100644 --- a/src/execution/subscribe.ts +++ b/src/execution/subscribe.ts @@ -1,7 +1,9 @@ import { inspect } from '../jsutils/inspect'; import { isAsyncIterable } from '../jsutils/isAsyncIterable'; +import { isPromise } from '../jsutils/isPromise'; import type { Maybe } from '../jsutils/Maybe'; import { addPath, pathToArray } from '../jsutils/Path'; +import type { PromiseOrValue } from '../jsutils/PromiseOrValue'; import { GraphQLError } from '../error/GraphQLError'; import { locatedError } from '../error/locatedError'; @@ -47,9 +49,11 @@ import { getArgumentValues } from './values'; * * Accepts either an object with named arguments, or individual arguments. */ -export async function subscribe( +export function subscribe( args: ExecutionArgs, -): Promise | ExecutionResult> { +): PromiseOrValue< + AsyncGenerator | ExecutionResult +> { const { schema, document, @@ -61,7 +65,7 @@ export async function subscribe( subscribeFieldResolver, } = args; - const resultOrStream = await createSourceEventStream( + const resultOrStream = createSourceEventStream( schema, document, rootValue, @@ -71,6 +75,42 @@ export async function subscribe( subscribeFieldResolver, ); + if (isPromise(resultOrStream)) { + return resultOrStream.then((resolvedResultOrStream) => + mapSourceToResponse( + schema, + document, + resolvedResultOrStream, + contextValue, + variableValues, + operationName, + fieldResolver, + ), + ); + } + + return mapSourceToResponse( + schema, + document, + resultOrStream, + contextValue, + variableValues, + operationName, + fieldResolver, + ); +} + +function mapSourceToResponse( + schema: GraphQLSchema, + document: DocumentNode, + resultOrStream: ExecutionResult | AsyncIterable, + contextValue?: unknown, + variableValues?: Maybe<{ readonly [variable: string]: unknown }>, + operationName?: Maybe, + fieldResolver?: Maybe>, +): PromiseOrValue< + AsyncGenerator | ExecutionResult +> { if (!isAsyncIterable(resultOrStream)) { return resultOrStream; } @@ -81,7 +121,7 @@ export async function subscribe( // the GraphQL specification. The `execute` function provides the // "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the // "ExecuteQuery" algorithm, for which `execute` is also used. - const mapSourceToResponse = (payload: unknown) => + return mapAsyncIterator(resultOrStream, (payload: unknown) => execute({ schema, document, @@ -90,10 +130,8 @@ export async function subscribe( variableValues, operationName, fieldResolver, - }); - - // Map every source value to a ExecutionResult value as described above. - return mapAsyncIterator(resultOrStream, mapSourceToResponse); + }), + ); } /** @@ -124,7 +162,7 @@ export async function subscribe( * or otherwise separating these two steps. For more on this, see the * "Supporting Subscriptions at Scale" information in the GraphQL specification. */ -export async function createSourceEventStream( +export function createSourceEventStream( schema: GraphQLSchema, document: DocumentNode, rootValue?: unknown, @@ -132,7 +170,7 @@ export async function createSourceEventStream( variableValues?: Maybe<{ readonly [variable: string]: unknown }>, operationName?: Maybe, subscribeFieldResolver?: Maybe>, -): Promise | ExecutionResult> { +): PromiseOrValue | ExecutionResult> { // If arguments are missing or incorrectly typed, this is an internal // developer mistake which should throw an early error. assertValidExecutionArguments(schema, document, variableValues); @@ -155,17 +193,22 @@ export async function createSourceEventStream( } try { - const eventStream = await executeSubscription(exeContext); - - // Assert field returned an event stream, otherwise yield an error. - if (!isAsyncIterable(eventStream)) { - throw new Error( - 'Subscription field must return Async Iterable. ' + - `Received: ${inspect(eventStream)}.`, - ); + const eventStream = executeSubscription(exeContext); + + if (isPromise(eventStream)) { + return eventStream + .then((resolvedEventStream) => ensureAsyncIterable(resolvedEventStream)) + .then(undefined, (error) => { + // If it GraphQLError, report it as an ExecutionResult, containing only errors and no data. + // Otherwise treat the error as a system-class error and re-throw it. + if (error instanceof GraphQLError) { + return { errors: [error] }; + } + throw error; + }); } - return eventStream; + return ensureAsyncIterable(eventStream); } catch (error) { // If it GraphQLError, report it as an ExecutionResult, containing only errors and no data. // Otherwise treat the error as a system-class error and re-throw it. @@ -176,9 +219,19 @@ export async function createSourceEventStream( } } -async function executeSubscription( - exeContext: ExecutionContext, -): Promise { +function ensureAsyncIterable(eventStream: unknown): AsyncIterable { + // Assert field returned an event stream, otherwise yield an error. + if (!isAsyncIterable(eventStream)) { + throw new Error( + 'Subscription field must return Async Iterable. ' + + `Received: ${inspect(eventStream)}.`, + ); + } + + return eventStream; +} + +function executeSubscription(exeContext: ExecutionContext): unknown { const { schema, fragments, operation, variableValues, rootValue } = exeContext; @@ -233,13 +286,26 @@ async function executeSubscription( // Call the `subscribe()` resolver or the default resolver to produce an // AsyncIterable yielding raw payloads. const resolveFn = fieldDef.subscribe ?? exeContext.subscribeFieldResolver; - const eventStream = await resolveFn(rootValue, args, contextValue, info); - if (eventStream instanceof Error) { - throw eventStream; + const eventStream = resolveFn(rootValue, args, contextValue, info); + + if (isPromise(eventStream)) { + return eventStream + .then((resolvedEventStream) => throwReturnedError(resolvedEventStream)) + .then(undefined, (error) => { + throw locatedError(error, fieldNodes, pathToArray(path)); + }); } - return eventStream; + + return throwReturnedError(eventStream); } catch (error) { throw locatedError(error, fieldNodes, pathToArray(path)); } } + +function throwReturnedError(eventStream: unknown): unknown { + if (eventStream instanceof Error) { + throw eventStream; + } + return eventStream; +}