Skip to content

Commit

Permalink
GODRIVER-2896 Add IsZero to BSON RawValue (#1332)
Browse files Browse the repository at this point in the history
Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com>
  • Loading branch information
prestonvasquez and matthewdale authored Jul 28, 2023
1 parent 6c59389 commit 215d683
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 3 deletions.
11 changes: 11 additions & 0 deletions bson/bsontype/bsontype.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,14 @@ func (bt Type) String() string {
return "invalid"
}
}

// IsValid will return true if the Type is valid.
func (bt Type) IsValid() bool {
switch bt {
case Double, String, EmbeddedDocument, Array, Binary, Undefined, ObjectID, Boolean, DateTime, Null, Regex,
DBPointer, JavaScript, Symbol, CodeWithScope, Int32, Timestamp, Int64, Decimal128, MinKey, MaxKey:
return true
default:
return false
}
}
18 changes: 15 additions & 3 deletions bson/primitive_codecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bson

import (
"errors"
"fmt"
"reflect"

"go.mongodb.org/mongo-driver/bson/bsoncodec"
Expand Down Expand Up @@ -45,15 +46,26 @@ func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder)

// RawValueEncodeValue is the ValueEncoderFunc for RawValue.
//
// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders
// registered.
// If the RawValue's Type is "invalid" and the RawValue's Value is not empty or
// nil, then this method will return an error.
//
// Deprecated: Use bson.NewRegistry to get a registry with all primitive
// encoders and decoders registered.
func (PrimitiveCodecs) RawValueEncodeValue(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tRawValue {
return bsoncodec.ValueEncoderError{Name: "RawValueEncodeValue", Types: []reflect.Type{tRawValue}, Received: val}
return bsoncodec.ValueEncoderError{
Name: "RawValueEncodeValue",
Types: []reflect.Type{tRawValue},
Received: val,
}
}

rawvalue := val.Interface().(RawValue)

if !rawvalue.Type.IsValid() {
return fmt.Errorf("the RawValue Type specifies an invalid BSON type: %#x", byte(rawvalue.Type))
}

return bsonrw.Copier{}.CopyValueFromBytes(vw, rawvalue.Type, rawvalue.Value)
}

Expand Down
38 changes: 38 additions & 0 deletions bson/primitive_codecs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ func compareErrors(err1, err2 error) bool {
}

func TestDefaultValueEncoders(t *testing.T) {
t.Parallel()

var pc PrimitiveCodecs

var wrong = func(string, string) string { return "wrong" }
Expand Down Expand Up @@ -107,6 +109,28 @@ func TestDefaultValueEncoders(t *testing.T) {
bsonrwtest.WriteDouble,
nil,
},
{
"RawValue Type is zero with non-zero value",
RawValue{
Type: 0x00,
Value: bsoncore.AppendDouble(nil, 3.14159),
},
nil,
nil,
bsonrwtest.Nothing,
fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x0"),
},
{
"RawValue Type is invalid",
RawValue{
Type: 0x8F,
Value: bsoncore.AppendDouble(nil, 3.14159),
},
nil,
nil,
bsonrwtest.Nothing,
fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x8f"),
},
},
},
{
Expand Down Expand Up @@ -166,9 +190,17 @@ func TestDefaultValueEncoders(t *testing.T) {
}

for _, tc := range testCases {
tc := tc // Capture the range variable

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

for _, subtest := range tc.subtests {
subtest := subtest // Capture the range variable

t.Run(subtest.name, func(t *testing.T) {
t.Parallel()

var ec bsoncodec.EncodeContext
if subtest.ectx != nil {
ec = *subtest.ectx
Expand All @@ -192,6 +224,8 @@ func TestDefaultValueEncoders(t *testing.T) {
}

t.Run("success path", func(t *testing.T) {
t.Parallel()

oid := primitive.NewObjectID()
oids := []primitive.ObjectID{primitive.NewObjectID(), primitive.NewObjectID(), primitive.NewObjectID()}
var str = new(string)
Expand Down Expand Up @@ -426,7 +460,11 @@ func TestDefaultValueEncoders(t *testing.T) {
}

for _, tc := range testCases {
tc := tc // Capture the range variable

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

b := make(bsonrw.SliceWriter, 0, 512)
vw, err := bsonrw.NewBSONValueWriter(&b)
noerr(t, err)
Expand Down
6 changes: 6 additions & 0 deletions bson/raw_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ type RawValue struct {
r *bsoncodec.Registry
}

// IsZero reports whether the RawValue is zero, i.e. no data is present on
// the RawValue. It returns true if Type is 0 and Value is empty or nil.
func (rv RawValue) IsZero() bool {
return rv.Type == 0x00 && len(rv.Value) == 0
}

// Unmarshal deserializes BSON into the provided val. If RawValue cannot be unmarshaled into val, an
// error is returned. This method will use the registry used to create the RawValue, if the RawValue
// was created from partial BSON processing, or it will use the default registry. Users wishing to
Expand Down
84 changes: 84 additions & 0 deletions bson/raw_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@ import (

"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/assert"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)

func TestRawValue(t *testing.T) {
t.Parallel()

t.Run("Unmarshal", func(t *testing.T) {
t.Parallel()

t.Run("Uses registry attached to value", func(t *testing.T) {
t.Parallel()

reg := bsoncodec.NewRegistryBuilder().Build()
val := RawValue{Type: bsontype.String, Value: bsoncore.AppendString(nil, "foobar"), r: reg}
var s string
Expand All @@ -29,6 +36,8 @@ func TestRawValue(t *testing.T) {
}
})
t.Run("Uses default registry if no registry attached", func(t *testing.T) {
t.Parallel()

want := "foobar"
val := RawValue{Type: bsontype.String, Value: bsoncore.AppendString(nil, want)}
var got string
Expand All @@ -40,7 +49,11 @@ func TestRawValue(t *testing.T) {
})
})
t.Run("UnmarshalWithRegistry", func(t *testing.T) {
t.Parallel()

t.Run("Returns error when registry is nil", func(t *testing.T) {
t.Parallel()

want := ErrNilRegistry
var val RawValue
got := val.UnmarshalWithRegistry(nil, &D{})
Expand All @@ -49,6 +62,8 @@ func TestRawValue(t *testing.T) {
}
})
t.Run("Returns lookup error", func(t *testing.T) {
t.Parallel()

reg := bsoncodec.NewRegistryBuilder().Build()
var val RawValue
var s string
Expand All @@ -59,6 +74,8 @@ func TestRawValue(t *testing.T) {
}
})
t.Run("Returns DecodeValue error", func(t *testing.T) {
t.Parallel()

reg := NewRegistryBuilder().Build()
val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, 3.14159)}
var s string
Expand All @@ -69,6 +86,8 @@ func TestRawValue(t *testing.T) {
}
})
t.Run("Success", func(t *testing.T) {
t.Parallel()

reg := NewRegistryBuilder().Build()
want := float64(3.14159)
val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, want)}
Expand All @@ -81,7 +100,11 @@ func TestRawValue(t *testing.T) {
})
})
t.Run("UnmarshalWithContext", func(t *testing.T) {
t.Parallel()

t.Run("Returns error when DecodeContext is nil", func(t *testing.T) {
t.Parallel()

want := ErrNilContext
var val RawValue
got := val.UnmarshalWithContext(nil, &D{})
Expand All @@ -90,6 +113,8 @@ func TestRawValue(t *testing.T) {
}
})
t.Run("Returns lookup error", func(t *testing.T) {
t.Parallel()

dc := bsoncodec.DecodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()}
var val RawValue
var s string
Expand All @@ -100,6 +125,8 @@ func TestRawValue(t *testing.T) {
}
})
t.Run("Returns DecodeValue error", func(t *testing.T) {
t.Parallel()

dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()}
val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, 3.14159)}
var s string
Expand All @@ -110,6 +137,8 @@ func TestRawValue(t *testing.T) {
}
})
t.Run("Success", func(t *testing.T) {
t.Parallel()

dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()}
want := float64(3.14159)
val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, want)}
Expand All @@ -121,4 +150,59 @@ func TestRawValue(t *testing.T) {
}
})
})

t.Run("IsZero", func(t *testing.T) {
t.Parallel()

tests := []struct {
name string
val RawValue
want bool
}{
{
name: "empty",
val: RawValue{},
want: true,
},
{
name: "zero type but non-zero value",
val: RawValue{
Type: 0x00,
Value: bsoncore.AppendInt32(nil, 0),
},
want: false,
},
{
name: "zero type and zero value",
val: RawValue{
Type: 0x00,
Value: bsoncore.AppendInt32(nil, 0),
},
},
{
name: "non-zero type and non-zero value",
val: RawValue{
Type: bsontype.String,
Value: bsoncore.AppendString(nil, "foobar"),
},
want: false,
},
{
name: "non-zero type and zero value",
val: RawValue{
Type: bsontype.String,
Value: bsoncore.AppendString(nil, "foobar"),
},
},
}

for _, tt := range tests {
tt := tt // Capture the range variable
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

assert.Equal(t, tt.want, tt.val.IsZero())
})
}
})
}

0 comments on commit 215d683

Please sign in to comment.