Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement parallel jdbc driver #23932

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ public class ClientOptions
@Option(names = "--decimal-data-size", description = "Show data size and rate in base 10 rather than base 2")
public boolean decimalDataSize;

@Option(names = "--prefetch-buffer-size", paramLabel = "<prefetch-buffer-size>", defaultValue = "64000", description = "Experimental spooled protocol prefetch buffer size, default: + " + DEFAULT_VALUE)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make this a DataSize; you'll need to register a DataSize converter, similar to the existing Duration converter.

public String prefetchBufferSize;

public enum OutputFormat
{
AUTO,
Expand Down Expand Up @@ -346,6 +349,7 @@ public ClientSession toClientSession(TrinoUri uri)
.toClientSessionBuilder()
.source(uri.getSource().orElse(SOURCE_DEFAULT))
.encoding(encoding)
.prefetchBufferSize(prefetchBufferSize)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ public void testDefaults()
assertThat(session.getServer().toString()).isEqualTo("http://localhost:8080");
assertThat(session.getSource()).isEqualTo("trino-cli");
assertThat(session.getTimeZone()).isEqualTo(ZoneId.systemDefault());
assertThat(session.getPrefetchBufferSize()).isEqualTo("64000000");
}

@Test
public void testPrefetchBufferSize()
{
    // An explicit --prefetch-buffer-size on the command line must propagate
    // unchanged into the ClientSession built from those options.
    Console cli = createConsole("--prefetch-buffer-size=32000");
    ClientSession clientSession = cli.clientOptions.toClientSession(cli.clientOptions.getTrinoUri());
    assertThat(clientSession.getPrefetchBufferSize()).isEqualTo("32000");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -57,6 +58,9 @@ public class ClientSession
private final Duration clientRequestTimeout;
private final boolean compressionDisabled;
private Optional<String> encoding;
private final String prefetchBufferSize;
private final ExecutorService decoderExecutorService;
private final ExecutorService segmentLoaderExecutorService;

public static Builder builder()
{
Expand Down Expand Up @@ -97,7 +101,10 @@ private ClientSession(
String transactionId,
Duration clientRequestTimeout,
boolean compressionDisabled,
Optional<String> encoding)
Optional<String> encoding,
String prefetchBufferSize,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DataSize

ExecutorService decoderExecutorService,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be added to ClientSession

ExecutorService segmentLoaderExecutorService)
{
this.server = requireNonNull(server, "server is null");
this.user = requireNonNull(user, "user is null");
Expand All @@ -121,6 +128,9 @@ private ClientSession(
this.clientRequestTimeout = clientRequestTimeout;
this.compressionDisabled = compressionDisabled;
this.encoding = requireNonNull(encoding, "encoding is null");
this.prefetchBufferSize = requireNonNull(prefetchBufferSize, "prefetchBufferSize is null");
this.decoderExecutorService = requireNonNull(decoderExecutorService, "decoderExecutorService is null");
this.segmentLoaderExecutorService = requireNonNull(segmentLoaderExecutorService, "segmentLoaderExecutorService is null");

for (String clientTag : clientTags) {
checkArgument(!clientTag.contains(","), "client tag cannot contain ','");
Expand Down Expand Up @@ -269,6 +279,21 @@ public Optional<String> getEncoding()
return encoding;
}

// Spooled-protocol prefetch buffer size as a raw decimal string (bytes);
// parsing happens downstream (e.g. in ResultRowsDecoder).
// NOTE(review): a reviewer suggested this should be a DataSize rather than a String — TODO confirm.
public String getPrefetchBufferSize()
{
return prefetchBufferSize;
}

// Executor used to decode downloaded segments.
// NOTE(review): a reviewer flagged that executors should not live on ClientSession — TODO confirm placement.
public ExecutorService getDecoderExecutorService()
{
return decoderExecutorService;
}

// Executor used to download spooled segments.
// NOTE(review): same placement concern as the decoder executor — TODO confirm.
public ExecutorService getSegmentLoaderExecutorService()
{
return segmentLoaderExecutorService;
}

@Override
public String toString()
{
Expand All @@ -292,6 +317,7 @@ public String toString()
.add("clientRequestTimeout", clientRequestTimeout)
.add("compressionDisabled", compressionDisabled)
.add("encoding", encoding)
.add("prefetchBufferSize", prefetchBufferSize)
.omitNullValues()
.toString();
}
Expand Down Expand Up @@ -320,6 +346,9 @@ public static final class Builder
private Duration clientRequestTimeout;
private boolean compressionDisabled;
private Optional<String> encoding = Optional.empty();
private String prefetchBufferSize;
private ExecutorService decoderExecutorService;
private ExecutorService segmentLoaderExecutorService;

private Builder() {}

Expand Down Expand Up @@ -482,6 +511,24 @@ public Builder encoding(Optional<String> encoding)
return this;
}

// Sets the spooled-protocol prefetch buffer size (decimal string, bytes).
public Builder prefetchBufferSize(String prefetchBufferSize)
{
this.prefetchBufferSize = prefetchBufferSize;
return this;
}

// Sets the executor used to decode downloaded segments.
public Builder decoderExecutorService(ExecutorService decoderExecutorService)
{
this.decoderExecutorService = decoderExecutorService;
return this;
}

// Sets the executor used to download spooled segments.
public Builder segmentLoaderExecutorService(ExecutorService segmentLoaderExecutorService)
{
this.segmentLoaderExecutorService = segmentLoaderExecutorService;
return this;
}

public ClientSession build()
{
return new ClientSession(
Expand All @@ -506,7 +553,10 @@ public ClientSession build()
transactionId,
clientRequestTimeout,
compressionDisabled,
encoding);
encoding,
prefetchBufferSize,
decoderExecutorService,
segmentLoaderExecutorService);
}
}
}
163 changes: 134 additions & 29 deletions client/trino-client/src/main/java/io/trino/client/ResultRowsDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
*/
package io.trino.client;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.trino.client.spooling.DataAttributes;
import io.trino.client.spooling.DeferredIterable;
import io.trino.client.spooling.EncodedQueryData;
import io.trino.client.spooling.InlineSegment;
import io.trino.client.spooling.Segment;
Expand All @@ -24,19 +27,32 @@
import org.gaul.modernizer_maven_annotations.SuppressModernizer;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.filter;
import static com.google.common.collect.Iterables.transform;
import static com.google.common.collect.Streams.forEachPair;
import static io.trino.client.ResultRows.NULL_ROWS;
import static io.trino.client.ResultRows.fromIterableRows;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newSingleThreadExecutor;

/**
* Class responsible for decoding any QueryData type.
Expand All @@ -45,16 +61,39 @@ public class ResultRowsDecoder
implements AutoCloseable
{
private final SegmentLoader loader;
private final long prefetchBufferSize;
private final ExecutorService decoderExecutorService;
private final ExecutorService segmentLoaderExecutorService;
private QueryDataDecoder decoder;
private final ReentrantLock lock = new ReentrantLock();
private final Condition sizeCondition = lock.newCondition();
private final Deque<Thread> waitingThreads = new ArrayDeque<>();
private long currentSizeInBytes;

/**
 * Test-only constructor using the default HTTP segment loader, the default
 * 64 MB (decimal) prefetch buffer, a cached decoder pool and a single-threaded
 * segment-loader executor (both daemon threads).
 */
@VisibleForTesting
public ResultRowsDecoder()
{
    // A constructor may delegate exactly once; the stale no-arg delegation
    // `this(new OkHttpSegmentLoader());` left over from the previous version
    // has been removed — it would not compile alongside this call.
    this(new OkHttpSegmentLoader(),
            "64000000",
            newCachedThreadPool(new ThreadFactoryBuilder().setNameFormat("Decoder-%s").setDaemon(true).build()),
            newSingleThreadExecutor(new ThreadFactoryBuilder().setNameFormat("Segment loader worker-%s").setDaemon(true).build()));
}

/**
 * Test-only constructor with a caller-supplied segment loader; uses the default
 * 64 MB (decimal) prefetch buffer and default daemon executors, mirroring the
 * no-arg constructor.
 */
@VisibleForTesting
public ResultRowsDecoder(SegmentLoader loader)
{
this(loader,
"64000000",
newCachedThreadPool(new ThreadFactoryBuilder().setNameFormat("Decoder-%s").setDaemon(true).build()),
newSingleThreadExecutor(new ThreadFactoryBuilder().setNameFormat("Segment loader worker-%s").setDaemon(true).build()));
}

/**
 * Creates a decoder that downloads spooled segments in parallel while keeping at
 * most {@code prefetchBufferSize} bytes of undecoded segment data buffered.
 *
 * @param loader fetches spooled segments
 * @param prefetchBufferSize maximum prefetch buffer size in bytes, as a decimal string
 * @param decoderExecutorService executor on which segments are decoded
 * @param segmentLoaderExecutorService executor on which segments are downloaded
 * @throws NumberFormatException if {@code prefetchBufferSize} is not a parseable long
 */
public ResultRowsDecoder(SegmentLoader loader, String prefetchBufferSize, ExecutorService decoderExecutorService, ExecutorService segmentLoaderExecutorService)
{
    this.loader = requireNonNull(loader, "loader is null");
    this.prefetchBufferSize = Long.parseLong(requireNonNull(prefetchBufferSize, "prefetchBufferSize is null"));
    // Fixed: the null-check messages previously named the wrong fields
    // ("decoder is null" / "segmentLoaderExecutor is null"), which would have
    // produced misleading NPE messages for callers.
    this.decoderExecutorService = requireNonNull(decoderExecutorService, "decoderExecutorService is null");
    this.segmentLoaderExecutorService = requireNonNull(segmentLoaderExecutorService, "segmentLoaderExecutorService is null");
}

private void setEncoding(List<Column> columns, String encoding)
Expand Down Expand Up @@ -106,38 +145,104 @@ public ResultRows toRows(List<Column> columns, QueryData data)
if (data instanceof EncodedQueryData) {
EncodedQueryData encodedData = (EncodedQueryData) data;
setEncoding(columns, encodedData.getEncoding());
return concat(transform(encodedData.getSegments(), this::segmentToRows));
}
List<Segment> segments = encodedData.getSegments();

throw new UnsupportedOperationException("Unsupported data type: " + data.getClass().getName());
}
List<Future<? extends InputStream>> futures = segments.stream().map(segment -> {
int segmentSize = segment.getSegmentSize();

private ResultRows segmentToRows(Segment segment)
{
if (segment instanceof InlineSegment) {
InlineSegment inlineSegment = (InlineSegment) segment;
try {
return decoder.decode(new ByteArrayInputStream(inlineSegment.getData()), inlineSegment.getMetadata());
}
catch (IOException e) {
throw new UncheckedIOException(e);
}
}
if (segment instanceof InlineSegment) {
InlineSegment inlineSegment = (InlineSegment) segment;
return completedFuture(new ByteArrayInputStream(inlineSegment.getData()));
}

if (segment instanceof SpooledSegment) {
SpooledSegment spooledSegment = (SpooledSegment) segment;
if (segment instanceof SpooledSegment) {
SpooledSegment spooledSegment = (SpooledSegment) segment;
// download segments in parallel
return segmentLoaderExecutorService.submit(() -> {
lock.lock();
try {
boolean mustWait = currentSizeInBytes + segmentSize > prefetchBufferSize;
if (mustWait) {
waitingThreads.addLast(Thread.currentThread());
}
while (mustWait && waitingThreads.peekFirst() != Thread.currentThread()) {
// block if prefetch buffer is full
sizeCondition.await();
// now unblock the first thread that came in
}
if (mustWait) {
waitingThreads.removeFirst();
}
currentSizeInBytes += segmentSize;
}
catch (InterruptedException e) {
waitingThreads.remove(Thread.currentThread());
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
finally {
lock.unlock();
}
// download whole segment
InputStream segmentInputStream = loader.load(spooledSegment);
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
byte[] off = new byte[1024];
int bytesRead;
while ((bytesRead = segmentInputStream.read(off, 0, off.length)) != -1) {
buffer.write(off, 0, bytesRead);
}
buffer.flush();
return new ByteArrayInputStream(buffer.toByteArray());
});
}

try {
// The returned rows are lazy which means that decoder is responsible for closing input stream
InputStream stream = loader.load(spooledSegment);
return decoder.decode(stream, spooledSegment.getMetadata());
}
catch (IOException e) {
throw new RuntimeException(e);
}
throw new UnsupportedOperationException("Unsupported segment type: " + segment.getClass().getName());
}).collect(toImmutableList());

List<ResultRows> resultRows = new ArrayList<>();
forEachPair(futures.stream(), segments.stream(), (future, segment) -> {
resultRows.add(ResultRows.fromIterableRows(new DeferredIterable(
decoderExecutorService.submit(() -> {
try {
// block decode if segment is not yet downloaded
InputStream input = future.get();
return decoder.decode(input, segment.getMetadata());
}
catch (IOException e) {
throw new RuntimeException(e);
}
catch (ExecutionException e) {
throw new RuntimeException(e);
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
}),
segment.getRowsCount(),
new Callable<Void>() {
@Override
public Void call()
throws Exception
{
// the data has been read, so we can free up the buffer
lock.lock();
try {
int segmentSize = segment.getSegmentSize();
currentSizeInBytes -= segmentSize;
sizeCondition.signalAll();
}
finally {
lock.unlock();
}
return null;
}
})));
});

return concat(resultRows);
}

throw new UnsupportedOperationException("Unsupported segment type: " + segment.getClass().getName());
throw new UnsupportedOperationException("Unsupported data type: " + data.getClass().getName());
}

public Optional<String> getEncoding()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ public StatementClientV1(Call.Factory httpCallFactory, Call.Factory segmentHttpC
.collect(toImmutableSet())));
this.compressionDisabled = session.isCompressionDisabled();

this.resultRowsDecoder = new ResultRowsDecoder(new OkHttpSegmentLoader(requireNonNull(segmentHttpCallFactory, "segmentHttpCallFactory is null")));
this.resultRowsDecoder = new ResultRowsDecoder(
new OkHttpSegmentLoader(requireNonNull(segmentHttpCallFactory, "segmentHttpCallFactory is null")),
session.getPrefetchBufferSize(),
session.getDecoderExecutorService(),
session.getSegmentLoaderExecutorService());

Request request = buildQueryRequest(session, query, session.getEncoding());
// Pass empty as materializedJsonSizeLimit to always materialize the first response
Expand Down
Loading
Loading