Skip to content

Commit 7003d7f

Browse files
authored
Merge branch 'v1' into dont-waste-id
2 parents 54ce4fd + 2f0917c commit 7003d7f

21 files changed

+171
-191
lines changed

bson/bsoncodec/registry.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,9 @@ func (r *Registry) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) {
388388
// If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for
389389
// concurrent use by multiple goroutines after all codecs and encoders are registered.
390390
func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) {
391+
if valueType == nil {
392+
return nil, ErrNoEncoder{Type: valueType}
393+
}
391394
enc, found := r.lookupTypeEncoder(valueType)
392395
if found {
393396
if enc == nil {
@@ -400,15 +403,10 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) {
400403
if found {
401404
return r.typeEncoders.LoadOrStore(valueType, enc), nil
402405
}
403-
if valueType == nil {
404-
r.storeTypeEncoder(valueType, nil)
405-
return nil, ErrNoEncoder{Type: valueType}
406-
}
407406

408407
if v, ok := r.kindEncoders.Load(valueType.Kind()); ok {
409408
return r.storeTypeEncoder(valueType, v), nil
410409
}
411-
r.storeTypeEncoder(valueType, nil)
412410
return nil, ErrNoEncoder{Type: valueType}
413411
}
414412

@@ -474,7 +472,6 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) {
474472
if v, ok := r.kindDecoders.Load(valueType.Kind()); ok {
475473
return r.storeTypeDecoder(valueType, v), nil
476474
}
477-
r.storeTypeDecoder(valueType, nil)
478475
return nil, ErrNoDecoder{Type: valueType}
479476
}
480477

bson/bsoncodec/registry_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,36 @@ func TestRegistry(t *testing.T) {
792792
})
793793
})
794794
}
795+
t.Run("nil type", func(t *testing.T) {
796+
t.Parallel()
797+
798+
t.Run("Encoder", func(t *testing.T) {
799+
t.Parallel()
800+
801+
wanterr := ErrNoEncoder{Type: reflect.TypeOf(nil)}
802+
803+
gotcodec, goterr := reg.LookupEncoder(nil)
804+
if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
805+
t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
806+
}
807+
if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) {
808+
t.Errorf("codecs did not match: got %#v, want nil", gotcodec)
809+
}
810+
})
811+
t.Run("Decoder", func(t *testing.T) {
812+
t.Parallel()
813+
814+
wanterr := ErrNilType
815+
816+
gotcodec, goterr := reg.LookupDecoder(nil)
817+
if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
818+
t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
819+
}
820+
if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) {
821+
t.Errorf("codecs did not match: got %v: want nil", gotcodec)
822+
}
823+
})
824+
})
795825
// lookup a type whose pointer implements an interface and expect that the registered hook is
796826
// returned
797827
t.Run("interface implementation with hook (pointer)", func(t *testing.T) {

bson/marshal_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"errors"
1212
"fmt"
1313
"reflect"
14+
"sync"
1415
"testing"
1516
"time"
1617

@@ -380,3 +381,19 @@ func TestMarshalExtJSONIndent(t *testing.T) {
380381
})
381382
}
382383
}
384+
385+
func TestMarshalConcurrently(t *testing.T) {
386+
t.Parallel()
387+
388+
const size = 10_000
389+
390+
wg := sync.WaitGroup{}
391+
wg.Add(size)
392+
for i := 0; i < size; i++ {
393+
go func() {
394+
defer wg.Done()
395+
_, _ = Marshal(struct{ LastError error }{})
396+
}()
397+
}
398+
wg.Wait()
399+
}

bson/unmarshal_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package bson
99
import (
1010
"math/rand"
1111
"reflect"
12+
"sync"
1213
"testing"
1314

1415
"go.mongodb.org/mongo-driver/bson/bsoncodec"
@@ -773,3 +774,21 @@ func TestUnmarshalByteSlicesUseDistinctArrays(t *testing.T) {
773774
})
774775
}
775776
}
777+
778+
func TestUnmarshalConcurrently(t *testing.T) {
779+
t.Parallel()
780+
781+
const size = 10_000
782+
783+
data := []byte{16, 0, 0, 0, 10, 108, 97, 115, 116, 101, 114, 114, 111, 114, 0, 0}
784+
wg := sync.WaitGroup{}
785+
wg.Add(size)
786+
for i := 0; i < size; i++ {
787+
go func() {
788+
defer wg.Done()
789+
var res struct{ LastError error }
790+
_ = Unmarshal(data, &res)
791+
}()
792+
}
793+
wg.Wait()
794+
}

mongo/integration/unified/client_entity.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,10 @@ func newClientEntity(ctx context.Context, em *EntityMap, entityOptions *entityOp
107107
}
108108

109109
if olm := entityOptions.ObserveLogMessages; olm != nil {
110-
clientLogger := newLogger(olm, expectedLogMessageCount(ctx))
110+
expectedLogMessagesCount := expectedLogMessagesCount(ctx, entityOptions.ID)
111+
ignoreLogMessages := ignoreLogMessages(ctx, entityOptions.ID)
112+
113+
clientLogger := newLogger(olm, expectedLogMessagesCount, ignoreLogMessages)
111114

112115
wrap := func(str string) options.LogLevel {
113116
return options.LogLevel(logger.ParseLevel(str))

mongo/integration/unified/context.go

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,23 @@ const (
2828
failPointsKey ctxKey = "test-failpoints"
2929
// targetedFailPointsKey is used to store a map from a fail point name to the host on which the fail point is set.
3030
targetedFailPointsKey ctxKey = "test-targeted-failpoints"
31-
// expectedLogMessageCountKey is used to store the number of log messages expected to be received by the test runner.
32-
expectedLogMessageCountKey ctxKey = "test-expected-log-message-count"
31+
clientLogMessagesKey ctxKey = "test-expected-log-message-count"
32+
ignoreLogMessagesKey ctxKey = "test-ignore-log-message-count"
3333
)
3434

3535
// newTestContext creates a new Context derived from ctx with values initialized to store the state required for test
3636
// execution.
3737
func newTestContext(
3838
ctx context.Context,
3939
entityMap *EntityMap,
40-
expectedLogMessageCount int,
40+
clientLogMessages []*clientLogMessages,
4141
hasOperationalFailPoint bool,
4242
) context.Context {
4343
ctx = context.WithValue(ctx, operationalFailPointKey, hasOperationalFailPoint)
4444
ctx = context.WithValue(ctx, entitiesKey, entityMap)
4545
ctx = context.WithValue(ctx, failPointsKey, make(map[string]*mongo.Client))
4646
ctx = context.WithValue(ctx, targetedFailPointsKey, make(map[string]string))
47-
ctx = context.WithValue(ctx, expectedLogMessageCountKey, expectedLogMessageCount)
47+
ctx = context.WithValue(ctx, clientLogMessagesKey, clientLogMessages)
4848
return ctx
4949
}
5050

@@ -84,6 +84,28 @@ func entities(ctx context.Context) *EntityMap {
8484
return ctx.Value(entitiesKey).(*EntityMap)
8585
}
8686

87-
func expectedLogMessageCount(ctx context.Context) int {
88-
return ctx.Value(expectedLogMessageCountKey).(int)
87+
func expectedLogMessagesCount(ctx context.Context, clientID string) int {
88+
messages := ctx.Value(clientLogMessagesKey).([]*clientLogMessages)
89+
90+
count := 0
91+
for _, message := range messages {
92+
if message.Client == clientID {
93+
count += len(message.LogMessages)
94+
}
95+
}
96+
97+
return count
98+
}
99+
100+
func ignoreLogMessages(ctx context.Context, clientID string) []*logMessage {
101+
messages := ctx.Value(clientLogMessagesKey).([]*clientLogMessages)
102+
103+
ignoreMessages := []*logMessage{}
104+
for _, message := range messages {
105+
if message.Client == clientID {
106+
ignoreMessages = append(ignoreMessages, message.IgnoreMessages...)
107+
}
108+
}
109+
110+
return ignoreMessages
89111
}

mongo/integration/unified/logger.go

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package unified
88

99
import (
10+
"context"
1011
"sync"
1112

1213
"go.mongodb.org/mongo-driver/internal/logger"
@@ -33,19 +34,21 @@ type Logger struct {
3334
// orderMu guards the order value, which increments each time the "Info"
3435
// method is called. This is necessary since "Info" could be called from
3536
// multiple go routines, e.g. SDAM logs.
36-
orderMu sync.RWMutex
37-
logQueue chan orderedLogMessage
37+
orderMu sync.RWMutex
38+
logQueue chan orderedLogMessage
39+
ignoreMessages []*logMessage
3840
}
3941

40-
func newLogger(olm *observeLogMessages, bufSize int) *Logger {
42+
func newLogger(olm *observeLogMessages, bufSize int, ignoreMessages []*logMessage) *Logger {
4143
if olm == nil {
4244
return nil
4345
}
4446

4547
return &Logger{
46-
lastOrder: 1,
47-
logQueue: make(chan orderedLogMessage, bufSize),
48-
bufSize: bufSize,
48+
lastOrder: 1,
49+
logQueue: make(chan orderedLogMessage, bufSize),
50+
bufSize: bufSize,
51+
ignoreMessages: ignoreMessages,
4952
}
5053
}
5154

@@ -65,8 +68,6 @@ func (log *Logger) Info(level int, msg string, args ...interface{}) {
6568
return
6669
}
6770

68-
defer func() { log.lastOrder++ }()
69-
7071
// Add the Diff back to the level, as there is no need to create a
7172
// logging offset.
7273
level = level + logger.DiffToInfo
@@ -76,12 +77,19 @@ func (log *Logger) Info(level int, msg string, args ...interface{}) {
7677
panic(err)
7778
}
7879

80+
for _, ignoreMessage := range log.ignoreMessages {
81+
if err := verifyLogMatch(context.Background(), ignoreMessage, logMessage); err == nil {
82+
return
83+
}
84+
}
85+
86+
defer func() { log.lastOrder++ }()
87+
7988
// Send the log message to the "orderedLogMessage" channel for
8089
// validation.
8190
log.logQueue <- orderedLogMessage{
8291
order: log.lastOrder + 1,
83-
logMessage: logMessage,
84-
}
92+
logMessage: logMessage}
8593

8694
// If the order has reached the buffer size, then close the channel.
8795
if log.lastOrder == log.bufSize {

mongo/integration/unified/logger_verification.go

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,6 @@ type clientLogMessages struct {
7979
LogMessages []*logMessage `bson:"messages"`
8080
}
8181

82-
// ignore checks to see if the message is in the "IgnoreMessages" slice.
83-
func (clm clientLogMessages) ignore(ctx context.Context, msg *logMessage) bool {
84-
for _, ignoreMessage := range clm.IgnoreMessages {
85-
if err := verifyLogMatch(ctx, ignoreMessage, msg); err == nil {
86-
return true
87-
}
88-
}
89-
90-
return false
91-
}
92-
9382
// logMessageValidator defines the expectation for log messages across all
9483
// clients.
9584
type logMessageValidator struct {
@@ -191,8 +180,7 @@ type logQueues struct {
191180
}
192181

193182
// partitionLogQueue will partition the expected logs into "unordered" and
194-
// "ordered" log channels. This function will also remove any logs in the
195-
// "ignoreMessages" list for a client.
183+
// "ordered" log channels.
196184
func partitionLogQueue(ctx context.Context, exp *clientLogMessages) logQueues {
197185
orderedLogCh := make(chan *logMessage, len(exp.LogMessages))
198186
unorderedLogCh := make(chan *logMessage, len(exp.LogMessages))
@@ -241,12 +229,6 @@ func matchOrderedLogs(ctx context.Context, logs logQueues) <-chan error {
241229
defer close(errs)
242230

243231
for actual := range logs.ordered {
244-
// Ignore logs that are in the "IngoreMessages" slice of
245-
// the expected results.
246-
if logs.expected.ignore(ctx, actual) {
247-
continue
248-
}
249-
250232
expected := expLogMessages[0]
251233
if expected == nil {
252234
continue

mongo/integration/unified/logger_verification_test.go

Lines changed: 0 additions & 84 deletions
This file was deleted.

0 commit comments

Comments
 (0)