diff --git a/decode.go b/decode.go index 5f34af4..5acc6d8 100644 --- a/decode.go +++ b/decode.go @@ -677,7 +677,6 @@ func (d *decodeState) object(v reflect.Value) error { return nil } - var mapElem reflect.Value origErrorContext := d.errorContext for { @@ -701,17 +700,66 @@ func (d *decodeState) object(v reflect.Value) error { } // Figure out field corresponding to key. - var subv reflect.Value + var kv, subv reflect.Value destring := false // whether the value is wrapped in a string to be decoded first if v.Kind() == reflect.Map { - elemType := t.Elem() - if !mapElem.IsValid() { - mapElem = reflect.New(elemType).Elem() - } else { - mapElem.Set(reflect.Zero(elemType)) + // First, figure out the key value from the input. + kt := t.Key() + switch { + case reflect.PtrTo(kt).Implements(textUnmarshalerType): + kv = reflect.New(kt) + if err := d.literalStore(item, kv, true); err != nil { + return err + } + kv = kv.Elem() + case kt.Kind() == reflect.String: + kv = reflect.ValueOf(key).Convert(kt) + default: + switch kt.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := string(key) + n, err := strconv.ParseInt(s, 10, 64) + if err != nil || reflect.Zero(kt).OverflowInt(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) + break + } + kv = reflect.ValueOf(n).Convert(kt) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + s := string(key) + n, err := strconv.ParseUint(s, 10, 64) + if err != nil || reflect.Zero(kt).OverflowUint(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) + break + } + kv = reflect.ValueOf(n).Convert(kt) + default: + panic("json: Unexpected key type") // should never occur + } + } + + // Then, decide what element value we'll decode into. + et := t.Elem() + if kv.IsValid() { + if existing := v.MapIndex(kv); !existing.IsValid() { + // Nothing to reuse. + } else if et.Kind() == reflect.Ptr { + // Pointer; decode directly into it if non-nil. + if !existing.IsNil() { + subv = existing + } + } else { + // Non-pointer. Make a copy and decode into the + // addressable copy. Don't just use a new/zero + // value, as that would lose existing data. + subv = reflect.New(et).Elem() + subv.Set(existing) + } + } + if !subv.IsValid() { + // We couldn't reuse an existing value. + subv = reflect.New(et).Elem() } - subv = mapElem } else { var f *field if i, ok := fields.nameIndex[string(key)]; ok { @@ -790,43 +838,8 @@ func (d *decodeState) object(v reflect.Value) error { // Write value back to map; // if using struct, subv points into struct already. - if v.Kind() == reflect.Map { - kt := t.Key() - var kv reflect.Value - switch { - case reflect.PtrTo(kt).Implements(textUnmarshalerType): - kv = reflect.New(kt) - if err := d.literalStore(item, kv, true); err != nil { - return err - } - kv = kv.Elem() - case kt.Kind() == reflect.String: - kv = reflect.ValueOf(key).Convert(kt) - default: - switch kt.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - s := string(key) - n, err := strconv.ParseInt(s, 10, 64) - if err != nil || reflect.Zero(kt).OverflowInt(n) { - d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) - break - } - kv = reflect.ValueOf(n).Convert(kt) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - s := string(key) - n, err := strconv.ParseUint(s, 10, 64) - if err != nil || reflect.Zero(kt).OverflowUint(n) { - d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) - break - } - kv = reflect.ValueOf(n).Convert(kt) - default: - panic("json: Unexpected key type") // should never occur - } - } - if kv.IsValid() { - v.SetMapIndex(kv, subv) - } + if v.Kind() == reflect.Map && kv.IsValid() { + v.SetMapIndex(kv, subv) } // Next token must be , or }. diff --git a/decode_test.go b/decode_test.go index 5ac1022..a62488d 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2569,3 +2569,57 @@ func TestUnmarshalMaxDepth(t *testing.T) { } } } + +func TestUnmarshalMapPointerElem(t *testing.T) { + type S struct{ Unchanged, Changed int } + input := []byte(`{"S":{"Changed":5}}`) + want := S{1, 5} + + // First, a map with struct pointer elements. The key-value pair exists, + // so reuse the existing value. + s := &S{1, 2} + ptrMap := map[string]*S{"S": s} + if err := Unmarshal(input, &ptrMap); err != nil { + t.Fatal(err) + } + if s != ptrMap["S"] { + t.Fatal("struct pointer element in map was completely replaced") + } + if got := *s; got != want { + t.Fatalf("want %#v, got %#v", want, got) + } + + // Second, a map with struct elements. The key-value pair exists, but + // the value isn't addresable, so make a copy and use that. + s = &S{1, 2} + strMap := map[string]S{"S": *s} + if err := Unmarshal(input, &strMap); err != nil { + t.Fatal(err) + } + if *s == strMap["S"] { + t.Fatal("struct element in map wasn't copied") + } + if got := strMap["S"]; got != want { + t.Fatalf("want %#v, got %#v", want, got) + } + + // Finally, check the cases where the key-value pair exists, but the + // value is zero. + want = S{0, 5} + + ptrMap = map[string]*S{"S": nil} + if err := Unmarshal(input, &ptrMap); err != nil { + t.Fatal(err) + } + if got := *ptrMap["S"]; got != want { + t.Fatalf("want %#v, got %#v", want, got) + } + + strMap = map[string]S{"S": {}} + if err := Unmarshal(input, &strMap); err != nil { + t.Fatal(err) + } + if got := strMap["S"]; got != want { + t.Fatalf("want %#v, got %#v", want, got) + } +}