Skip to content

Commit 0d80418

Browse files
committed
Alternative fix to dead lock challenge (graphql-java#1255)
* Added @OverRide as part of errorprone code health check * Revert "Added @OverRide as part of errorprone code health check" This reverts commit 38dfab1 * Brads attempt at graphql-java#1234 * Missed the test
1 parent a46d462 commit 0d80418

File tree

2 files changed

+248
-11
lines changed

2 files changed

+248
-11
lines changed

src/main/java/graphql/execution/instrumentation/dataloader/FieldLevelTrackingApproach.java

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,13 @@ public String toString() {
9494
'}';
9595
}
9696

97-
public void dispatchIfNotDispatchedBefore(int level, Runnable dispatch) {
97+
public boolean dispatchIfNotDispatchedBefore(int level) {
9898
if (dispatchedLevels.contains(level)) {
9999
Assert.assertShouldNeverHappen("level " + level + " already dispatched");
100-
return;
100+
return false;
101101
}
102102
dispatchedLevels.add(level);
103-
dispatch.run();
103+
return true;
104104
}
105105

106106
public void clearAndMarkCurrentLevelAsReady(int level) {
@@ -151,17 +151,25 @@ public void onCompleted(ExecutionResult result, Throwable t) {
151151

152152
@Override
153153
public void onFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList) {
154+
boolean dispatchNeeded;
154155
synchronized (callStack) {
155-
handleOnFieldValuesInfo(fieldValueInfoList, callStack, curLevel);
156+
dispatchNeeded = handleOnFieldValuesInfo(fieldValueInfoList, callStack, curLevel);
157+
}
158+
if (dispatchNeeded) {
159+
dispatch();
156160
}
157161
}
158162

159163
@Override
160164
public void onDeferredField(List<Field> field) {
165+
boolean dispatchNeeded;
161166
// fake fetch count for this field
162167
synchronized (callStack) {
163168
callStack.increaseFetchCount(curLevel);
164-
dispatchIfNeeded(callStack, curLevel);
169+
dispatchNeeded = dispatchIfNeeded(callStack, curLevel);
170+
}
171+
if (dispatchNeeded) {
172+
dispatch();
165173
}
166174
}
167175
};
@@ -170,7 +178,7 @@ public void onDeferredField(List<Field> field) {
170178
//
171179
// thread safety : called with synchronised(callStack)
172180
//
173-
private void handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, CallStack callStack, int curLevel) {
181+
private boolean handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, CallStack callStack, int curLevel) {
174182
callStack.increaseHappenedOnFieldValueCalls(curLevel);
175183
int expectedStrategyCalls = 0;
176184
for (FieldValueInfo fieldValueInfo : fieldValueInfoList) {
@@ -181,7 +189,7 @@ private void handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, Ca
181189
}
182190
}
183191
callStack.increaseExpectedStrategyCalls(curLevel + 1, expectedStrategyCalls);
184-
dispatchIfNeeded(callStack, curLevel + 1);
192+
return dispatchIfNeeded(callStack, curLevel + 1);
185193
}
186194

187195
private int getCountForList(FieldValueInfo fieldValueInfo) {
@@ -215,8 +223,12 @@ public void onCompleted(ExecutionResult result, Throwable t) {
215223

216224
@Override
217225
public void onFieldValueInfo(FieldValueInfo fieldValueInfo) {
226+
boolean dispatchNeeded;
218227
synchronized (callStack) {
219-
handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), callStack, level);
228+
dispatchNeeded = handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), callStack, level);
229+
}
230+
if (dispatchNeeded) {
231+
dispatch();
220232
}
221233
}
222234
};
@@ -230,10 +242,15 @@ public InstrumentationContext<Object> beginFieldFetch(InstrumentationFieldFetchP
230242

231243
@Override
232244
public void onDispatched(CompletableFuture result) {
245+
boolean dispatchNeeded;
233246
synchronized (callStack) {
234247
callStack.increaseFetchCount(level);
235-
dispatchIfNeeded(callStack, level);
248+
dispatchNeeded = dispatchIfNeeded(callStack, level);
236249
}
250+
if (dispatchNeeded) {
251+
dispatch();
252+
}
253+
237254
}
238255

239256
@Override
@@ -246,10 +263,11 @@ public void onCompleted(Object result, Throwable t) {
246263
//
247264
// thread safety : called with synchronised(callStack)
248265
//
249-
private void dispatchIfNeeded(CallStack callStack, int level) {
266+
private boolean dispatchIfNeeded(CallStack callStack, int level) {
250267
if (levelReady(callStack, level)) {
251-
callStack.dispatchIfNotDispatchedBefore(level, this::dispatch);
268+
return callStack.dispatchIfNotDispatchedBefore(level);
252269
}
270+
return false;
253271
}
254272

255273
//
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
package graphql.execution.instrumentation.dataloader
2+
3+
import graphql.ExecutionInput
4+
import graphql.ExecutionResult
5+
import graphql.GraphQL
6+
import graphql.TestUtil
7+
import graphql.execution.Async
8+
import graphql.schema.DataFetcher
9+
import graphql.schema.DataFetchingEnvironment
10+
import graphql.schema.idl.RuntimeWiring
11+
import org.apache.commons.lang3.concurrent.BasicThreadFactory
12+
import org.dataloader.BatchLoader
13+
import org.dataloader.DataLoader
14+
import org.dataloader.DataLoaderOptions
15+
import org.dataloader.DataLoaderRegistry
16+
import spock.lang.Specification
17+
18+
import java.util.concurrent.CompletableFuture
19+
import java.util.concurrent.CompletionStage
20+
import java.util.concurrent.SynchronousQueue
21+
import java.util.concurrent.ThreadFactory
22+
import java.util.concurrent.ThreadPoolExecutor
23+
import java.util.concurrent.TimeUnit
24+
25+
import static graphql.schema.idl.TypeRuntimeWiring.newTypeWiring
26+
27+
class DataLoaderHangingTest extends Specification {
28+
29+
public static final int NUM_OF_REPS = 50
30+
31+
def "deadlock attempt"() {
32+
setup:
33+
def sdl = """
34+
type Album {
35+
id: ID!
36+
title: String!
37+
artist: Artist
38+
songs(
39+
limit: Int,
40+
nextToken: String
41+
): ModelSongConnection
42+
}
43+
44+
type Artist {
45+
id: ID!
46+
name: String!
47+
albums(
48+
limit: Int,
49+
nextToken: String
50+
): ModelAlbumConnection
51+
songs(
52+
limit: Int,
53+
nextToken: String
54+
): ModelSongConnection
55+
}
56+
57+
type ModelAlbumConnection {
58+
items: [Album]
59+
nextToken: String
60+
}
61+
62+
type ModelArtistConnection {
63+
items: [Artist]
64+
nextToken: String
65+
}
66+
67+
type ModelSongConnection {
68+
items: [Song]
69+
nextToken: String
70+
}
71+
72+
type Query {
73+
listArtists(limit: Int, nextToken: String): ModelArtistConnection
74+
}
75+
76+
type Song {
77+
id: ID!
78+
title: String!
79+
artist: Artist
80+
album: Album
81+
}
82+
"""
83+
84+
ThreadFactory threadFactory = new BasicThreadFactory.Builder()
85+
.namingPattern("resolver-chain-thread-%d").build()
86+
def executor = new ThreadPoolExecutor(15, 15, 0L,
87+
TimeUnit.MILLISECONDS, new SynchronousQueue<>(), threadFactory,
88+
new ThreadPoolExecutor.CallerRunsPolicy())
89+
90+
def dataLoaderAlbums = new DataLoader<Object, Object>(new BatchLoader<DataFetchingEnvironment, List<Object>>() {
91+
@Override
92+
CompletionStage<List<List<Object>>> load(List<DataFetchingEnvironment> keys) {
93+
return CompletableFuture.supplyAsync({
94+
def limit = keys.first().getArgument("limit") as Integer
95+
return keys.collect({ k ->
96+
def albums = []
97+
for (int i = 1; i <= limit; i++) {
98+
albums.add(['id': "artist-$k.source.id-$i", 'title': "album-$i"])
99+
}
100+
def albumsConnection = ['nextToken': 'album-next', 'items': albums]
101+
return albumsConnection
102+
})
103+
}, executor)
104+
}
105+
}, DataLoaderOptions.newOptions().setMaxBatchSize(5))
106+
107+
def dataLoaderSongs = new DataLoader<Object, Object>(new BatchLoader<DataFetchingEnvironment, List<Object>>() {
108+
@Override
109+
CompletionStage<List<List<Object>>> load(List<DataFetchingEnvironment> keys) {
110+
return CompletableFuture.supplyAsync({
111+
def limit = keys.first().getArgument("limit") as Integer
112+
return keys.collect({ k ->
113+
def songs = []
114+
for (int i = 1; i <= limit; i++) {
115+
songs.add(['id': "album-$k.source.id-$i", 'title': "song-$i"])
116+
}
117+
def songsConnection = ['nextToken': 'song-next', 'items': songs]
118+
return songsConnection
119+
})
120+
}, executor)
121+
}
122+
}, DataLoaderOptions.newOptions().setMaxBatchSize(5))
123+
124+
def dataLoaderRegistry = new DataLoaderRegistry()
125+
dataLoaderRegistry.register("artist.albums", dataLoaderAlbums)
126+
dataLoaderRegistry.register("album.songs", dataLoaderSongs)
127+
128+
129+
def albumsDf = new MyForwardingDataFetcher(dataLoaderAlbums)
130+
def songsDf = new MyForwardingDataFetcher(dataLoaderSongs)
131+
132+
def dataFetcherArtists = new DataFetcher() {
133+
@Override
134+
Object get(DataFetchingEnvironment environment) {
135+
def limit = environment.getArgument("limit") as Integer
136+
def artists = []
137+
for (int i = 1; i <= limit; i++) {
138+
artists.add(['id': "artist-$i", 'name': "artist-$i"])
139+
}
140+
return ['nextToken': 'artist-next', 'items': artists]
141+
}
142+
}
143+
144+
def wiring = RuntimeWiring.newRuntimeWiring()
145+
.type(newTypeWiring("Query")
146+
.dataFetcher("listArtists", dataFetcherArtists))
147+
.type(newTypeWiring("Artist")
148+
.dataFetcher("albums", albumsDf))
149+
.type(newTypeWiring("Album")
150+
.dataFetcher("songs", songsDf))
151+
.build()
152+
153+
def schema = TestUtil.schema(sdl, wiring)
154+
155+
when:
156+
def graphql = GraphQL.newGraphQL(schema)
157+
.instrumentation(new DataLoaderDispatcherInstrumentation(dataLoaderRegistry))
158+
.build()
159+
160+
then: "execution shouldn't hang"
161+
List<CompletableFuture<ExecutionResult>> futures = []
162+
for (int i = 0; i < NUM_OF_REPS; i++) {
163+
def result = graphql.executeAsync(ExecutionInput.newExecutionInput()
164+
.query("""
165+
query getArtistsWithData {
166+
listArtists(limit: 1) {
167+
items {
168+
name
169+
albums(limit: 200) {
170+
items {
171+
title
172+
# Uncommenting the following causes query to timeout
173+
songs(limit: 5) {
174+
nextToken
175+
items {
176+
title
177+
}
178+
}
179+
}
180+
}
181+
}
182+
}
183+
}
184+
""")
185+
.build())
186+
result.whenComplete({ res, error ->
187+
if (error) {
188+
throw error
189+
}
190+
assert res.errors.empty
191+
})
192+
// add all futures
193+
futures.add(result)
194+
}
195+
// wait for each future to complete and grab the results
196+
Async.each(futures)
197+
.whenComplete({ results, error ->
198+
if (error) {
199+
throw error
200+
}
201+
results.each { assert it.errors.empty }
202+
})
203+
.join()
204+
}
205+
206+
static class MyForwardingDataFetcher implements DataFetcher<CompletableFuture<Object>> {
207+
208+
private final DataLoader dataLoader
209+
210+
public MyForwardingDataFetcher(DataLoader dataLoader) {
211+
this.dataLoader = dataLoader
212+
}
213+
214+
@Override
215+
CompletableFuture<Object> get(DataFetchingEnvironment environment) {
216+
return dataLoader.load(environment)
217+
}
218+
}
219+
}

0 commit comments

Comments
 (0)