Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clearer errors when decoding to invalid types #332

Merged
merged 1 commit into from
Nov 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 40 additions & 19 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,34 @@ func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}

var (
unmarshalToml = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
unmarshalText = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
)

// Decode TOML data in to the pointer `v`.
func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return MetaData{}, e("Decode of non-pointer %s", reflect.TypeOf(v))
s := "%q"
if reflect.TypeOf(v) == nil {
s = "%v"
}

return MetaData{}, e("cannot decode to non-pointer "+s, reflect.TypeOf(v))
}
if rv.IsNil() {
return MetaData{}, e("Decode of nil %s", reflect.TypeOf(v))
return MetaData{}, e("cannot decode to nil value of %q", reflect.TypeOf(v))
}

// Check if this is a supported type: struct, map, interface{}, or something
// that implements UnmarshalTOML or UnmarshalText.
rv = indirect(rv)
rt := rv.Type()
if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map &&
!(rv.Kind() == reflect.Interface && rv.NumMethod() == 0) &&
!rt.Implements(unmarshalToml) && !rt.Implements(unmarshalText) {
return MetaData{}, e("cannot decode to type %s", rt)
}

// TODO: parser should read from io.Reader? Or at the very least, make it
Expand All @@ -135,7 +155,7 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
decoded: make(map[string]struct{}, len(p.ordered)),
context: nil,
}
return md, md.unify(p.mapping, indirect(rv))
return md, md.unify(p.mapping, rv)
}

// Decode the TOML data in to the pointer v.
Expand Down Expand Up @@ -291,7 +311,7 @@ func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
if tmap == nil {
return nil
}
return badtype("map", mapping)
return md.badtype("map", mapping)
}
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
Expand Down Expand Up @@ -319,7 +339,7 @@ func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
if !datav.IsValid() {
return nil
}
return badtype("slice", data)
return md.badtype("slice", data)
}
if l := datav.Len(); l != rv.Len() {
return e("expected array length %d; got TOML array of length %d", rv.Len(), l)
Expand All @@ -333,7 +353,7 @@ func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
if !datav.IsValid() {
return nil
}
return badtype("slice", data)
return md.badtype("slice", data)
}
n := datav.Len()
if rv.IsNil() || rv.Cap() < n {
Expand All @@ -359,7 +379,7 @@ func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
rv.SetString(s)
return nil
}
return badtype("string", data)
return md.badtype("string", data)
}

func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
Expand Down Expand Up @@ -396,7 +416,7 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
return nil
}

return badtype("float", data)
return md.badtype("float", data)
}

func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
Expand Down Expand Up @@ -443,15 +463,15 @@ func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
}
return nil
}
return badtype("integer", data)
return md.badtype("integer", data)
}

func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
if b, ok := data.(bool); ok {
rv.SetBool(b)
return nil
}
return badtype("boolean", data)
return md.badtype("boolean", data)
}

func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
Expand Down Expand Up @@ -485,25 +505,30 @@ func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) erro
case float64:
s = fmt.Sprintf("%f", sdata)
default:
return badtype("primitive (string-like)", data)
return md.badtype("primitive (string-like)", data)
}
if err := v.UnmarshalText([]byte(s)); err != nil {
return err
}
return nil
}

func (md *MetaData) badtype(dst string, data interface{}) error {
return e("incompatible types: TOML key %q has type %T; destination has type %s", md.context, data, dst)
}

// rvalue returns a reflect.Value of `v`. All pointers are resolved.
func rvalue(v interface{}) reflect.Value {
return indirect(reflect.ValueOf(v))
}

// indirect returns the value pointed to by a pointer.
// Pointers are followed until the value is not a pointer.
// New values are allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of
// interest to us (like encoding.TextUnmarshaler).
// Pointers are followed until the value is not a pointer. New values are
// allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of interest
// to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
if v.CanSet() {
Expand Down Expand Up @@ -533,7 +558,3 @@ func isUnifiable(rv reflect.Value) bool {
func e(format string, args ...interface{}) error {
return fmt.Errorf("toml: "+format, args...)
}

func badtype(expected string, data interface{}) error {
return e("cannot load TOML value of type %T into a Go %s", data, expected)
}
25 changes: 20 additions & 5 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,23 +346,38 @@ func TestDecodeSizedInts(t *testing.T) {
}
}

type NopUnmarshalTOML int

func (NopUnmarshalTOML) UnmarshalTOML(p interface{}) error { return nil }

func TestDecodeTypes(t *testing.T) {
type mystr string

for _, tt := range []struct {
v interface{}
want string
}{
{new(map[string]int64), ""},
{new(map[mystr]int64), ""},
{new(map[string]bool), ""},
{new(map[mystr]bool), ""},
{new(NopUnmarshalTOML), ""},

{3, `toml: cannot decode to non-pointer "int"`},
{map[string]interface{}{}, `toml: cannot decode to non-pointer "map[string]interface {}"`},

{(*int)(nil), `toml: cannot decode to nil value of "*int"`},
{(*Unmarshaler)(nil), `toml: cannot decode to nil value of "*toml.Unmarshaler"`},
{nil, `toml: cannot decode to non-pointer <nil>`},

{3, "non-pointer int"},
{(*int)(nil), "nil"},
{new(map[int]string), "cannot decode to a map with non-string key type"},
{new(map[interface{}]string), "cannot decode to a map with non-string key type"},

{new(struct{ F int }), `toml: incompatible types: TOML key "F" has type bool; destination has type integer`},
{new(map[string]int), `toml: incompatible types: TOML key "F" has type bool; destination has type integer`},
{new(int), `toml: cannot decode to type int`},
{new([]int), "toml: cannot decode to type []int"},
} {
t.Run(fmt.Sprintf("%T", tt.v), func(t *testing.T) {
_, err := Decode(`x = 3`, tt.v)
_, err := Decode(`F = true`, tt.v)
if !errorContains(err, tt.want) {
t.Errorf("wrong error\nhave: %q\nwant: %q", err, tt.want)
}
Expand Down