Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
import java.io.IOException;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.text.MessageFormat;
import java.time.Duration;
Expand Down Expand Up @@ -391,8 +392,7 @@ enum RunConnectionState {
*/
private RunConnectionState runConnection(boolean ssl) {
RunConnectionState result = RunConnectionState.TERMINATED;
try (ConnectionMetadata connectionMetadata =
new ConnectionMetadata(this.socket.getInputStream(), this.socket.getOutputStream())) {
try (ConnectionMetadata connectionMetadata = new ConnectionMetadata(this.socket)) {
this.connectionMetadata = connectionMetadata;

try {
Expand Down Expand Up @@ -862,7 +862,7 @@ public void setWellKnownClient(WellKnownClient wellKnownClient) {
* opportunity to determine the client that is connected based on the SQL string that is being
* executed.
*/
public void maybeDetermineWellKnownClient(Statement statement) {
public void maybeDetermineWellKnownClient(Statement statement) throws IOException {
if (!this.hasDeterminedClientUsingQuery) {
if (this.wellKnownClient == WellKnownClient.UNSPECIFIED
&& getServer().getOptions().shouldAutoDetectClient()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,42 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.concurrent.TimeUnit;

@InternalApi
public class ConnectionMetadata implements AutoCloseable {
private static final int SOCKET_BUFFER_SIZE = 1 << 16;

private final ByteBuffer buffer = ByteBuffer.allocateDirect(SOCKET_BUFFER_SIZE);
private final SocketChannel socketChannel;
private final DataInputStream inputStream;
private final DataOutputStream outputStream;
private boolean markedForRestart;

public ConnectionMetadata(InputStream rawInputStream, OutputStream rawOutputStream) {
this.socketChannel = null;
this.inputStream = new DataInputStream(rawInputStream);
this.outputStream = new DataOutputStream(rawOutputStream);
}

/**
* Creates a {@link DataInputStream} and a {@link DataOutputStream} from the given raw streams and
* pushes these as the current streams to use for communication for a connection.
*/
public ConnectionMetadata(InputStream rawInputStream, OutputStream rawOutputStream) {
public ConnectionMetadata(Socket socket) throws IOException {
this.socketChannel = socket.getChannel();
this.inputStream =
new DataInputStream(
new BufferedInputStream(
Preconditions.checkNotNull(rawInputStream), SOCKET_BUFFER_SIZE));
Preconditions.checkNotNull(socket.getInputStream()), SOCKET_BUFFER_SIZE));
this.outputStream =
new DataOutputStream(
new BufferedOutputStream(
Preconditions.checkNotNull(rawOutputStream), SOCKET_BUFFER_SIZE));
Preconditions.checkNotNull(socket.getOutputStream()), SOCKET_BUFFER_SIZE));

}

public void markForRestart() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
import com.google.cloud.spanner.pgadapter.statements.DeclareStatement.ParsedDeclareStatement.Builder;
import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName;
import com.google.cloud.spanner.pgadapter.wireprotocol.BindMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.ManuallyCreatedToken;
import com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage.ManuallyCreatedToken;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.Futures;
import java.util.concurrent.Future;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata;
import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName;
import com.google.cloud.spanner.pgadapter.wireprotocol.BindMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.ManuallyCreatedToken;
import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.PreparedType;
import com.google.cloud.spanner.pgadapter.wireprotocol.DescribeMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.ExecuteMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage.ManuallyCreatedToken;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.Futures;
import java.nio.charset.StandardCharsets;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata;
import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.ManuallyCreatedToken;
import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.PreparedType;
import com.google.cloud.spanner.pgadapter.wireprotocol.DescribeMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.ExecuteMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage.ManuallyCreatedToken;

public class FetchStatement extends AbstractFetchOrMoveStatement {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata;
import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.ManuallyCreatedToken;
import com.google.cloud.spanner.pgadapter.wireprotocol.ExecuteMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage.ManuallyCreatedToken;

/**
* MOVE is the same as FETCH, except it just skips the results instead of actually sending the rows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
import com.google.cloud.spanner.pgadapter.statements.BackendConnection.NoResult;
import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TableOrIndexName;
import com.google.cloud.spanner.pgadapter.statements.SimpleParser.TypeDefinition;
import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.ManuallyCreatedToken;
import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.PreparedType;
import com.google.cloud.spanner.pgadapter.wireprotocol.DescribeMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage.ManuallyCreatedToken;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
import com.google.cloud.spanner.pgadapter.statements.BackendConnection.ConnectionState;
import com.google.cloud.spanner.pgadapter.utils.ClientAutoDetector.WellKnownClient;
import com.google.cloud.spanner.pgadapter.wireprotocol.BindMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.ControlMessage.ManuallyCreatedToken;
import com.google.cloud.spanner.pgadapter.wireprotocol.DescribeMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.ExecuteMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.FlushMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.ParseMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.SyncMessage;
import com.google.cloud.spanner.pgadapter.wireprotocol.WireMessage.ManuallyCreatedToken;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ public abstract class AbstractQueryProtocolMessage extends ControlMessage {
}

AbstractQueryProtocolMessage(
ConnectionHandler connection, int length, ManuallyCreatedToken manuallyCreatedToken) {
ConnectionHandler connection, int length, ManuallyCreatedToken manuallyCreatedToken)
throws IOException {
super(connection, length, manuallyCreatedToken);
this.handler = connection.getExtendedQueryProtocolHandler();
this.queryMode = QueryMode.SIMPLE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement;
import com.google.cloud.spanner.pgadapter.wireoutput.BindCompleteResponse;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.Arrays;

Expand Down Expand Up @@ -58,7 +59,8 @@ public BindMessage(ConnectionHandler connection) throws Exception {
}

/** Constructor for Bind messages that are constructed to execute a Query message. */
public BindMessage(ConnectionHandler connection, ManuallyCreatedToken manuallyCreatedToken) {
public BindMessage(ConnectionHandler connection, ManuallyCreatedToken manuallyCreatedToken)
throws IOException {
this(connection, "", "", new byte[0][], manuallyCreatedToken);
}

Expand All @@ -68,7 +70,8 @@ public BindMessage(
String statementName,
String portalName,
byte[][] parameters,
ManuallyCreatedToken manuallyCreatedToken) {
ManuallyCreatedToken manuallyCreatedToken)
throws IOException {
super(connection, 4, manuallyCreatedToken);
this.portalName = portalName;
this.statementName = statementName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.cloud.spanner.pgadapter.wireoutput.ParameterStatusResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.ZoneId;
import java.util.ArrayList;
Expand All @@ -36,8 +37,8 @@
*/
@InternalApi
public abstract class BootstrapMessage extends WireMessage {
public BootstrapMessage(ConnectionHandler connection, int length) {
super(connection, length);
public BootstrapMessage(ConnectionHandler connection, int length) throws IOException {
super(connection, length, /* ManuallyCreatedToken= */ null);
}

/**
Expand Down Expand Up @@ -173,4 +174,9 @@ public static void sendStartupMessage(
ReadyResponse.sendIdleResponse(output);
output.flush();
}

@Override
protected int getHeaderLength() {
return 8;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,29 +84,17 @@ public abstract class ControlMessage extends WireMessage {
/** Maximum number of invalid messages in a row allowed before we terminate the connection. */
static final int MAX_INVALID_MESSAGE_COUNT = 50;

/**
* Token that is used to mark {@link ControlMessage}s that are manually created to execute a
* {@link QueryMessage}.
*/
public enum ManuallyCreatedToken {
MANUALLY_CREATED_TOKEN
}

private final ManuallyCreatedToken manuallyCreatedToken;

public ControlMessage(ConnectionHandler connection) throws IOException {
super(connection, connection.getConnectionMetadata().getInputStream().readInt());
this.manuallyCreatedToken = null;
super(
connection,
connection.getConnectionMetadata().getInputStream().readInt(),
/* manuallyCreatedToken= */ null);
}

/** Constructor for manually created Control messages. */
protected ControlMessage(ConnectionHandler connection, int length, ManuallyCreatedToken token) {
super(connection, length);
this.manuallyCreatedToken = token;
}

public boolean isExtendedProtocol() {
return manuallyCreatedToken == null;
protected ControlMessage(ConnectionHandler connection, int length, ManuallyCreatedToken token)
throws IOException {
super(connection, length, token);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.cloud.spanner.pgadapter.wireoutput.ParameterDescriptionResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.RowDescriptionResponse;
import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
Expand Down Expand Up @@ -53,7 +54,8 @@ public DescribeMessage(ConnectionHandler connection) throws Exception {
}

/** Constructor for manually created Describe messages from the simple query protocol. */
public DescribeMessage(ConnectionHandler connection, ManuallyCreatedToken manuallyCreatedToken) {
public DescribeMessage(ConnectionHandler connection, ManuallyCreatedToken manuallyCreatedToken)
throws IOException {
this(connection, PreparedType.Portal, "", manuallyCreatedToken);
}

Expand All @@ -62,7 +64,8 @@ public DescribeMessage(
ConnectionHandler connection,
PreparedType type,
String name,
ManuallyCreatedToken manuallyCreatedToken) {
ManuallyCreatedToken manuallyCreatedToken)
throws IOException {
super(connection, 4, manuallyCreatedToken);
this.type = type;
this.name = name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.cloud.spanner.pgadapter.statements.BackendConnection;
import com.google.cloud.spanner.pgadapter.statements.CopyStatement;
import com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement;
import java.io.IOException;
import java.text.MessageFormat;

/** Executes a portal. */
Expand All @@ -40,7 +41,8 @@ public ExecuteMessage(ConnectionHandler connection) throws Exception {
}

/** Constructor for execute messages that are generated by the simple query protocol. */
public ExecuteMessage(ConnectionHandler connection, ManuallyCreatedToken manuallyCreatedToken) {
public ExecuteMessage(ConnectionHandler connection, ManuallyCreatedToken manuallyCreatedToken)
throws IOException {
this(connection, "", 0, null, true, manuallyCreatedToken);
}

Expand All @@ -50,7 +52,8 @@ public ExecuteMessage(
int maxRows,
String commandTag,
boolean cleanupAfterExecute,
ManuallyCreatedToken manuallyCreatedToken) {
ManuallyCreatedToken manuallyCreatedToken)
throws IOException {
super(connection, 8, manuallyCreatedToken);
this.name = name;
this.maxRows = maxRows;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import com.google.api.core.InternalApi;
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import java.io.IOException;
import java.text.MessageFormat;

/**
Expand All @@ -31,7 +32,8 @@ public FlushMessage(ConnectionHandler connection) throws Exception {
super(connection);
}

public FlushMessage(ConnectionHandler connection, ManuallyCreatedToken manuallyCreatedToken) {
public FlushMessage(ConnectionHandler connection, ManuallyCreatedToken manuallyCreatedToken)
throws IOException {
super(connection, 4, manuallyCreatedToken);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class GSSENCRequestMessage extends BootstrapMessage {

private final ThreadLocal<Boolean> executedOnce = ThreadLocal.withInitial(() -> false);

public GSSENCRequestMessage(ConnectionHandler connection) {
public GSSENCRequestMessage(ConnectionHandler connection) throws IOException {
super(connection, MESSAGE_LENGTH);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import com.google.cloud.spanner.pgadapter.statements.VacuumStatement;
import com.google.cloud.spanner.pgadapter.wireoutput.ParseCompleteResponse;
import com.google.common.base.Strings;
import java.io.IOException;
import java.text.MessageFormat;

/** Creates a prepared statement. */
Expand Down Expand Up @@ -96,7 +97,8 @@ public ParseMessage(ConnectionHandler connection) throws Exception {
* Constructor for manually created Parse messages that originate from the simple query protocol.
*/
public ParseMessage(
ConnectionHandler connection, ParsedStatement parsedStatement, Statement originalStatement) {
ConnectionHandler connection, ParsedStatement parsedStatement, Statement originalStatement)
throws IOException {
this(connection, "", new int[0], parsedStatement, originalStatement);
}

Expand All @@ -106,7 +108,8 @@ public ParseMessage(
String name,
int[] parameterDataTypes,
ParsedStatement parsedStatement,
Statement originalStatement) {
Statement originalStatement)
throws IOException {
super(
connection,
5 + parsedStatement.getSqlWithoutComments().length(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class SSLMessage extends BootstrapMessage {

private final ThreadLocal<Boolean> executedOnce = ThreadLocal.withInitial(() -> false);

public SSLMessage(ConnectionHandler connection) {
public SSLMessage(ConnectionHandler connection) throws IOException {
super(connection, MESSAGE_LENGTH);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,6 @@ public void nextHandler() throws Exception {
}
}

@Override
protected int getHeaderLength() {
return 8;
}

@Override
protected String getMessageName() {
return "Start-Up";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import com.google.api.core.InternalApi;
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import java.io.IOException;
import java.text.MessageFormat;

/**
Expand All @@ -31,7 +32,8 @@ public SyncMessage(ConnectionHandler connection) throws Exception {
super(connection);
}

public SyncMessage(ConnectionHandler connection, ManuallyCreatedToken manuallyCreatedToken) {
public SyncMessage(ConnectionHandler connection, ManuallyCreatedToken manuallyCreatedToken)
throws IOException {
super(connection, 4, manuallyCreatedToken);
}

Expand Down
Loading