diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinder.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinder.java index 2dfc0658cb8d1..5fd02d08636b3 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinder.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinder.java @@ -17,6 +17,11 @@ package org.apache.arrow.adapter.jdbc; +import static java.nio.charset.StandardCharsets.UTF_8; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; import java.sql.PreparedStatement; import java.sql.SQLException; import java.util.HashMap; @@ -38,6 +43,7 @@ public class JdbcParameterBinder { private final ColumnBinder[] binders; private final int[] parameterIndices; private int nextRowIndex; + private byte[] bindersAsByteArray; /** * Create a new parameter binder. @@ -51,7 +57,8 @@ private JdbcParameterBinder( final PreparedStatement statement, final VectorSchemaRoot root, final ColumnBinder[] binders, - int[] parameterIndices) { + int[] parameterIndices, + byte[] bindersAsByteArray) { Preconditions.checkArgument( binders.length == parameterIndices.length, "Number of column binders (%s) must equal number of parameter indices (%s)", @@ -61,6 +68,7 @@ private JdbcParameterBinder( this.binders = binders; this.parameterIndices = parameterIndices; this.nextRowIndex = 0; + this.bindersAsByteArray = bindersAsByteArray; } /** @@ -137,7 +145,7 @@ public Builder bind(int parameterIndex, ColumnBinder binder) { } /** Build the binder. */ - public JdbcParameterBinder build() { + public JdbcParameterBinder build() throws IOException { ColumnBinder[] binders = new ColumnBinder[bindings.size()]; int[] parameterIndices = new int[bindings.size()]; int index = 0; @@ -146,7 +154,20 @@ public JdbcParameterBinder build() { parameterIndices[index] = entry.getKey(); index++; } - return new JdbcParameterBinder(statement, root, binders, parameterIndices); + + // Convert parameters to byte array + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + try (ObjectOutputStream outObject = new ObjectOutputStream(outStream)) { + outObject.writeObject(bindings.toString().getBytes(UTF_8)); + outObject.flush(); + } + + // return new JdbcParameterBinder(statement, root, binders, parameterIndices, outStream.toByteArray()); + return new JdbcParameterBinder(statement, root, binders, parameterIndices, outStream.toByteArray()); } } + + public byte[] getBindersAsByteArray() { + return bindersAsByteArray; + } } diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinderTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinderTest.java index 15b9ab0386159..fc64891cf91e9 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinderTest.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinderTest.java @@ -19,6 +19,7 @@ import static org.assertj.core.api.Assertions.assertThat; +import java.io.IOException; import java.math.BigDecimal; import java.nio.charset.StandardCharsets; import java.sql.Date; @@ -80,6 +81,7 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.JsonStringHashMap; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -98,7 +100,7 @@ void afterEach() { } @Test - void bindOrder() throws SQLException { + void bindOrder() throws SQLException, IOException { final Schema schema = new Schema( Arrays.asList( @@ -159,7 +161,7 @@ void bindOrder() throws SQLException { } @Test - void customBinder() throws SQLException { + void customBinder() throws SQLException, IOException { final Schema schema = new Schema(Collections.singletonList( Field.nullable("ints0", new ArrowType.Int(32, true)))); @@ -562,7 +564,7 @@ void testSimpleType(ArrowType arrowType, int jdbcType try (final MockPreparedStatement statement = new MockPreparedStatement(); final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { final JdbcParameterBinder binder = - JdbcParameterBinder.builder(statement, root).bindAll().build(); + JdbcParameterBinder.builder(statement, root).bindAll().build(); assertThat(binder.next()).isFalse(); @SuppressWarnings("unchecked") @@ -605,6 +607,8 @@ void testSimpleType(ArrowType arrowType, int jdbcType assertThat(binder.next()).isTrue(); assertThat(statement.getParamValue(1)).isEqualTo(values.get(1)); assertThat(binder.next()).isFalse(); + } catch (IOException e) { + Assertions.fail("Unexpected binding error."); } // Non-nullable (since some types have a specialized binder) @@ -647,6 +651,8 @@ void testSimpleType(ArrowType arrowType, int jdbcType assertThat(binder.next()).isTrue(); assertThat(statement.getParamValue(1)).isEqualTo(values.get(1)); assertThat(binder.next()).isFalse(); + } catch (IOException e) { + Assertions.fail("Unexpected binding error."); } } @@ -660,11 +666,10 @@ void testListType(ArrowType arrowType, TriConsumer void testListType(ArrowType arrowType, TriConsumer void testListType(ArrowType arrowType, TriConsumer void testMapType(ArrowType arrowType, TriConsumer void testMapType(ArrowType arrowType, TriConsumer { assert statementContext != null; PreparedStatement preparedStatement = statementContext.getStatement(); + JdbcParameterBinder binder = null; try { while (flightStream.next()) { final VectorSchemaRoot root = flightStream.getRoot(); - final JdbcParameterBinder binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build(); + binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build(); while (binder.next()) { // Do not execute() - will be done in a getStream call } @@ -927,6 +931,24 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co .withCause(e) .toRuntimeException()); return; + } catch (IOException e) { + ackStream.onError(CallStatus.INTERNAL + .withDescription("Failed to bind parameters: " + e.getMessage()) + .withCause(e) + .toRuntimeException()); + return; + } + + if (binder != null && binder.getBindersAsByteArray() != null) { + final byte[] byteArray = binder.getBindersAsByteArray(); + final DoPutPreparedStatementResult build = + DoPutPreparedStatementResult.newBuilder() + .setPreparedStatementHandle(ByteString.copyFrom(ByteBuffer.wrap(byteArray))).build(); + + try (final ArrowBuf buffer = rootAllocator.buffer(build.getSerializedSize())) { + buffer.writeBytes(build.toByteArray()); + ackStream.onNext(PutResult.metadata(buffer)); + } } ackStream.onCompleted(); };