Skip to content

Commit

Permalink
fix(spanner): json null handling (#10660)
Browse files Browse the repository at this point in the history
* perf(spanner): compare faster with json null value

By using a string comparison the compiler can significantly optimize
the check.

* fix(spanner): avoid reusing []byte("null")

The caller can change the result from these calls and this may
cause weird bugs down the line. Instead allocate a new slice.

---------

Co-authored-by: rahul2393 <irahul@google.com>
  • Loading branch information
egonelbre and rahul2393 authored Aug 20, 2024
1 parent 95ae207 commit 4c519e3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
39 changes: 21 additions & 18 deletions spanner/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,6 @@ var (
CommitTimestamp = commitTimestamp
commitTimestamp = time.Unix(0, 0).In(time.FixedZone("CommitTimestamp placeholder", 0xDB))

jsonNullBytes = []byte("null")

jsonUseNumber bool

protoMsgReflectType = reflect.TypeOf((*proto.Message)(nil)).Elem()
Expand All @@ -140,6 +138,11 @@ func jsonUnmarshal(data []byte, v any) error {
return dec.Decode(v)
}

// jsonIsNull returns whether v matches JSON null literal
func jsonIsNull(v []byte) bool {
return string(v) == "null"
}

// Encoder is the interface implemented by a custom type that can be encoded to
// a supported type by Spanner. A code example:
//
Expand Down Expand Up @@ -220,7 +223,7 @@ func (n *NullInt64) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.Int64 = int64(0)
n.Valid = false
return nil
Expand Down Expand Up @@ -300,7 +303,7 @@ func (n *NullString) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.StringVal = ""
n.Valid = false
return nil
Expand Down Expand Up @@ -385,7 +388,7 @@ func (n *NullFloat64) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.Float64 = float64(0)
n.Valid = false
return nil
Expand Down Expand Up @@ -465,7 +468,7 @@ func (n *NullFloat32) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.Float32 = float32(0)
n.Valid = false
return nil
Expand Down Expand Up @@ -545,7 +548,7 @@ func (n *NullBool) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.Bool = false
n.Valid = false
return nil
Expand Down Expand Up @@ -625,7 +628,7 @@ func (n *NullTime) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.Time = time.Time{}
n.Valid = false
return nil
Expand Down Expand Up @@ -710,7 +713,7 @@ func (n *NullDate) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.Date = civil.Date{}
n.Valid = false
return nil
Expand Down Expand Up @@ -795,7 +798,7 @@ func (n *NullNumeric) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.Numeric = big.Rat{}
n.Valid = false
return nil
Expand Down Expand Up @@ -892,7 +895,7 @@ func (n *NullJSON) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.Valid = false
return nil
}
Expand Down Expand Up @@ -940,7 +943,7 @@ func (n *PGNumeric) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.Numeric = ""
n.Valid = false
return nil
Expand Down Expand Up @@ -979,15 +982,15 @@ func (n NullProtoMessage) MarshalJSON() ([]byte, error) {
if n.Valid {
return json.Marshal(n.ProtoMessageVal)
}
return jsonNullBytes, nil
return []byte("null"), nil
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullProtoMessage.
func (n *NullProtoMessage) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.ProtoMessageVal = nil
n.Valid = false
return nil
Expand Down Expand Up @@ -1025,15 +1028,15 @@ func (n NullProtoEnum) MarshalJSON() ([]byte, error) {
if n.Valid && n.ProtoEnumVal != nil {
return []byte(fmt.Sprintf("%v", n.ProtoEnumVal.Number())), nil
}
return jsonNullBytes, nil
return []byte("null"), nil
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullProtoEnum.
func (n *NullProtoEnum) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.ProtoEnumVal = nil
n.Valid = false
return nil
Expand Down Expand Up @@ -1094,7 +1097,7 @@ func (n *PGJsonB) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
}
if bytes.Equal(payload, jsonNullBytes) {
if jsonIsNull(payload) {
n.Valid = false
return nil
}
Expand All @@ -1110,7 +1113,7 @@ func (n *PGJsonB) UnmarshalJSON(payload []byte) error {

func nulljson(valid bool, v interface{}) ([]byte, error) {
if !valid {
return jsonNullBytes, nil
return []byte("null"), nil
}
return json.Marshal(v)
}
Expand Down
9 changes: 9 additions & 0 deletions spanner/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3266,3 +3266,12 @@ func expectUnmarshalNullableTypes(t *testing.T, err error, v interface{}, isNull
t.Fatalf("Incorrect unmarshalling a json string to nullable types: got %q, want %q", v, expect)
}
}

func TestNullJson(t *testing.T) {
v, _ := nulljson(false, nil)
v[0] = 'X'
v, _ = nulljson(false, nil)
if string(v) != "null" {
t.Fatalf("expected null, got %s", v)
}
}

0 comments on commit 4c519e3

Please sign in to comment.