Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.google.cloud.spanner.SpannerOptions;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.TransactionRunner.TransactionCallable;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
Expand All @@ -55,6 +56,7 @@
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.it.common.ResourceManager;
import org.apache.beam.it.common.utils.ExceptionUtils;
Expand Down Expand Up @@ -370,6 +372,64 @@ public synchronized void write(Iterable<Mutation> tableRecords) throws IllegalSt
}
}

/**
* Writes a collection of mutations into one or more tables inside a ReadWriteTransaction. This
* method requires {@link SpannerResourceManager#executeDdlStatement(String)} to be called
* beforehand.
*
* @param mutations A collection of mutation objects.
*/
public void writeInTransaction(Iterable<Mutation> mutations) {
checkIsUsable();
checkHasInstanceAndDatabase();

LOG.info("Sending {} mutations to {}.{}", Iterables.size(mutations), instanceId, databaseId);
DatabaseClient databaseClient =
spanner.getDatabaseClient(DatabaseId.of(projectId, instanceId, databaseId));
databaseClient
.readWriteTransaction()
.run(
(TransactionCallable<Void>)
transaction -> {
transaction.buffer(mutations);
return null;
});
LOG.info("Successfully sent mutations to {}.{}", instanceId, databaseId);
}

/**
* Executes a list of DML statements. This method requires {@link
* SpannerResourceManager#executeDdlStatement(String)} to be called beforehand.
*
* @param statements The DML statements.
* @throws IllegalStateException if method is called after resources have been cleaned up.
*/
public synchronized void executeDMLStatements(List<String> statements)
throws IllegalStateException {
checkIsUsable();
checkHasInstanceAndDatabase();

LOG.info("Executing DML statements on database {}.", statements, databaseId);
List<Statement> statementsList =
statements.stream().map(s -> Statement.of(s)).collect(Collectors.toList());
try {
DatabaseClient databaseClient =
spanner.getDatabaseClient(DatabaseId.of(projectId, instanceId, databaseId));
databaseClient
.readWriteTransaction()
.run(
(TransactionCallable<Void>)
transaction -> {
transaction.batchUpdate(statementsList);
return null;
});
LOG.debug(
"Successfully executed DML statements '{}' on database {}.", statements, databaseId);
} catch (Exception e) {
throw new SpannerResourceManagerException("Failed to execute statement.", e);
}
}

/**
* Runs the specified query.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import com.google.cloud.spanner.Mutation;
import com.google.cloud.spanner.Struct;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.beam.it.common.utils.ResourceManagerUtils;
Expand Down Expand Up @@ -104,6 +105,73 @@ public void testResourceManagerE2E() {
Map.of("RowId", 2, "FirstName", "Jane", "LastName", "Doe", "Company", "Alphabet")));
}

@Test
public void testResourceManagerWriteInTransactionAndExecuteDML() {
// Arrange
spannerResourceManager.executeDdlStatement(
"CREATE TABLE "
+ TABLE_ID
+ " ("
+ "RowId INT64 NOT NULL,"
+ "FirstName STRING(1024),"
+ "LastName STRING(1024),"
+ "Company STRING(1024)"
+ ") PRIMARY KEY (RowId)");

List<Mutation> mutations =
List.of(
Mutation.newInsertBuilder(TABLE_ID)
.set("RowId")
.to(1)
.set("FirstName")
.to("John")
.set("LastName")
.to("Doe")
.set("Company")
.to("Google")
.build(),
Mutation.newInsertBuilder(TABLE_ID)
.set("RowId")
.to(2)
.set("FirstName")
.to("Jane")
.set("LastName")
.to("Doe")
.set("Company")
.to("Alphabet")
.build());

List<String> statements =
Arrays.asList(
"INSERT INTO "
+ TABLE_ID
+ " (RowId, FirstName, LastName, Company) values (3, 'Tester', 'Doe', 'Youtube')",
"INSERT INTO "
+ TABLE_ID
+ " (RowId, FirstName, LastName, Company) values (4, 'Jacob', 'Doe', 'DeepMind')");

// Act
spannerResourceManager.writeInTransaction(mutations);
spannerResourceManager.executeDMLStatements(statements);
long rowCount = spannerResourceManager.getRowCount(TABLE_ID);

List<Struct> fetchRecords =
spannerResourceManager.readTableRecords(
TABLE_ID, List.of("RowId", "FirstName", "LastName", "Company"));

// Assert
assertThat(rowCount).isEqualTo(4);
assertThat(fetchRecords).hasSize(4);
assertThatStructs(fetchRecords)
.hasRecordsUnorderedCaseInsensitiveColumns(
List.of(
Map.of("RowId", 1, "FirstName", "John", "LastName", "Doe", "Company", "Google"),
Map.of("RowId", 2, "FirstName", "Jane", "LastName", "Doe", "Company", "Alphabet"),
Map.of("RowId", 3, "FirstName", "Tester", "LastName", "Doe", "Company", "Youtube"),
Map.of(
"RowId", 4, "FirstName", "Jacob", "LastName", "Doe", "Company", "DeepMind")));
}

@After
public void tearDown() {
ResourceManagerUtils.cleanResources(spannerResourceManager);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -34,18 +35,26 @@
import com.google.cloud.Timestamp;
import com.google.cloud.spanner.Database;
import com.google.cloud.spanner.DatabaseAdminClient;
import com.google.cloud.spanner.DatabaseClient;
import com.google.cloud.spanner.Dialect;
import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.Instance;
import com.google.cloud.spanner.InstanceAdminClient;
import com.google.cloud.spanner.Mutation;
import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.TransactionContext;
import com.google.cloud.spanner.TransactionRunner;
import com.google.common.collect.ImmutableList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import org.apache.beam.it.gcp.monitoring.MonitoringClient;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -393,6 +402,192 @@ public void testWriteMultipleRecordsShouldThrowExceptionWhenSpannerWriteFails()
assertThrows(SpannerResourceManagerException.class, () -> testManager.write(testMutations));
}

@Test
public void testWriteInTransactionShouldWorkWhenSpannerWriteSucceeds()
throws ExecutionException, InterruptedException {
// arrange
prepareTable();
DatabaseClient databaseClientMock = mock(DatabaseClient.class);
TransactionRunner transactionCallableMock = mock(TransactionRunner.class);
TransactionContext transactionContext = mock(TransactionContext.class);
when(spanner.getDatabaseClient(any())).thenReturn(databaseClientMock);
when(databaseClientMock.readWriteTransaction()).thenReturn(transactionCallableMock);
when(transactionCallableMock.run(any()))
.thenAnswer(
invocation -> {
TransactionRunner.TransactionCallable<Void> callable = invocation.getArgument(0);
return callable.run(transactionContext);
});

ImmutableList<Mutation> testMutations =
ImmutableList.of(
Mutation.newInsertOrUpdateBuilder("SingerId")
.set("SingerId")
.to(1)
.set("FirstName")
.to("Marc")
.set("LastName")
.to("Richards")
.build(),
Mutation.newInsertOrUpdateBuilder("SingerId")
.set("SingerId")
.to(2)
.set("FirstName")
.to("Catalina")
.set("LastName")
.to("Smith")
.build());

// act
testManager.writeInTransaction(testMutations);

// assert
ArgumentCaptor<Iterable<Mutation>> argument = ArgumentCaptor.forClass(Iterable.class);
verify(transactionContext, times(1)).buffer(argument.capture());
Iterable<Mutation> capturedMutations = argument.getValue();

assertThat(capturedMutations).containsExactlyElementsIn(testMutations);
}

@Test
public void testWriteInTransactionShouldThrowExceptionWhenCalledBeforeExecuteDdlStatement() {
// arrange
ImmutableList<Mutation> testMutations =
ImmutableList.of(
Mutation.newInsertOrUpdateBuilder("SingerId")
.set("SingerId")
.to(1)
.set("FirstName")
.to("Marc")
.set("LastName")
.to("Richards")
.build(),
Mutation.newInsertOrUpdateBuilder("SingerId")
.set("SingerId")
.to(2)
.set("FirstName")
.to("Catalina")
.set("LastName")
.to("Smith")
.build());

// act & assert
assertThrows(IllegalStateException.class, () -> testManager.writeInTransaction(testMutations));
}

@Test
public void testWriteInTransactionShouldThrowExceptionWhenSpannerWriteFails()
throws ExecutionException, InterruptedException {
// arrange
prepareTable();
prepareTable();
DatabaseClient databaseClientMock = mock(DatabaseClient.class);
TransactionRunner transactionCallableMock = mock(TransactionRunner.class);
when(spanner.getDatabaseClient(any())).thenReturn(databaseClientMock);
when(databaseClientMock.readWriteTransaction()).thenReturn(transactionCallableMock);
when(transactionCallableMock.run(any()))
.thenAnswer(
invocation -> {
throw SpannerExceptionFactory.newSpannerException(ErrorCode.NOT_FOUND, "Not found");
});
ImmutableList<Mutation> testMutations =
ImmutableList.of(
Mutation.newInsertOrUpdateBuilder("SingerId")
.set("SingerId")
.to(1)
.set("FirstName")
.to("Marc")
.set("LastName")
.to("Richards")
.build(),
Mutation.newInsertOrUpdateBuilder("SingerId")
.set("SingerId")
.to(2)
.set("FirstName")
.to("Catalina")
.set("LastName")
.to("Smith")
.build());

// act & assert
assertThrows(SpannerException.class, () -> testManager.writeInTransaction(testMutations));
}

@Test
public void testExecuteDMLShouldWorkWhenSpannerWriteSucceeds()
throws ExecutionException, InterruptedException {
// arrange
prepareTable();
DatabaseClient databaseClientMock = mock(DatabaseClient.class);
TransactionRunner transactionCallableMock = mock(TransactionRunner.class);
TransactionContext transactionContext = mock(TransactionContext.class);
when(spanner.getDatabaseClient(any())).thenReturn(databaseClientMock);
when(databaseClientMock.readWriteTransaction()).thenReturn(transactionCallableMock);
when(transactionCallableMock.run(any()))
.thenAnswer(
invocation -> {
TransactionRunner.TransactionCallable<Void> callable = invocation.getArgument(0);
return callable.run(transactionContext);
});

ImmutableList<String> testStatements =
ImmutableList.of(
"INSERT INTO Singers (SingerId, FirstName, LastName) values (1, 'Marc', 'Richards')",
"INSERT INTO Singers (SingerId, FirstName, LastName) values (2, 'Catalina', 'Smith')");

// act
testManager.executeDMLStatements(testStatements);

// assert
ArgumentCaptor<Iterable<Statement>> argument = ArgumentCaptor.forClass(Iterable.class);
verify(transactionContext, times(1)).batchUpdate(argument.capture());
Iterable<Statement> capturedStatements = argument.getValue();

List<Statement> statementList =
testStatements.stream().map(s -> Statement.of(s)).collect(Collectors.toList());
assertThat(capturedStatements).containsExactlyElementsIn(statementList);
}

@Test
public void testExecuteDMLShouldThrowExceptionWhenCalledBeforeExecuteDdlStatement() {
// arrange
ImmutableList<String> testStatements =
ImmutableList.of(
"INSERT INTO Singers (SingerId, FirstName, LastName) values (1, 'Marc', 'Richards')",
"INSERT INTO Singers (SingerId, FirstName, LastName) values (2, 'Catalina', 'Smith')");

// act & assert
assertThrows(
IllegalStateException.class, () -> testManager.executeDMLStatements(testStatements));
}

@Test
public void testExecuteDMLShouldThrowExceptionWhenSpannerWriteFails()
throws ExecutionException, InterruptedException {
// arrange
prepareTable();
DatabaseClient databaseClientMock = mock(DatabaseClient.class);
TransactionRunner transactionCallableMock = mock(TransactionRunner.class);
when(spanner.getDatabaseClient(any())).thenReturn(databaseClientMock);
when(databaseClientMock.readWriteTransaction()).thenReturn(transactionCallableMock);
when(transactionCallableMock.run(any()))
.thenAnswer(
invocation -> {
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.DEADLINE_EXCEEDED, "Deadline exceeded while processing the request");
});

ImmutableList<String> testStatements =
ImmutableList.of(
"INSERT INTO Singers (SingerId, FirstName, LastName) values (1, 'Marc', 'Richards')",
"INSERT INTO Singers (SingerId, FirstName, LastName) values (2, 'Catalina', 'Smith')");

// act & assert
assertThrows(
SpannerResourceManagerException.class,
() -> testManager.executeDMLStatements(testStatements));
}

@Test
public void testReadRecordsShouldWorkWhenSpannerReadSucceeds()
throws ExecutionException, InterruptedException {
Expand Down
Loading
Loading