Skip to content

defer/stream: fix flattenAsyncIterable concurrency #3710

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 44 additions & 63 deletions src/execution/__tests__/flattenAsyncIterable-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,20 @@ import { describe, it } from 'mocha';
import { flattenAsyncIterable } from '../flattenAsyncIterable';

describe('flattenAsyncIterable', () => {
it('does not modify an already flat async generator', async () => {
async function* source() {
yield await Promise.resolve(1);
yield await Promise.resolve(2);
yield await Promise.resolve(3);
}

const result = flattenAsyncIterable(source());

expect(await result.next()).to.deep.equal({ value: 1, done: false });
expect(await result.next()).to.deep.equal({ value: 2, done: false });
expect(await result.next()).to.deep.equal({ value: 3, done: false });
expect(await result.next()).to.deep.equal({
value: undefined,
done: true,
});
});

it('does not modify an already flat async iterator', async () => {
const items = [1, 2, 3];

const iterator: any = {
[Symbol.asyncIterator]() {
return this;
},
next() {
return Promise.resolve({
done: items.length === 0,
value: items.shift(),
});
},
};

const result = flattenAsyncIterable(iterator);

expect(await result.next()).to.deep.equal({ value: 1, done: false });
expect(await result.next()).to.deep.equal({ value: 2, done: false });
expect(await result.next()).to.deep.equal({ value: 3, done: false });
expect(await result.next()).to.deep.equal({
value: undefined,
done: true,
});
});

it('flatten nested async generators', async () => {
async function* source() {
yield await Promise.resolve(1);
yield await Promise.resolve(2);
yield await Promise.resolve(
(async function* nested(): AsyncGenerator<number, void, void> {
yield await Promise.resolve(1.1);
yield await Promise.resolve(1.2);
})(),
);
yield await Promise.resolve(
(async function* nested(): AsyncGenerator<number, void, void> {
yield await Promise.resolve(2.1);
yield await Promise.resolve(2.2);
})(),
);
yield await Promise.resolve(3);
}

const doubles = flattenAsyncIterable(source());
Expand All @@ -67,13 +26,17 @@ describe('flattenAsyncIterable', () => {
for await (const x of doubles) {
result.push(x);
}
expect(result).to.deep.equal([1, 2, 2.1, 2.2, 3]);
expect(result).to.deep.equal([1.1, 1.2, 2.1, 2.2]);
});

it('allows returning early from a nested async generator', async () => {
async function* source() {
yield await Promise.resolve(1);
yield await Promise.resolve(2);
yield await Promise.resolve(
(async function* nested(): AsyncGenerator<number, void, void> {
yield await Promise.resolve(1.1);
yield await Promise.resolve(1.2);
})(),
);
yield await Promise.resolve(
(async function* nested(): AsyncGenerator<number, void, void> {
yield await Promise.resolve(2.1); /* c8 ignore start */
Expand All @@ -82,14 +45,19 @@ describe('flattenAsyncIterable', () => {
})(),
);
// Not reachable, early return
yield await Promise.resolve(3);
yield await Promise.resolve(
(async function* nested(): AsyncGenerator<number, void, void> {
yield await Promise.resolve(3.1);
yield await Promise.resolve(3.2);
})(),
);
}
/* c8 ignore stop */

const doubles = flattenAsyncIterable(source());

expect(await doubles.next()).to.deep.equal({ value: 1, done: false });
expect(await doubles.next()).to.deep.equal({ value: 2, done: false });
expect(await doubles.next()).to.deep.equal({ value: 1.1, done: false });
expect(await doubles.next()).to.deep.equal({ value: 1.2, done: false });
expect(await doubles.next()).to.deep.equal({ value: 2.1, done: false });

// Early return
Expand All @@ -111,8 +79,12 @@ describe('flattenAsyncIterable', () => {

it('allows throwing errors from a nested async generator', async () => {
async function* source() {
yield await Promise.resolve(1);
yield await Promise.resolve(2);
yield await Promise.resolve(
(async function* nested(): AsyncGenerator<number, void, void> {
yield await Promise.resolve(1.1);
yield await Promise.resolve(1.2);
})(),
);
yield await Promise.resolve(
(async function* nested(): AsyncGenerator<number, void, void> {
yield await Promise.resolve(2.1); /* c8 ignore start */
Expand All @@ -121,14 +93,19 @@ describe('flattenAsyncIterable', () => {
})(),
);
// Not reachable, early return
yield await Promise.resolve(3);
yield await Promise.resolve(
(async function* nested(): AsyncGenerator<number, void, void> {
yield await Promise.resolve(3.1);
yield await Promise.resolve(3.2);
})(),
);
}
/* c8 ignore stop */

const doubles = flattenAsyncIterable(source());

expect(await doubles.next()).to.deep.equal({ value: 1, done: false });
expect(await doubles.next()).to.deep.equal({ value: 2, done: false });
expect(await doubles.next()).to.deep.equal({ value: 1.1, done: false });
expect(await doubles.next()).to.deep.equal({ value: 1.2, done: false });
expect(await doubles.next()).to.deep.equal({ value: 2.1, done: false });

// Throw error
Expand All @@ -142,16 +119,20 @@ describe('flattenAsyncIterable', () => {
}
expect(caughtError).to.equal('ouch');
});
/* c8 ignore start */
it.skip('completely yields sub-iterables even when next() called in parallel', async () => {
it('completely yields sub-iterables even when next() called in parallel', async () => {
async function* source() {
yield await Promise.resolve(
(async function* nested(): AsyncGenerator<number, void, void> {
yield await Promise.resolve(1.1);
yield await Promise.resolve(1.2);
})(),
);
yield await Promise.resolve(2);
yield await Promise.resolve(
(async function* nested(): AsyncGenerator<number, void, void> {
yield await Promise.resolve(2.1);
yield await Promise.resolve(2.2);
})(),
);
}

const result = flattenAsyncIterable(source());
Expand All @@ -160,11 +141,11 @@ describe('flattenAsyncIterable', () => {
const promise2 = result.next();
expect(await promise1).to.deep.equal({ value: 1.1, done: false });
expect(await promise2).to.deep.equal({ value: 1.2, done: false });
expect(await result.next()).to.deep.equal({ value: 2, done: false });
expect(await result.next()).to.deep.equal({ value: 2.1, done: false });
expect(await result.next()).to.deep.equal({ value: 2.2, done: false });
expect(await result.next()).to.deep.equal({
value: undefined,
done: true,
});
});
/* c8 ignore stop */
});
26 changes: 21 additions & 5 deletions src/execution/execute.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,8 @@ export const defaultFieldResolver: GraphQLFieldResolver<unknown, unknown> =
export function subscribe(
args: ExecutionArgs,
): PromiseOrValue<
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
| AsyncGenerator<ExecutionResult | AsyncExecutionResult, void, void>
| ExecutionResult
> {
// If a valid execution context cannot be created due to incorrect arguments,
// a "Response" with only errors is returned.
Expand All @@ -1384,11 +1385,24 @@ export function subscribe(
return mapSourceToResponse(exeContext, resultOrStream);
}

async function* ensureAsyncIterable(
someExecutionResult:
| ExecutionResult
| AsyncGenerator<AsyncExecutionResult, void, void>,
): AsyncGenerator<ExecutionResult | AsyncExecutionResult, void, void> {
if (isAsyncIterable(someExecutionResult)) {
yield* someExecutionResult;
} else {
yield someExecutionResult;
}
}

function mapSourceToResponse(
exeContext: ExecutionContext,
resultOrStream: ExecutionResult | AsyncIterable<unknown>,
): PromiseOrValue<
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
| AsyncGenerator<ExecutionResult | AsyncExecutionResult, void, void>
| ExecutionResult
> {
if (!isAsyncIterable(resultOrStream)) {
return resultOrStream;
Expand All @@ -1400,9 +1414,11 @@ function mapSourceToResponse(
// the GraphQL specification. The `execute` function provides the
// "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
// "ExecuteQuery" algorithm, for which `execute` is also used.
return flattenAsyncIterable<ExecutionResult, AsyncExecutionResult>(
mapAsyncIterable(resultOrStream, (payload: unknown) =>
executeImpl(buildPerEventExecutionContext(exeContext, payload)),
return flattenAsyncIterable(
mapAsyncIterable(resultOrStream, async (payload: unknown) =>
ensureAsyncIterable(
await executeImpl(buildPerEventExecutionContext(exeContext, payload)),
),
),
);
}
Expand Down
116 changes: 86 additions & 30 deletions src/execution/flattenAsyncIterable.ts
Original file line number Diff line number Diff line change
@@ -1,49 +1,105 @@
import { isAsyncIterable } from '../jsutils/isAsyncIterable';

type AsyncIterableOrGenerator<T> =
| AsyncGenerator<T, void, void>
| AsyncIterable<T>;

/**
* Given an AsyncIterable that could potentially yield other async iterators,
* flatten all yielded results into a single AsyncIterable
* Given an AsyncIterable of AsyncIterables, flatten all yielded results into a
* single AsyncIterable.
*/
export function flattenAsyncIterable<T, AT>(
iterable: AsyncIterableOrGenerator<T | AsyncIterableOrGenerator<AT>>,
): AsyncGenerator<T | AT, void, void> {
const iteratorMethod = iterable[Symbol.asyncIterator];
const iterator: any = iteratorMethod.call(iterable);
let iteratorStack: Array<AsyncIterator<T>> = [iterator];
export function flattenAsyncIterable<T>(
iterable: AsyncIterableOrGenerator<AsyncIterableOrGenerator<T>>,
): AsyncGenerator<T, void, void> {
// You might think this whole function could be replaced with
//
// async function* flattenAsyncIterable(iterable) {
// for await (const subIterator of iterable) {
// yield* subIterator;
// }
// }
//
// but calling `.return()` on the iterator it returns won't interrupt the `for await`.

const topIterator = iterable[Symbol.asyncIterator]();
let currentNestedIterator: AsyncIterator<T> | undefined;
let waitForCurrentNestedIterator: Promise<void> | undefined;
let done = false;

async function next(): Promise<IteratorResult<T | AT, void>> {
const currentIterator = iteratorStack[0];
if (!currentIterator) {
async function next(): Promise<IteratorResult<T, void>> {
if (done) {
return { value: undefined, done: true };
}
const result = await currentIterator.next();
if (result.done) {
iteratorStack.shift();
return next();
} else if (isAsyncIterable(result.value)) {
const childIterator = result.value[
Symbol.asyncIterator
]() as AsyncIterator<T>;
iteratorStack.unshift(childIterator);
return next();

try {
if (!currentNestedIterator) {
// Somebody else is getting it already.
if (waitForCurrentNestedIterator) {
await waitForCurrentNestedIterator;
return await next();
}
// Nobody else is getting it. We should!
let resolve: () => void;
waitForCurrentNestedIterator = new Promise<void>((r) => {
resolve = r;
});
const topIteratorResult = await topIterator.next();
if (topIteratorResult.done) {
// Given that done only ever transitions from false to true,
// require-atomic-updates is being unnecessarily cautious.
// eslint-disable-next-line require-atomic-updates
done = true;
return await next();
}
// eslint is making a reasonable point here, but we've explicitly protected
// ourself from the race condition by ensuring that only the single call
// that assigns to waitForCurrentNestedIterator is allowed to assign to
// currentNestedIterator or waitForCurrentNestedIterator.
// eslint-disable-next-line require-atomic-updates
currentNestedIterator = topIteratorResult.value[Symbol.asyncIterator]();
// eslint-disable-next-line require-atomic-updates
waitForCurrentNestedIterator = undefined;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
resolve!();
return await next();
}

const rememberCurrentNestedIterator = currentNestedIterator;
const nestedIteratorResult = await currentNestedIterator.next();
if (!nestedIteratorResult.done) {
return nestedIteratorResult;
}

// The nested iterator is done. If it's still the current one, make it not
// current. (If it's not the current one, somebody else has made us move on.)
if (currentNestedIterator === rememberCurrentNestedIterator) {
currentNestedIterator = undefined;
}
return await next();
} catch (err) {
done = true;
throw err;
}
return result;
}
return {
next,
return() {
iteratorStack = [];
return iterator.return();
async return() {
done = true;
await Promise.all([
currentNestedIterator?.return?.(),
topIterator.return?.(),
]);
return { value: undefined, done: true };
},
throw(error?: unknown): Promise<IteratorResult<T | AT>> {
iteratorStack = [];
return iterator.throw(error);
async throw(error?: unknown): Promise<IteratorResult<T>> {
done = true;
await Promise.all([
currentNestedIterator?.throw?.(error),
topIterator.throw?.(error),
]);
/* c8 ignore next */
throw error;
},
[Symbol.asyncIterator]() {
/* c8 ignore next */
return this;
},
};
Expand Down