Skip to content

Commit

Permalink
gcp/observability: update method name validation (#5951)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasweq authored Jan 23, 2023
1 parent 4075ef0 commit 52a8392
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 23 deletions.
33 changes: 23 additions & 10 deletions gcp/observability/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,14 @@ import (
"errors"
"fmt"
"os"
"regexp"
"strings"

gcplogging "cloud.google.com/go/logging"
"golang.org/x/oauth2/google"
"google.golang.org/grpc/internal/envconfig"
)

const (
envProjectID = "GOOGLE_CLOUD_PROJECT"
methodStringRegexpStr = `^([\w./]+)/((?:\w+)|[*])$`
)

var methodStringRegexp = regexp.MustCompile(methodStringRegexpStr)
const envProjectID = "GOOGLE_CLOUD_PROJECT"

// fetchDefaultProjectID fetches the default GCP project id from environment.
func fetchDefaultProjectID(ctx context.Context) string {
Expand All @@ -59,6 +54,25 @@ func fetchDefaultProjectID(ctx context.Context) string {
return credentials.ProjectID
}

// validateMethodString validates whether the string passed in is a valid
// pattern.
func validateMethodString(method string) error {
if strings.HasPrefix(method, "/") {
return errors.New("cannot have a leading slash")
}
serviceMethod := strings.Split(method, "/")
if len(serviceMethod) != 2 {
return errors.New("/ must come in between service and method, only one /")
}
if serviceMethod[1] == "" {
return errors.New("method name must be non empty")
}
if serviceMethod[0] == "*" {
return errors.New("cannot have service wildcard * i.e. (*/m)")
}
return nil
}

func validateLogEventMethod(methods []string, exclude bool) error {
for _, method := range methods {
if method == "*" {
Expand All @@ -67,9 +81,8 @@ func validateLogEventMethod(methods []string, exclude bool) error {
}
continue
}
match := methodStringRegexp.FindStringSubmatch(method)
if match == nil {
return fmt.Errorf("invalid method string: %v", method)
if err := validateMethodString(method); err != nil {
return fmt.Errorf("invalid method string: %v, err: %v", method, err)
}
}
return nil
Expand Down
23 changes: 18 additions & 5 deletions gcp/observability/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -322,6 +323,7 @@ func (bml *binaryMethodLogger) Log(c iblog.LogEntryConfig) {
}

type eventConfig struct {
// ServiceMethod has /s/m syntax for fast matching.
ServiceMethod map[string]bool
Services map[string]bool
MatchAll bool
Expand Down Expand Up @@ -364,6 +366,17 @@ func (bl *binaryLogger) GetMethodLogger(methodName string) iblog.MethodLogger {
return nil
}

// parseMethod splits service and method from the input. It expects format
// "service/method".
func parseMethod(method string) (string, string, error) {
pos := strings.Index(method, "/")
if pos < 0 {
// Shouldn't happen, config already validated.
return "", "", errors.New("invalid method name: no / found")
}
return method[:pos], method[pos+1:], nil
}

func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter loggingExporter) {
if len(clientRPCEvents) == 0 {
return
Expand All @@ -382,15 +395,15 @@ func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter logging
eventConfig.MatchAll = true
continue
}
s, m, err := grpcutil.ParseMethod(method)
s, m, err := parseMethod(method)
if err != nil {
continue
}
if m == "*" {
eventConfig.Services[s] = true
continue
}
eventConfig.ServiceMethod[method] = true
eventConfig.ServiceMethod["/"+method] = true
}
eventConfigs = append(eventConfigs, eventConfig)
}
Expand Down Expand Up @@ -419,15 +432,15 @@ func registerServerRPCEvents(serverRPCEvents []serverRPCEvents, exporter logging
eventConfig.MatchAll = true
continue
}
s, m, err := grpcutil.ParseMethod(method)
if err != nil { // Shouldn't happen, already validated at this point.
s, m, err := parseMethod(method)
if err != nil {
continue
}
if m == "*" {
eventConfig.Services[s] = true
continue
}
eventConfig.ServiceMethod[method] = true
eventConfig.ServiceMethod["/"+method] = true
}
eventConfigs = append(eventConfigs, eventConfig)
}
Expand Down
126 changes: 118 additions & 8 deletions gcp/observability/logging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"testing"

Expand Down Expand Up @@ -99,13 +100,14 @@ func setupObservabilitySystemWithConfig(cfg *config) (func(), error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
err = Start(ctx)
if err != nil {
return nil, fmt.Errorf("error in Start: %v", err)
}
return func() {
cleanup := func() {
End()
envconfig.ObservabilityConfig = oldObservabilityConfig
}, nil
}
if err != nil {
return cleanup, fmt.Errorf("error in Start: %v", err)
}
return cleanup, nil
}

// TestClientRPCEventsLogAll tests the observability system configured with a
Expand Down Expand Up @@ -777,18 +779,18 @@ func (s) TestPrecedenceOrderingInConfiguration(t *testing.T) {
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"/grpc.testing.TestService/UnaryCall"},
Methods: []string{"grpc.testing.TestService/UnaryCall"},
MaxMetadataBytes: 30,
MaxMessageBytes: 30,
},
{
Methods: []string{"/grpc.testing.TestService/EmptyCall"},
Methods: []string{"grpc.testing.TestService/EmptyCall"},
Exclude: true,
MaxMetadataBytes: 30,
MaxMessageBytes: 30,
},
{
Methods: []string{"/grpc.testing.TestService/*"},
Methods: []string{"grpc.testing.TestService/*"},
MaxMetadataBytes: 30,
MaxMessageBytes: 30,
},
Expand Down Expand Up @@ -1273,3 +1275,111 @@ func (s) TestMetadataTruncationAccountsKey(t *testing.T) {
}
fle.mu.Unlock()
}

// TestMethodInConfiguration tests different method names with an expectation on
// whether they should error or not.
func (s) TestMethodInConfiguration(t *testing.T) {
// To skip creating a stackdriver exporter.
fle := &fakeLoggingExporter{
t: t,
}

defer func(ne func(ctx context.Context, config *config) (loggingExporter, error)) {
newLoggingExporter = ne
}(newLoggingExporter)

newLoggingExporter = func(ctx context.Context, config *config) (loggingExporter, error) {
return fle, nil
}

tests := []struct {
name string
config *config
wantErr string
}{
{
name: "leading-slash",
config: &config{
ProjectID: "fake",
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"/service/method"},
},
},
},
},
wantErr: "cannot have a leading slash",
},
{
name: "wildcard service/method",
config: &config{
ProjectID: "fake",
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"*/method"},
},
},
},
},
wantErr: "cannot have service wildcard *",
},
{
name: "/ in service name",
config: &config{
ProjectID: "fake",
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"ser/vice/method"},
},
},
},
},
wantErr: "only one /",
},
{
name: "empty method name",
config: &config{
ProjectID: "fake",
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"service/"},
},
},
},
},
wantErr: "method name must be non empty",
},
{
name: "normal",
config: &config{
ProjectID: "fake",
CloudLogging: &cloudLogging{
ClientRPCEvents: []clientRPCEvents{
{
Methods: []string{"service/method"},
},
},
},
},
wantErr: "",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
cleanup, gotErr := setupObservabilitySystemWithConfig(test.config)
if cleanup != nil {
defer cleanup()
}
if gotErr != nil && !strings.Contains(gotErr.Error(), test.wantErr) {
t.Fatalf("Start(%v) = %v, wantErr %v", test.config, gotErr, test.wantErr)
}
if (gotErr != nil) != (test.wantErr != "") {
t.Fatalf("Start(%v) = %v, wantErr %v", test.config, gotErr, test.wantErr)
}
})
}
}

0 comments on commit 52a8392

Please sign in to comment.