From 215d683afcf2d95b8414f46f82f6a8e7dac688b1 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 28 Jul 2023 10:50:04 -0600 Subject: [PATCH] GODRIVER-2896 Add IsZero to BSON RawValue (#1332) Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com> --- bson/bsontype/bsontype.go | 11 +++++ bson/primitive_codecs.go | 18 ++++++-- bson/primitive_codecs_test.go | 38 ++++++++++++++++ bson/raw_value.go | 6 +++ bson/raw_value_test.go | 84 +++++++++++++++++++++++++++++++++++ 5 files changed, 154 insertions(+), 3 deletions(-) diff --git a/bson/bsontype/bsontype.go b/bson/bsontype/bsontype.go index f38c263a4c..8cff5492d1 100644 --- a/bson/bsontype/bsontype.go +++ b/bson/bsontype/bsontype.go @@ -102,3 +102,14 @@ func (bt Type) String() string { return "invalid" } } + +// IsValid will return true if the Type is valid. +func (bt Type) IsValid() bool { + switch bt { + case Double, String, EmbeddedDocument, Array, Binary, Undefined, ObjectID, Boolean, DateTime, Null, Regex, + DBPointer, JavaScript, Symbol, CodeWithScope, Int32, Timestamp, Int64, Decimal128, MinKey, MaxKey: + return true + default: + return false + } +} diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index 6b9602589c..ff32a87a79 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -8,6 +8,7 @@ package bson import ( "errors" + "fmt" "reflect" "go.mongodb.org/mongo-driver/bson/bsoncodec" @@ -45,15 +46,26 @@ func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) // RawValueEncodeValue is the ValueEncoderFunc for RawValue. // -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. +// If the RawValue's Type is "invalid" and the RawValue's Value is not empty or +// nil, then this method will return an error. +// +// Deprecated: Use bson.NewRegistry to get a registry with all primitive +// encoders and decoders registered. func (PrimitiveCodecs) RawValueEncodeValue(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRawValue { - return bsoncodec.ValueEncoderError{Name: "RawValueEncodeValue", Types: []reflect.Type{tRawValue}, Received: val} + return bsoncodec.ValueEncoderError{ + Name: "RawValueEncodeValue", + Types: []reflect.Type{tRawValue}, + Received: val, + } } rawvalue := val.Interface().(RawValue) + if !rawvalue.Type.IsValid() { + return fmt.Errorf("the RawValue Type specifies an invalid BSON type: %#x", byte(rawvalue.Type)) + } + return bsonrw.Copier{}.CopyValueFromBytes(vw, rawvalue.Type, rawvalue.Value) } diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index 3fb606d2f4..466f135e83 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -65,6 +65,8 @@ func compareErrors(err1, err2 error) bool { } func TestDefaultValueEncoders(t *testing.T) { + t.Parallel() + var pc PrimitiveCodecs var wrong = func(string, string) string { return "wrong" } @@ -107,6 +109,28 @@ func TestDefaultValueEncoders(t *testing.T) { bsonrwtest.WriteDouble, nil, }, + { + "RawValue Type is zero with non-zero value", + RawValue{ + Type: 0x00, + Value: bsoncore.AppendDouble(nil, 3.14159), + }, + nil, + nil, + bsonrwtest.Nothing, + fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x0"), + }, + { + "RawValue Type is invalid", + RawValue{ + Type: 0x8F, + Value: bsoncore.AppendDouble(nil, 3.14159), + }, + nil, + nil, + bsonrwtest.Nothing, + fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x8f"), + }, }, }, { @@ -166,9 +190,17 @@ func TestDefaultValueEncoders(t *testing.T) { } for _, tc := range testCases { + tc := tc // Capture the range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + for _, subtest := range tc.subtests { + subtest := subtest // Capture the range variable + t.Run(subtest.name, func(t *testing.T) { + t.Parallel() + var ec bsoncodec.EncodeContext if subtest.ectx != nil { ec = *subtest.ectx @@ -192,6 +224,8 @@ func TestDefaultValueEncoders(t *testing.T) { } t.Run("success path", func(t *testing.T) { + t.Parallel() + oid := primitive.NewObjectID() oids := []primitive.ObjectID{primitive.NewObjectID(), primitive.NewObjectID(), primitive.NewObjectID()} var str = new(string) @@ -426,7 +460,11 @@ func TestDefaultValueEncoders(t *testing.T) { } for _, tc := range testCases { + tc := tc // Capture the range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + b := make(bsonrw.SliceWriter, 0, 512) vw, err := bsonrw.NewBSONValueWriter(&b) noerr(t, err) diff --git a/bson/raw_value.go b/bson/raw_value.go index 6627294c4d..4d1bfb3160 100644 --- a/bson/raw_value.go +++ b/bson/raw_value.go @@ -37,6 +37,12 @@ type RawValue struct { r *bsoncodec.Registry } +// IsZero reports whether the RawValue is zero, i.e. no data is present on +// the RawValue. It returns true if Type is 0 and Value is empty or nil. +func (rv RawValue) IsZero() bool { + return rv.Type == 0x00 && len(rv.Value) == 0 +} + // Unmarshal deserializes BSON into the provided val. If RawValue cannot be unmarshaled into val, an // error is returned. This method will use the registry used to create the RawValue, if the RawValue // was created from partial BSON processing, or it will use the default registry. Users wishing to diff --git a/bson/raw_value_test.go b/bson/raw_value_test.go index fbc0715600..87f08c4a55 100644 --- a/bson/raw_value_test.go +++ b/bson/raw_value_test.go @@ -13,12 +13,19 @@ import ( "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) func TestRawValue(t *testing.T) { + t.Parallel() + t.Run("Unmarshal", func(t *testing.T) { + t.Parallel() + t.Run("Uses registry attached to value", func(t *testing.T) { + t.Parallel() + reg := bsoncodec.NewRegistryBuilder().Build() val := RawValue{Type: bsontype.String, Value: bsoncore.AppendString(nil, "foobar"), r: reg} var s string @@ -29,6 +36,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Uses default registry if no registry attached", func(t *testing.T) { + t.Parallel() + want := "foobar" val := RawValue{Type: bsontype.String, Value: bsoncore.AppendString(nil, want)} var got string @@ -40,7 +49,11 @@ func TestRawValue(t *testing.T) { }) }) t.Run("UnmarshalWithRegistry", func(t *testing.T) { + t.Parallel() + t.Run("Returns error when registry is nil", func(t *testing.T) { + t.Parallel() + want := ErrNilRegistry var val RawValue got := val.UnmarshalWithRegistry(nil, &D{}) @@ -49,6 +62,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Returns lookup error", func(t *testing.T) { + t.Parallel() + reg := bsoncodec.NewRegistryBuilder().Build() var val RawValue var s string @@ -59,6 +74,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Returns DecodeValue error", func(t *testing.T) { + t.Parallel() + reg := NewRegistryBuilder().Build() val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string @@ -69,6 +86,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Success", func(t *testing.T) { + t.Parallel() + reg := NewRegistryBuilder().Build() want := float64(3.14159) val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, want)} @@ -81,7 +100,11 @@ func TestRawValue(t *testing.T) { }) }) t.Run("UnmarshalWithContext", func(t *testing.T) { + t.Parallel() + t.Run("Returns error when DecodeContext is nil", func(t *testing.T) { + t.Parallel() + want := ErrNilContext var val RawValue got := val.UnmarshalWithContext(nil, &D{}) @@ -90,6 +113,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Returns lookup error", func(t *testing.T) { + t.Parallel() + dc := bsoncodec.DecodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()} var val RawValue var s string @@ -100,6 +125,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Returns DecodeValue error", func(t *testing.T) { + t.Parallel() + dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()} val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string @@ -110,6 +137,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Success", func(t *testing.T) { + t.Parallel() + dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()} want := float64(3.14159) val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, want)} @@ -121,4 +150,59 @@ func TestRawValue(t *testing.T) { } }) }) + + t.Run("IsZero", func(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + val RawValue + want bool + }{ + { + name: "empty", + val: RawValue{}, + want: true, + }, + { + name: "zero type but non-zero value", + val: RawValue{ + Type: 0x00, + Value: bsoncore.AppendInt32(nil, 0), + }, + want: false, + }, + { + name: "zero type and zero value", + val: RawValue{ + Type: 0x00, + Value: bsoncore.AppendInt32(nil, 0), + }, + }, + { + name: "non-zero type and non-zero value", + val: RawValue{ + Type: bsontype.String, + Value: bsoncore.AppendString(nil, "foobar"), + }, + want: false, + }, + { + name: "non-zero type and zero value", + val: RawValue{ + Type: bsontype.String, + Value: bsoncore.AppendString(nil, "foobar"), + }, + }, + } + + for _, tt := range tests { + tt := tt // Capture the range variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.want, tt.val.IsZero()) + }) + } + }) }