From 511b7f3f2624e2d3c4bf2697da4ab69ed5359aac Mon Sep 17 00:00:00 2001 From: Olav Loite Date: Wed, 21 Aug 2019 11:16:34 +0200 Subject: [PATCH] Spanner: move mocked test server to testutil The readily mocked inmem Spanner server was included in the normal build, instead of only being included in test builds. The test server has been moved to the internal testutil package, the package has been renamed as testutil_test, and the setup for a test client has been added to the client_test.go file to prevent a circular import. Fixes #1539 Change-Id: I56dc57345ef65dd2bebd57742961848e7abe1818 Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/44176 Reviewed-by: kokoro Reviewed-by: Jean de Klerk --- spanner/client_test.go | 171 ++++++++++-------- .../internal/testutil/inmem_spanner_server.go | 77 +++++++- .../testutil/inmem_spanner_server_test.go | 4 +- spanner/internal/testutil/mockclient.go | 2 +- .../testutil}/mocked_inmem_server.go | 117 ++++++------ spanner/internal/testutil/mockserver.go | 2 +- spanner/pdml_test.go | 15 +- spanner/read_test.go | 70 +++---- spanner/session_test.go | 137 +++++++------- spanner/transaction_test.go | 80 ++++---- 10 files changed, 374 insertions(+), 301 deletions(-) rename spanner/{ => internal/testutil}/mocked_inmem_server.go (50%) diff --git a/spanner/client_test.go b/spanner/client_test.go index 912e159f0204..ec69e89a98fc 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -18,20 +18,45 @@ package spanner import ( "context" + "fmt" "io" "os" "strings" "testing" "cloud.google.com/go/spanner/internal/benchserver" - "cloud.google.com/go/spanner/internal/testutil" + . "cloud.google.com/go/spanner/internal/testutil" "google.golang.org/api/iterator" + "google.golang.org/api/option" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" gstatus "google.golang.org/grpc/status" ) +func setupMockedTestServer(t *testing.T) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) { + return setupMockedTestServerWithConfig(t, ClientConfig{}) +} + +func setupMockedTestServerWithConfig(t *testing.T, config ClientConfig) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) { + return setupMockedTestServerWithConfigAndClientOptions(t, config, []option.ClientOption{}) +} + +func setupMockedTestServerWithConfigAndClientOptions(t *testing.T, config ClientConfig, clientOptions []option.ClientOption) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) { + server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) + opts = append(opts, clientOptions...) + ctx := context.Background() + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + client, err := NewClientWithConfig(ctx, formattedDatabase, config, opts...) 
+ if err != nil { + t.Fatal(err) + } + return server, client, func() { + client.Close() + serverTeardown() + } +} + // Test validDatabaseName() func TestValidDatabaseName(t *testing.T) { validDbURI := "projects/spanner-cloud-test/instances/foo/databases/foodb" @@ -87,14 +112,13 @@ func TestClient_Single_InvalidArgument(t *testing.T) { } func testSingleQuery(t *testing.T, serverError error) error { - config := ClientConfig{} - server, client := newSpannerInMemTestServerWithConfig(t, config) - defer server.teardown(client) + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() if serverError != nil { - server.testSpanner.SetError(serverError) + server.TestSpanner.SetError(serverError) } - ctx := context.Background() - iter := client.Single().Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() for { row, err := iter.Next() @@ -113,12 +137,12 @@ func testSingleQuery(t *testing.T, serverError error) error { return nil } -func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[string]testutil.SimulatedExecutionTime { +func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[string]SimulatedExecutionTime { errors := make([]error, 2) errors[0] = gstatus.Error(codes.Unavailable, "Temporary unavailable") errors[1] = gstatus.Error(codes.Unavailable, "Temporary unavailable") - executionTimes := make(map[string]testutil.SimulatedExecutionTime) - executionTimes[method] = testutil.SimulatedExecutionTime{ + executionTimes := make(map[string]SimulatedExecutionTime) + executionTimes[method] = SimulatedExecutionTime{ Errors: errors, } return executionTimes @@ -126,37 +150,37 @@ func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[str func TestClient_ReadOnlyTransaction(t *testing.T) { t.Parallel() - if err := testReadOnlyTransaction(t, make(map[string]testutil.SimulatedExecutionTime)); err != nil { + if err := testReadOnlyTransaction(t, make(map[string]SimulatedExecutionTime)); err != nil { t.Fatal(err) } } func TestClient_ReadOnlyTransaction_UnavailableOnSessionCreate(t *testing.T) { t.Parallel() - if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodCreateSession)); err != nil { + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodCreateSession)); err != nil { t.Fatal(err) } } func TestClient_ReadOnlyTransaction_UnavailableOnBeginTransaction(t *testing.T) { t.Parallel() - if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodBeginTransaction)); err != nil { + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodBeginTransaction)); err != nil { t.Fatal(err) } } func TestClient_ReadOnlyTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) { t.Parallel() - if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodExecuteStreamingSql)); err != nil { + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodExecuteStreamingSql)); err != nil { t.Fatal(err) } } func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndBeginTransaction(t *testing.T) { t.Parallel() - exec := map[string]testutil.SimulatedExecutionTime{ - testutil.MethodCreateSession: {Errors: 
[]error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, - testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + exec := map[string]SimulatedExecutionTime{ + MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, } if err := testReadOnlyTransaction(t, exec); err != nil { t.Fatal(err) @@ -165,9 +189,9 @@ func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndBeginTransactio func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndInvalidArgumentOnBeginTransaction(t *testing.T) { t.Parallel() - exec := map[string]testutil.SimulatedExecutionTime{ - testutil.MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, - testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.InvalidArgument, "Invalid argument")}}, + exec := map[string]SimulatedExecutionTime{ + MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.InvalidArgument, "Invalid argument")}}, } if err := testReadOnlyTransaction(t, exec); err == nil { t.Fatalf("Missing expected exception") @@ -176,16 +200,16 @@ func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndInvalidArgument } } -func testReadOnlyTransaction(t *testing.T, executionTimes map[string]testutil.SimulatedExecutionTime) error { - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) +func testReadOnlyTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime) error { + server, client, teardown := setupMockedTestServer(t) + defer teardown() for method, exec := range executionTimes { - server.testSpanner.PutExecutionTime(method, exec) + server.TestSpanner.PutExecutionTime(method, exec) } - ctx := context.Background() tx := client.ReadOnlyTransaction() defer tx.Close() - iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + ctx := context.Background() + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() for { row, err := iter.Next() @@ -206,15 +230,15 @@ func testReadOnlyTransaction(t *testing.T, executionTimes map[string]testutil.Si func TestClient_ReadWriteTransaction(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, make(map[string]testutil.SimulatedExecutionTime), 1); err != nil { + if err := testReadWriteTransaction(t, make(map[string]SimulatedExecutionTime), 1); err != nil { t.Fatal(err) } } func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, }, 2); err != nil { t.Fatal(err) } @@ -222,8 +246,8 @@ func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) { func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, + if err 
:= testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, }, 2); err != nil { t.Fatal(err) } @@ -231,17 +255,17 @@ func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) { func TestClient_ReadWriteTransaction_UnavailableOnBeginTransaction(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, }, 1); err != nil { t.Fatal(err) } } func TestClient_ReadWriteTransaction_UnavailableOnBeginAndAbortOnCommit(t *testing.T) { - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, - testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, }, 2); err != nil { t.Fatal(err) } @@ -249,8 +273,8 @@ func TestClient_ReadWriteTransaction_UnavailableOnBeginAndAbortOnCommit(t *testi func TestClient_ReadWriteTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, }, 1); err != nil { t.Fatal(err) } @@ -258,10 +282,10 @@ func TestClient_ReadWriteTransaction_UnavailableOnExecuteStreamingSql(t *testing func TestClient_ReadWriteTransaction_UnavailableOnBeginAndExecuteStreamingSqlAndTwiceAbortOnCommit(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, - testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, - testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, }, 3); err != nil { t.Fatal(err) } @@ -269,9 +293,9 @@ func TestClient_ReadWriteTransaction_UnavailableOnBeginAndExecuteStreamingSqlAnd func TestClient_ReadWriteTransaction_AbortedOnExecuteStreamingSqlAndCommit(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, - 
testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, + MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, }, 4); err != nil { t.Fatal(err) } @@ -279,8 +303,8 @@ func TestClient_ReadWriteTransaction_AbortedOnExecuteStreamingSqlAndCommit(t *te func TestClient_ReadWriteTransactionCommitAbortedAndUnavailable(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodCommitTransaction: { + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodCommitTransaction: { Errors: []error{ gstatus.Error(codes.Aborted, "Transaction aborted"), gstatus.Error(codes.Unavailable, "Unavailable"), @@ -293,8 +317,8 @@ func TestClient_ReadWriteTransactionCommitAbortedAndUnavailable(t *testing.T) { func TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) { t.Parallel() - if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ - testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.AlreadyExists, "A row with this key already exists")}}, + if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ + MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.AlreadyExists, "A row with this key already exists")}}, }, 1); err != nil { if gstatus.Code(err) != codes.AlreadyExists { t.Fatalf("Got unexpected error %v, expected %v", err, codes.AlreadyExists) @@ -304,17 +328,17 @@ func TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) { } } -func testReadWriteTransaction(t *testing.T, executionTimes map[string]testutil.SimulatedExecutionTime, expectedAttempts int) error { - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) +func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error { + server, client, teardown := setupMockedTestServer(t) + defer teardown() for method, exec := range executionTimes { - server.testSpanner.PutExecutionTime(method, exec) + server.TestSpanner.PutExecutionTime(method, exec) } - var attempts int ctx := context.Background() + var attempts int _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { attempts++ - iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() for { row, err := iter.Next() @@ -343,14 +367,14 @@ func testReadWriteTransaction(t *testing.T, executionTimes map[string]testutil.S func TestClient_ApplyAtLeastOnce(t *testing.T) { t.Parallel() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() ms := []*Mutation{ Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}), } - server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodCommitTransaction, + SimulatedExecutionTime{ Errors: 
[]error{gstatus.Error(codes.Aborted, "Transaction aborted")}, }) _, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce()) @@ -362,14 +386,14 @@ // PartitionedUpdate should not retry on aborted. func TestClient_PartitionedUpdate(t *testing.T) { t.Parallel() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() // PartitionedDML transactions are not committed. - server.testSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, + SimulatedExecutionTime{ Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}, }) - _, err := client.PartitionedUpdate(context.Background(), NewStatement(updateBarSetFoo)) + _, err := client.PartitionedUpdate(context.Background(), NewStatement(UpdateBarSetFoo)) if err == nil { t.Fatalf("Missing expected Aborted exception") } else { @@ -380,13 +404,13 @@ func TestClient_PartitionedUpdate(t *testing.T) { } func TestReadWriteTransaction_ErrUnexpectedEOF(t *testing.T) { - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) - var attempts int + _, client, teardown := setupMockedTestServer(t) + defer teardown() ctx := context.Background() + var attempts int _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { attempts++ - iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() for { row, err := iter.Next() @@ -440,7 +464,7 @@ func TestNewClient_ConnectToEmulator(t *testing.T) { func TestClient_ApiClientHeader(t *testing.T) { t.Parallel() - server, client := newSpannerInMemTestServerWithInterceptor(t, func( + interceptor := func( ctx context.Context, method string, req interface{}, @@ -468,11 +492,12 @@ func TestClient_ApiClientHeader(t *testing.T) { return spannerErrorf(codes.Internal, "unexpected api client token: %v", token[0]) } return invoker(ctx, method, req, reply, cc, opts...) - }, - ) - defer server.teardown(client) + } + opts := []option.ClientOption{option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptor))} + _, client, teardown := setupMockedTestServerWithConfigAndClientOptions(t, ClientConfig{}, opts) + defer teardown() ctx := context.Background() - iter := client.Single().Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() for { _, err := iter.Next() diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 837b862cc742..c23e480f5315 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testutil +package testutil_test import ( emptypb "github.com/golang/protobuf/ptypes/empty" @@ -179,7 +179,9 @@ type inMemSpannerServer struct { spannerpb.SpannerServer mu sync.Mutex - + // Set to true when this server has been stopped. This is the end state of a + // server; a stopped server cannot be restarted. + stopped bool // If set, all calls return this error. 
err error // The mock server creates session IDs using this counter. @@ -188,7 +190,6 @@ type inMemSpannerServer struct { sessions map[string]*spannerpb.Session // Last use times per session. sessionLastUseTime map[string]time.Time - // The mock server creates transaction IDs per session using these // counters. transactionCounters map[string]*uint64 @@ -198,19 +199,18 @@ type inMemSpannerServer struct { abortedTransactions map[string]bool // The transactions that are marked as PartitionedDMLTransaction partitionedDmlTransactions map[string]bool - // The mocked results for this server. statementResults map[string]*StatementResult // The simulated execution times per method. - executionTimes map[string]*SimulatedExecutionTime - // Server will stall on any requests. - freezed chan struct{} - + executionTimes map[string]*SimulatedExecutionTime totalSessionsCreated uint totalSessionsDeleted uint receivedRequests chan interface{} // Session ping history. pings []string + + // Server will stall on any requests. + freezed chan struct{} } // NewInMemSpannerServer creates a new in-mem test server. @@ -227,6 +227,9 @@ func NewInMemSpannerServer() InMemSpannerServer { } func (s *inMemSpannerServer) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + s.stopped = true close(s.receivedRequests) } @@ -234,12 +237,16 @@ func (s *inMemSpannerServer) Stop() { // transactions that have been created on the server. This method will not // remove mocked results. func (s *inMemSpannerServer) Reset() { + s.mu.Lock() + defer s.mu.Unlock() close(s.receivedRequests) s.receivedRequests = make(chan interface{}, 1000000) s.initDefaults() } func (s *inMemSpannerServer) SetError(err error) { + s.mu.Lock() + defer s.mu.Unlock() s.err = err } @@ -442,7 +449,13 @@ func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, e } func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() s.ready() s.mu.Lock() if s.err != nil { @@ -506,7 +519,13 @@ func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetS } func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() if req.Database == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing database") } @@ -544,7 +563,13 @@ func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.D } func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } @@ -624,7 +649,13 @@ func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlReques } func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + 
s.mu.Unlock() if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } @@ -664,12 +695,24 @@ func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb } func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") } func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() return gstatus.Error(codes.Unimplemented, "Method not yet implemented") } @@ -717,7 +760,13 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe } func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() if req.Session == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") } @@ -735,11 +784,23 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba } func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") } func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) { + s.mu.Lock() + if s.stopped { + s.mu.Unlock() + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") + } s.receivedRequests <- req + s.mu.Unlock() return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") } diff --git a/spanner/internal/testutil/inmem_spanner_server_test.go b/spanner/internal/testutil/inmem_spanner_server_test.go index d563ff4b6175..3074c9330c3d 100644 --- a/spanner/internal/testutil/inmem_spanner_server_test.go +++ b/spanner/internal/testutil/inmem_spanner_server_test.go @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testutil +package testutil_test_test import ( "strconv" + . "cloud.google.com/go/spanner/internal/testutil" + structpb "github.com/golang/protobuf/ptypes/struct" spannerpb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc/codes" diff --git a/spanner/internal/testutil/mockclient.go b/spanner/internal/testutil/mockclient.go index 2a3b918353f9..f24816f9fb27 100644 --- a/spanner/internal/testutil/mockclient.go +++ b/spanner/internal/testutil/mockclient.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -package testutil +package testutil_test import ( "context" diff --git a/spanner/mocked_inmem_server.go b/spanner/internal/testutil/mocked_inmem_server.go similarity index 50% rename from spanner/mocked_inmem_server.go rename to spanner/internal/testutil/mocked_inmem_server.go index ab96f71b9fe0..403df729928c 100644 --- a/spanner/mocked_inmem_server.go +++ b/spanner/internal/testutil/mocked_inmem_server.go @@ -12,74 +12,73 @@ // See the License for the specific language governing permissions and // limitations under the License. -package spanner +package testutil_test import ( - "context" "fmt" "net" "strconv" "testing" - "cloud.google.com/go/spanner/internal/testutil" structpb "github.com/golang/protobuf/ptypes/struct" "google.golang.org/api/option" spannerpb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc" ) -// The SQL statements and results that are already mocked for this test server. -const selectFooFromBar = "SELECT FOO FROM BAR" +// SelectFooFromBar is a SELECT statement that is added to the mocked test +// server and will return a one-col-two-rows result set containing the INT64 +// values 1 and 2. +const SelectFooFromBar = "SELECT FOO FROM BAR" const selectFooFromBarRowCount int64 = 2 const selectFooFromBarColCount int = 1 var selectFooFromBarResults = [...]int64{1, 2} -const selectSingerIDAlbumIDAlbumTitleFromAlbums = "SELECT SingerId, AlbumId, AlbumTitle FROM Albums" -const selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount int64 = 3 -const selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount int = 3 +// SelectSingerIDAlbumIDAlbumTitleFromAlbums is a SELECT statement that is added +// to the mocked test server and will return a 3-cols-3-rows result set. +const SelectSingerIDAlbumIDAlbumTitleFromAlbums = "SELECT SingerId, AlbumId, AlbumTitle FROM Albums" -const updateBarSetFoo = "UPDATE FOO SET BAR=1 WHERE BAZ=2" -const updateBarSetFooRowCount = 5 +// SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount is the number of rows +// returned by the SelectSingerIDAlbumIDAlbumTitleFromAlbums statement. +const SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount int64 = 3 -// An InMemSpannerServer with results for a number of SQL statements readily -// mocked. -type spannerInMemTestServer struct { - testSpanner testutil.InMemSpannerServer - server *grpc.Server -} +// SelectSingerIDAlbumIDAlbumTitleFromAlbumsColCount is the number of cols +// returned by the SelectSingerIDAlbumIDAlbumTitleFromAlbums statement. +const SelectSingerIDAlbumIDAlbumTitleFromAlbumsColCount int = 3 -// Create a spannerInMemTestServer with default configuration. -func newSpannerInMemTestServer(t *testing.T) (*spannerInMemTestServer, *Client) { - s := &spannerInMemTestServer{} - client := s.setup(t) - return s, client -} +// UpdateBarSetFoo is an UPDATE statement that is added to the mocked test +// server and will return an update count of 5. +const UpdateBarSetFoo = "UPDATE FOO SET BAR=1 WHERE BAZ=2" -// Create a spannerInMemTestServer with default configuration and a client interceptor. -func newSpannerInMemTestServerWithInterceptor(t *testing.T, interceptor grpc.UnaryClientInterceptor) (*spannerInMemTestServer, *Client) { - s := &spannerInMemTestServer{} - client := s.setupWithConfig(t, ClientConfig{}, interceptor) - return s, client -} +// UpdateBarSetFooRowCount is the constant update count value returned by the +// statement defined in UpdateBarSetFoo. +const UpdateBarSetFooRowCount = 5 -// Create a spannerInMemTestServer with the specified configuration. 
-func newSpannerInMemTestServerWithConfig(t *testing.T, config ClientConfig) (*spannerInMemTestServer, *Client) { - s := &spannerInMemTestServer{} - client := s.setupWithConfig(t, config, nil) - return s, client +// MockedSpannerInMemTestServer is an InMemSpannerServer with results for a +// number of SQL statements readily mocked. +type MockedSpannerInMemTestServer struct { + TestSpanner InMemSpannerServer + server *grpc.Server } -func (s *spannerInMemTestServer) setup(t *testing.T) *Client { - return s.setupWithConfig(t, ClientConfig{}, nil) +// NewMockedSpannerInMemTestServer creates a MockedSpannerInMemTestServer and +// returns client options that can be used to connect to it. +func NewMockedSpannerInMemTestServer(t *testing.T) (mockedServer *MockedSpannerInMemTestServer, opts []option.ClientOption, teardown func()) { + mockedServer = &MockedSpannerInMemTestServer{} + opts = mockedServer.setupMockedServer(t) + return mockedServer, opts, func() { + mockedServer.TestSpanner.Stop() + mockedServer.server.Stop() + } } -func (s *spannerInMemTestServer) setupWithConfig(t *testing.T, config ClientConfig, interceptor grpc.UnaryClientInterceptor) *Client { - s.testSpanner = testutil.NewInMemSpannerServer() +func (s *MockedSpannerInMemTestServer) setupMockedServer(t *testing.T) []option.ClientOption { + s.TestSpanner = NewInMemSpannerServer() s.setupFooResults() s.setupSingersResults() s.server = grpc.NewServer() - spannerpb.RegisterSpannerServer(s.server, s.testSpanner) + spannerpb.RegisterSpannerServer(s.server, s.TestSpanner) lis, err := net.Listen("tcp", "localhost:0") if err != nil { @@ -88,24 +87,15 @@ func (s *spannerInMemTestServer) setupWithConfig(t *testing.T, config ClientConf go s.server.Serve(lis) serverAddress := lis.Addr().String() - ctx := context.Background() - var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") opts := []option.ClientOption{ option.WithEndpoint(serverAddress), option.WithGRPCDialOption(grpc.WithInsecure()), option.WithoutAuthentication(), } - if interceptor != nil { - opts = append(opts, option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptor))) - } - client, err := NewClientWithConfig(ctx, formattedDatabase, config, opts...) 
- if err != nil { - t.Fatal(err) - } - return client + return opts } -func (s *spannerInMemTestServer) setupFooResults() { +func (s *MockedSpannerInMemTestServer) setupFooResults() { fields := make([]*spannerpb.StructType_Field, selectFooFromBarColCount) fields[0] = &spannerpb.StructType_Field{ Name: "FOO", @@ -131,16 +121,16 @@ func (s *spannerInMemTestServer) setupFooResults() { Metadata: metadata, Rows: rows, } - result := &testutil.StatementResult{Type: testutil.StatementResultResultSet, ResultSet: resultSet} - s.testSpanner.PutStatementResult(selectFooFromBar, result) - s.testSpanner.PutStatementResult(updateBarSetFoo, &testutil.StatementResult{ - Type: testutil.StatementResultUpdateCount, - UpdateCount: updateBarSetFooRowCount, + result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet} + s.TestSpanner.PutStatementResult(SelectFooFromBar, result) + s.TestSpanner.PutStatementResult(UpdateBarSetFoo, &StatementResult{ + Type: StatementResultUpdateCount, + UpdateCount: UpdateBarSetFooRowCount, }) } -func (s *spannerInMemTestServer) setupSingersResults() { - fields := make([]*spannerpb.StructType_Field, selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) +func (s *MockedSpannerInMemTestServer) setupSingersResults() { + fields := make([]*spannerpb.StructType_Field, SelectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) fields[0] = &spannerpb.StructType_Field{ Name: "SingerId", Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64}, @@ -159,10 +149,10 @@ func (s *spannerInMemTestServer) setupSingersResults() { metadata := &spannerpb.ResultSetMetadata{ RowType: rowType, } - rows := make([]*structpb.ListValue, selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) + rows := make([]*structpb.ListValue, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) var idx int64 - for idx = 0; idx < selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount; idx++ { - rowValue := make([]*structpb.Value, selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) + for idx = 0; idx < SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount; idx++ { + rowValue := make([]*structpb.Value, SelectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) rowValue[0] = &structpb.Value{ Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(idx+1, 10)}, } @@ -180,11 +170,6 @@ func (s *spannerInMemTestServer) setupSingersResults() { Metadata: metadata, Rows: rows, } - result := &testutil.StatementResult{Type: testutil.StatementResultResultSet, ResultSet: resultSet} - s.testSpanner.PutStatementResult(selectSingerIDAlbumIDAlbumTitleFromAlbums, result) -} - -func (s *spannerInMemTestServer) teardown(client *Client) { - client.Close() - s.server.Stop() + result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet} + s.TestSpanner.PutStatementResult(SelectSingerIDAlbumIDAlbumTitleFromAlbums, result) } diff --git a/spanner/internal/testutil/mockserver.go b/spanner/internal/testutil/mockserver.go index b5a6f5e505fd..333d0e989bb7 100644 --- a/spanner/internal/testutil/mockserver.go +++ b/spanner/internal/testutil/mockserver.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package testutil +package testutil_test import ( "context" diff --git a/spanner/pdml_test.go b/spanner/pdml_test.go index 743f5113cc10..9f479bb6189e 100644 --- a/spanner/pdml_test.go +++ b/spanner/pdml_test.go @@ -18,21 +18,22 @@ import ( "context" "testing" + . 
"cloud.google.com/go/spanner/internal/testutil" "google.golang.org/grpc/codes" ) func TestMockPartitionedUpdate(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + _, client, teardown := setupMockedTestServer(t) + defer teardown() - stmt := NewStatement(updateBarSetFoo) + stmt := NewStatement(UpdateBarSetFoo) rowCount, err := client.PartitionedUpdate(ctx, stmt) if err != nil { t.Fatal(err) } - want := int64(updateBarSetFooRowCount) + want := int64(UpdateBarSetFooRowCount) if rowCount != want { t.Errorf("got %d, want %d", rowCount, want) } @@ -41,10 +42,10 @@ func TestMockPartitionedUpdate(t *testing.T) { func TestMockPartitionedUpdateWithQuery(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + _, client, teardown := setupMockedTestServer(t) + defer teardown() - stmt := NewStatement(selectFooFromBar) + stmt := NewStatement(SelectFooFromBar) _, err := client.PartitionedUpdate(ctx, stmt) wantCode := codes.InvalidArgument if serr, ok := err.(*Error); !ok || serr.Code != wantCode { diff --git a/spanner/read_test.go b/spanner/read_test.go index 5a660c252525..b18b0938d7d1 100644 --- a/spanner/read_test.go +++ b/spanner/read_test.go @@ -26,7 +26,7 @@ import ( "time" "cloud.google.com/go/spanner/internal/backoff" - "cloud.google.com/go/spanner/internal/testutil" + . "cloud.google.com/go/spanner/internal/testutil" "github.com/golang/protobuf/proto" proto3 "github.com/golang/protobuf/ptypes/struct" "google.golang.org/api/iterator" @@ -42,7 +42,7 @@ var ( // Metadata for mocked KV table, its rows are returned by SingleUse // transactions. kvMeta = func() *sppb.ResultSetMetadata { - meta := testutil.KvMeta + meta := KvMeta meta.Transaction = &sppb.Transaction{ ReadTimestamp: timestampProto(trxTs), } @@ -642,7 +642,7 @@ func TestRsdNonblockingStates(t *testing.T) { defer restore() tests := []struct { name string - msgs []testutil.MockCtlMsg + msgs []MockCtlMsg rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error) sql string // Expected values @@ -655,7 +655,7 @@ func TestRsdNonblockingStates(t *testing.T) { { // unConnected->queueingRetryable->finished name: "unConnected->queueingRetryable->finished", - msgs: []testutil.MockCtlMsg{ + msgs: []MockCtlMsg{ {}, {}, {Err: io.EOF, ResumeToken: false}, @@ -689,7 +689,7 @@ func TestRsdNonblockingStates(t *testing.T) { { // unConnected->queueingRetryable->aborted name: "unConnected->queueingRetryable->aborted", - msgs: []testutil.MockCtlMsg{ + msgs: []MockCtlMsg{ {}, {Err: nil, ResumeToken: true}, {}, @@ -710,7 +710,7 @@ func TestRsdNonblockingStates(t *testing.T) { {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, }, - ResumeToken: testutil.EncodeResumeToken(1), + ResumeToken: EncodeResumeToken(1), }, }, stateHistory: []resumableStreamDecoderState{ @@ -726,9 +726,9 @@ func TestRsdNonblockingStates(t *testing.T) { { // unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable name: "unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable", - msgs: func() (m []testutil.MockCtlMsg) { + msgs: func() (m []MockCtlMsg) { for i := 0; i < maxBuffers+1; i++ { - m = append(m, testutil.MockCtlMsg{}) + m = append(m, MockCtlMsg{}) } return m }(), @@ -760,11 +760,11 @@ func TestRsdNonblockingStates(t *testing.T) { { // unConnected->queueingRetryable->queueingUnretryable->aborted 
name: "unConnected->queueingRetryable->queueingUnretryable->aborted", - msgs: func() (m []testutil.MockCtlMsg) { + msgs: func() (m []MockCtlMsg) { for i := 0; i < maxBuffers; i++ { - m = append(m, testutil.MockCtlMsg{}) + m = append(m, MockCtlMsg{}) } - m = append(m, testutil.MockCtlMsg{Err: errors.New("Just Abort It"), ResumeToken: false}) + m = append(m, MockCtlMsg{Err: errors.New("Just Abort It"), ResumeToken: false}) return m }(), sql: "SELECT t.key key, t.value value FROM t_mock t", @@ -794,7 +794,7 @@ func TestRsdNonblockingStates(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() mc := sppb.NewSpannerClient(dialMock(t, ms)) if test.rpc == nil { @@ -890,7 +890,7 @@ func TestRsdBlockingStates(t *testing.T) { defer restore() for _, test := range []struct { name string - msgs []testutil.MockCtlMsg + msgs []MockCtlMsg rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error) sql string // Expected values @@ -919,7 +919,7 @@ func TestRsdBlockingStates(t *testing.T) { { // unConnected->queueingRetryable->queueingRetryable name: "unConnected->queueingRetryable->queueingRetryable", - msgs: []testutil.MockCtlMsg{ + msgs: []MockCtlMsg{ {}, {Err: nil, ResumeToken: true}, {Err: nil, ResumeToken: true}, @@ -940,7 +940,7 @@ func TestRsdBlockingStates(t *testing.T) { {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}}, {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}}, }, - ResumeToken: testutil.EncodeResumeToken(1), + ResumeToken: EncodeResumeToken(1), }, { Metadata: kvMeta, @@ -948,7 +948,7 @@ func TestRsdBlockingStates(t *testing.T) { {Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}}, {Kind: &proto3.Value_StringValue{StringValue: valStr(2)}}, }, - ResumeToken: testutil.EncodeResumeToken(2), + ResumeToken: EncodeResumeToken(2), }, }, queue: []*sppb.PartialResultSet{ @@ -960,7 +960,7 @@ func TestRsdBlockingStates(t *testing.T) { }, }, }, - resumeToken: testutil.EncodeResumeToken(2), + resumeToken: EncodeResumeToken(2), stateHistory: []resumableStreamDecoderState{ queueingRetryable, // do RPC queueingRetryable, // got foo-00 @@ -974,12 +974,12 @@ func TestRsdBlockingStates(t *testing.T) { { // unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable name: "unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable", - msgs: func() (m []testutil.MockCtlMsg) { + msgs: func() (m []MockCtlMsg) { for i := 0; i < maxBuffers+1; i++ { - m = append(m, testutil.MockCtlMsg{}) + m = append(m, MockCtlMsg{}) } - m = append(m, testutil.MockCtlMsg{Err: nil, ResumeToken: true}) - m = append(m, testutil.MockCtlMsg{}) + m = append(m, MockCtlMsg{Err: nil, ResumeToken: true}) + m = append(m, MockCtlMsg{}) return m }(), sql: "SELECT t.key key, t.value value FROM t_mock t", @@ -993,10 +993,10 @@ func TestRsdBlockingStates(t *testing.T) { }, }) } - s[maxBuffers+1].ResumeToken = testutil.EncodeResumeToken(maxBuffers + 1) + s[maxBuffers+1].ResumeToken = EncodeResumeToken(maxBuffers + 1) return s }(), - resumeToken: testutil.EncodeResumeToken(maxBuffers + 1), + resumeToken: EncodeResumeToken(maxBuffers + 1), queue: []*sppb.PartialResultSet{ { Metadata: kvMeta, @@ -1026,11 +1026,11 @@ func TestRsdBlockingStates(t *testing.T) { { // unConnected->queueingRetryable->queueingUnretryable->finished name: "unConnected->queueingRetryable->queueingUnretryable->finished", - msgs: func() (m 
[]testutil.MockCtlMsg) { + msgs: func() (m []MockCtlMsg) { for i := 0; i < maxBuffers; i++ { - m = append(m, testutil.MockCtlMsg{}) + m = append(m, MockCtlMsg{}) } - m = append(m, testutil.MockCtlMsg{Err: io.EOF, ResumeToken: false}) + m = append(m, MockCtlMsg{Err: io.EOF, ResumeToken: false}) return m }(), sql: "SELECT t.key key, t.value value FROM t_mock t", @@ -1058,7 +1058,7 @@ func TestRsdBlockingStates(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() cc := dialMock(t, ms) mc := sppb.NewSpannerClient(cc) @@ -1194,7 +1194,7 @@ func (sr *sReceiver) waitn(n int) error { func TestQueueBytes(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1279,7 +1279,7 @@ func TestQueueBytes(t *testing.T) { func TestResumeToken(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1476,7 +1476,7 @@ func TestResumeToken(t *testing.T) { func TestGrpcReconnect(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1543,7 +1543,7 @@ func TestGrpcReconnect(t *testing.T) { func TestCancelTimeout(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1630,7 +1630,7 @@ func TestCancelTimeout(t *testing.T) { func TestRowIteratorDo(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1663,7 +1663,7 @@ func TestRowIteratorDo(t *testing.T) { func TestRowIteratorDoWithError(t *testing.T) { restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1694,7 +1694,7 @@ func TestIteratorStopEarly(t *testing.T) { ctx := context.Background() restore := setMaxBytesBetweenResumeTokens() defer restore() - ms := testutil.NewMockCloudSpanner(t, trxTs) + ms := NewMockCloudSpanner(t, trxTs) ms.Serve() defer ms.Stop() cc := dialMock(t, ms) @@ -1738,7 +1738,7 @@ func TestIteratorWithError(t *testing.T) { } } -func dialMock(t *testing.T, ms *testutil.MockCloudSpanner) *grpc.ClientConn { +func dialMock(t *testing.T, ms *MockCloudSpanner) *grpc.ClientConn { cc, err := grpc.Dial(ms.Addr(), grpc.WithInsecure(), grpc.WithBlock()) if err != nil { t.Fatalf("Dial(%q) = %v", ms.Addr(), err) diff --git a/spanner/session_test.go b/spanner/session_test.go index f3f779bc8b11..7e0392556bb9 100644 --- a/spanner/session_test.go +++ b/spanner/session_test.go @@ -26,7 +26,7 @@ import ( "time" vkit "cloud.google.com/go/spanner/apiv1" - "cloud.google.com/go/spanner/internal/testutil" + . "cloud.google.com/go/spanner/internal/testutil" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -34,8 +34,8 @@ import ( // TestSessionPoolConfigValidation tests session pool config validation. 
func TestSessionPoolConfigValidation(t *testing.T) { t.Parallel() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + _, client, teardown := setupMockedTestServer(t) + defer teardown() for _, test := range []struct { spc SessionPoolConfig @@ -66,8 +66,8 @@ func TestSessionPoolConfigValidation(t *testing.T) { func TestSessionCreation(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServer(t) - defer server.teardown(client) + server, client, teardown := setupMockedTestServer(t) + defer teardown() sp := client.idleSessions // Take three sessions from session pool, this should trigger session pool @@ -86,7 +86,7 @@ func TestSessionCreation(t *testing.T) { if len(gotDs) != len(shs) { t.Fatalf("session pool created %v sessions, want %v", len(gotDs), len(shs)) } - if wantDs := server.testSpanner.DumpSessions(); !testEqual(gotDs, wantDs) { + if wantDs := server.TestSpanner.DumpSessions(); !testEqual(gotDs, wantDs) { t.Fatalf("session pool creates sessions %v, want %v", gotDs, wantDs) } // Verify that created sessions are recorded correctly in session pool. @@ -118,11 +118,11 @@ func TestTakeFromIdleList(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{MaxIdle: 10}, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Take ten sessions from session pool and recycle them. @@ -142,7 +142,7 @@ func TestTakeFromIdleList(t *testing.T) { } // Further session requests from session pool won't cause mockclient to // create more sessions. - wantSessions := server.testSpanner.DumpSessions() + wantSessions := server.TestSpanner.DumpSessions() // Take ten sessions from session pool again, this time all sessions should // come from idle list. gotSessions := map[string]bool{} @@ -168,11 +168,11 @@ func TestTakeWriteSessionFromIdleList(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{MaxIdle: 20}, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Take ten sessions from session pool and recycle them. @@ -192,7 +192,7 @@ func TestTakeWriteSessionFromIdleList(t *testing.T) { } // Further session requests from session pool won't cause mockclient to // create more sessions. - wantSessions := server.testSpanner.DumpSessions() + wantSessions := server.TestSpanner.DumpSessions() // Take ten sessions from session pool again, this time all sessions should // come from idle list. gotSessions := map[string]bool{} @@ -218,7 +218,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MaxIdle: 1, @@ -226,7 +226,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { healthCheckSampleInterval: 10 * time.Millisecond, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Stop healthcheck workers to simulate slow pings. 
@@ -262,7 +262,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { // The two back-to-back session requests shouldn't trigger any session // pings because sessionPool.Take // reschedules the next healthcheck. - if got, want := server.testSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { + if got, want := server.TestSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { t.Fatalf("%v - got ping session requests: %v, want %v", i, got, want) } sh.recycle() @@ -271,8 +271,8 @@ func TestTakeFromIdleListChecked(t *testing.T) { // Inject session error to server stub, and take the session from the // session pool, the old session should be destroyed and the session pool // will create a new session. - server.testSpanner.PutExecutionTime(testutil.MethodGetSession, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodGetSession, + SimulatedExecutionTime{ Errors: []error{status.Errorf(codes.NotFound, "Session not found")}, }) @@ -287,7 +287,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { if err != nil { t.Fatalf("failed to get session: %v", err) } - ds := server.testSpanner.DumpSessions() + ds := server.TestSpanner.DumpSessions() if len(ds) != 1 { t.Fatalf("dumped sessions from mockclient: %v, want %v", ds, sh.getID()) } @@ -303,7 +303,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MaxIdle: 1, @@ -311,7 +311,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { healthCheckSampleInterval: 10 * time.Millisecond, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Stop healthcheck workers to simulate slow pings. @@ -345,7 +345,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { } // The two back-to-back session requests shouldn't trigger any session // pings because sessionPool.Take reschedules the next healthcheck. - if got, want := server.testSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { + if got, want := server.TestSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { t.Fatalf("%v - got ping session requests: %v, want %v", i, got, want) } sh.recycle() @@ -354,8 +354,8 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { // Inject session error to mockclient, and take the session from the // session pool, the old session should be destroyed and the session pool // will create a new session. 
- server.testSpanner.PutExecutionTime(testutil.MethodGetSession, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodGetSession, + SimulatedExecutionTime{ Errors: []error{status.Errorf(codes.NotFound, "Session not found")}, }) @@ -367,7 +367,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { if err != nil { t.Fatalf("failed to get session: %v", err) } - ds := server.testSpanner.DumpSessions() + ds := server.TestSpanner.DumpSessions() if len(ds) != 1 { t.Fatalf("dumped sessions from mockclient: %v, want %v", ds, sh.getID()) } @@ -380,13 +380,13 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { func TestMaxOpenedSessions(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MaxOpened: 1, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions sh1, err := sp.take(ctx) @@ -425,13 +425,13 @@ func TestMaxOpenedSessions(t *testing.T) { func TestMinOpenedSessions(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MinOpened: 1, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Take ten sessions from session pool and recycle them. @@ -468,18 +468,18 @@ func TestMinOpenedSessions(t *testing.T) { func TestMaxBurst(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MaxBurst: 1, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Will cause session creation RPC to be retried forever. - server.testSpanner.PutExecutionTime(testutil.MethodCreateSession, - testutil.SimulatedExecutionTime{ + server.TestSpanner.PutExecutionTime(MethodCreateSession, + SimulatedExecutionTime{ Errors: []error{status.Errorf(codes.Unavailable, "try later")}, KeepError: true, }) @@ -511,10 +511,10 @@ func TestMaxBurst(t *testing.T) { } // Let the first session request succeed. - server.testSpanner.Freeze() - server.testSpanner.PutExecutionTime(testutil.MethodCreateSession, testutil.SimulatedExecutionTime{}) + server.TestSpanner.Freeze() + server.TestSpanner.PutExecutionTime(MethodCreateSession, SimulatedExecutionTime{}) //close(allowRequests) - server.testSpanner.Unfreeze() + server.TestSpanner.Unfreeze() // Now new session request can proceed because the first session request will eventually succeed. sh, err := sp.take(ctx) @@ -530,14 +530,14 @@ func TestMaxBurst(t *testing.T) { func TestSessionRecycle(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MinOpened: 1, MaxIdle: 5, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Test session is correctly recycled and reused. 
@@ -568,13 +568,13 @@ func TestSessionDestroy(t *testing.T) { t.Skip("s.destroy(true) is flakey") t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MinOpened: 1, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions <-time.After(10 * time.Millisecond) // maintainer will create one session, we wait for it create session to avoid flakiness in test @@ -631,14 +631,14 @@ func TestHcHeap(t *testing.T) { func TestHealthCheckScheduler(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ HealthCheckInterval: 50 * time.Millisecond, healthCheckSampleInterval: 10 * time.Millisecond, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Create 50 sessions. @@ -652,13 +652,13 @@ func TestHealthCheckScheduler(t *testing.T) { // Make sure we start with a ping history to avoid that the first // sessions that were created have not already exceeded the maximum // number of pings. - server.testSpanner.ClearPings() + server.TestSpanner.ClearPings() // Wait for 10-30 pings per session. waitFor(t, func() error { // Only check actually live sessions and ignore any sessions the // session pool may have deleted in the meantime. - liveSessions := server.testSpanner.DumpSessions() - dp := server.testSpanner.DumpPings() + liveSessions := server.TestSpanner.DumpSessions() + dp := server.TestSpanner.DumpPings() gotPings := map[string]int64{} for _, p := range dp { gotPings[p]++ @@ -678,14 +678,14 @@ func TestHealthCheckScheduler(t *testing.T) { func TestWriteSessionsPrepared(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ WriteSessions: 0.5, MaxIdle: 20, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions shs := make([]*sessionHandle, 10) @@ -748,7 +748,7 @@ func TestWriteSessionsPrepared(t *testing.T) { func TestTakeFromWriteQueue(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ MaxOpened: 1, @@ -756,7 +756,7 @@ func TestTakeFromWriteQueue(t *testing.T) { MaxIdle: 1, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions sh, err := sp.take(ctx) @@ -786,14 +786,14 @@ func TestTakeFromWriteQueue(t *testing.T) { func TestSessionHealthCheck(t *testing.T) { t.Parallel() ctx := context.Background() - server, client := newSpannerInMemTestServerWithConfig(t, + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ SessionPoolConfig: SessionPoolConfig{ HealthCheckInterval: 50 * time.Millisecond, healthCheckSampleInterval: 10 * time.Millisecond, }, }) - defer server.teardown(client) + defer teardown() sp := client.idleSessions // Test pinging sessions. @@ -804,7 +804,7 @@ func TestSessionHealthCheck(t *testing.T) { // Wait for healthchecker to send pings to session. 
 	waitFor(t, func() error {
-		pings := server.testSpanner.DumpPings()
+		pings := server.TestSpanner.DumpPings()
 		if len(pings) == 0 || pings[0] != sh.getID() {
 			return fmt.Errorf("healthchecker didn't send any ping to session %v", sh.getID())
 		}
@@ -816,13 +816,13 @@ func TestSessionHealthCheck(t *testing.T) {
 		t.Fatalf("cannot get session from session pool: %v", err)
 	}

-	server.testSpanner.Freeze()
-	server.testSpanner.PutExecutionTime(testutil.MethodGetSession,
-		testutil.SimulatedExecutionTime{
+	server.TestSpanner.Freeze()
+	server.TestSpanner.PutExecutionTime(MethodGetSession,
+		SimulatedExecutionTime{
 			Errors:    []error{status.Errorf(codes.NotFound, "Session not found")},
 			KeepError: true,
 		})
-	server.testSpanner.Unfreeze()
+	server.TestSpanner.Unfreeze()
 	//atomic.SwapInt64(&requestShouldErr, 1)

 	// Wait for healthcheck workers to find the broken session and tear it down.
@@ -834,9 +834,9 @@ func TestSessionHealthCheck(t *testing.T) {
 		t.Fatalf("session(%v) is still alive, want it to be dropped by healthcheck workers", s)
 	}

-	server.testSpanner.Freeze()
-	server.testSpanner.PutExecutionTime(testutil.MethodGetSession, testutil.SimulatedExecutionTime{})
-	server.testSpanner.Unfreeze()
+	server.TestSpanner.Freeze()
+	server.TestSpanner.PutExecutionTime(MethodGetSession, SimulatedExecutionTime{})
+	server.TestSpanner.Unfreeze()

 	// Test garbage collection.
 	sh, err = sp.take(ctx)
@@ -876,10 +876,9 @@ func TestStressSessionPool(t *testing.T) {
 	cfg.healthCheckSampleInterval = 10 * time.Millisecond
 	cfg.HealthCheckWorkers = 50

-	server, client := newSpannerInMemTestServerWithConfig(t,
-		ClientConfig{
-			SessionPoolConfig: cfg,
-		})
+	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
+		SessionPoolConfig: cfg,
+	})
 	sp := client.idleSessions

 	// Create a test group for this configuration and schedule 100 sub
@@ -898,7 +897,7 @@ func TestStressSessionPool(t *testing.T) {
 		// stable.
 		idleSessions := map[string]bool{}
 		hcSessions := map[string]bool{}
-		mockSessions := server.testSpanner.DumpSessions()
+		mockSessions := server.TestSpanner.DumpSessions()
 		// Dump session pool's idle list.
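 		// (The idle list and the healthcheck queue are collected into the
 		// idleSessions and hcSessions maps declared above, so that both views
 		// can later be cross-checked against the server's session dump.)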
 		for sl := sp.idleList.Front(); sl != nil; sl = sl.Next() {
 			s := sl.Value.(*session)
@@ -944,13 +943,13 @@ func TestStressSessionPool(t *testing.T) {
 			}
 		}
 		sp.close()
-		mockSessions = server.testSpanner.DumpSessions()
+		mockSessions = server.TestSpanner.DumpSessions()
 		for id, b := range hcSessions {
 			if b && mockSessions[id] {
 				t.Fatalf("Found session from pool still live on server: %v", id)
 			}
 		}
-		server.teardown(client)
+		teardown()
 	}
 }
@@ -1034,14 +1033,14 @@ func TestMaintainer(t *testing.T) {
 	minOpened := uint64(5)
 	maxIdle := uint64(4)
-	server, client := newSpannerInMemTestServerWithConfig(t,
+	_, client, teardown := setupMockedTestServerWithConfig(t,
 		ClientConfig{
 			SessionPoolConfig: SessionPoolConfig{
 				MinOpened: minOpened,
 				MaxIdle:   maxIdle,
 			},
 		})
-	defer server.teardown(client)
+	defer teardown()
 	sp := client.idleSessions

 	sampleInterval := sp.SessionPoolConfig.healthCheckSampleInterval
@@ -1114,11 +1113,11 @@ func TestMaintainer_CreatesSessions(t *testing.T) {
 		MaxIdle:                   10,
 		healthCheckSampleInterval: 20 * time.Millisecond,
 	}
-	server, client := newSpannerInMemTestServerWithConfig(t,
+	_, client, teardown := setupMockedTestServerWithConfig(t,
 		ClientConfig{
 			SessionPoolConfig: spc,
 		})
-	defer server.teardown(client)
+	defer teardown()
 	sp := client.idleSessions

 	timeoutAmt := 4 * time.Second
diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go
index 9c10ac4d8195..2976c96b595d 100644
--- a/spanner/transaction_test.go
+++ b/spanner/transaction_test.go
@@ -25,7 +25,7 @@ import (
 	"testing"
 	"time"

-	"cloud.google.com/go/spanner/internal/testutil"
+	. "cloud.google.com/go/spanner/internal/testutil"
 	sppb "google.golang.org/genproto/googleapis/spanner/v1"
 	"google.golang.org/grpc/codes"
 	gstatus "google.golang.org/grpc/status"
@@ -35,8 +35,8 @@ import (
 func TestSingle(t *testing.T) {
 	t.Parallel()
 	ctx := context.Background()
-	server, client := newSpannerInMemTestServer(t)
-	defer server.teardown(client)
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()
 	txn := client.Single()
 	defer txn.Close()
@@ -50,7 +50,7 @@ func TestSingle(t *testing.T) {
 	}

 	// Only one CreateSessionRequest is sent.
-	if _, err := shouldHaveReceived(server.testSpanner, []interface{}{&sppb.CreateSessionRequest{}}); err != nil {
+	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{&sppb.CreateSessionRequest{}}); err != nil {
 		t.Fatal(err)
 	}
 }
@@ -59,16 +59,16 @@ func TestSingle(t *testing.T) {
 func TestReadOnlyTransaction_RecoverFromFailure(t *testing.T) {
 	t.Parallel()
 	ctx := context.Background()
-	server, client := newSpannerInMemTestServer(t)
-	defer server.teardown(client)
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()
 	txn := client.ReadOnlyTransaction()
 	defer txn.Close()

 	// First request will fail.
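 	// (PutExecutionTime below queues the given errors for the named method;
 	// retried calls appear to consume them one by one, and unless KeepError
 	// is set, as in TestMaxBurst earlier, later calls succeed again.)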
 	errUsr := gstatus.Error(codes.Unknown, "error")
-	server.testSpanner.PutExecutionTime(testutil.MethodBeginTransaction,
-		testutil.SimulatedExecutionTime{
+	server.TestSpanner.PutExecutionTime(MethodBeginTransaction,
+		SimulatedExecutionTime{
 			Errors: []error{errUsr},
 		})
@@ -86,8 +86,8 @@ func TestReadOnlyTransaction_RecoverFromFailure(t *testing.T) {
 func TestReadOnlyTransaction_UseAfterClose(t *testing.T) {
 	t.Parallel()
 	ctx := context.Background()
-	server, client := newSpannerInMemTestServer(t)
-	defer server.teardown(client)
+	_, client, teardown := setupMockedTestServer(t)
+	defer teardown()
 	txn := client.ReadOnlyTransaction()
 	txn.Close()
@@ -102,12 +102,12 @@ func TestReadOnlyTransaction_UseAfterClose(t *testing.T) {
 func TestReadOnlyTransaction_Concurrent(t *testing.T) {
 	t.Parallel()
 	ctx := context.Background()
-	server, client := newSpannerInMemTestServer(t)
-	defer server.teardown(client)
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()
 	txn := client.ReadOnlyTransaction()
 	defer txn.Close()

-	server.testSpanner.Freeze()
+	server.TestSpanner.Freeze()
 	var (
 		sh1 *sessionHandle
 		sh2 *sessionHandle
@@ -130,7 +130,7 @@ func TestReadOnlyTransaction_Concurrent(t *testing.T) {
 	// TODO(deklerk): Get rid of this.
 	<-time.After(100 * time.Millisecond)
-	server.testSpanner.Unfreeze()
+	server.TestSpanner.Unfreeze()
 	wg.Wait()
 	if sh1.session.id != sh2.session.id {
 		t.Fatalf("Expected acquire to get same session handle, got %v and %v.", sh1, sh2)
 	}
@@ -143,8 +143,8 @@ func TestReadOnlyTransaction_Concurrent(t *testing.T) {
 func TestApply_Single(t *testing.T) {
 	t.Parallel()
 	ctx := context.Background()
-	server, client := newSpannerInMemTestServer(t)
-	defer server.teardown(client)
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()

 	ms := []*Mutation{
 		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
@@ -154,7 +154,7 @@ func TestApply_Single(t *testing.T) {
 		t.Fatalf("applyAtLeastOnce retry on abort, got %v, want nil.", e)
 	}

-	if _, err := shouldHaveReceived(server.testSpanner, []interface{}{
+	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
 		&sppb.CreateSessionRequest{},
 		&sppb.CommitRequest{},
 	}); err != nil {
@@ -166,13 +166,13 @@ func TestApply_RetryOnAbort(t *testing.T) {
 	ctx := context.Background()
 	t.Parallel()
-	server, client := newSpannerInMemTestServer(t)
-	defer server.teardown(client)
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()

 	// First commit will fail, and the retry will begin a new transaction.
 	errAbrt := spannerErrorf(codes.Aborted, "")
-	server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction,
-		testutil.SimulatedExecutionTime{
+	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
+		SimulatedExecutionTime{
 			Errors: []error{errAbrt},
 		})
@@ -184,7 +184,7 @@ func TestApply_RetryOnAbort(t *testing.T) {
 		t.Fatalf("ReadWriteTransaction retry on abort, got %v, want nil.", e)
 	}

-	if _, err := shouldHaveReceived(server.testSpanner, []interface{}{
+	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
 		&sppb.CreateSessionRequest{},
 		&sppb.BeginTransactionRequest{},
 		&sppb.CommitRequest{}, // First commit fails.
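For reference, shouldHaveReceived (updated near the end of this patch) asserts the exact RPC sequence the mock server saw, in order. An illustrative sketch of the retried-commit case, extending the pattern above; the request list here is an assumption based on the comment that the retry begins a new transaction, not the patch's own assertion:

	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
		&sppb.CreateSessionRequest{},
		&sppb.BeginTransactionRequest{},
		&sppb.CommitRequest{},           // first attempt, aborted
		&sppb.BeginTransactionRequest{}, // retry begins a new transaction
		&sppb.CommitRequest{},           // retry succeeds
	}); err != nil {
		t.Fatal(err)
	}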
@@ -199,16 +199,16 @@ func TestApply_RetryOnAbort(t *testing.T) {
 func TestTransaction_NotFound(t *testing.T) {
 	t.Parallel()
 	ctx := context.Background()
-	server, client := newSpannerInMemTestServer(t)
-	defer server.teardown(client)
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()

 	wantErr := spannerErrorf(codes.NotFound, "Session not found")
-	server.testSpanner.PutExecutionTime(testutil.MethodBeginTransaction,
-		testutil.SimulatedExecutionTime{
+	server.TestSpanner.PutExecutionTime(MethodBeginTransaction,
+		SimulatedExecutionTime{
 			Errors: []error{wantErr, wantErr, wantErr},
 		})
-	server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction,
-		testutil.SimulatedExecutionTime{
+	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
+		SimulatedExecutionTime{
 			Errors: []error{wantErr, wantErr, wantErr},
 		})
@@ -243,8 +243,8 @@ func TestTransaction_NotFound(t *testing.T) {
 func TestReadWriteTransaction_ErrorReturned(t *testing.T) {
 	t.Parallel()
 	ctx := context.Background()
-	server, client := newSpannerInMemTestServer(t)
-	defer server.teardown(client)
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()

 	want := errors.New("an error")
 	_, got := client.ReadWriteTransaction(ctx, func(context.Context, *ReadWriteTransaction) error {
@@ -253,7 +253,7 @@ func TestReadWriteTransaction_ErrorReturned(t *testing.T) {
 	if got != want {
 		t.Fatalf("got %+v, want %+v", got, want)
 	}
-	requests := drainRequestsFromServer(server.testSpanner)
+	requests := drainRequestsFromServer(server.TestSpanner)
 	if err := compareRequests([]interface{}{
 		&sppb.CreateSessionRequest{},
 		&sppb.BeginTransactionRequest{},
@@ -277,27 +277,27 @@ func TestReadWriteTransaction_ErrorReturned(t *testing.T) {
 func TestBatchDML_WithMultipleDML(t *testing.T) {
 	t.Parallel()
 	ctx := context.Background()
-	server, client := newSpannerInMemTestServer(t)
-	defer server.teardown(client)
+	server, client, teardown := setupMockedTestServer(t)
+	defer teardown()

 	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
-		if _, err = tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil {
+		if _, err = tx.Update(ctx, Statement{SQL: UpdateBarSetFoo}); err != nil {
 			return err
 		}
-		if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}, {SQL: updateBarSetFoo}}); err != nil {
+		if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}, {SQL: UpdateBarSetFoo}}); err != nil {
 			return err
 		}
-		if _, err = tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil {
+		if _, err = tx.Update(ctx, Statement{SQL: UpdateBarSetFoo}); err != nil {
 			return err
 		}
-		_, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}})
+		_, err = tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}})
 		return err
 	})
 	if err != nil {
 		t.Fatal(err)
 	}
-	gotReqs, err := shouldHaveReceived(server.testSpanner, []interface{}{
+	gotReqs, err := shouldHaveReceived(server.TestSpanner, []interface{}{
 		&sppb.CreateSessionRequest{},
 		&sppb.BeginTransactionRequest{},
 		&sppb.ExecuteSqlRequest{},
@@ -329,7 +329,7 @@ func TestBatchDML_WithMultipleDML(t *testing.T) {
 //
 // Note: this modifies the given server in place by popping items off the
 // ReceivedRequests channel.
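 // (Tests that need to assert requests in stages can instead call
 // drainRequestsFromServer and compareRequests separately, as
 // TestReadWriteTransaction_ErrorReturned above does.)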
-func shouldHaveReceived(server testutil.InMemSpannerServer, want []interface{}) ([]interface{}, error) {
+func shouldHaveReceived(server InMemSpannerServer, want []interface{}) ([]interface{}, error) {
 	got := drainRequestsFromServer(server)
 	return got, compareRequests(want, got)
 }
@@ -358,7 +358,7 @@ func compareRequests(want []interface{}, got []interface{}) error {
 	return nil
 }

-func drainRequestsFromServer(server testutil.InMemSpannerServer) []interface{} {
+func drainRequestsFromServer(server InMemSpannerServer) []interface{} {
 	var reqs []interface{}
 loop:
 	for {