Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spanner): implement generation and propagation of "x-goog-spanner-request-id" Header #11048

Merged
merged 14 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions spanner/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
var (
sh *sessionHandle
err error
rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
rpc func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error)
)
if sh, _, err = t.acquire(ctx); err != nil {
return &RowIterator{err: err}
Expand All @@ -322,7 +322,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
sh.updateLastUseTime()
// Read or query partition.
if p.rreq != nil {
rpc = func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) {
rpc = func(ctx context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) {
client, err := client.StreamingRead(ctx, &sppb.ReadRequest{
Session: p.rreq.Session,
Transaction: p.rreq.Transaction,
Expand All @@ -335,7 +335,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
ResumeToken: resumeToken,
DataBoostEnabled: p.rreq.DataBoostEnabled,
DirectedReadOptions: p.rreq.DirectedReadOptions,
})
}, opts...)
if err != nil {
return client, err
}
Expand All @@ -351,7 +351,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
return client, err
}
} else {
rpc = func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) {
rpc = func(ctx context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) {
client, err := client.ExecuteStreamingSql(ctx, &sppb.ExecuteSqlRequest{
Session: p.qreq.Session,
Transaction: p.qreq.Transaction,
Expand All @@ -364,7 +364,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
ResumeToken: resumeToken,
DataBoostEnabled: p.qreq.DataBoostEnabled,
DirectedReadOptions: p.qreq.DirectedReadOptions,
})
}, opts...)
if err != nil {
return client, err
}
Expand All @@ -387,7 +387,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
t.sp.sc.metricsTracerFactory,
rpc,
t.setTimestamp,
t.release)
t.release, client.(*grpcSpannerClient))
}

// MarshalBinary implements BinaryMarshaler.
Expand Down
8 changes: 8 additions & 0 deletions spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,14 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
} else {
// Create gtransport ConnPool as usual if MultiEndpoint is not used.
// gRPC options.

// Add a unaryClientInterceptor and streamClientInterceptor.
reqIDInjector := new(requestIDHeaderInjector)
opts = append(opts,
option.WithGRPCDialOption(grpc.WithChainStreamInterceptor(reqIDInjector.interceptStream)),
option.WithGRPCDialOption(grpc.WithChainUnaryInterceptor(reqIDInjector.interceptUnary)),
)

allOpts := allClientOpts(config.NumChannels, config.Compression, opts...)
pool, err = gtransport.DialPool(ctx, allOpts...)
if err != nil {
Expand Down
8 changes: 6 additions & 2 deletions spanner/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4187,13 +4187,17 @@ func TestReadWriteTransaction_ContextTimeoutDuringCommit(t *testing.T) {
if se.GRPCStatus().Code() != w.GRPCStatus().Code() {
t.Fatalf("Error status mismatch:\nGot: %v\nWant: %v", se.GRPCStatus(), w.GRPCStatus())
}
if se.Error() != w.Error() {
t.Fatalf("Error message mismatch:\nGot %s\nWant: %s", se.Error(), w.Error())
if !testEqual(se, w) {
t.Fatalf("Error message mismatch:\nGot: %s\nWant: %s", se.Error(), w.Error())
}
var outcome *TransactionOutcomeUnknownError
if !errors.As(err, &outcome) {
t.Fatalf("Missing wrapped TransactionOutcomeUnknownError error")
}

if w.RequestID != "" {
t.Fatal("Missing .RequestID")
}
}

func TestFailedCommit_NoRollback(t *testing.T) {
Expand Down
3 changes: 3 additions & 0 deletions spanner/cmp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ func testEqual(a, b interface{}) bool {
if strings.Contains(path.GoString(), "{*spanner.Error}.err") {
return true
}
if strings.Contains(path.GoString(), "{*spanner.Error}.RequestID") {
return true
}
return false
}, cmp.Ignore()))
}
25 changes: 20 additions & 5 deletions spanner/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ type Error struct {
// additionalInformation optionally contains any additional information
// about the error.
additionalInformation string

// RequestID is the associated ID that was sent to Google Cloud Spanner's
// backend, as the value in the "x-goog-spanner-request-id" gRPC header.
RequestID string
}

// TransactionOutcomeUnknownError is wrapped in a Spanner error when the error
Expand Down Expand Up @@ -85,10 +89,17 @@ func (e *Error) Error() string {
return "spanner: OK"
}
code := ErrCode(e)

var s string
if e.additionalInformation == "" {
return fmt.Sprintf("spanner: code = %q, desc = %q", code, e.Desc)
s = fmt.Sprintf("spanner: code = %q, desc = %q", code, e.Desc)
} else {
s = fmt.Sprintf("spanner: code = %q, desc = %q, additional information = %s", code, e.Desc, e.additionalInformation)
}
return fmt.Sprintf("spanner: code = %q, desc = %q, additional information = %s", code, e.Desc, e.additionalInformation)
if e.RequestID != "" {
s = fmt.Sprintf("%s, requestID = %q", s, e.RequestID)
}
return s
}

// Unwrap returns the wrapped error (if any).
Expand Down Expand Up @@ -123,6 +134,10 @@ func (e *Error) decorate(info string) {
// APIError error having given error code as its status.
func spannerErrorf(code codes.Code, format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
return spannerError(code, msg)
}

func spannerError(code codes.Code, msg string) error {
wrapped, _ := apierror.FromError(status.Error(code, msg))
return &Error{
Code: code,
Expand Down Expand Up @@ -172,9 +187,9 @@ func toSpannerErrorWithCommitInfo(err error, errorDuringCommit bool) error {
desc = fmt.Sprintf("%s, %s", desc, transactionOutcomeUnknownMsg)
wrapped = &TransactionOutcomeUnknownError{err: wrapped}
}
return &Error{status.FromContextError(err).Code(), toAPIError(wrapped), desc, ""}
return &Error{status.FromContextError(err).Code(), toAPIError(wrapped), desc, "", ""}
case status.Code(err) == codes.Unknown:
return &Error{codes.Unknown, toAPIError(err), err.Error(), ""}
return &Error{codes.Unknown, toAPIError(err), err.Error(), "", ""}
olavloite marked this conversation as resolved.
Show resolved Hide resolved
default:
statusErr := status.Convert(err)
code, desc := statusErr.Code(), statusErr.Message()
Expand All @@ -183,7 +198,7 @@ func toSpannerErrorWithCommitInfo(err error, errorDuringCommit bool) error {
desc = fmt.Sprintf("%s, %s", desc, transactionOutcomeUnknownMsg)
wrapped = &TransactionOutcomeUnknownError{err: wrapped}
}
return &Error{code, toAPIError(wrapped), desc, ""}
return &Error{code, toAPIError(wrapped), desc, "", ""}
olavloite marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
47 changes: 32 additions & 15 deletions spanner/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package spanner
import (
"context"
"strings"
"sync/atomic"

vkit "cloud.google.com/go/spanner/apiv1"
"cloud.google.com/go/spanner/apiv1/spannerpb"
Expand Down Expand Up @@ -67,6 +68,15 @@ type spannerClient interface {
type grpcSpannerClient struct {
raw *vkit.Client
metricsTracerFactory *builtinMetricsTracerFactory

// These fields are used to uniquely track x-goog-spanner-request-id where:
// raw(*vkit.Client) is the channel, and channelID is derived from the ordinal
// count of unique *vkit.Client as retrieved from the session pool.
channelID uint64
// id is derived from the SpannerClient.
id int
// nthRequest is incremented for each new request (but not for retries of requests).
nthRequest *atomic.Uint32
}

var (
Expand All @@ -76,13 +86,16 @@ var (

// newGRPCSpannerClient initializes a new spannerClient that uses the gRPC
// Spanner API.
func newGRPCSpannerClient(ctx context.Context, sc *sessionClient, opts ...option.ClientOption) (spannerClient, error) {
func newGRPCSpannerClient(ctx context.Context, sc *sessionClient, channelID uint64, opts ...option.ClientOption) (spannerClient, error) {
raw, err := vkit.NewClient(ctx, opts...)
if err != nil {
return nil, err
}

g := &grpcSpannerClient{raw: raw, metricsTracerFactory: sc.metricsTracerFactory}
clientID := sc.nthClient
g.prepareRequestIDTrackers(clientID, channelID, sc.nthRequest)

clientInfo := []string{"gccl", internal.Version}
if sc.userAgent != "" {
agentWithVersion := strings.SplitN(sc.userAgent, "/", 2)
Expand Down Expand Up @@ -118,7 +131,7 @@ func (g *grpcSpannerClient) CreateSession(ctx context.Context, req *spannerpb.Cr
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.CreateSession(ctx, req, opts...)
resp, err := g.raw.CreateSession(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
Expand All @@ -128,7 +141,7 @@ func (g *grpcSpannerClient) BatchCreateSessions(ctx context.Context, req *spanne
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.BatchCreateSessions(ctx, req, opts...)
resp, err := g.raw.BatchCreateSessions(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
Expand All @@ -138,21 +151,21 @@ func (g *grpcSpannerClient) GetSession(ctx context.Context, req *spannerpb.GetSe
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.GetSession(ctx, req, opts...)
resp, err := g.raw.GetSession(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest, opts ...gax.CallOption) *vkit.SessionIterator {
return g.raw.ListSessions(ctx, req, opts...)
return g.raw.ListSessions(ctx, req, g.optsWithNextRequestID(opts)...)
}

func (g *grpcSpannerClient) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest, opts ...gax.CallOption) error {
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
err := g.raw.DeleteSession(ctx, req, opts...)
err := g.raw.DeleteSession(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return err
Expand All @@ -162,21 +175,23 @@ func (g *grpcSpannerClient) ExecuteSql(ctx context.Context, req *spannerpb.Execu
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.ExecuteSql(ctx, req, opts...)
resp, err := g.raw.ExecuteSql(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) ExecuteStreamingSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest, opts ...gax.CallOption) (spannerpb.Spanner_ExecuteStreamingSqlClient, error) {
// Note: This method does not add g.optsWithNextRequestID to inject x-goog-spanner-request-id
// as it is already manually added when creating Stream iterators for ExecuteStreamingSql.
return g.raw.ExecuteStreamingSql(peer.NewContext(ctx, &peer.Peer{}), req, opts...)
}

func (g *grpcSpannerClient) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest, opts ...gax.CallOption) (*spannerpb.ExecuteBatchDmlResponse, error) {
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.ExecuteBatchDml(ctx, req, opts...)
resp, err := g.raw.ExecuteBatchDml(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
Expand All @@ -186,21 +201,23 @@ func (g *grpcSpannerClient) Read(ctx context.Context, req *spannerpb.ReadRequest
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.Read(ctx, req, opts...)
resp, err := g.raw.Read(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) StreamingRead(ctx context.Context, req *spannerpb.ReadRequest, opts ...gax.CallOption) (spannerpb.Spanner_StreamingReadClient, error) {
// Note: This method does not add g.optsWithNextRequestID, as it is already
// manually added when creating Stream iterators for StreamingRead.
return g.raw.StreamingRead(peer.NewContext(ctx, &peer.Peer{}), req, opts...)
}

func (g *grpcSpannerClient) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest, opts ...gax.CallOption) (*spannerpb.Transaction, error) {
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.BeginTransaction(ctx, req, opts...)
resp, err := g.raw.BeginTransaction(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
Expand All @@ -210,7 +227,7 @@ func (g *grpcSpannerClient) Commit(ctx context.Context, req *spannerpb.CommitReq
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.Commit(ctx, req, opts...)
resp, err := g.raw.Commit(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
Expand All @@ -220,7 +237,7 @@ func (g *grpcSpannerClient) Rollback(ctx context.Context, req *spannerpb.Rollbac
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
err := g.raw.Rollback(ctx, req, opts...)
err := g.raw.Rollback(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return err
Expand All @@ -230,7 +247,7 @@ func (g *grpcSpannerClient) PartitionQuery(ctx context.Context, req *spannerpb.P
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.PartitionQuery(ctx, req, opts...)
resp, err := g.raw.PartitionQuery(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
Expand All @@ -240,12 +257,12 @@ func (g *grpcSpannerClient) PartitionRead(ctx context.Context, req *spannerpb.Pa
mt := g.newBuiltinMetricsTracer(ctx)
defer recordOperationCompletion(mt)
ctx = context.WithValue(ctx, metricsTracerKey, mt)
resp, err := g.raw.PartitionRead(ctx, req, opts...)
resp, err := g.raw.PartitionRead(ctx, req, g.optsWithNextRequestID(opts)...)
statusCode, _ := status.FromError(err)
mt.currOp.setStatus(statusCode.Code().String())
return resp, err
}

func (g *grpcSpannerClient) BatchWrite(ctx context.Context, req *spannerpb.BatchWriteRequest, opts ...gax.CallOption) (spannerpb.Spanner_BatchWriteClient, error) {
return g.raw.BatchWrite(peer.NewContext(ctx, &peer.Peer{}), req, opts...)
return g.raw.BatchWrite(peer.NewContext(ctx, &peer.Peer{}), req, g.optsWithNextRequestID(opts)...)
}
4 changes: 4 additions & 0 deletions spanner/internal/testutil/inmem_spanner_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ const (
MethodExecuteBatchDml string = "EXECUTE_BATCH_DML"
MethodStreamingRead string = "EXECUTE_STREAMING_READ"
MethodBatchWrite string = "BATCH_WRITE"
MethodPartitionQuery string = "PARTITION_QUERY"
)

// StatementResult represents a mocked result on the test server. The result is
Expand Down Expand Up @@ -1107,6 +1108,9 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba
}

func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) {
if err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil {
return nil, err
}
s.mu.Lock()
if s.stopped {
s.mu.Unlock()
Expand Down
Loading
Loading