Skip to content

Commit b3d4058

Browse files
committed
defer/stream: fix flattenAsyncIterable concurrency
Fixes the bug demonstrated in #3709 (which has already been incorporated into the defer-stream branch). This fix is extracted from #3703, which also updates the typing and API around execute. This particular change doesn't affect the API, other than making the `subscribe` return type more honest, as its returned generator could yield AsyncExecutionResult before this change as well. (The reason the previous version built is because every ExecutionResult is actually an AsyncExecutionResult; fixing that fact is part of what #3703 does.)
1 parent feb203a commit b3d4058

File tree

3 files changed

+151
-98
lines changed

3 files changed

+151
-98
lines changed

src/execution/__tests__/flattenAsyncIterable-test.ts

Lines changed: 44 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,61 +4,20 @@ import { describe, it } from 'mocha';
44
import { flattenAsyncIterable } from '../flattenAsyncIterable';
55

66
describe('flattenAsyncIterable', () => {
7-
it('does not modify an already flat async generator', async () => {
8-
async function* source() {
9-
yield await Promise.resolve(1);
10-
yield await Promise.resolve(2);
11-
yield await Promise.resolve(3);
12-
}
13-
14-
const result = flattenAsyncIterable(source());
15-
16-
expect(await result.next()).to.deep.equal({ value: 1, done: false });
17-
expect(await result.next()).to.deep.equal({ value: 2, done: false });
18-
expect(await result.next()).to.deep.equal({ value: 3, done: false });
19-
expect(await result.next()).to.deep.equal({
20-
value: undefined,
21-
done: true,
22-
});
23-
});
24-
25-
it('does not modify an already flat async iterator', async () => {
26-
const items = [1, 2, 3];
27-
28-
const iterator: any = {
29-
[Symbol.asyncIterator]() {
30-
return this;
31-
},
32-
next() {
33-
return Promise.resolve({
34-
done: items.length === 0,
35-
value: items.shift(),
36-
});
37-
},
38-
};
39-
40-
const result = flattenAsyncIterable(iterator);
41-
42-
expect(await result.next()).to.deep.equal({ value: 1, done: false });
43-
expect(await result.next()).to.deep.equal({ value: 2, done: false });
44-
expect(await result.next()).to.deep.equal({ value: 3, done: false });
45-
expect(await result.next()).to.deep.equal({
46-
value: undefined,
47-
done: true,
48-
});
49-
});
50-
517
it('flatten nested async generators', async () => {
528
async function* source() {
53-
yield await Promise.resolve(1);
54-
yield await Promise.resolve(2);
9+
yield await Promise.resolve(
10+
(async function* nested(): AsyncGenerator<number, void, void> {
11+
yield await Promise.resolve(1.1);
12+
yield await Promise.resolve(1.2);
13+
})(),
14+
);
5515
yield await Promise.resolve(
5616
(async function* nested(): AsyncGenerator<number, void, void> {
5717
yield await Promise.resolve(2.1);
5818
yield await Promise.resolve(2.2);
5919
})(),
6020
);
61-
yield await Promise.resolve(3);
6221
}
6322

6423
const doubles = flattenAsyncIterable(source());
@@ -67,13 +26,17 @@ describe('flattenAsyncIterable', () => {
6726
for await (const x of doubles) {
6827
result.push(x);
6928
}
70-
expect(result).to.deep.equal([1, 2, 2.1, 2.2, 3]);
29+
expect(result).to.deep.equal([1.1, 1.2, 2.1, 2.2]);
7130
});
7231

7332
it('allows returning early from a nested async generator', async () => {
7433
async function* source() {
75-
yield await Promise.resolve(1);
76-
yield await Promise.resolve(2);
34+
yield await Promise.resolve(
35+
(async function* nested(): AsyncGenerator<number, void, void> {
36+
yield await Promise.resolve(1.1);
37+
yield await Promise.resolve(1.2);
38+
})(),
39+
);
7740
yield await Promise.resolve(
7841
(async function* nested(): AsyncGenerator<number, void, void> {
7942
yield await Promise.resolve(2.1); /* c8 ignore start */
@@ -82,14 +45,19 @@ describe('flattenAsyncIterable', () => {
8245
})(),
8346
);
8447
// Not reachable, early return
85-
yield await Promise.resolve(3);
48+
yield await Promise.resolve(
49+
(async function* nested(): AsyncGenerator<number, void, void> {
50+
yield await Promise.resolve(3.1);
51+
yield await Promise.resolve(3.2);
52+
})(),
53+
);
8654
}
8755
/* c8 ignore stop */
8856

8957
const doubles = flattenAsyncIterable(source());
9058

91-
expect(await doubles.next()).to.deep.equal({ value: 1, done: false });
92-
expect(await doubles.next()).to.deep.equal({ value: 2, done: false });
59+
expect(await doubles.next()).to.deep.equal({ value: 1.1, done: false });
60+
expect(await doubles.next()).to.deep.equal({ value: 1.2, done: false });
9361
expect(await doubles.next()).to.deep.equal({ value: 2.1, done: false });
9462

9563
// Early return
@@ -111,8 +79,12 @@ describe('flattenAsyncIterable', () => {
11179

11280
it('allows throwing errors from a nested async generator', async () => {
11381
async function* source() {
114-
yield await Promise.resolve(1);
115-
yield await Promise.resolve(2);
82+
yield await Promise.resolve(
83+
(async function* nested(): AsyncGenerator<number, void, void> {
84+
yield await Promise.resolve(1.1);
85+
yield await Promise.resolve(1.2);
86+
})(),
87+
);
11688
yield await Promise.resolve(
11789
(async function* nested(): AsyncGenerator<number, void, void> {
11890
yield await Promise.resolve(2.1); /* c8 ignore start */
@@ -121,14 +93,19 @@ describe('flattenAsyncIterable', () => {
12193
})(),
12294
);
12395
// Not reachable, early return
124-
yield await Promise.resolve(3);
96+
yield await Promise.resolve(
97+
(async function* nested(): AsyncGenerator<number, void, void> {
98+
yield await Promise.resolve(3.1);
99+
yield await Promise.resolve(3.2);
100+
})(),
101+
);
125102
}
126103
/* c8 ignore stop */
127104

128105
const doubles = flattenAsyncIterable(source());
129106

130-
expect(await doubles.next()).to.deep.equal({ value: 1, done: false });
131-
expect(await doubles.next()).to.deep.equal({ value: 2, done: false });
107+
expect(await doubles.next()).to.deep.equal({ value: 1.1, done: false });
108+
expect(await doubles.next()).to.deep.equal({ value: 1.2, done: false });
132109
expect(await doubles.next()).to.deep.equal({ value: 2.1, done: false });
133110

134111
// Throw error
@@ -142,16 +119,20 @@ describe('flattenAsyncIterable', () => {
142119
}
143120
expect(caughtError).to.equal('ouch');
144121
});
145-
/* c8 ignore start */
146-
it.skip('completely yields sub-iterables even when next() called in parallel', async () => {
122+
it('completely yields sub-iterables even when next() called in parallel', async () => {
147123
async function* source() {
148124
yield await Promise.resolve(
149125
(async function* nested(): AsyncGenerator<number, void, void> {
150126
yield await Promise.resolve(1.1);
151127
yield await Promise.resolve(1.2);
152128
})(),
153129
);
154-
yield await Promise.resolve(2);
130+
yield await Promise.resolve(
131+
(async function* nested(): AsyncGenerator<number, void, void> {
132+
yield await Promise.resolve(2.1);
133+
yield await Promise.resolve(2.2);
134+
})(),
135+
);
155136
}
156137

157138
const result = flattenAsyncIterable(source());
@@ -160,11 +141,11 @@ describe('flattenAsyncIterable', () => {
160141
const promise2 = result.next();
161142
expect(await promise1).to.deep.equal({ value: 1.1, done: false });
162143
expect(await promise2).to.deep.equal({ value: 1.2, done: false });
163-
expect(await result.next()).to.deep.equal({ value: 2, done: false });
144+
expect(await result.next()).to.deep.equal({ value: 2.1, done: false });
145+
expect(await result.next()).to.deep.equal({ value: 2.2, done: false });
164146
expect(await result.next()).to.deep.equal({
165147
value: undefined,
166148
done: true,
167149
});
168150
});
169-
/* c8 ignore stop */
170151
});

src/execution/execute.ts

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,7 +1362,8 @@ export const defaultFieldResolver: GraphQLFieldResolver<unknown, unknown> =
13621362
export function subscribe(
13631363
args: ExecutionArgs,
13641364
): PromiseOrValue<
1365-
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
1365+
| AsyncGenerator<ExecutionResult | AsyncExecutionResult, void, void>
1366+
| ExecutionResult
13661367
> {
13671368
// If a valid execution context cannot be created due to incorrect arguments,
13681369
// a "Response" with only errors is returned.
@@ -1384,11 +1385,24 @@ export function subscribe(
13841385
return mapSourceToResponse(exeContext, resultOrStream);
13851386
}
13861387

1388+
async function* ensureAsyncIterable(
1389+
someExecutionResult:
1390+
| ExecutionResult
1391+
| AsyncGenerator<AsyncExecutionResult, void, void>,
1392+
): AsyncGenerator<ExecutionResult | AsyncExecutionResult, void, void> {
1393+
if (isAsyncIterable(someExecutionResult)) {
1394+
yield* someExecutionResult;
1395+
} else {
1396+
yield someExecutionResult;
1397+
}
1398+
}
1399+
13871400
function mapSourceToResponse(
13881401
exeContext: ExecutionContext,
13891402
resultOrStream: ExecutionResult | AsyncIterable<unknown>,
13901403
): PromiseOrValue<
1391-
AsyncGenerator<ExecutionResult, void, void> | ExecutionResult
1404+
| AsyncGenerator<ExecutionResult | AsyncExecutionResult, void, void>
1405+
| ExecutionResult
13921406
> {
13931407
if (!isAsyncIterable(resultOrStream)) {
13941408
return resultOrStream;
@@ -1400,9 +1414,11 @@ function mapSourceToResponse(
14001414
// the GraphQL specification. The `execute` function provides the
14011415
// "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the
14021416
// "ExecuteQuery" algorithm, for which `execute` is also used.
1403-
return flattenAsyncIterable<ExecutionResult, AsyncExecutionResult>(
1404-
mapAsyncIterable(resultOrStream, (payload: unknown) =>
1405-
executeImpl(buildPerEventExecutionContext(exeContext, payload)),
1417+
return flattenAsyncIterable(
1418+
mapAsyncIterable(resultOrStream, async (payload: unknown) =>
1419+
ensureAsyncIterable(
1420+
await executeImpl(buildPerEventExecutionContext(exeContext, payload)),
1421+
),
14061422
),
14071423
);
14081424
}

src/execution/flattenAsyncIterable.ts

Lines changed: 86 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,105 @@
1-
import { isAsyncIterable } from '../jsutils/isAsyncIterable';
2-
31
type AsyncIterableOrGenerator<T> =
42
| AsyncGenerator<T, void, void>
53
| AsyncIterable<T>;
64

75
/**
8-
* Given an AsyncIterable that could potentially yield other async iterators,
9-
* flatten all yielded results into a single AsyncIterable
6+
* Given an AsyncIterable of AsyncIterables, flatten all yielded results into a
7+
* single AsyncIterable.
108
*/
11-
export function flattenAsyncIterable<T, AT>(
12-
iterable: AsyncIterableOrGenerator<T | AsyncIterableOrGenerator<AT>>,
13-
): AsyncGenerator<T | AT, void, void> {
14-
const iteratorMethod = iterable[Symbol.asyncIterator];
15-
const iterator: any = iteratorMethod.call(iterable);
16-
let iteratorStack: Array<AsyncIterator<T>> = [iterator];
9+
export function flattenAsyncIterable<T>(
10+
iterable: AsyncIterableOrGenerator<AsyncIterableOrGenerator<T>>,
11+
): AsyncGenerator<T, void, void> {
12+
// You might think this whole function could be replaced with
13+
//
14+
// async function* flattenAsyncIterable(iterable) {
15+
// for await (const subIterator of iterable) {
16+
// yield* subIterator;
17+
// }
18+
// }
19+
//
20+
// but calling `.return()` on the iterator it returns won't interrupt the `for await`.
21+
22+
const topIterator = iterable[Symbol.asyncIterator]();
23+
let currentNestedIterator: AsyncIterator<T> | undefined;
24+
let waitForCurrentNestedIterator: Promise<void> | undefined;
25+
let done = false;
1726

18-
async function next(): Promise<IteratorResult<T | AT, void>> {
19-
const currentIterator = iteratorStack[0];
20-
if (!currentIterator) {
27+
async function next(): Promise<IteratorResult<T, void>> {
28+
if (done) {
2129
return { value: undefined, done: true };
2230
}
23-
const result = await currentIterator.next();
24-
if (result.done) {
25-
iteratorStack.shift();
26-
return next();
27-
} else if (isAsyncIterable(result.value)) {
28-
const childIterator = result.value[
29-
Symbol.asyncIterator
30-
]() as AsyncIterator<T>;
31-
iteratorStack.unshift(childIterator);
32-
return next();
31+
32+
try {
33+
if (!currentNestedIterator) {
34+
// Somebody else is getting it already.
35+
if (waitForCurrentNestedIterator) {
36+
await waitForCurrentNestedIterator;
37+
return await next();
38+
}
39+
// Nobody else is getting it. We should!
40+
let resolve: () => void;
41+
waitForCurrentNestedIterator = new Promise<void>((r) => {
42+
resolve = r;
43+
});
44+
const topIteratorResult = await topIterator.next();
45+
if (topIteratorResult.done) {
46+
// Given that done only ever transitions from false to true,
47+
// require-atomic-updates is being unnecessarily cautious.
48+
// eslint-disable-next-line require-atomic-updates
49+
done = true;
50+
return await next();
51+
}
52+
// eslint is making a reasonable point here, but we've explicitly protected
53+
// ourself from the race condition by ensuring that only the single call
54+
// that assigns to waitForCurrentNestedIterator is allowed to assign to
55+
// currentNestedIterator or waitForCurrentNestedIterator.
56+
// eslint-disable-next-line require-atomic-updates
57+
currentNestedIterator = topIteratorResult.value[Symbol.asyncIterator]();
58+
// eslint-disable-next-line require-atomic-updates
59+
waitForCurrentNestedIterator = undefined;
60+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
61+
resolve!();
62+
return await next();
63+
}
64+
65+
const rememberCurrentNestedIterator = currentNestedIterator;
66+
const nestedIteratorResult = await currentNestedIterator.next();
67+
if (!nestedIteratorResult.done) {
68+
return nestedIteratorResult;
69+
}
70+
71+
// The nested iterator is done. If it's still the current one, make it not
72+
// current. (If it's not the current one, somebody else has made us move on.)
73+
if (currentNestedIterator === rememberCurrentNestedIterator) {
74+
currentNestedIterator = undefined;
75+
}
76+
return await next();
77+
} catch (err) {
78+
done = true;
79+
throw err;
3380
}
34-
return result;
3581
}
3682
return {
3783
next,
38-
return() {
39-
iteratorStack = [];
40-
return iterator.return();
84+
async return() {
85+
done = true;
86+
await Promise.all([
87+
currentNestedIterator?.return?.(),
88+
topIterator.return?.(),
89+
]);
90+
return { value: undefined, done: true };
4191
},
42-
throw(error?: unknown): Promise<IteratorResult<T | AT>> {
43-
iteratorStack = [];
44-
return iterator.throw(error);
92+
async throw(error?: unknown): Promise<IteratorResult<T>> {
93+
done = true;
94+
await Promise.all([
95+
currentNestedIterator?.throw?.(error),
96+
topIterator.throw?.(error),
97+
]);
98+
/* c8 ignore next */
99+
throw error;
45100
},
46101
[Symbol.asyncIterator]() {
102+
/* c8 ignore next */
47103
return this;
48104
},
49105
};

0 commit comments

Comments
 (0)