Skip to content

Commit a096a80

Browse files
committed
Properly release buffers in Flight integration client
1 parent 7478fac commit a096a80

File tree

6 files changed

+71
-21
lines changed

6 files changed

+71
-21
lines changed

java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ public ArrowDictionaryBatch asDictionaryBatch() throws IOException {
194194
Preconditions.checkArgument(bufs.size() == 1, "A batch can only be consumed if it contains a single ArrowBuf.");
195195
Preconditions.checkArgument(getMessageType() == HeaderType.DICTIONARY_BATCH);
196196
ArrowBuf underlying = bufs.get(0);
197+
// Retain a reference to keep the batch alive when the message is closed
198+
underlying.getReferenceManager().retain();
197199
return MessageSerializer.deserializeDictionaryBatch(message, underlying);
198200
}
199201

java/flight/src/main/java/org/apache/arrow/flight/FlightStream.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,30 @@ public Schema getSchema() {
9595
return schema;
9696
}
9797

98+
/**
99+
* Get the provider for dictionaries in this stream.
100+
*
101+
* <p>Does NOT retain a reference to the underlying dictionaries. Dictionaries may be updated as the stream is read.
102+
*
103+
* @see #retainDictionaries()
104+
*/
98105
public DictionaryProvider getDictionaryProvider() {
99106
return dictionaries;
100107
}
101108

109+
/**
110+
* Retain a reference to each of the dictionaries in the stream. Should be called after finishing reading the stream,
111+
* but before closing.
112+
*
113+
* @return The current dictionaries in the stream.
114+
*/
115+
public DictionaryProvider retainDictionaries() {
116+
// Swap out the provider so it is not closed
117+
final DictionaryProvider provider = dictionaries;
118+
dictionaries = new DictionaryProvider.MapDictionaryProvider();
119+
return provider;
120+
}
121+
102122
public FlightDescriptor getDescriptor() {
103123
return descriptor;
104124
}
@@ -117,8 +137,12 @@ public void close() throws Exception {
117137
.map(t -> ((AutoCloseable) t))
118138
.collect(Collectors.toList());
119139

140+
final List<FieldVector> dictionaryVectors = dictionaries.getDictionaryIds().stream()
141+
.map(id -> dictionaries.lookup(id).getVector()).collect(Collectors.toList());
142+
120143
// Must check for null since ImmutableList doesn't accept nulls
121144
AutoCloseables.close(Iterables.concat(closeables,
145+
dictionaryVectors,
122146
applicationMetadata != null ? ImmutableList.of(root.get(), applicationMetadata)
123147
: ImmutableList.of(root.get())));
124148
}

java/flight/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import org.apache.arrow.util.AutoCloseables;
2727

2828
/**
29-
* An Example Flight Server that provides access to the InMemoryStore.
29+
* An Example Flight Server that provides access to the InMemoryStore. Used for integration testing.
3030
*/
3131
public class ExampleFlightServer implements AutoCloseable {
3232

java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
package org.apache.arrow.flight.example;
1919

2020
import java.util.ArrayList;
21+
import java.util.HashSet;
2122
import java.util.List;
23+
import java.util.Set;
2224
import java.util.concurrent.CopyOnWriteArrayList;
2325
import java.util.stream.Collectors;
2426

@@ -31,6 +33,7 @@
3133
import org.apache.arrow.util.Preconditions;
3234
import org.apache.arrow.vector.dictionary.DictionaryProvider;
3335
import org.apache.arrow.vector.types.pojo.Schema;
36+
import org.apache.arrow.vector.util.DictionaryUtility;
3437

3538
import com.google.common.collect.ImmutableList;
3639
import com.google.common.collect.Iterables;
@@ -106,6 +109,13 @@ public FlightInfo getFlightInfo(final Location l) {
106109

107110
@Override
108111
public void close() throws Exception {
109-
AutoCloseables.close(Iterables.concat(streams, ImmutableList.of(allocator)));
112+
// Close dictionaries
113+
final Set<Long> dictionaryIds = new HashSet<>();
114+
schema.getFields().forEach(field -> DictionaryUtility.toMessageFormat(field, dictionaryProvider, dictionaryIds));
115+
116+
final Iterable<AutoCloseable> dictionaries = dictionaryIds.stream()
117+
.map(id -> (AutoCloseable) dictionaryProvider.lookup(id).getVector())::iterator;
118+
119+
AutoCloseables.close(Iterables.concat(streams, ImmutableList.of(allocator), dictionaries));
110120
}
111121
}

java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import org.apache.arrow.vector.VectorUnloader;
4040

4141
/**
42-
* A FlightProducer that hosts an in memory store of Arrow buffers.
42+
* A FlightProducer that hosts an in memory store of Arrow buffers. Used for integration testing.
4343
*/
4444
public class InMemoryStore implements FlightProducer, AutoCloseable {
4545

@@ -80,8 +80,7 @@ public Stream getStream(Ticket t) {
8080
}
8181

8282
@Override
83-
public void listFlights(CallContext context, Criteria criteria,
84-
StreamListener<FlightInfo> listener) {
83+
public void listFlights(CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
8584
try {
8685
for (FlightHolder h : holders.values()) {
8786
listener.onNext(h.getFlightInfo(location));
@@ -93,8 +92,7 @@ public void listFlights(CallContext context, Criteria criteria,
9392
}
9493

9594
@Override
96-
public FlightInfo getFlightInfo(CallContext context,
97-
FlightDescriptor descriptor) {
95+
public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
9896
FlightHolder h = holders.get(descriptor);
9997
if (h == null) {
10098
throw new IllegalStateException("Unknown descriptor.");
@@ -121,6 +119,8 @@ public Runnable acceptPut(CallContext context,
121119
ackStream.onNext(PutResult.metadata(flightStream.getLatestMetadata()));
122120
creator.add(unloader.getRecordBatch());
123121
}
122+
// Closing the stream will release the dictionaries
123+
flightStream.retainDictionaries();
124124
creator.complete();
125125
success = true;
126126
} finally {

java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.apache.arrow.vector.VectorSchemaRoot;
3838
import org.apache.arrow.vector.VectorUnloader;
3939
import org.apache.arrow.vector.ipc.JsonFileReader;
40+
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
4041
import org.apache.arrow.vector.util.Validator;
4142
import org.apache.commons.cli.CommandLine;
4243
import org.apache.commons.cli.CommandLineParser;
@@ -84,12 +85,19 @@ private void run(String[] args) throws ParseException, IOException {
8485
final String host = cmd.getOptionValue("host", "localhost");
8586
final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));
8687

87-
final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
8888
final Location defaultLocation = Location.forGrpcInsecure(host, port);
89-
final FlightClient client = FlightClient.builder(allocator, defaultLocation).build();
89+
try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
90+
final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) {
9091

91-
final String inputPath = cmd.getOptionValue("j");
92+
final String inputPath = cmd.getOptionValue("j");
93+
testStream(allocator, defaultLocation, client, inputPath);
94+
} catch (InterruptedException e) {
95+
throw new RuntimeException(e);
96+
}
97+
}
9298

99+
private static void testStream(BufferAllocator allocator, Location server, FlightClient client, String inputPath)
100+
throws IOException {
93101
// 1. Read data from JSON and upload to server.
94102
FlightDescriptor descriptor = FlightDescriptor.path(inputPath);
95103
VectorSchemaRoot jsonRoot;
@@ -121,7 +129,9 @@ public void onNext(PutResult val) {
121129
metadata.writeBytes(rawMetadata);
122130
// Transfers ownership of the buffer, so do not release it ourselves
123131
stream.putNext(metadata);
124-
jsonLoader.load(unloader.getRecordBatch());
132+
try (final ArrowRecordBatch arb = unloader.getRecordBatch()) {
133+
jsonLoader.load(arb);
134+
}
125135
root.clear();
126136
counter++;
127137
}
@@ -141,25 +151,29 @@ public void onNext(PutResult val) {
141151
// 3. Download the data from the server.
142152
List<Location> locations = endpoint.getLocations();
143153
if (locations.size() == 0) {
144-
locations = Collections.singletonList(defaultLocation);
154+
locations = Collections.singletonList(server);
145155
}
146156
for (Location location : locations) {
147157
System.out.println("Verifying location " + location.getUri());
148-
FlightClient readClient = FlightClient.builder(allocator, location).build();
149-
FlightStream stream = readClient.getStream(endpoint.getTicket());
150-
VectorSchemaRoot downloadedRoot;
151-
try (VectorSchemaRoot root = stream.getRoot()) {
152-
downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator);
158+
try (FlightClient readClient = FlightClient.builder(allocator, location).build();
159+
FlightStream stream = readClient.getStream(endpoint.getTicket());
160+
VectorSchemaRoot root = stream.getRoot();
161+
VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator)) {
153162
VectorLoader loader = new VectorLoader(downloadedRoot);
154163
VectorUnloader unloader = new VectorUnloader(root);
155164
while (stream.next()) {
156-
loader.load(unloader.getRecordBatch());
165+
try (final ArrowRecordBatch arb = unloader.getRecordBatch()) {
166+
loader.load(arb);
167+
}
157168
}
158-
}
159169

160-
// 4. Validate the data.
161-
Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot);
170+
// 4. Validate the data.
171+
Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot);
172+
} catch (Exception e) {
173+
throw new RuntimeException(e);
174+
}
162175
}
163176
}
177+
jsonRoot.close();
164178
}
165179
}

0 commit comments

Comments
 (0)