Skip to content

Commit

Permalink
Add support for keys implement TextMarshaler & TextUnmarshaler from M…
Browse files Browse the repository at this point in the history
…apCodec (mongodb#946)
  • Loading branch information
aaronjheng authored May 18, 2022
1 parent 0cdb185 commit a80ae1d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
44 changes: 44 additions & 0 deletions bson/bson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"bytes"
"fmt"
"reflect"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -140,6 +141,29 @@ func (kb *keyBool) UnmarshalKey(key string) error {
return nil
}

type keyStruct struct {
val int64
}

func (k keyStruct) MarshalText() (text []byte, err error) {
str := strconv.FormatInt(k.val, 10)

return []byte(str), nil
}

func (k *keyStruct) UnmarshalText(text []byte) error {
val, err := strconv.ParseInt(string(text), 10, 64)
if err != nil {
return err
}

*k = keyStruct{
val: val,
}

return nil
}

func TestMapCodec(t *testing.T) {
t.Run("EncodeKeysWithStringer", func(t *testing.T) {
strstr := stringerString("foo")
Expand All @@ -163,6 +187,7 @@ func TestMapCodec(t *testing.T) {
})
}
})

t.Run("keys implements keyMarshaler and keyUnmarshaler", func(t *testing.T) {
mapObj := map[keyBool]int{keyBool(true): 1}

Expand All @@ -179,6 +204,25 @@ func TestMapCodec(t *testing.T) {
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)

})

t.Run("keys implements encoding.TextMarshaler and encoding.TextUnmarshaler", func(t *testing.T) {
mapObj := map[keyStruct]int{
{val: 10}: 100,
}

doc, err := Marshal(mapObj)
assert.Nil(t, err, "Marshal error: %v", err)
idx, want := bsoncore.AppendDocumentStart(nil)
want = bsoncore.AppendInt32Element(want, "10", 100)
want, _ = bsoncore.AppendDocumentEnd(want, idx)
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))

var got map[keyStruct]int
err = Unmarshal(doc, &got)
assert.Nil(t, err, "Unmarshal error: %v", err)
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)

})
}

func TestExtJSONEscapeKey(t *testing.T) {
Expand Down
21 changes: 21 additions & 0 deletions bson/bsoncodec/map_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package bsoncodec

import (
"encoding"
"fmt"
"reflect"
"strconv"
Expand Down Expand Up @@ -230,6 +231,19 @@ func (mc *MapCodec) encodeKey(val reflect.Value) (string, error) {
}
return "", err
}
// keys implement encoding.TextMarshaler are marshaled.
if km, ok := val.Interface().(encoding.TextMarshaler); ok {
if val.Kind() == reflect.Ptr && val.IsNil() {
return "", nil
}

buf, err := km.MarshalText()
if err != nil {
return "", err
}

return string(buf), nil
}

switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
Expand All @@ -241,6 +255,7 @@ func (mc *MapCodec) encodeKey(val reflect.Value) (string, error) {
}

var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()

func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
keyVal := reflect.ValueOf(key)
Expand All @@ -252,6 +267,12 @@ func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value,
v := keyVal.Interface().(KeyUnmarshaler)
err = v.UnmarshalKey(key)
keyVal = keyVal.Elem()
// Try to decode encoding.TextUnmarshalers.
case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
keyVal = reflect.New(keyType)
v := keyVal.Interface().(encoding.TextUnmarshaler)
err = v.UnmarshalText([]byte(key))
keyVal = keyVal.Elem()
// Otherwise, go to type specific behavior
default:
switch keyType.Kind() {
Expand Down

0 comments on commit a80ae1d

Please sign in to comment.