diff --git a/spanner/client_test.go b/spanner/client_test.go index 912e159f0204..ec69e89a98fc 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -18,20 +18,45 @@ package spanner import ( "context" + "fmt" "io" "os" "strings" "testing" "cloud.google.com/go/spanner/internal/benchserver" - "cloud.google.com/go/spanner/internal/testutil" + . "cloud.google.com/go/spanner/internal/testutil" "google.golang.org/api/iterator" + "google.golang.org/api/option" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" gstatus "google.golang.org/grpc/status" ) +func setupMockedTestServer(t *testing.T) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) { + return setupMockedTestServerWithConfig(t, ClientConfig{}) +} + +func setupMockedTestServerWithConfig(t *testing.T, config ClientConfig) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) { + return setupMockedTestServerWithConfigAndClientOptions(t, config, []option.ClientOption{}) +} + +func setupMockedTestServerWithConfigAndClientOptions(t *testing.T, config ClientConfig, clientOptions []option.ClientOption) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) { + server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) + opts = append(opts, clientOptions...) + ctx := context.Background() + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + client, err := NewClientWithConfig(ctx, formattedDatabase, config, opts...) + if err != nil { + t.Fatal(err) + } + return server, client, func() { + client.Close() + serverTeardown() + } +} + // Test validDatabaseName() func TestValidDatabaseName(t *testing.T) { validDbURI := "projects/spanner-cloud-test/instances/foo/databases/foodb" @@ -87,14 +112,13 @@ func TestClient_Single_InvalidArgument(t *testing.T) { } func testSingleQuery(t *testing.T, serverError error) error { - config := ClientConfig{} - server, client := newSpannerInMemTestServerWithConfig(t, config) - defer server.teardown(client) + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() if serverError != nil { - server.testSpanner.SetError(serverError) + server.TestSpanner.SetError(serverError) } - ctx := context.Background() - iter := client.Single().Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() for { row, err := iter.Next() @@ -113,12 +137,12 @@ func testSingleQuery(t *testing.T, serverError error) error { return nil } -func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[string]testutil.SimulatedExecutionTime { +func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[string]SimulatedExecutionTime { errors := make([]error, 2) errors[0] = gstatus.Error(codes.Unavailable, "Temporary unavailable") errors[1] = gstatus.Error(codes.Unavailable, "Temporary unavailable") - executionTimes := make(map[string]testutil.SimulatedExecutionTime) - executionTimes[method] = testutil.SimulatedExecutionTime{ + executionTimes := make(map[string]SimulatedExecutionTime) + executionTimes[method] = SimulatedExecutionTime{ Errors: errors, } return executionTimes @@ -126,37 +150,37 @@ func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[str func TestClient_ReadOnlyTransaction(t *testing.T) { t.Parallel() - if err := testReadOnlyTransaction(t, make(map[string]testutil.SimulatedExecutionTime)); err != nil { + if err := testReadOnlyTransaction(t, make(map[string]SimulatedExecutionTime)); err != nil { t.Fatal(err) } } func TestClient_ReadOnlyTransaction_UnavailableOnSessionCreate(t *testing.T) { t.Parallel() - if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodCreateSession)); err != nil { + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodCreateSession)); err != nil { t.Fatal(err) } } func TestClient_ReadOnlyTransaction_UnavailableOnBeginTransaction(t *testing.T) { t.Parallel() - if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodBeginTransaction)); err != nil { + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodBeginTransaction)); err != nil { t.Fatal(err) } } func TestClient_ReadOnlyTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) { t.Parallel() - if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodExecuteStreamingSql)); err != nil { + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodExecuteStreamingSql)); err != nil { t.Fatal(err) } } func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndBeginTransaction(t *testing.T) { t.Parallel() - exec := map[string]testutil.SimulatedExecutionTime{ - testutil.MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, - testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + exec := map[string]SimulatedExecutionTime{ + MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, } if err := testReadOnlyTransaction(t, exec); err != nil { t.Fatal(err) @@ -165,9 +189,9 @@ func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndBeginTransactio func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndInvalidArgumentOnBeginTransaction(t *testing.T) { t.Parallel() - exec := map[string]testutil.SimulatedExecutionTime{ - testutil.MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, - testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.InvalidArgument, "Invalid argument")}}, + exec := map[string]SimulatedExecutionTime{ + MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.InvalidArgument, "Invalid argument")}}, } if err := testReadOnlyTransaction(t, exec); err == nil { t.Fatalf("Missing expected exception") @@ -176,16 +200,16 @@ func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndInvalidArgument } } -func testReadOnlyTransaction(t *testing.T, executionTimes map[string]testutil.SimulatedExecutionTime) error { - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) +func testReadOnlyTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime) error { + server, client, teardown := setupMockedTestServer(t) + defer teardown() for method, exec := range executionTimes { - server.testSpanner.PutExecutionTime(method, exec) + server.TestSpanner.PutExecutionTime(method, exec) } - ctx := context.Background() tx := client.ReadOnlyTransaction() defer tx.Close() - iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + ctx := context.Background() + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() for { row, err := iter.Next() @@ -206,15 +230,15 @@ func testReadOnlyTransaction(t *testing.T, executionTimes map[string]testutil.Si func TestClient_ReadWriteTransaction(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, make(map[string]testutil.SimulatedExecutionTime), 1); err != nil { + if err := testReadWriteTransaction(t, make(map[string]SimulatedExecutionTime), 1); err != nil { t.Fatal(err) } } func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, }, 2); err != nil { t.Fatal(err) } @@ -222,8 +246,8 @@ func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) { func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, }, 2); err != nil { t.Fatal(err) } @@ -231,17 +255,17 @@ func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) { func TestClient_ReadWriteTransaction_UnavailableOnBeginTransaction(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, }, 1); err != nil { t.Fatal(err) } } func TestClient_ReadWriteTransaction_UnavailableOnBeginAndAbortOnCommit(t *testing.T) { - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, - testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, }, 2); err != nil { t.Fatal(err) } @@ -249,8 +273,8 @@ func TestClient_ReadWriteTransaction_UnavailableOnBeginAndAbortOnCommit(t *testi func TestClient_ReadWriteTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, }, 1); err != nil { t.Fatal(err) } @@ -258,10 +282,10 @@ func TestClient_ReadWriteTransaction_UnavailableOnExecuteStreamingSql(t *testing func TestClient_ReadWriteTransaction_UnavailableOnBeginAndExecuteStreamingSqlAndTwiceAbortOnCommit(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, - testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, - testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, }, 3); err != nil { t.Fatal(err) } @@ -269,9 +293,9 @@ func TestClient_ReadWriteTransaction_UnavailableOnBeginAndExecuteStreamingSqlAnd func TestClient_ReadWriteTransaction_AbortedOnExecuteStreamingSqlAndCommit(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, - testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, + MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, }, 4); err != nil { t.Fatal(err) } @@ -279,8 +303,8 @@ func TestClient_ReadWriteTransaction_AbortedOnExecuteStreamingSqlAndCommit(t *te func TestClient_ReadWriteTransactionCommitAbortedAndUnavailable(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodCommitTransaction: { + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodCommitTransaction: { Errors: []error{ gstatus.Error(codes.Aborted, "Transaction aborted"), gstatus.Error(codes.Unavailable, "Unavailable"), @@ -293,8 +317,8 @@ func TestClient_ReadWriteTransactionCommitAbortedAndUnavailable(t *testing.T) { func TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.AlreadyExists, "A row with this key already exists")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.AlreadyExists, "A row with this key already exists")}}, }, 1); err != nil { if gstatus.Code(err) != codes.AlreadyExists { t.Fatalf("Got unexpected error %v, expected %v", err, codes.AlreadyExists) @@ -304,17 +328,17 @@ func TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) { } } -func testReadWriteTransaction(t *testing.T, executionTimes map[string]testutil.SimulatedExecutionTime, expectedAttempts int) error { - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) +func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error { + server, client, teardown := setupMockedTestServer(t) + defer teardown() for method, exec := range executionTimes { - server.testSpanner.PutExecutionTime(method, exec) + server.TestSpanner.PutExecutionTime(method, exec) } - var attempts int ctx := context.Background() + var attempts int _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { attempts++ - iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() for { row, err := iter.Next() @@ -343,14 +367,14 @@ func testReadWriteTransaction(t *testing.T, executionTimes map[string]testutil.S func TestClient_ApplyAtLeastOnce(t *testing.T) { t.Parallel() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() ms := []*Mutation{ Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}), } - server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodCommitTransaction, + SimulatedExecutionTime{ Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}, }) _, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce()) @@ -362,14 +386,14 @@ func TestClient_ApplyAtLeastOnce(t *testing.T) { // PartitionedUpdate should not retry on aborted. func TestClient_PartitionedUpdate(t *testing.T) { t.Parallel() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() // PartitionedDML transactions are not committed. - server.testSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, + SimulatedExecutionTime{ Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}, }) - _, err := client.PartitionedUpdate(context.Background(), NewStatement(updateBarSetFoo)) + _, err := client.PartitionedUpdate(context.Background(), NewStatement(UpdateBarSetFoo)) if err == nil { t.Fatalf("Missing expected Aborted exception") } else { @@ -380,13 +404,13 @@ func TestClient_PartitionedUpdate(t *testing.T) { } func TestReadWriteTransaction_ErrUnexpectedEOF(t *testing.T) { - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) - var attempts int + _, client, teardown := setupMockedTestServer(t) + defer teardown() ctx := context.Background() + var attempts int _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { attempts++ - iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() for { row, err := iter.Next() @@ -440,7 +464,7 @@ func TestNewClient_ConnectToEmulator(t *testing.T) { func TestClient_ApiClientHeader(t *testing.T) { t.Parallel() - server, client := newSpannerInMemTestServerWithInterceptor(t, func( + interceptor := func( ctx context.Context, method string, req interface{}, @@ -468,11 +492,12 @@ func TestClient_ApiClientHeader(t *testing.T) { return spannerErrorf(codes.Internal, "unexpected api client token: %v", token[0]) } return invoker(ctx, method, req, reply, cc, opts...) - }, - ) - defer server.teardown(client) + } + opts := []option.ClientOption{option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptor))} + _, client, teardown := setupMockedTestServerWithConfigAndClientOptions(t, ClientConfig{}, opts) + defer teardown() ctx := context.Background() - iter := client.Single().Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() for { _, err := iter.Next() diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 837b862cc742..c23e480f5315 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testutil +package testutil_test import ( emptypb "github.com/golang/protobuf/ptypes/empty" @@ -179,7 +179,9 @@ type inMemSpannerServer struct { spannerpb.SpannerServer mu sync.Mutex - + // Set to true when this server been stopped. This is the end state of a + // server, a stopped server cannot be restarted. + stopped bool // If set, all calls return this error. err error // The mock server creates session IDs using this counter. @@ -188,7 +190,6 @@ type inMemSpannerServer struct { sessions map[string]*spannerpb.Session // Last use times per session. sessionLastUseTime map[string]time.Time - // The mock server creates transaction IDs per session using these // counters. transactionCounters map[string]*uint64 @@ -198,19 +199,18 @@ type inMemSpannerServer struct { abortedTransactions map[string]bool // The transactions that are marked as PartitionedDMLTransaction partitionedDmlTransactions map[string]bool - // The mocked results for this server. statementResults map[string]*StatementResult // The simulated execution times per method. - executionTimes map[string]*SimulatedExecutionTime - // Server will stall on any requests. - freezed chan struct{} - + executionTimes map[string]*SimulatedExecutionTime totalSessionsCreated uint totalSessionsDeleted uint receivedRequests chan interface{} // Session ping history. pings []string + + // Server will stall on any requests. + freezed chan struct{} } // NewInMemSpannerServer creates a new in-mem test server. @@ -227,6 +227,9 @@ func NewInMemSpannerServer() InMemSpannerServer { } func (s *inMemSpannerServer) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + s.stopped = true close(s.receivedRequests) } @@ -234,12 +237,16 @@ func (s *inMemSpannerServer) Stop() { // transactions that have been created on the server. This method will not // remove mocked results. func (s *inMemSpannerServer) Reset() { + s.mu.Lock() + defer s.mu.Unlock() close(s.receivedRequests) s.receivedRequests = make(chan interface{}, 1000000) s.initDefaults() } func (s *inMemSpannerServer) SetError(err error) { + s.mu.Lock() + defer s.mu.Unlock() s.err = err } @@ -442,7 +449,13 @@ func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, e } func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() s.ready() s.mu.Lock() if s.err != nil { @@ -506,7 +519,13 @@ func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetS } func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() if req.Database == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing database") } @@ -544,7 +563,13 @@ func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.D } func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } @@ -624,7 +649,13 @@ func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlReques } func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } @@ -664,12 +695,24 @@ func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb } func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") } func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() return gstatus.Error(codes.Unimplemented, "Method not yet implemented") } @@ -717,7 +760,13 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe } func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } @@ -735,11 +784,23 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba } func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") } func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") } diff --git a/spanner/internal/testutil/inmem_spanner_server_test.go b/spanner/internal/testutil/inmem_spanner_server_test.go index d563ff4b6175..3074c9330c3d 100644 --- a/spanner/internal/testutil/inmem_spanner_server_test.go +++ b/spanner/internal/testutil/inmem_spanner_server_test.go @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testutil +package testutil_test_test import ( "strconv" + . "cloud.google.com/go/spanner/internal/testutil" + structpb "github.com/golang/protobuf/ptypes/struct" spannerpb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc/codes" diff --git a/spanner/internal/testutil/mockclient.go b/spanner/internal/testutil/mockclient.go index 2a3b918353f9..f24816f9fb27 100644 --- a/spanner/internal/testutil/mockclient.go +++ b/spanner/internal/testutil/mockclient.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package testutil +package testutil_test import ( "context" diff --git a/spanner/mocked_inmem_server.go b/spanner/internal/testutil/mocked_inmem_server.go similarity index 50% rename from spanner/mocked_inmem_server.go rename to spanner/internal/testutil/mocked_inmem_server.go index ab96f71b9fe0..403df729928c 100644 --- a/spanner/mocked_inmem_server.go +++ b/spanner/internal/testutil/mocked_inmem_server.go @@ -12,74 +12,73 @@ // See the License for the specific language governing permissions and // limitations under the License. -package spanner +package testutil_test import ( - "context" "fmt" "net" "strconv" "testing" - "cloud.google.com/go/spanner/internal/testutil" structpb "github.com/golang/protobuf/ptypes/struct" "google.golang.org/api/option" spannerpb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc" ) -// The SQL statements and results that are already mocked for this test server. -const selectFooFromBar = "SELECT FOO FROM BAR" +// SelectFooFromBar is a SELECT statement that is added to the mocked test +// server and will return a one-col-two-rows result set containing the INT64 +// values 1 and 2. +const SelectFooFromBar = "SELECT FOO FROM BAR" const selectFooFromBarRowCount int64 = 2 const selectFooFromBarColCount int = 1 var selectFooFromBarResults = [...]int64{1, 2} -const selectSingerIDAlbumIDAlbumTitleFromAlbums = "SELECT SingerId, AlbumId, AlbumTitle FROM Albums" -const selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount int64 = 3 -const selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount int = 3 +// SelectSingerIDAlbumIDAlbumTitleFromAlbums i a SELECT statement that is added +// to the mocked test server and will return a 3-cols-3-rows result set. +const SelectSingerIDAlbumIDAlbumTitleFromAlbums = "SELECT SingerId, AlbumId, AlbumTitle FROM Albums" -const updateBarSetFoo = "UPDATE FOO SET BAR=1 WHERE BAZ=2" -const updateBarSetFooRowCount = 5 +// SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount is the number of rows +// returned by the SelectSingerIDAlbumIDAlbumTitleFromAlbums statement. +const SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount int64 = 3 -// An InMemSpannerServer with results for a number of SQL statements readily -// mocked. -type spannerInMemTestServer struct { - testSpanner testutil.InMemSpannerServer - server *grpc.Server -} +// SelectSingerIDAlbumIDAlbumTitleFromAlbumsColCount is the number of cols +// returned by the SelectSingerIDAlbumIDAlbumTitleFromAlbums statement. +const SelectSingerIDAlbumIDAlbumTitleFromAlbumsColCount int = 3 -// Create a spannerInMemTestServer with default configuration. -func newSpannerInMemTestServer(t *testing.T) (*spannerInMemTestServer, *Client) { - s := &spannerInMemTestServer{} - client := s.setup(t) - return s, client -} +// UpdateBarSetFoo is an UPDATE statement that is added to the mocked test +// server that will return an update count of 5. +const UpdateBarSetFoo = "UPDATE FOO SET BAR=1 WHERE BAZ=2" -// Create a spannerInMemTestServer with default configuration and a client interceptor. -func newSpannerInMemTestServerWithInterceptor(t *testing.T, interceptor grpc.UnaryClientInterceptor) (*spannerInMemTestServer, *Client) { - s := &spannerInMemTestServer{} - client := s.setupWithConfig(t, ClientConfig{}, interceptor) - return s, client -} +// UpdateBarSetFooRowCount is the constant update count value returned by the +// statement defined in UpdateBarSetFoo. +const UpdateBarSetFooRowCount = 5 -// Create a spannerInMemTestServer with the specified configuration. -func newSpannerInMemTestServerWithConfig(t *testing.T, config ClientConfig) (*spannerInMemTestServer, *Client) { - s := &spannerInMemTestServer{} - client := s.setupWithConfig(t, config, nil) - return s, client +// MockedSpannerInMemTestServer is an InMemSpannerServer with results for a +// number of SQL statements readily mocked. +type MockedSpannerInMemTestServer struct { + TestSpanner InMemSpannerServer + server *grpc.Server } -func (s *spannerInMemTestServer) setup(t *testing.T) *Client { - return s.setupWithConfig(t, ClientConfig{}, nil) +// NewMockedSpannerInMemTestServer creates a MockedSpannerInMemTestServer and +// returns client options that can be used to connect to it. +func NewMockedSpannerInMemTestServer(t *testing.T) (mockedServer *MockedSpannerInMemTestServer, opts []option.ClientOption, teardown func()) { + mockedServer = &MockedSpannerInMemTestServer{} + opts = mockedServer.setupMockedServer(t) + return mockedServer, opts, func() { + mockedServer.TestSpanner.Stop() + mockedServer.server.Stop() + } } -func (s *spannerInMemTestServer) setupWithConfig(t *testing.T, config ClientConfig, interceptor grpc.UnaryClientInterceptor) *Client { - s.testSpanner = testutil.NewInMemSpannerServer() +func (s *MockedSpannerInMemTestServer) setupMockedServer(t *testing.T) []option.ClientOption { + s.TestSpanner = NewInMemSpannerServer() s.setupFooResults() s.setupSingersResults() s.server = grpc.NewServer() - spannerpb.RegisterSpannerServer(s.server, s.testSpanner) + spannerpb.RegisterSpannerServer(s.server, s.TestSpanner) lis, err := net.Listen("tcp", "localhost:0") if err != nil { @@ -88,24 +87,15 @@ func (s *spannerInMemTestServer) setupWithConfig(t *testing.T, config ClientConf go s.server.Serve(lis) serverAddress := lis.Addr().String() - ctx := context.Background() - var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") opts := []option.ClientOption{ option.WithEndpoint(serverAddress), option.WithGRPCDialOption(grpc.WithInsecure()), option.WithoutAuthentication(), } - if interceptor != nil { - opts = append(opts, option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptor))) - } - client, err := NewClientWithConfig(ctx, formattedDatabase, config, opts...) - if err != nil { - t.Fatal(err) - } - return client + return opts } -func (s *spannerInMemTestServer) setupFooResults() { +func (s *MockedSpannerInMemTestServer) setupFooResults() { fields := make([]*spannerpb.StructType_Field, selectFooFromBarColCount) fields[0] = &spannerpb.StructType_Field{ Name: "FOO", @@ -131,16 +121,16 @@ func (s *spannerInMemTestServer) setupFooResults() { Metadata: metadata, Rows: rows, } - result := &testutil.StatementResult{Type: testutil.StatementResultResultSet, ResultSet: resultSet} - s.testSpanner.PutStatementResult(selectFooFromBar, result) - s.testSpanner.PutStatementResult(updateBarSetFoo, &testutil.StatementResult{ - Type: testutil.StatementResultUpdateCount, - UpdateCount: updateBarSetFooRowCount, + result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet} + s.TestSpanner.PutStatementResult(SelectFooFromBar, result) + s.TestSpanner.PutStatementResult(UpdateBarSetFoo, &StatementResult{ + Type: StatementResultUpdateCount, + UpdateCount: UpdateBarSetFooRowCount, }) } -func (s *spannerInMemTestServer) setupSingersResults() { - fields := make([]*spannerpb.StructType_Field, selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) +func (s *MockedSpannerInMemTestServer) setupSingersResults() { + fields := make([]*spannerpb.StructType_Field, SelectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) fields[0] = &spannerpb.StructType_Field{ Name: "SingerId", Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64}, @@ -159,10 +149,10 @@ func (s *spannerInMemTestServer) setupSingersResults() { metadata := &spannerpb.ResultSetMetadata{ RowType: rowType, } - rows := make([]*structpb.ListValue, selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) + rows := make([]*structpb.ListValue, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) var idx int64 - for idx = 0; idx < selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount; idx++ { - rowValue := make([]*structpb.Value, selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) + for idx = 0; idx < SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount; idx++ { + rowValue := make([]*structpb.Value, SelectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) rowValue[0] = &structpb.Value{ Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(idx+1, 10)}, } @@ -180,11 +170,6 @@ func (s *spannerInMemTestServer) setupSingersResults() { Metadata: metadata, Rows: rows, } - result := &testutil.StatementResult{Type: testutil.StatementResultResultSet, ResultSet: resultSet} - s.testSpanner.PutStatementResult(selectSingerIDAlbumIDAlbumTitleFromAlbums, result) -} - -func (s *spannerInMemTestServer) teardown(client *Client) { - client.Close() - s.server.Stop() + result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet} + s.TestSpanner.PutStatementResult(SelectSingerIDAlbumIDAlbumTitleFromAlbums, result) } diff --git a/spanner/internal/testutil/mockserver.go b/spanner/internal/testutil/mockserver.go index b5a6f5e505fd..333d0e989bb7 100644 --- a/spanner/internal/testutil/mockserver.go +++ b/spanner/internal/testutil/mockserver.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package testutil +package testutil_test import ( "context" diff --git a/spanner/pdml_test.go b/spanner/pdml_test.go index 743f5113cc10..9f479bb6189e 100644 --- a/spanner/pdml_test.go +++ b/spanner/pdml_test.go @@ -18,21 +18,22 @@ import ( "context" "testing" + . "cloud.google.com/go/spanner/internal/testutil" "google.golang.org/grpc/codes" ) func TestMockPartitionedUpdate(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + _, client, teardown := setupMockedTestServer(t) + defer teardown() - stmt := NewStatement(updateBarSetFoo) + stmt := NewStatement(UpdateBarSetFoo) rowCount, err := client.PartitionedUpdate(ctx, stmt) if err != nil { t.Fatal(err) } - want := int64(updateBarSetFooRowCount) + want := int64(UpdateBarSetFooRowCount) if rowCount != want { t.Errorf("got %d, want %d", rowCount, want) } @@ -41,10 +42,10 @@ func TestMockPartitionedUpdate(t *testing.T) { func TestMockPartitionedUpdateWithQuery(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + _, client, teardown := setupMockedTestServer(t) + defer teardown() - stmt := NewStatement(selectFooFromBar) + stmt := NewStatement(SelectFooFromBar) _, err := client.PartitionedUpdate(ctx, stmt) wantCode := codes.InvalidArgument if serr, ok := err.(*Error); !ok || serr.Code != wantCode { diff --git a/spanner/read_test.go b/spanner/read_test.go index 5a660c252525..b18b0938d7d1 100644 --- a/spanner/read_test.go +++ b/spanner/read_test.go @@ -26,7 +26,7 @@ import ( "time" "cloud.google.com/go/spanner/internal/backoff" - "cloud.google.com/go/spanner/internal/testutil" + . "cloud.google.com/go/spanner/internal/testutil" "github.com/golang/protobuf/proto" proto3 "github.com/golang/protobuf/ptypes/struct" "google.golang.org/api/iterator" @@ -42,7 +42,7 @@ var ( // Metadata for mocked KV table, its rows are returned by SingleUse // transactions. kvMeta = func() *sppb.ResultSetMetadata { - meta := testutil.KvMeta + meta := KvMeta meta.Transaction = &sppb.Transaction{ ReadTimestamp: timestampProto(trxTs), } @@ -642,7 +642,7 @@ func TestRsdNonblockingStates(t *testing.T) { defer restore() tests := []struct { name string - msgs []testutil.MockCtlMsg + msgs []MockCtlMsg rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error) sql string // Expected values @@ -655,7 +655,7 @@ func TestRsdNonblockingStates(t *testing.T) { { // unConnected->queueingRetryable->finished name: "unConnected->queueingRetryable->finished", - msgs: []testutil.MockCtlMsg{ + msgs: []MockCtlMsg{ {}, {}, {Err: io.EOF, ResumeToken: false}, @@ -689,7 +689,7 @@ func TestRsdNonblockingStates(t *testing.T) { { // unConnected->queueingRetryable->aborted name: "unConnected->queueingRetryable->aborted", - msgs: []testutil.MockCtlMsg{ + msgs: []MockCtlMsg{ {}, {Err: nil, ResumeToken: true}, {}, @@ -710,7 +710,7 @@ func TestRsdNonblockingStates(t *testing.T) { {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, }, - ResumeToken: testutil.EncodeResumeToken(1), + ResumeToken: EncodeResumeToken(1), }, }, stateHistory: []resumableStreamDecoderState{ @@ -726,9 +726,9 @@ func TestRsdNonblockingStates(t *testing.T) { { // unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable name: "unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable", - msgs: func() (m []testutil.MockCtlMsg) { + msgs: func() (m []MockCtlMsg) { for i := 0; i < maxBuffers+1; i++ { - m = append(m, testutil.MockCtlMsg{}) + m = append(m, MockCtlMsg{}) } return m }(), @@ -760,11 +760,11 @@ func TestRsdNonblockingStates(t *testing.T) { { // unConnected->queueingRetryable->queueingUnretryable->aborted name: "unConnected->queueingRetryable->queueingUnretryable->aborted", - msgs: func() (m []testutil.MockCtlMsg) { + msgs: func() (m []MockCtlMsg) { for i := 0; i < maxBuffers; i++ { - m = append(m, testutil.MockCtlMsg{}) + m = append(m, MockCtlMsg{}) } - m = append(m, testutil.MockCtlMsg{Err: errors.New("Just Abort It"), ResumeToken: false}) + m = append(m, MockCtlMsg{Err: errors.New("Just Abort It"), ResumeToken: false}) return m }(), sql: "SELECT t.key key, t.value value FROM t_mock t", @@ -794,7 +794,7 @@ func TestRsdNonblockingStates(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() mc := sppb.NewSpannerClient(dialMock(t, ms)) if test.rpc == nil { @@ -890,7 +890,7 @@ func TestRsdBlockingStates(t *testing.T) { defer restore() for _, test := range []struct { name string - msgs []testutil.MockCtlMsg + msgs []MockCtlMsg rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error) sql string // Expected values @@ -919,7 +919,7 @@ func TestRsdBlockingStates(t *testing.T) { { // unConnected->queueingRetryable->queueingRetryable name: "unConnected->queueingRetryable->queueingRetryable", - msgs: []testutil.MockCtlMsg{ + msgs: []MockCtlMsg{ {}, {Err: nil, ResumeToken: true}, {Err: nil, ResumeToken: true}, @@ -940,7 +940,7 @@ func TestRsdBlockingStates(t *testing.T) { {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, }, - ResumeToken: testutil.EncodeResumeToken(1), + ResumeToken: EncodeResumeToken(1), }, { Metadata: kvMeta, @@ -948,7 +948,7 @@ func TestRsdBlockingStates(t *testing.T) { {Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}}, {Kind: &proto3.Value_StringValue{StringValue: valStr(2)}}, }, - ResumeToken: testutil.EncodeResumeToken(2), + ResumeToken: EncodeResumeToken(2), }, }, queue: []*sppb.PartialResultSet{ @@ -960,7 +960,7 @@ func TestRsdBlockingStates(t *testing.T) { }, }, }, - resumeToken: testutil.EncodeResumeToken(2), + resumeToken: EncodeResumeToken(2), stateHistory: []resumableStreamDecoderState{ queueingRetryable, // do RPC queueingRetryable, // got foo-00 @@ -974,12 +974,12 @@ func TestRsdBlockingStates(t *testing.T) { { // unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable name: "unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable", - msgs: func() (m []testutil.MockCtlMsg) { + msgs: func() (m []MockCtlMsg) { for i := 0; i < maxBuffers+1; i++ { - m = append(m, testutil.MockCtlMsg{}) + m = append(m, MockCtlMsg{}) } - m = append(m, testutil.MockCtlMsg{Err: nil, ResumeToken: true}) - m = append(m, testutil.MockCtlMsg{}) + m = append(m, MockCtlMsg{Err: nil, ResumeToken: true}) + m = append(m, MockCtlMsg{}) return m }(), sql: "SELECT t.key key, t.value value FROM t_mock t", @@ -993,10 +993,10 @@ func TestRsdBlockingStates(t *testing.T) { }, }) } - s[maxBuffers+1].ResumeToken = testutil.EncodeResumeToken(maxBuffers + 1) + s[maxBuffers+1].ResumeToken = EncodeResumeToken(maxBuffers + 1) return s }(), - resumeToken: testutil.EncodeResumeToken(maxBuffers + 1), + resumeToken: EncodeResumeToken(maxBuffers + 1), queue: []*sppb.PartialResultSet{ { Metadata: kvMeta, @@ -1026,11 +1026,11 @@ func TestRsdBlockingStates(t *testing.T) { { // unConnected->queueingRetryable->queueingUnretryable->finished name: "unConnected->queueingRetryable->queueingUnretryable->finished", - msgs: func() (m []testutil.MockCtlMsg) { + msgs: func() (m []MockCtlMsg) { for i := 0; i < maxBuffers; i++ { - m = append(m, testutil.MockCtlMsg{}) + m = append(m, MockCtlMsg{}) } - m = append(m, testutil.MockCtlMsg{Err: io.EOF, ResumeToken: false}) + m = append(m, MockCtlMsg{Err: io.EOF, ResumeToken: false}) return m }(), sql: "SELECT t.key key, t.value value FROM t_mock t", @@ -1058,7 +1058,7 @@ func TestRsdBlockingStates(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() cc := dialMock(t, ms) mc := sppb.NewSpannerClient(cc) @@ -1194,7 +1194,7 @@ func (sr *sReceiver) waitn(n int) error { func TestQueueBytes(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1279,7 +1279,7 @@ func TestQueueBytes(t *testing.T) { func TestResumeToken(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1476,7 +1476,7 @@ func TestResumeToken(t *testing.T) { func TestGrpcReconnect(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1543,7 +1543,7 @@ func TestGrpcReconnect(t *testing.T) { func TestCancelTimeout(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1630,7 +1630,7 @@ func TestCancelTimeout(t *testing.T) { func TestRowIteratorDo(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1663,7 +1663,7 @@ func TestRowIteratorDo(t *testing.T) { func TestRowIteratorDoWithError(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1694,7 +1694,7 @@ func TestIteratorStopEarly(t *testing.T) { ctx := context.Background() restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1738,7 +1738,7 @@ func TestIteratorWithError(t *testing.T) { } } -func dialMock(t *testing.T, ms *testutil.MockCloudSpanner) *grpc.ClientConn { +func dialMock(t *testing.T, ms *MockCloudSpanner) *grpc.ClientConn { cc, err := grpc.Dial(ms.Addr(), grpc.WithInsecure(), grpc.WithBlock()) if err != nil { t.Fatalf("Dial(%q) = %v", ms.Addr(), err) diff --git a/spanner/session_test.go b/spanner/session_test.go index f3f779bc8b11..7e0392556bb9 100644 --- a/spanner/session_test.go +++ b/spanner/session_test.go @@ -26,7 +26,7 @@ import ( "time" vkit "cloud.google.com/go/spanner/apiv1" - "cloud.google.com/go/spanner/internal/testutil" + . "cloud.google.com/go/spanner/internal/testutil" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -34,8 +34,8 @@ import ( // TestSessionPoolConfigValidation tests session pool config validation. func TestSessionPoolConfigValidation(t *testing.T) { t.Parallel() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + _, client, teardown := setupMockedTestServer(t) + defer teardown() for _, test := range []struct { spc SessionPoolConfig @@ -66,8 +66,8 @@ func TestSessionPoolConfigValidation(t *testing.T) { func TestSessionCreation(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() sp := client.idleSessions // Take three sessions from session pool, this should trigger session pool @@ -86,7 +86,7 @@ func TestSessionCreation(t *testing.T) { if len(gotDs) != len(shs) { t.Fatalf("session pool created %v sessions, want %v", len(gotDs), len(shs)) } - if wantDs := server.testSpanner.DumpSessions(); !testEqual(gotDs, wantDs) { + if wantDs := server.TestSpanner.DumpSessions(); !testEqual(gotDs, wantDs) { t.Fatalf("session pool creates sessions %v, want %v", gotDs, wantDs) } // Verify that created sessions are recorded correctly in session pool. @@ -118,11 +118,11 @@ func TestTakeFromIdleList(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{MaxIdle: 10}, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Take ten sessions from session pool and recycle them. @@ -142,7 +142,7 @@ func TestTakeFromIdleList(t *testing.T) { } // Further session requests from session pool won't cause mockclient to // create more sessions. - wantSessions := server.testSpanner.DumpSessions() + wantSessions := server.TestSpanner.DumpSessions() // Take ten sessions from session pool again, this time all sessions should // come from idle list. gotSessions := map[string]bool{} @@ -168,11 +168,11 @@ func TestTakeWriteSessionFromIdleList(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{MaxIdle: 20}, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Take ten sessions from session pool and recycle them. @@ -192,7 +192,7 @@ func TestTakeWriteSessionFromIdleList(t *testing.T) { } // Further session requests from session pool won't cause mockclient to // create more sessions. - wantSessions := server.testSpanner.DumpSessions() + wantSessions := server.TestSpanner.DumpSessions() // Take ten sessions from session pool again, this time all sessions should // come from idle list. gotSessions := map[string]bool{} @@ -218,7 +218,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MaxIdle: 1, @@ -226,7 +226,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { healthCheckSampleInterval: 10 * time.Millisecond, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Stop healthcheck workers to simulate slow pings. @@ -262,7 +262,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { // The two back-to-back session requests shouldn't trigger any session // pings because sessionPool.Take // reschedules the next healthcheck. - if got, want := server.testSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { + if got, want := server.TestSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { t.Fatalf("%v - got ping session requests: %v, want %v", i, got, want) } sh.recycle() @@ -271,8 +271,8 @@ func TestTakeFromIdleListChecked(t *testing.T) { // Inject session error to server stub, and take the session from the // session pool, the old session should be destroyed and the session pool // will create a new session. - server.testSpanner.PutExecutionTime(testutil.MethodGetSession, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodGetSession, + SimulatedExecutionTime{ Errors: []error{status.Errorf(codes.NotFound, "Session not found")}, }) @@ -287,7 +287,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { if err != nil { t.Fatalf("failed to get session: %v", err) } - ds := server.testSpanner.DumpSessions() + ds := server.TestSpanner.DumpSessions() if len(ds) != 1 { t.Fatalf("dumped sessions from mockclient: %v, want %v", ds, sh.getID()) } @@ -303,7 +303,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MaxIdle: 1, @@ -311,7 +311,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { healthCheckSampleInterval: 10 * time.Millisecond, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Stop healthcheck workers to simulate slow pings. @@ -345,7 +345,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { } // The two back-to-back session requests shouldn't trigger any session // pings because sessionPool.Take reschedules the next healthcheck. - if got, want := server.testSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { + if got, want := server.TestSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { t.Fatalf("%v - got ping session requests: %v, want %v", i, got, want) } sh.recycle() @@ -354,8 +354,8 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { // Inject session error to mockclient, and take the session from the // session pool, the old session should be destroyed and the session pool // will create a new session. - server.testSpanner.PutExecutionTime(testutil.MethodGetSession, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodGetSession, + SimulatedExecutionTime{ Errors: []error{status.Errorf(codes.NotFound, "Session not found")}, }) @@ -367,7 +367,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { if err != nil { t.Fatalf("failed to get session: %v", err) } - ds := server.testSpanner.DumpSessions() + ds := server.TestSpanner.DumpSessions() if len(ds) != 1 { t.Fatalf("dumped sessions from mockclient: %v, want %v", ds, sh.getID()) } @@ -380,13 +380,13 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { func TestMaxOpenedSessions(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MaxOpened: 1, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions sh1, err := sp.take(ctx) @@ -425,13 +425,13 @@ func TestMaxOpenedSessions(t *testing.T) { func TestMinOpenedSessions(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MinOpened: 1, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Take ten sessions from session pool and recycle them. @@ -468,18 +468,18 @@ func TestMinOpenedSessions(t *testing.T) { func TestMaxBurst(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MaxBurst: 1, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Will cause session creation RPC to be retried forever. - server.testSpanner.PutExecutionTime(testutil.MethodCreateSession, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodCreateSession, + SimulatedExecutionTime{ Errors: []error{status.Errorf(codes.Unavailable, "try later")}, KeepError: true, }) @@ -511,10 +511,10 @@ func TestMaxBurst(t *testing.T) { } // Let the first session request succeed. - server.testSpanner.Freeze() - server.testSpanner.PutExecutionTime(testutil.MethodCreateSession, testutil.SimulatedExecutionTime{}) + server.TestSpanner.Freeze() + server.TestSpanner.PutExecutionTime(MethodCreateSession, SimulatedExecutionTime{}) //close(allowRequests) - server.testSpanner.Unfreeze() + server.TestSpanner.Unfreeze() // Now new session request can proceed because the first session request will eventually succeed. sh, err := sp.take(ctx) @@ -530,14 +530,14 @@ func TestMaxBurst(t *testing.T) { func TestSessionRecycle(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MinOpened: 1, MaxIdle: 5, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Test session is correctly recycled and reused. @@ -568,13 +568,13 @@ func TestSessionDestroy(t *testing.T) { t.Skip("s.destroy(true) is flakey") t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MinOpened: 1, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions <-time.After(10 * time.Millisecond) // maintainer will create one session, we wait for it create session to avoid flakiness in test @@ -631,14 +631,14 @@ func TestHcHeap(t *testing.T) { func TestHealthCheckScheduler(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ HealthCheckInterval: 50 * time.Millisecond, healthCheckSampleInterval: 10 * time.Millisecond, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Create 50 sessions. @@ -652,13 +652,13 @@ func TestHealthCheckScheduler(t *testing.T) { // Make sure we start with a ping history to avoid that the first // sessions that were created have not already exceeded the maximum // number of pings. - server.testSpanner.ClearPings() + server.TestSpanner.ClearPings() // Wait for 10-30 pings per session. waitFor(t, func() error { // Only check actually live sessions and ignore any sessions the // session pool may have deleted in the meantime. - liveSessions := server.testSpanner.DumpSessions() - dp := server.testSpanner.DumpPings() + liveSessions := server.TestSpanner.DumpSessions() + dp := server.TestSpanner.DumpPings() gotPings := map[string]int64{} for _, p := range dp { gotPings[p]++ @@ -678,14 +678,14 @@ func TestHealthCheckScheduler(t *testing.T) { func TestWriteSessionsPrepared(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ WriteSessions: 0.5, MaxIdle: 20, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions shs := make([]*sessionHandle, 10) @@ -748,7 +748,7 @@ func TestWriteSessionsPrepared(t *testing.T) { func TestTakeFromWriteQueue(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MaxOpened: 1, @@ -756,7 +756,7 @@ func TestTakeFromWriteQueue(t *testing.T) { MaxIdle: 1, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions sh, err := sp.take(ctx) @@ -786,14 +786,14 @@ func TestTakeFromWriteQueue(t *testing.T) { func TestSessionHealthCheck(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ HealthCheckInterval: 50 * time.Millisecond, healthCheckSampleInterval: 10 * time.Millisecond, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Test pinging sessions. @@ -804,7 +804,7 @@ func TestSessionHealthCheck(t *testing.T) { // Wait for healthchecker to send pings to session. waitFor(t, func() error { - pings := server.testSpanner.DumpPings() + pings := server.TestSpanner.DumpPings() if len(pings) == 0 || pings[0] != sh.getID() { return fmt.Errorf("healthchecker didn't send any ping to session %v", sh.getID()) } @@ -816,13 +816,13 @@ func TestSessionHealthCheck(t *testing.T) { t.Fatalf("cannot get session from session pool: %v", err) } - server.testSpanner.Freeze() - server.testSpanner.PutExecutionTime(testutil.MethodGetSession, - testutil.SimulatedExecutionTime{ + server.TestSpanner.Freeze() + server.TestSpanner.PutExecutionTime(MethodGetSession, + SimulatedExecutionTime{ Errors: []error{status.Errorf(codes.NotFound, "Session not found")}, KeepError: true, }) - server.testSpanner.Unfreeze() + server.TestSpanner.Unfreeze() //atomic.SwapInt64(&requestShouldErr, 1) // Wait for healthcheck workers to find the broken session and tear it down. @@ -834,9 +834,9 @@ func TestSessionHealthCheck(t *testing.T) { t.Fatalf("session(%v) is still alive, want it to be dropped by healthcheck workers", s) } - server.testSpanner.Freeze() - server.testSpanner.PutExecutionTime(testutil.MethodGetSession, testutil.SimulatedExecutionTime{}) - server.testSpanner.Unfreeze() + server.TestSpanner.Freeze() + server.TestSpanner.PutExecutionTime(MethodGetSession, SimulatedExecutionTime{}) + server.TestSpanner.Unfreeze() // Test garbage collection. sh, err = sp.take(ctx) @@ -876,10 +876,9 @@ func TestStressSessionPool(t *testing.T) { cfg.healthCheckSampleInterval = 10 * time.Millisecond cfg.HealthCheckWorkers = 50 - server, client := newSpannerInMemTestServerWithConfig(t, - ClientConfig{ - SessionPoolConfig: cfg, - }) + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + SessionPoolConfig: cfg, + }) sp := client.idleSessions // Create a test group for this configuration and schedule 100 sub @@ -898,7 +897,7 @@ func TestStressSessionPool(t *testing.T) { // stable. idleSessions := map[string]bool{} hcSessions := map[string]bool{} - mockSessions := server.testSpanner.DumpSessions() + mockSessions := server.TestSpanner.DumpSessions() // Dump session pool's idle list. for sl := sp.idleList.Front(); sl != nil; sl = sl.Next() { s := sl.Value.(*session) @@ -944,13 +943,13 @@ func TestStressSessionPool(t *testing.T) { } } sp.close() - mockSessions = server.testSpanner.DumpSessions() + mockSessions = server.TestSpanner.DumpSessions() for id, b := range hcSessions { if b && mockSessions[id] { t.Fatalf("Found session from pool still live on server: %v", id) } } - server.teardown(client) + teardown() } } @@ -1034,14 +1033,14 @@ func TestMaintainer(t *testing.T) { minOpened := uint64(5) maxIdle := uint64(4) - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MinOpened: minOpened, MaxIdle: maxIdle, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions sampleInterval := sp.SessionPoolConfig.healthCheckSampleInterval @@ -1114,11 +1113,11 @@ func TestMaintainer_CreatesSessions(t *testing.T) { MaxIdle: 10, healthCheckSampleInterval: 20 * time.Millisecond, } - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: spc, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions timeoutAmt := 4 * time.Second diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go index 9c10ac4d8195..2976c96b595d 100644 --- a/spanner/transaction_test.go +++ b/spanner/transaction_test.go @@ -25,7 +25,7 @@ import ( "testing" "time" - "cloud.google.com/go/spanner/internal/testutil" + . "cloud.google.com/go/spanner/internal/testutil" sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" @@ -35,8 +35,8 @@ import ( func TestSingle(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() txn := client.Single() defer txn.Close() @@ -50,7 +50,7 @@ func TestSingle(t *testing.T) { } // Only one CreateSessionRequest is sent. - if _, err := shouldHaveReceived(server.testSpanner, []interface{}{&sppb.CreateSessionRequest{}}); err != nil { + if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{&sppb.CreateSessionRequest{}}); err != nil { t.Fatal(err) } } @@ -59,16 +59,16 @@ func TestSingle(t *testing.T) { func TestReadOnlyTransaction_RecoverFromFailure(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() txn := client.ReadOnlyTransaction() defer txn.Close() // First request will fail. errUsr := gstatus.Error(codes.Unknown, "error") - server.testSpanner.PutExecutionTime(testutil.MethodBeginTransaction, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodBeginTransaction, + SimulatedExecutionTime{ Errors: []error{errUsr}, }) @@ -86,8 +86,8 @@ func TestReadOnlyTransaction_RecoverFromFailure(t *testing.T) { func TestReadOnlyTransaction_UseAfterClose(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + _, client, teardown := setupMockedTestServer(t) + defer teardown() txn := client.ReadOnlyTransaction() txn.Close() @@ -102,12 +102,12 @@ func TestReadOnlyTransaction_UseAfterClose(t *testing.T) { func TestReadOnlyTransaction_Concurrent(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() txn := client.ReadOnlyTransaction() defer txn.Close() - server.testSpanner.Freeze() + server.TestSpanner.Freeze() var ( sh1 *sessionHandle sh2 *sessionHandle @@ -130,7 +130,7 @@ func TestReadOnlyTransaction_Concurrent(t *testing.T) { // TODO(deklerk): Get rid of this. <-time.After(100 * time.Millisecond) - server.testSpanner.Unfreeze() + server.TestSpanner.Unfreeze() wg.Wait() if sh1.session.id != sh2.session.id { t.Fatalf("Expected acquire to get same session handle, got %v and %v.", sh1, sh2) @@ -143,8 +143,8 @@ func TestReadOnlyTransaction_Concurrent(t *testing.T) { func TestApply_Single(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() ms := []*Mutation{ Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), @@ -154,7 +154,7 @@ func TestApply_Single(t *testing.T) { t.Fatalf("applyAtLeastOnce retry on abort, got %v, want nil.", e) } - if _, err := shouldHaveReceived(server.testSpanner, []interface{}{ + if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{ &sppb.CreateSessionRequest{}, &sppb.CommitRequest{}, }); err != nil { @@ -166,13 +166,13 @@ func TestApply_Single(t *testing.T) { func TestApply_RetryOnAbort(t *testing.T) { ctx := context.Background() t.Parallel() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() // First commit will fail, and the retry will begin a new transaction. errAbrt := spannerErrorf(codes.Aborted, "") - server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodCommitTransaction, + SimulatedExecutionTime{ Errors: []error{errAbrt}, }) @@ -184,7 +184,7 @@ func TestApply_RetryOnAbort(t *testing.T) { t.Fatalf("ReadWriteTransaction retry on abort, got %v, want nil.", e) } - if _, err := shouldHaveReceived(server.testSpanner, []interface{}{ + if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{ &sppb.CreateSessionRequest{}, &sppb.BeginTransactionRequest{}, &sppb.CommitRequest{}, // First commit fails. @@ -199,16 +199,16 @@ func TestApply_RetryOnAbort(t *testing.T) { func TestTransaction_NotFound(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() wantErr := spannerErrorf(codes.NotFound, "Session not found") - server.testSpanner.PutExecutionTime(testutil.MethodBeginTransaction, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodBeginTransaction, + SimulatedExecutionTime{ Errors: []error{wantErr, wantErr, wantErr}, }) - server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodCommitTransaction, + SimulatedExecutionTime{ Errors: []error{wantErr, wantErr, wantErr}, }) @@ -243,8 +243,8 @@ func TestTransaction_NotFound(t *testing.T) { func TestReadWriteTransaction_ErrorReturned(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() want := errors.New("an error") _, got := client.ReadWriteTransaction(ctx, func(context.Context, *ReadWriteTransaction) error { @@ -253,7 +253,7 @@ func TestReadWriteTransaction_ErrorReturned(t *testing.T) { if got != want { t.Fatalf("got %+v, want %+v", got, want) } - requests := drainRequestsFromServer(server.testSpanner) + requests := drainRequestsFromServer(server.TestSpanner) if err := compareRequests([]interface{}{ &sppb.CreateSessionRequest{}, &sppb.BeginTransactionRequest{}, @@ -277,27 +277,27 @@ func TestReadWriteTransaction_ErrorReturned(t *testing.T) { func TestBatchDML_WithMultipleDML(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { - if _, err = tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil { + if _, err = tx.Update(ctx, Statement{SQL: UpdateBarSetFoo}); err != nil { return err } - if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}, {SQL: updateBarSetFoo}}); err != nil { + if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}, {SQL: UpdateBarSetFoo}}); err != nil { return err } - if _, err = tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil { + if _, err = tx.Update(ctx, Statement{SQL: UpdateBarSetFoo}); err != nil { return err } - _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}}) + _, err = tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}}) return err }) if err != nil { t.Fatal(err) } - gotReqs, err := shouldHaveReceived(server.testSpanner, []interface{}{ + gotReqs, err := shouldHaveReceived(server.TestSpanner, []interface{}{ &sppb.CreateSessionRequest{}, &sppb.BeginTransactionRequest{}, &sppb.ExecuteSqlRequest{}, @@ -329,7 +329,7 @@ func TestBatchDML_WithMultipleDML(t *testing.T) { // // Note: this in-place modifies serverClientMock by popping items off the // ReceivedRequests channel. -func shouldHaveReceived(server testutil.InMemSpannerServer, want []interface{}) ([]interface{}, error) { +func shouldHaveReceived(server InMemSpannerServer, want []interface{}) ([]interface{}, error) { got := drainRequestsFromServer(server) return got, compareRequests(want, got) } @@ -358,7 +358,7 @@ func compareRequests(want []interface{}, got []interface{}) error { return nil } -func drainRequestsFromServer(server testutil.InMemSpannerServer) []interface{} { +func drainRequestsFromServer(server InMemSpannerServer) []interface{} { var reqs []interface{} loop: for {