diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java index b2de2c63a..7511efff0 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java @@ -7537,7 +7537,7 @@ protected void setInterruptsEnabled(boolean interruptsEnabled) { // Flag set to indicate that an interrupt has happened. private volatile boolean wasInterrupted = false; - private boolean wasInterrupted() { + boolean wasInterrupted() { return wasInterrupted; } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLJdbcVersion.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLJdbcVersion.java index 5c677a301..33159702c 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLJdbcVersion.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLJdbcVersion.java @@ -8,7 +8,7 @@ final class SQLJdbcVersion { static final int major = 10; static final int minor = 2; - static final int patch = 1; + static final int patch = 2; static final int build = 0; /* * Used to load mssql-jdbc_auth DLL. diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java index 17bd4a5fe..b14966001 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java @@ -2804,6 +2804,21 @@ final void doExecutePreparedStatementBatch(PrepStmtBatchExecCmd batchCommand) th for (int attempt = 1; attempt <= 2; ++attempt) { try { + // If the command was interrupted, that means the TDS.PKT_CANCEL_REQ was sent to the server. + // Since the cancelation request was sent, stop processing the batch query and process the + // cancelation request and then return. + // + // Otherwise, if we do continue processing the batch query, in the case where a query requires + // prepexec/sp_prepare, the TDS request for prepexec/sp_prepare will be sent regardless of + // query cancelation. This will cause a TDS token error in the post processing when we + // close the query. + if (batchCommand.wasInterrupted()) { + ensureExecuteResultsReader(batchCommand.startResponse(getIsResponseBufferingAdaptive())); + startResults(); + getNextResult(true); + return; + } + // Re-use handle if available, requires parameter definitions which are not available until here. if (reuseCachedHandle(hasNewTypeDefinitions, 1 < attempt)) { hasNewTypeDefinitions = false; diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java index 1bbb29384..9cf8e78b1 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java @@ -48,7 +48,7 @@ protected Object[][] getContents() { {"R_noServerResponse", "SQL Server did not return a response. The connection has been closed."}, {"R_truncatedServerResponse", "SQL Server returned an incomplete response. The connection has been closed."}, {"R_queryTimedOut", "The query has timed out."}, - {"R_queryCancelled", "The query was canceled."}, + {"R_queryCanceled", "The query was canceled."}, {"R_errorReadingStream", "An error occurred while reading the value from the stream object. Error: \"{0}\""}, {"R_streamReadReturnedInvalidValue", "The stream read operation returned an invalid value for the amount of data read."}, {"R_mismatchedStreamLength", "The stream value is not the specified length. The specified length was {0}, the actual length is {1}."}, diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java b/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java index da6773116..dd664bc72 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java @@ -161,7 +161,6 @@ protected Object[][] getContents() { {"R_cancellationFailed", "Cancellation failed."}, {"R_executionNotTimeout", "Execution did not timeout."}, {"R_executionTooLong", "Execution took too long."}, {"R_executionNotLong", "Execution did not take long enough."}, - {"R_queryCancelled", "The query was canceled."}, {"R_statementShouldBeClosed", "statement should be closed since resultset is closed."}, {"R_statementShouldBeOpened", "statement should be opened since resultset is opened."}, {"R_shouldBeWrapper", "{0} should be a wrapper for {1}."}, @@ -201,5 +200,6 @@ protected Object[][] getContents() { {"R_objectNullOrEmpty", "The {0} is null or empty."}, {"R_cekDecryptionFailed", "Failed to decrypt a column encryption key using key store provider: {0}."}, {"R_connectTimedOut", "connect timed out"}, + {"R_queryCanceled", "The query was canceled."}, {"R_sessionKilled", "Cannot continue the execution because the session is in the kill state"}}; } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/callablestatement/CallableStatementTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/callablestatement/CallableStatementTest.java index 74cab8dba..5ca86eeb7 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/callablestatement/CallableStatementTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/callablestatement/CallableStatementTest.java @@ -152,7 +152,7 @@ public void testCallableStatementManyParameters() throws SQLException { @Test public void getStringGUIDTest() throws SQLException { - String sql = "{call " + AbstractSQLGenerator.escapeIdentifier(outputProcedureNameGUID) + "(?)}"; + String sql = "{call " + outputProcedureNameGUID + "(?)}"; try (SQLServerCallableStatement callableStatement = (SQLServerCallableStatement) connection.prepareCall(sql)) { @@ -181,7 +181,7 @@ public void getSetNullWithTypeVarchar() throws SQLException { SQLServerDataSource ds = new SQLServerDataSource(); ds.setURL(connectionString); ds.setSendStringParametersAsUnicode(true); - String sql = "{? = call " + AbstractSQLGenerator.escapeIdentifier(setNullProcedureName) + " (?,?)}"; + String sql = "{? = call " + setNullProcedureName + " (?,?)}"; try (Connection connection = ds.getConnection(); SQLServerCallableStatement cs = (SQLServerCallableStatement) connection.prepareCall(sql); SQLServerCallableStatement cs2 = (SQLServerCallableStatement) connection.prepareCall(sql)) { @@ -213,7 +213,7 @@ public void getSetNullWithTypeVarchar() throws SQLException { */ @Test public void testGetObjectAsLocalDateTime() throws SQLException { - String sql = "{CALL " + AbstractSQLGenerator.escapeIdentifier(getObjectLocalDateTimeProcedureName) + " (?)}"; + String sql = "{CALL " + getObjectLocalDateTimeProcedureName + " (?)}"; try (Connection con = DriverManager.getConnection(connectionString); CallableStatement cs = con.prepareCall(sql)) { cs.registerOutParameter(1, Types.TIMESTAMP); @@ -253,7 +253,7 @@ public void testGetObjectAsLocalDateTime() throws SQLException { @Test @Tag(Constants.xAzureSQLDW) public void testGetObjectAsOffsetDateTime() throws SQLException { - String sql = "{CALL " + AbstractSQLGenerator.escapeIdentifier(getObjectOffsetDateTimeProcedureName) + String sql = "{CALL " + getObjectOffsetDateTimeProcedureName + " (?, ?)}"; try (Connection con = DriverManager.getConnection(connectionString); CallableStatement cs = con.prepareCall(sql)) { @@ -283,7 +283,7 @@ public void testGetObjectAsOffsetDateTime() throws SQLException { */ @Test public void inputParamsTest() throws SQLException { - String call = "{CALL " + AbstractSQLGenerator.escapeIdentifier(inputParamsProcedureName) + " (?,?)}"; + String call = "{CALL " + inputParamsProcedureName + " (?,?)}"; // the historical way: no leading '@', parameter names respected (not positional) try (CallableStatement cs = connection.prepareCall(call)) { @@ -338,25 +338,25 @@ public static void cleanup() throws SQLException { } private static void createGUIDStoredProcedure(Statement stmt) throws SQLException { - String sql = "CREATE PROCEDURE " + AbstractSQLGenerator.escapeIdentifier(outputProcedureNameGUID) + String sql = "CREATE PROCEDURE " + outputProcedureNameGUID + "(@p1 uniqueidentifier OUTPUT) AS SELECT @p1 = c1 FROM " - + AbstractSQLGenerator.escapeIdentifier(tableNameGUID) + Constants.SEMI_COLON; + + tableNameGUID + Constants.SEMI_COLON; stmt.execute(sql); } private static void createGUIDTable(Statement stmt) throws SQLException { - String sql = "CREATE TABLE " + AbstractSQLGenerator.escapeIdentifier(tableNameGUID) + String sql = "CREATE TABLE " + tableNameGUID + " (c1 uniqueidentifier null)"; stmt.execute(sql); } private static void createSetNullProcedure(Statement stmt) throws SQLException { - stmt.execute("create procedure " + AbstractSQLGenerator.escapeIdentifier(setNullProcedureName) + stmt.execute("create procedure " + setNullProcedureName + " (@p1 nvarchar(255), @p2 nvarchar(255) output) as select @p2=@p1 return 0"); } private static void createInputParamsProcedure(Statement stmt) throws SQLException { - String sql = "CREATE PROCEDURE " + AbstractSQLGenerator.escapeIdentifier(inputParamsProcedureName) + String sql = "CREATE PROCEDURE " + inputParamsProcedureName + " @p1 nvarchar(max) = N'parameter1', " + " @p2 nvarchar(max) = N'parameter2' " + "AS " + "BEGIN " + " SET NOCOUNT ON; " + " SELECT @p1 + @p2 AS result; " + "END "; @@ -364,13 +364,13 @@ private static void createInputParamsProcedure(Statement stmt) throws SQLExcepti } private static void createGetObjectLocalDateTimeProcedure(Statement stmt) throws SQLException { - String sql = "CREATE PROCEDURE " + AbstractSQLGenerator.escapeIdentifier(getObjectLocalDateTimeProcedureName) + String sql = "CREATE PROCEDURE " + getObjectLocalDateTimeProcedureName + "(@p1 datetime2(7) OUTPUT) AS " + "SELECT @p1 = '2018-03-11T02:00:00.1234567'"; stmt.execute(sql); } private static void createGetObjectOffsetDateTimeProcedure(Statement stmt) throws SQLException { - String sql = "CREATE PROCEDURE " + AbstractSQLGenerator.escapeIdentifier(getObjectOffsetDateTimeProcedureName) + String sql = "CREATE PROCEDURE " + getObjectOffsetDateTimeProcedureName + "(@p1 DATETIMEOFFSET OUTPUT, @p2 DATETIMEOFFSET OUTPUT) AS " + "SELECT @p1 = '2018-01-02T11:22:33.123456700+12:34', @p2 = NULL"; stmt.execute(sql); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java index 8cb258a03..90a3b0a9a 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java @@ -5,6 +5,7 @@ package com.microsoft.sqlserver.jdbc.unit.statement; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; import java.lang.reflect.Field; @@ -60,6 +61,51 @@ public void testBatchExceptionAEOn() throws Exception { testExecuteBatch1UseBulkCopyAPI(); } + @Test + public void testBatchStatementCancellation() throws Exception { + try (Connection connection = PrepUtil.getConnection(connectionString)) { + connection.setAutoCommit(false); + + try (PreparedStatement statement = connection.prepareStatement( + "if object_id('test_table') is not null drop table test_table")) { + statement.execute(); + } + connection.commit(); + + try (PreparedStatement statement = connection.prepareStatement( + "create table test_table (column_name bit)")) { + statement.execute(); + } + connection.commit(); + + for (long delayInMilliseconds : new long[] { 1, 2, 4, 8, 16, 32, 64, 128 }) { + for (int numberOfCommands : new int[] { 1, 2, 4, 8, 16, 32, 64 }) { + int parameterCount = 512; + + try (PreparedStatement statement = connection.prepareStatement( + "insert into test_table values (?)" + repeat(",(?)", parameterCount - 1))) { + + for (int i = 0; i < numberOfCommands; i++) { + for (int j = 0; j < parameterCount; j++) { + statement.setBoolean(j + 1, true); + } + statement.addBatch(); + } + + Thread cancelThread = cancelAsync(statement, delayInMilliseconds); + try { + statement.executeBatch(); + } catch (SQLException e) { + assertEquals(TestResource.getResource("R_queryCancelled"), e.getMessage()); + } + cancelThread.join(); + } + connection.commit(); + } + } + } + } + /** * Get a PreparedStatement object and call the addBatch() method with 3 SQL statements and call the executeBatch() * method and it should return array of Integer values of length 3 @@ -231,6 +277,29 @@ private void modifyConnectionForBulkCopyAPI(SQLServerConnection con) throws Exce con.setUseBulkCopyForBatchInsert(true); } + private static String repeat(String string, int count) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < count; i++) { + sb.append(string); + } + return sb.toString(); + } + + private static Thread cancelAsync(Statement statement, long delayInMilliseconds) { + Thread thread = new Thread(() -> { + try { + Thread.sleep(delayInMilliseconds); + statement.cancel(); + } catch (SQLException | InterruptedException e) { + // does not/must not happen + e.printStackTrace(); + throw new IllegalStateException(e); + } + }); + thread.start(); + return thread; + } + @BeforeAll public static void testSetup() throws TestAbortedException, Exception { connectionString = TestUtils.addOrOverrideProperty(connectionString,"trustServerCertificate", "true"); diff --git a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java index 496cfa697..2bacd129b 100644 --- a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java +++ b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java @@ -71,6 +71,7 @@ public abstract class AbstractTest { protected static String trustStorePath = ""; + protected static String trustServerCertificate = ""; protected static String windowsKeyPath = null; protected static String javaKeyPath = null; protected static String javaKeyAliases = null; @@ -134,6 +135,10 @@ public static void setup() throws Exception { applicationKey = getConfiguredProperty("applicationKey"); tenantID = getConfiguredProperty("tenantID"); + trustServerCertificate = getConfiguredProperty("trustServerCertificate", "true"); + connectionString = TestUtils.addOrOverrideProperty(connectionString, "trustServerCertificate", + trustServerCertificate); + javaKeyPath = TestUtils.getCurrentClassPath() + Constants.JKS_NAME; keyIDs = getConfiguredProperty("keyID", "").split(Constants.SEMI_COLON);