Skip to content

Commit 98af5b4

Browse files
iwysiuDivjot Arora
andauthored
GODRIVER-1923 Error if BSON cstrings contain null bytes (#622) (#684)
Co-authored-by: Divjot Arora <divjot.arora@10gen.com>
1 parent 575a442 commit 98af5b4

File tree

5 files changed

+118
-6
lines changed

5 files changed

+118
-6
lines changed

bson/bsonrw/value_writer.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"io"
1313
"math"
1414
"strconv"
15+
"strings"
1516
"sync"
1617

1718
"go.mongodb.org/mongo-driver/bson/bsontype"
@@ -247,7 +248,12 @@ func (vw *valueWriter) invalidTransitionError(destination mode, name string, mod
247248
func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error {
248249
switch vw.stack[vw.frame].mode {
249250
case mElement:
250-
vw.buf = bsoncore.AppendHeader(vw.buf, t, vw.stack[vw.frame].key)
251+
key := vw.stack[vw.frame].key
252+
if !isValidCString(key) {
253+
return errors.New("BSON element key cannot contain null bytes")
254+
}
255+
256+
vw.buf = bsoncore.AppendHeader(vw.buf, t, key)
251257
case mValue:
252258
// TODO: Do this with a cache of the first 1000 or so array keys.
253259
vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey))
@@ -430,6 +436,9 @@ func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error {
430436
}
431437

432438
func (vw *valueWriter) WriteRegex(pattern string, options string) error {
439+
if !isValidCString(pattern) || !isValidCString(options) {
440+
return errors.New("BSON regex values cannot contain null bytes")
441+
}
433442
if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil {
434443
return err
435444
}
@@ -602,3 +611,7 @@ func (vw *valueWriter) writeLength() error {
602611
vw.buf[start+3] = byte(length >> 24)
603612
return nil
604613
}
614+
615+
func isValidCString(cs string) bool {
616+
return !strings.ContainsRune(cs, '\x00')
617+
}

bson/marshal_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package bson
88

99
import (
1010
"bytes"
11+
"errors"
1112
"testing"
1213

1314
"github.com/google/go-cmp/cmp"
@@ -207,3 +208,35 @@ func TestMarshal_roundtripFromDoc(t *testing.T) {
207208
t.Errorf("Documents to not match. got %v; want %v", after, before)
208209
}
209210
}
211+
212+
func TestNullBytes(t *testing.T) {
213+
t.Run("element keys", func(t *testing.T) {
214+
doc := D{{"a\x00", "foobar"}}
215+
res, err := Marshal(doc)
216+
want := errors.New("BSON element key cannot contain null bytes")
217+
require.Equal(t, want, err, "expected Marshal error %v, got error %v with result %q", want, err, Raw(res))
218+
})
219+
220+
t.Run("regex values", func(t *testing.T) {
221+
wantErr := errors.New("BSON regex values cannot contain null bytes")
222+
223+
testCases := []struct {
224+
name string
225+
pattern string
226+
options string
227+
}{
228+
{"null bytes in pattern", "a\x00", "i"},
229+
{"null bytes in options", "pattern", "i\x00"},
230+
}
231+
for _, tc := range testCases {
232+
t.Run(tc.name, func(t *testing.T) {
233+
regex := primitive.Regex{
234+
Pattern: tc.pattern,
235+
Options: tc.options,
236+
}
237+
res, err := Marshal(D{{"foo", regex}})
238+
require.Equal(t, wantErr, err, "expected Marshal error %v, got error %v with result %q", wantErr, err, Raw(res))
239+
})
240+
}
241+
})
242+
}

mongo/mongo_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ func TestMongoHelpers(t *testing.T) {
116116
got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc)
117117
assert.Nil(t, err, "transformAndEnsureID error: %v", err)
118118
_, ok := id.(string)
119-
assert.True(t, ok, "expected returned id type %T, got %T", string(0), id)
119+
assert.True(t, ok, "expected returned id type string, got %T", id)
120120
assert.Equal(t, got, want, "expected document %v, got %v", got, want)
121121
})
122122
})

x/bsonx/bsoncore/bsoncore.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,35 @@ import (
3030
"fmt"
3131
"math"
3232
"strconv"
33+
"strings"
3334
"time"
3435

3536
"go.mongodb.org/mongo-driver/bson/bsontype"
3637
"go.mongodb.org/mongo-driver/bson/primitive"
3738
)
3839

39-
// EmptyDocumentLength is the length of a document that has been started/ended but has no elements.
40-
const EmptyDocumentLength = 5
40+
const (
41+
// EmptyDocumentLength is the length of a document that has been started/ended but has no elements.
42+
EmptyDocumentLength = 5
43+
// nullTerminator is a string version of the 0 byte that is appended at the end of cstrings.
44+
nullTerminator = string(byte(0))
45+
invalidKeyPanicMsg = "BSON element keys cannot contain null bytes"
46+
invalidRegexPanicMsg = "BSON regex values cannot contain null bytes"
47+
)
4148

4249
// AppendType will append t to dst and return the extended buffer.
4350
func AppendType(dst []byte, t bsontype.Type) []byte { return append(dst, byte(t)) }
4451

4552
// AppendKey will append key to dst and return the extended buffer.
46-
func AppendKey(dst []byte, key string) []byte { return append(dst, key+string(0x00)...) }
53+
func AppendKey(dst []byte, key string) []byte { return append(dst, key+nullTerminator...) }
4754

4855
// AppendHeader will append Type t and key to dst and return the extended
4956
// buffer.
5057
func AppendHeader(dst []byte, t bsontype.Type, key string) []byte {
58+
if !isValidCString(key) {
59+
panic(invalidKeyPanicMsg)
60+
}
61+
5162
dst = AppendType(dst, t)
5263
dst = append(dst, key...)
5364
return append(dst, 0x00)
@@ -427,7 +438,11 @@ func AppendNullElement(dst []byte, key string) []byte { return AppendHeader(dst,
427438

428439
// AppendRegex will append pattern and options to dst and return the extended buffer.
429440
func AppendRegex(dst []byte, pattern, options string) []byte {
430-
return append(dst, pattern+string(0x00)+options+string(0x00)...)
441+
if !isValidCString(pattern) || !isValidCString(options) {
442+
panic(invalidRegexPanicMsg)
443+
}
444+
445+
return append(dst, pattern+nullTerminator+options+nullTerminator...)
431446
}
432447

433448
// AppendRegexElement will append a BSON regex element using key, pattern, and
@@ -841,3 +856,7 @@ func appendBinarySubtype2(dst []byte, subtype byte, b []byte) []byte {
841856
dst = appendLength(dst, int32(len(b)))
842857
return append(dst, b...)
843858
}
859+
860+
func isValidCString(cs string) bool {
861+
return !strings.ContainsRune(cs, '\x00')
862+
}

x/bsonx/bsoncore/bsoncore_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/google/go-cmp/cmp"
1717
"go.mongodb.org/mongo-driver/bson/bsontype"
1818
"go.mongodb.org/mongo-driver/bson/primitive"
19+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1920
)
2021

2122
func noerr(t *testing.T, err error) {
@@ -899,6 +900,52 @@ func TestBuild(t *testing.T) {
899900
}
900901
}
901902

903+
func TestNullBytes(t *testing.T) {
904+
// Helper function to execute the provided callback and assert that it panics with the expected message. The
905+
// createBSONFn callback should create a BSON document/array/value and return the stringified version.
906+
assertBSONCreationPanics := func(t *testing.T, createBSONFn func(), expected string) {
907+
t.Helper()
908+
909+
defer func() {
910+
got := recover()
911+
assert.Equal(t, expected, got, "expected panic with error %v, got error %v", expected, got)
912+
}()
913+
createBSONFn()
914+
}
915+
916+
t.Run("element keys", func(t *testing.T) {
917+
createDocFn := func() {
918+
BuildDocument(nil, AppendStringElement(nil, "a\x00", "foo"))
919+
}
920+
921+
assertBSONCreationPanics(t, createDocFn, invalidKeyPanicMsg)
922+
})
923+
t.Run("regex values", func(t *testing.T) {
924+
testCases := []struct {
925+
name string
926+
pattern string
927+
options string
928+
}{
929+
{"null bytes in pattern", "a\x00", "i"},
930+
{"null bytes in options", "pattern", "i\x00"},
931+
}
932+
for _, tc := range testCases {
933+
t.Run(tc.name+"-AppendRegexElement", func(t *testing.T) {
934+
createDocFn := func() {
935+
AppendRegexElement(nil, "foo", tc.pattern, tc.options)
936+
}
937+
assertBSONCreationPanics(t, createDocFn, invalidRegexPanicMsg)
938+
})
939+
t.Run(tc.name+"-AppendRegex", func(t *testing.T) {
940+
createValFn := func() {
941+
AppendRegex(nil, tc.pattern, tc.options)
942+
}
943+
assertBSONCreationPanics(t, createValFn, invalidRegexPanicMsg)
944+
})
945+
}
946+
})
947+
}
948+
902949
func compareDecimal128(d1, d2 primitive.Decimal128) bool {
903950
d1H, d1L := d1.GetBytes()
904951
d2H, d2L := d2.GetBytes()

0 commit comments

Comments
 (0)