Skip to content

Commit a40d6b6

Browse files
lidavidm authored and emkornfield committed
ARROW-5978: [FlightRPC] [Java] Properly release buffers in Flight integration client
Fixes a bug where dictionaries weren't properly released when cleaning up a Flight stream.

Travis build: https://travis-ci.com/lihalite/arrow/builds/119807464

Recreated from #4905

Closes #4913 from lihalite/flight-leak and squashes the following commits:

ac8ba8d <David Li> Improve documentation/tests for FlightStream dictionary provider
bca02a7 <David Li> Add test case for freeing dictionaries in Flight
a096a80 <David Li> Properly release buffers in Flight integration client

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: Micah Kornfield <emkornfield@gmail.com>
1 parent 8f690e3 commit a40d6b6

File tree

8 files changed

+271
-25
lines changed

8 files changed

+271
-25
lines changed

java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ public ArrowDictionaryBatch asDictionaryBatch() throws IOException {
194194
Preconditions.checkArgument(bufs.size() == 1, "A batch can only be consumed if it contains a single ArrowBuf.");
195195
Preconditions.checkArgument(getMessageType() == HeaderType.DICTIONARY_BATCH);
196196
ArrowBuf underlying = bufs.get(0);
197+
// Retain a reference to keep the batch alive when the message is closed
198+
underlying.getReferenceManager().retain();
197199
return MessageSerializer.deserializeDictionaryBatch(message, underlying);
198200
}
199201

java/flight/src/main/java/org/apache/arrow/flight/DictionaryUtils.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
import java.util.List;
2424
import java.util.Set;
2525
import java.util.function.Consumer;
26+
import java.util.stream.Collectors;
2627

28+
import org.apache.arrow.util.AutoCloseables;
2729
import org.apache.arrow.vector.FieldVector;
2830
import org.apache.arrow.vector.VectorSchemaRoot;
2931
import org.apache.arrow.vector.VectorUnloader;
@@ -74,4 +76,14 @@ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDe
7476
}
7577
return schema;
7678
}
79+
80+
static void closeDictionaries(final Schema schema, final DictionaryProvider provider) throws Exception {
81+
// Close dictionaries
82+
final Set<Long> dictionaryIds = new HashSet<>();
83+
schema.getFields().forEach(field -> DictionaryUtility.toMessageFormat(field, provider, dictionaryIds));
84+
85+
final List<AutoCloseable> dictionaryVectors = dictionaryIds.stream()
86+
.map(id -> (AutoCloseable) provider.lookup(id).getVector()).collect(Collectors.toList());
87+
AutoCloseables.close(dictionaryVectors);
88+
}
7789
}

java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.util.stream.Collectors;
2828

2929
import org.apache.arrow.flight.ArrowMessage.HeaderType;
30+
import org.apache.arrow.flight.grpc.StatusUtils;
3031
import org.apache.arrow.memory.BufferAllocator;
3132
import org.apache.arrow.util.AutoCloseables;
3233
import org.apache.arrow.vector.FieldVector;
@@ -40,7 +41,6 @@
4041
import org.apache.arrow.vector.types.pojo.Schema;
4142
import org.apache.arrow.vector.util.DictionaryUtility;
4243

43-
import com.google.common.base.Throwables;
4444
import com.google.common.collect.ImmutableList;
4545
import com.google.common.collect.Iterables;
4646
import com.google.common.util.concurrent.SettableFuture;
@@ -95,10 +95,42 @@ public Schema getSchema() {
9595
return schema;
9696
}
9797

98+
/**
99+
* Get the provider for dictionaries in this stream.
100+
*
101+
* <p>Does NOT retain a reference to the underlying dictionaries. Dictionaries may be updated as the stream is read.
102+
* This method is intended for stream processing, where the application code will not retain references to values
103+
* after the stream is closed.
104+
*
105+
* @throws IllegalStateException if {@link #takeDictionaryOwnership()} was called
106+
* @see #takeDictionaryOwnership()
107+
*/
98108
public DictionaryProvider getDictionaryProvider() {
109+
if (dictionaries == null) {
110+
throw new IllegalStateException("Dictionary ownership was claimed by the application.");
111+
}
99112
return dictionaries;
100113
}
101114

115+
/**
116+
* Get an owned reference to the dictionaries in this stream. Should be called after finishing reading the stream,
117+
* but before closing.
118+
*
119+
* <p>If called, the client is responsible for closing the dictionaries in this provider. Can only be called once.
120+
*
121+
* @return The dictionary provider for the stream.
122+
* @throws IllegalStateException if called more than once.
123+
*/
124+
public DictionaryProvider takeDictionaryOwnership() {
125+
if (dictionaries == null) {
126+
throw new IllegalStateException("Dictionary ownership was claimed by the application.");
127+
}
128+
// Swap out the provider so it is not closed
129+
final DictionaryProvider provider = dictionaries;
130+
dictionaries = null;
131+
return provider;
132+
}
133+
102134
public FlightDescriptor getDescriptor() {
103135
return descriptor;
104136
}
@@ -117,8 +149,13 @@ public void close() throws Exception {
117149
.map(t -> ((AutoCloseable) t))
118150
.collect(Collectors.toList());
119151

152+
final List<FieldVector> dictionaryVectors =
153+
dictionaries == null ? Collections.emptyList() : dictionaries.getDictionaryIds().stream()
154+
.map(id -> dictionaries.lookup(id).getVector()).collect(Collectors.toList());
155+
120156
// Must check for null since ImmutableList doesn't accept nulls
121157
AutoCloseables.close(Iterables.concat(closeables,
158+
dictionaryVectors,
122159
applicationMetadata != null ? ImmutableList.of(root.get(), applicationMetadata)
123160
: ImmutableList.of(root.get())));
124161
}
@@ -168,6 +205,9 @@ public boolean next() {
168205
} else if (msg.getMessageType() == HeaderType.DICTIONARY_BATCH) {
169206
try (ArrowDictionaryBatch arb = msg.asDictionaryBatch()) {
170207
final long id = arb.getDictionaryId();
208+
if (dictionaries == null) {
209+
throw new IllegalStateException("Dictionary ownership was claimed by the application.");
210+
}
171211
final Dictionary dictionary = dictionaries.lookup(id);
172212
if (dictionary == null) {
173213
throw new IllegalArgumentException("Dictionary not defined in schema: ID " + id);
@@ -195,8 +235,10 @@ public boolean next() {
195235
public VectorSchemaRoot getRoot() {
196236
try {
197237
return root.get();
198-
} catch (InterruptedException | ExecutionException e) {
199-
throw Throwables.propagate(e);
238+
} catch (InterruptedException e) {
239+
throw CallStatus.INTERNAL.withCause(e).toRuntimeException();
240+
} catch (ExecutionException e) {
241+
throw StatusUtils.fromThrowable(e.getCause());
200242
}
201243
}
202244

java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import org.apache.arrow.util.AutoCloseables;
2727

2828
/**
29-
* An Example Flight Server that provides access to the InMemoryStore.
29+
* An Example Flight Server that provides access to the InMemoryStore. Used for integration testing.
3030
*/
3131
public class ExampleFlightServer implements AutoCloseable {
3232

java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
package org.apache.arrow.flight.example;
1919

2020
import java.util.ArrayList;
21+
import java.util.HashSet;
2122
import java.util.List;
23+
import java.util.Set;
2224
import java.util.concurrent.CopyOnWriteArrayList;
2325
import java.util.stream.Collectors;
2426

@@ -31,6 +33,7 @@
3133
import org.apache.arrow.util.Preconditions;
3234
import org.apache.arrow.vector.dictionary.DictionaryProvider;
3335
import org.apache.arrow.vector.types.pojo.Schema;
36+
import org.apache.arrow.vector.util.DictionaryUtility;
3437

3538
import com.google.common.collect.ImmutableList;
3639
import com.google.common.collect.Iterables;
@@ -106,6 +109,13 @@ public FlightInfo getFlightInfo(final Location l) {
106109

107110
@Override
108111
public void close() throws Exception {
109-
AutoCloseables.close(Iterables.concat(streams, ImmutableList.of(allocator)));
112+
// Close dictionaries
113+
final Set<Long> dictionaryIds = new HashSet<>();
114+
schema.getFields().forEach(field -> DictionaryUtility.toMessageFormat(field, dictionaryProvider, dictionaryIds));
115+
116+
final Iterable<AutoCloseable> dictionaries = dictionaryIds.stream()
117+
.map(id -> (AutoCloseable) dictionaryProvider.lookup(id).getVector())::iterator;
118+
119+
AutoCloseables.close(Iterables.concat(streams, ImmutableList.of(allocator), dictionaries));
110120
}
111121
}

java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import org.apache.arrow.vector.VectorUnloader;
4040

4141
/**
42-
* A FlightProducer that hosts an in memory store of Arrow buffers.
42+
* A FlightProducer that hosts an in memory store of Arrow buffers. Used for integration testing.
4343
*/
4444
public class InMemoryStore implements FlightProducer, AutoCloseable {
4545

@@ -80,8 +80,7 @@ public Stream getStream(Ticket t) {
8080
}
8181

8282
@Override
83-
public void listFlights(CallContext context, Criteria criteria,
84-
StreamListener<FlightInfo> listener) {
83+
public void listFlights(CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
8584
try {
8685
for (FlightHolder h : holders.values()) {
8786
listener.onNext(h.getFlightInfo(location));
@@ -93,8 +92,7 @@ public void listFlights(CallContext context, Criteria criteria,
9392
}
9493

9594
@Override
96-
public FlightInfo getFlightInfo(CallContext context,
97-
FlightDescriptor descriptor) {
95+
public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
9896
FlightHolder h = holders.get(descriptor);
9997
if (h == null) {
10098
throw new IllegalStateException("Unknown descriptor.");
@@ -121,6 +119,8 @@ public Runnable acceptPut(CallContext context,
121119
ackStream.onNext(PutResult.metadata(flightStream.getLatestMetadata()));
122120
creator.add(unloader.getRecordBatch());
123121
}
122+
// Closing the stream would release the dictionaries, so take ownership of them first so they outlive the stream
123+
flightStream.takeDictionaryOwnership();
124124
creator.complete();
125125
success = true;
126126
} finally {

java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.apache.arrow.vector.VectorSchemaRoot;
3838
import org.apache.arrow.vector.VectorUnloader;
3939
import org.apache.arrow.vector.ipc.JsonFileReader;
40+
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
4041
import org.apache.arrow.vector.util.Validator;
4142
import org.apache.commons.cli.CommandLine;
4243
import org.apache.commons.cli.CommandLineParser;
@@ -84,12 +85,19 @@ private void run(String[] args) throws ParseException, IOException {
8485
final String host = cmd.getOptionValue("host", "localhost");
8586
final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));
8687

87-
final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
8888
final Location defaultLocation = Location.forGrpcInsecure(host, port);
89-
final FlightClient client = FlightClient.builder(allocator, defaultLocation).build();
89+
try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
90+
final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) {
9091

91-
final String inputPath = cmd.getOptionValue("j");
92+
final String inputPath = cmd.getOptionValue("j");
93+
testStream(allocator, defaultLocation, client, inputPath);
94+
} catch (InterruptedException e) {
95+
throw new RuntimeException(e);
96+
}
97+
}
9298

99+
private static void testStream(BufferAllocator allocator, Location server, FlightClient client, String inputPath)
100+
throws IOException {
93101
// 1. Read data from JSON and upload to server.
94102
FlightDescriptor descriptor = FlightDescriptor.path(inputPath);
95103
VectorSchemaRoot jsonRoot;
@@ -121,7 +129,9 @@ public void onNext(PutResult val) {
121129
metadata.writeBytes(rawMetadata);
122130
// Transfers ownership of the buffer, so do not release it ourselves
123131
stream.putNext(metadata);
124-
jsonLoader.load(unloader.getRecordBatch());
132+
try (final ArrowRecordBatch arb = unloader.getRecordBatch()) {
133+
jsonLoader.load(arb);
134+
}
125135
root.clear();
126136
counter++;
127137
}
@@ -141,25 +151,29 @@ public void onNext(PutResult val) {
141151
// 3. Download the data from the server.
142152
List<Location> locations = endpoint.getLocations();
143153
if (locations.size() == 0) {
144-
locations = Collections.singletonList(defaultLocation);
154+
locations = Collections.singletonList(server);
145155
}
146156
for (Location location : locations) {
147157
System.out.println("Verifying location " + location.getUri());
148-
FlightClient readClient = FlightClient.builder(allocator, location).build();
149-
FlightStream stream = readClient.getStream(endpoint.getTicket());
150-
VectorSchemaRoot downloadedRoot;
151-
try (VectorSchemaRoot root = stream.getRoot()) {
152-
downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator);
158+
try (FlightClient readClient = FlightClient.builder(allocator, location).build();
159+
FlightStream stream = readClient.getStream(endpoint.getTicket());
160+
VectorSchemaRoot root = stream.getRoot();
161+
VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator)) {
153162
VectorLoader loader = new VectorLoader(downloadedRoot);
154163
VectorUnloader unloader = new VectorUnloader(root);
155164
while (stream.next()) {
156-
loader.load(unloader.getRecordBatch());
165+
try (final ArrowRecordBatch arb = unloader.getRecordBatch()) {
166+
loader.load(arb);
167+
}
157168
}
158-
}
159169

160-
// 4. Validate the data.
161-
Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot);
170+
// 4. Validate the data.
171+
Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot);
172+
} catch (Exception e) {
173+
throw new RuntimeException(e);
174+
}
162175
}
163176
}
177+
jsonRoot.close();
164178
}
165179
}

0 commit comments

Comments
 (0)