Skip to content

Commit

Permalink
chore: add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kevwan committed Aug 4, 2024
1 parent 3eaa672 commit c0000c7
Show file tree
Hide file tree
Showing 14 changed files with 296 additions and 144 deletions.
21 changes: 11 additions & 10 deletions core/metainfo/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,29 @@ import (

var _ propagation.TextMapCarrier = (*GrpcHeaderCarrier)(nil)

// GrpcHeaderCarrier impl propagation.TextMapCarrier for grpc metadata.MD.
// GrpcHeaderCarrier implements propagation.TextMapCarrier for grpc metadata.MD.
type GrpcHeaderCarrier metadata.MD

// Get returns the value associated with the passed key.
func (mc GrpcHeaderCarrier) Get(key string) string {
vals := metadata.MD(mc).Get(key)
if len(vals) > 0 {
return vals[0]
vals := mc[key]
if len(vals) == 0 {
return ""
}
return ""
}

// Set stores the key-value pair.
func (mc GrpcHeaderCarrier) Set(key string, value string) {
metadata.MD(mc).Set(key, value)
return vals[0]
}

// Keys lists the keys stored in this carrier.
func (mc GrpcHeaderCarrier) Keys() []string {
keys := make([]string, 0, len(mc))
for k := range metadata.MD(mc) {
for k := range mc {
keys = append(keys, k)
}
return keys
}

// Set stores the key-value pair.
func (mc GrpcHeaderCarrier) Set(key, value string) {
metadata.MD(mc).Set(key, value)
}
54 changes: 54 additions & 0 deletions core/metainfo/grpc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package metainfo

import (
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
)

func TestGrpcHeaderCarrier_Get(t *testing.T) {
md := metadata.MD{
"key1": []string{"value1"},
"key2": []string{"value2", "value3"},
}
carrier := GrpcHeaderCarrier(md)

tests := []struct {
key string
expected string
}{
{"key1", "value1"},
{"key2", "value2"},
{"key3", ""},
}

for _, tt := range tests {
t.Run(tt.key, func(t *testing.T) {
assert.Equal(t, tt.expected, carrier.Get(tt.key))
})
}
}

func TestGrpcHeaderCarrier_Set(t *testing.T) {
md := metadata.MD{}
carrier := GrpcHeaderCarrier(md)

carrier.Set("key1", "value1")
carrier.Set("key2", "value2")

assert.Equal(t, metadata.MD{
"key1": []string{"value1"},
"key2": []string{"value2"},
}, md)
}

func TestGrpcHeaderCarrier_Keys(t *testing.T) {
md := metadata.MD{
"key1": []string{"value1"},
"key2": []string{"value2"},
}
carrier := GrpcHeaderCarrier(md)

assert.ElementsMatch(t, []string{"key1", "key2"}, carrier.Keys())
}
121 changes: 57 additions & 64 deletions core/metainfo/metainfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,23 @@ import (
"context"
"strings"

"go.opentelemetry.io/otel/propagation"

"github.com/zeromicro/go-zero/core/collection"
"go.opentelemetry.io/otel/propagation"
)

const (
// LogKey is the key for custom keys in logger fields.
LogKey = "custom_keys"
// PrefixPass means that header/metadata key with this prefix will be passed to the other servers.
PrefixPass = "x-pass-"

lenPP = len(PrefixPass)
)

var (
// CustomKeysMapPropagator impl propagation.TextMapPropagator for custom keys passing.
// CustomKeysMapPropagator implements propagation.TextMapPropagator for custom keys passing.
CustomKeysMapPropagator propagation.TextMapPropagator = (*customKeysPropagator)(nil)

ctxKey ctxKeyType
customKeyStore = contextKeyStore{
keyArr: make([]string, 0),
keySet: collection.NewSet(),
}
customKeyStore = newContextKeyStore()
)

type (
Expand All @@ -37,78 +33,75 @@ type (
}
)

// RegisterCustomKeys register custom keys globally.
// GetMapFromContext retrieves all custom keys and values from the context.
func GetMapFromContext(ctx context.Context) map[string]string {
mp := getMap(ctx)
if len(mp) == 0 {
return mp
}

m := make(map[string]string, len(mp))
for k, v := range mp {
m[k] = v
}
return m
}

// GetMapFromPropagator retrieves all custom keys and values from the propagation carrier.
func GetMapFromPropagator(carrier propagation.TextMapCarrier) map[string]string {
mp := make(map[string]string)
for _, k := range carrier.Keys() {
lowerKey := strings.ToLower(k)
if customKeyStore.keySet.Contains(lowerKey) || (len(lowerKey) > len(PrefixPass) &&
strings.HasPrefix(lowerKey, PrefixPass)) {
v := carrier.Get(lowerKey)
if len(v) > 0 {
mp[lowerKey] = v
}
}
}
return mp
}

// RegisterCustomKeys registers custom keys globally.
// Key must be lowercase.
// Should only be called once before application start.
func RegisterCustomKeys(keys []string) {
for _, k := range keys {
kk := strings.ToLower(k)
if k != kk {
panic("custom key only support lowercase")
lowerKey := strings.ToLower(k)
if k != lowerKey {
panic("custom keys must be lowercase")
}
customKeyStore.keySet.AddStr(k)
customKeyStore.keySet.AddStr(lowerKey)
}
customKeyStore.keyArr = customKeyStore.keySet.KeysStr()
}

// for test only
func reset() {
customKeyStore = contextKeyStore{
keyArr: make([]string, 0),
keySet: collection.NewSet(),
}
}

func getMap(ctx context.Context) map[string]string {
if ctx != nil {
if val, ok := ctx.Value(ctxKey).(map[string]string); ok {
return val
}
if val, ok := ctx.Value(ctxKey).(map[string]string); ok {
return val
}

return make(map[string]string, 0)
return make(map[string]string)
}

func setMap(ctx context.Context, m map[string]string) context.Context {
if ctx == nil {
return nil
func newContextKeyStore() contextKeyStore {
return contextKeyStore{
keyArr: make([]string, 0),
keySet: collection.NewSet(),
}

return context.WithValue(ctx, ctxKey, m)
}

// GetMapFromContext retrieves all custom keys and values from context.
func GetMapFromContext(ctx context.Context) map[string]string {
mp := getMap(ctx)

if len(mp) > 0 {
m := make(map[string]string, len(mp))
for k, v := range mp {
m[k] = v
}
return m
}

return mp
// for test only
func reset() {
customKeyStore = newContextKeyStore()
}

// GetMapFromPropagator retrieves all custom keys and values from propagation carrier.
func GetMapFromPropagator(carrier propagation.TextMapCarrier) map[string]string {
mp := make(map[string]string)
for _, k := range carrier.Keys() {
kk := strings.ToLower(k)

if customKeyStore.keySet.Contains(kk) || (len(kk) > lenPP && strings.HasPrefix(kk, PrefixPass)) {
v := carrier.Get(kk)
if len(v) > 0 {
mp[kk] = v
}
}
}
return mp
func setMap(ctx context.Context, m map[string]string) context.Context {
return context.WithValue(ctx, ctxKey, m)
}

// Inject impl TextMapPropagator for customKeysPropagator.
// Inject implements TextMapPropagator for customKeysPropagator.
func (c *customKeysPropagator) Inject(ctx context.Context, carrier propagation.TextMapCarrier) {
mp := getMap(ctx)
for k, v := range mp {
Expand All @@ -118,10 +111,10 @@ func (c *customKeysPropagator) Inject(ctx context.Context, carrier propagation.T
}
}

// Extract impl TextMapPropagator for customKeysPropagator.
func (c *customKeysPropagator) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context {
// Extract implements TextMapPropagator for customKeysPropagator.
func (c *customKeysPropagator) Extract(ctx context.Context,
carrier propagation.TextMapCarrier) context.Context {
mp := getMap(ctx)

cmp := GetMapFromPropagator(carrier)
if len(cmp) == 0 {
return ctx
Expand All @@ -134,7 +127,7 @@ func (c *customKeysPropagator) Extract(ctx context.Context, carrier propagation.
return setMap(ctx, mp)
}

// Fields not used
// Fields returns nil as it's not used.
func (c *customKeysPropagator) Fields() []string {
return nil
}
101 changes: 80 additions & 21 deletions core/metainfo/metainfo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -229,27 +230,85 @@ func TestRegisterCustomKeys_Mix(t *testing.T) {
assert.Equal(t, customMap, GetMapFromPropagator(propagation.HeaderCarrier(header)))
}

//goos: darwin
//goarch: arm64
//pkg: code.bydev.io/cht/fiat/backend/lib.git/pkg/transport
//BenchmarkCustomKeys
//BenchmarkCustomKeys-8 1414234 762.8 ns/op
//BenchmarkCustomKeys_10
//BenchmarkCustomKeys_10-8 615555 1849 ns/op
//BenchmarkCustomKeys_50
//BenchmarkCustomKeys_50-8 104818 11497 ns/op
//BenchmarkCustomKeysAutoPass
//BenchmarkCustomKeysAutoPass-8 861883 1333 ns/op
//BenchmarkCustomKeysAutoPass_10
//BenchmarkCustomKeysAutoPass_10-8 392179 3085 ns/op
//BenchmarkCustomKeysAutoPass_50
//BenchmarkCustomKeysAutoPass_50-8 75937 15628 ns/op
//BenchmarkCustomKeysMix
//BenchmarkCustomKeysMix-8 1201923 972.0 ns/op
//BenchmarkCustomKeysMix_10
//BenchmarkCustomKeysMix_10-8 450882 2786 ns/op
//BenchmarkCustomKeysMix_50
//BenchmarkCustomKeysMix_50-8 82384 14543 ns/op
func TestGetMap(t *testing.T) {
// Test when context has no map
ctx := context.Background()
assert.Empty(t, getMap(ctx))

// Test when context has a map
expectedMap := map[string]string{"key": "value"}
ctx = setMap(ctx, expectedMap)
assert.Equal(t, expectedMap, getMap(ctx))
}

func TestGetMapFromContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
want map[string]string
}{
{
name: "Empty context",
ctx: context.Background(),
want: map[string]string{},
},
{
name: "Context with custom map",
ctx: context.WithValue(context.Background(), ctxKey,
map[string]string{"key1": "value1", "key2": "value2"}),
want: map[string]string{"key1": "value1", "key2": "value2"},
},
{
name: "Context with empty custom map",
ctx: context.WithValue(context.Background(), ctxKey, map[string]string{}),
want: map[string]string{},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetMapFromContext(tt.ctx)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetMapFromContext() = %v, want %v", got, tt.want)
}
})
}
}

func TestSetMap(t *testing.T) {
// Test when context is not nil
ctx := context.Background()
expectedMap := map[string]string{"key": "value"}
newCtx := setMap(ctx, expectedMap)
assert.NotNil(t, newCtx)
assert.Equal(t, expectedMap, getMap(newCtx))
}

func TestFields(t *testing.T) {
ck := &customKeysPropagator{}
assert.Nil(t, ck.Fields())
}

// Benchmark tests (remaining the same)
// goos: darwin
// goarch: arm64
// pkg: code.bydev.io/cht/fiat/backend/lib.git/pkg/transport
// BenchmarkCustomKeys
// BenchmarkCustomKeys-8 1414234 762.8 ns/op
// BenchmarkCustomKeys_10
// BenchmarkCustomKeys_10-8 615555 1849 ns/op
// BenchmarkCustomKeys_50
// BenchmarkCustomKeys_50-8 104818 11497 ns/op
// BenchmarkCustomKeysAutoPass
// BenchmarkCustomKeysAutoPass-8 861883 1333 ns/op
// BenchmarkCustomKeysAutoPass_10
// BenchmarkCustomKeysAutoPass_10-8 392179 3085 ns/op
// BenchmarkCustomKeysAutoPass_50-8 75937 15628 ns/op
// BenchmarkCustomKeysMix
// BenchmarkCustomKeysMix-8 1201923 972.0 ns/op
// BenchmarkCustomKeysMix_10
// BenchmarkCustomKeysMix_10-8 450882 2786 ns/op
// BenchmarkCustomKeysMix_50-8 82384 14543 ns/op

func benchmarkLen(b *testing.B, l int) {
reset()
Expand Down
Loading

0 comments on commit c0000c7

Please sign in to comment.