Skip to content

Commit

Permalink
apacheGH-37720: [Java][FlightSQL] Implement stateless prepared statem…
Browse files Browse the repository at this point in the history
…ents

Add tests
  • Loading branch information
stevelorddremio committed May 21, 2024
1 parent 9e611c0 commit 9fd5cfa
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

package org.apache.arrow.adapter.jdbc;

import static java.nio.charset.StandardCharsets.UTF_8;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.HashMap;
Expand All @@ -38,6 +43,7 @@ public class JdbcParameterBinder {
private final ColumnBinder[] binders;
private final int[] parameterIndices;
private int nextRowIndex;
private byte[] bindersAsByteArray;

/**
* Create a new parameter binder.
Expand All @@ -51,7 +57,8 @@ private JdbcParameterBinder(
final PreparedStatement statement,
final VectorSchemaRoot root,
final ColumnBinder[] binders,
int[] parameterIndices) {
int[] parameterIndices,
byte[] bindersAsByteArray) {
Preconditions.checkArgument(
binders.length == parameterIndices.length,
"Number of column binders (%s) must equal number of parameter indices (%s)",
Expand All @@ -61,6 +68,7 @@ private JdbcParameterBinder(
this.binders = binders;
this.parameterIndices = parameterIndices;
this.nextRowIndex = 0;
this.bindersAsByteArray = bindersAsByteArray;
}

/**
Expand Down Expand Up @@ -137,7 +145,7 @@ public Builder bind(int parameterIndex, ColumnBinder binder) {
}

/** Build the binder. */
public JdbcParameterBinder build() {
public JdbcParameterBinder build() throws IOException {
ColumnBinder[] binders = new ColumnBinder[bindings.size()];
int[] parameterIndices = new int[bindings.size()];
int index = 0;
Expand All @@ -146,7 +154,20 @@ public JdbcParameterBinder build() {
parameterIndices[index] = entry.getKey();
index++;
}
return new JdbcParameterBinder(statement, root, binders, parameterIndices);

// Convert parameters to byte array
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
try (ObjectOutputStream outObject = new ObjectOutputStream(outStream)) {
outObject.writeObject(bindings.toString().getBytes(UTF_8));
outObject.flush();
}

// return new JdbcParameterBinder(statement, root, binders, parameterIndices, outStream.toByteArray());
return new JdbcParameterBinder(statement, root, binders, parameterIndices, outStream.toByteArray());
}
}

public byte[] getBindersAsByteArray() {
return bindersAsByteArray;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import static org.assertj.core.api.Assertions.assertThat;

import java.io.IOException;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.sql.Date;
Expand Down Expand Up @@ -80,6 +81,7 @@
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.JsonStringHashMap;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand All @@ -98,7 +100,7 @@ void afterEach() {
}

@Test
void bindOrder() throws SQLException {
void bindOrder() throws SQLException, IOException {
final Schema schema =
new Schema(
Arrays.asList(
Expand Down Expand Up @@ -159,7 +161,7 @@ void bindOrder() throws SQLException {
}

@Test
void customBinder() throws SQLException {
void customBinder() throws SQLException, IOException {
final Schema schema =
new Schema(Collections.singletonList(
Field.nullable("ints0", new ArrowType.Int(32, true))));
Expand Down Expand Up @@ -562,7 +564,7 @@ <T, V extends FieldVector> void testSimpleType(ArrowType arrowType, int jdbcType
try (final MockPreparedStatement statement = new MockPreparedStatement();
final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
final JdbcParameterBinder binder =
JdbcParameterBinder.builder(statement, root).bindAll().build();
JdbcParameterBinder.builder(statement, root).bindAll().build();
assertThat(binder.next()).isFalse();

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -605,6 +607,8 @@ <T, V extends FieldVector> void testSimpleType(ArrowType arrowType, int jdbcType
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}

// Non-nullable (since some types have a specialized binder)
Expand Down Expand Up @@ -647,6 +651,8 @@ <T, V extends FieldVector> void testSimpleType(ArrowType arrowType, int jdbcType
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}
}

Expand All @@ -660,11 +666,10 @@ <T, V extends FieldVector> void testListType(ArrowType arrowType, TriConsumer<V,
try (final MockPreparedStatement statement = new MockPreparedStatement();
final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
final JdbcParameterBinder binder =
JdbcParameterBinder.builder(statement, root).bindAll().build();
JdbcParameterBinder.builder(statement, root).bindAll().build();
assertThat(binder.next()).isFalse();

@SuppressWarnings("unchecked")
final V vector = (V) root.getVector(0);
@SuppressWarnings("unchecked") final V vector = (V) root.getVector(0);
final ColumnBinder columnBinder = ColumnBinder.forVector(vector);
assertThat(columnBinder.getJdbcType()).isEqualTo(jdbcType);

Expand Down Expand Up @@ -703,6 +708,8 @@ <T, V extends FieldVector> void testListType(ArrowType arrowType, TriConsumer<V,
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}

// Non-nullable (since some types have a specialized binder)
Expand Down Expand Up @@ -748,6 +755,8 @@ <T, V extends FieldVector> void testListType(ArrowType arrowType, TriConsumer<V,
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}
}

Expand Down Expand Up @@ -807,6 +816,8 @@ <T, V extends FieldVector> void testMapType(ArrowType arrowType, TriConsumer<V,
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1).toString());
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}

// Non-nullable (since some types have a specialized binder)
Expand Down Expand Up @@ -854,6 +865,8 @@ <T, V extends FieldVector> void testMapType(ArrowType arrowType, TriConsumer<V,
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate
while (binder.next()) {
preparedStatement.addBatch();
}
int[] recordCounts = preparedStatement.executeBatch();
final int[] recordCounts = preparedStatement.executeBatch();
recordCount = Arrays.stream(recordCounts).sum();
}

Expand All @@ -897,6 +897,9 @@ public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate
} catch (SQLException e) {
ackStream.onError(CallStatus.INTERNAL.withDescription("Failed to execute update: " + e).toRuntimeException());
return;
} catch (IOException e) {
ackStream.onError(CallStatus.INTERNAL.withDescription("Failed to execute update: " + e).toRuntimeException());
return;
}
ackStream.onCompleted();
};
Expand All @@ -911,11 +914,12 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co
return () -> {
assert statementContext != null;
PreparedStatement preparedStatement = statementContext.getStatement();
JdbcParameterBinder binder = null;

try {
while (flightStream.next()) {
final VectorSchemaRoot root = flightStream.getRoot();
final JdbcParameterBinder binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build();
binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build();
while (binder.next()) {
// Do not execute() - will be done in a getStream call
}
Expand All @@ -927,6 +931,24 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co
.withCause(e)
.toRuntimeException());
return;
} catch (IOException e) {
ackStream.onError(CallStatus.INTERNAL
.withDescription("Failed to bind parameters: " + e.getMessage())
.withCause(e)
.toRuntimeException());
return;
}

if (binder != null && binder.getBindersAsByteArray() != null) {
final byte[] byteArray = binder.getBindersAsByteArray();
final DoPutPreparedStatementResult build =
DoPutPreparedStatementResult.newBuilder()
.setPreparedStatementHandle(ByteString.copyFrom(ByteBuffer.wrap(byteArray))).build();

try (final ArrowBuf buffer = rootAllocator.buffer(build.getSerializedSize())) {
buffer.writeBytes(build.toByteArray());
ackStream.onNext(PutResult.metadata(buffer));
}
}
ackStream.onCompleted();
};
Expand Down

0 comments on commit 9fd5cfa

Please sign in to comment.