diff --git a/debug.go b/debug.go index 6b3f42d..3b8806b 100644 --- a/debug.go +++ b/debug.go @@ -12,7 +12,11 @@ import ( "text/tabwriter" ) -// TODO: delete me +// debugging tools for dynamo developers +// disabled by default, use `go build -tags debug` to enable +// warning: not covered by API stability guarantees + +// DumpType dumps a description of x's typedef to stdout. func DumpType(x any) { plan, err := typedefOf(reflect.TypeOf(x)) if err != nil { diff --git a/decodefunc.go b/decodefunc.go index 0a5b7b3..140fb63 100644 --- a/decodefunc.go +++ b/decodefunc.go @@ -182,7 +182,7 @@ func decodeMap(decodeKey func(reflect.Value, string) error) func(plan *typedef, for name, av := range item { vp := new(V) decodeKey(kp, name) - decodeAttr(av, vp) // TODO fix order + decodeAttr(av, vp) // TODO: make argument order consistent out[*kp] = *vp } */ diff --git a/encode.go b/encode.go index 7340163..227a69d 100644 --- a/encode.go +++ b/encode.go @@ -72,7 +72,7 @@ func marshalSliceNoOmit(values []interface{}) ([]*dynamodb.AttributeValue, error return avs, nil } -func encodeItem(fields []fieldMeta, rv reflect.Value) (Item, error) { +func encodeItem(fields []structField, rv reflect.Value) (Item, error) { item := make(Item, len(fields)) for _, field := range fields { fv := dig(rv, field.index) diff --git a/encodefunc.go b/encodefunc.go index f4c9c90..190ee87 100644 --- a/encodefunc.go +++ b/encodefunc.go @@ -106,6 +106,22 @@ func encodeType(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { return nil, fmt.Errorf("dynamo marshal: unsupported type %s", rt.String()) } +func encodePtr(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { + elem, err := encodeType(rt.Elem(), flags) + if err != nil { + return nil, err + } + return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { + if rv.IsNil() { + if flags&flagNull != 0 { + return nullAV, nil + } + return nil, nil + } + return elem(rv.Elem(), flags) + }, nil +} + func encode2[T any](fn func(T, encodeFlags) (*dynamodb.AttributeValue, error)) func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { target := reflect.TypeOf((*T)(nil)).Elem() interfacing := target.Kind() == reflect.Interface @@ -133,7 +149,7 @@ func encodeString(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue s := rv.String() if len(s) == 0 { if flags&flagAllowEmpty != 0 { - return &dynamodb.AttributeValue{S: &s}, nil + return emptyS, nil } if flags&flagNull != 0 { return nullAV, nil @@ -158,11 +174,25 @@ var encodeTextMarshaler = encode2[encoding.TextMarshaler](func(x encoding.TextMa return &dynamodb.AttributeValue{S: &str}, nil }) -func encodePtr(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { - elem, err := encodeType(rt.Elem(), flags) - if err != nil { - return nil, err +func encodeBytes(rt reflect.Type, flags encodeFlags) encodeFunc { + if rt.Kind() == reflect.Array { + size := rt.Len() + return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { + if rv.IsZero() { + switch { + case flags&flagNull != 0: + return nullAV, nil + case flags&flagAllowEmpty != 0: + return emptyB, nil + } + return nil, nil + } + data := make([]byte, size) + reflect.Copy(reflect.ValueOf(data), rv) + return &dynamodb.AttributeValue{B: data}, nil + } } + return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { if rv.IsNil() { if flags&flagNull != 0 { @@ -170,8 +200,14 @@ func encodePtr(rt reflect.Type, flags encodeFlags) (encodeFunc, error) { } return nil, nil } - return elem(rv.Elem(), flags) - }, nil + if rv.Len() == 0 { + if flags&flagAllowEmpty != 0 { + return emptyB, nil + } + return nil, nil + } + return &dynamodb.AttributeValue{B: rv.Bytes()}, nil + } } func encodeStruct(rt reflect.Type) (encodeFunc, error) { diff --git a/encoding.go b/encoding.go index c9e9d11..2f7877e 100644 --- a/encoding.go +++ b/encoding.go @@ -11,38 +11,22 @@ import ( "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" ) -var planCache sync.Map // unmarshalKey → *decodePlan - -type unmarshalKey struct { - gotype reflect.Type - shape shapeKey -} - -func (key unmarshalKey) GoString() string { - return fmt.Sprintf("%s:%v", key.shape.GoString(), key.gotype.String()) -} - -func (key unmarshalKey) Less(other unmarshalKey) bool { - if key.gotype == other.gotype { - return key.shape < other.shape - } - return key.gotype.String() < other.gotype.String() -} +var typeCache sync.Map // unmarshalKey → *typedef type typedef struct { decoders map[unmarshalKey]decodeFunc - fields []fieldMeta + fields []structField } -type fieldMeta struct { - index []int - name string - flags encodeFlags - enc encodeFunc - isZero func(reflect.Value) bool +func newTypedef(rt reflect.Type) (*typedef, error) { + def := &typedef{ + decoders: make(map[unmarshalKey]decodeFunc), + } + err := def.init(rt) + return def, err } -func (def *typedef) analyze(rt reflect.Type) { +func (def *typedef) init(rt reflect.Type) error { for rt.Kind() == reflect.Pointer { rt = rt.Elem() } @@ -50,82 +34,46 @@ func (def *typedef) analyze(rt reflect.Type) { def.learn(rt) if rt.Kind() != reflect.Struct { - return + return nil } var err error def.fields, err = structFields(rt) - if err != nil { - panic(err) // TODO - } + return err } -func structFields(rt reflect.Type) ([]fieldMeta, error) { - var fields []fieldMeta - err := visitTypeFields(rt, nil, nil, func(name string, index []int, flags encodeFlags, vt reflect.Type) error { - enc, err := encodeType(vt, flags) - if err != nil { - return err - } - field := fieldMeta{ - index: index, - name: name, - flags: flags, - enc: enc, - isZero: isZeroFunc(vt), - } - // if flags&flagOmitEmpty != 0 { - // field.isZero = isZeroFunc(rt) - // } - fields = append(fields, field) - return nil - }) - return fields, err -} - -func newTypedef(rt reflect.Type) (*typedef, error) { - plan := &typedef{ - decoders: make(map[unmarshalKey]decodeFunc), - } - - plan.analyze(rt) - - return plan, nil -} - -func registerTypedef(gotype reflect.Type, r *typedef) *typedef { - plan, _ := planCache.LoadOrStore(gotype, r) - return plan.(*typedef) +func registerTypedef(gotype reflect.Type, def *typedef) *typedef { + canon, _ := typeCache.LoadOrStore(gotype, def) + return canon.(*typedef) } func typedefOf(rt reflect.Type) (*typedef, error) { - v, ok := planCache.Load(rt) + v, ok := typeCache.Load(rt) if ok { return v.(*typedef), nil } - plan, err := newTypedef(rt) + def, err := newTypedef(rt) if err != nil { return nil, err } - plan = registerTypedef(rt, plan) - return plan, nil + def = registerTypedef(rt, def) + return def, nil } -func (plan *typedef) seen(gotype reflect.Type) bool { - _, ok := plan.decoders[unmarshalKey{gotype: gotype, shape: '0'}] +func (def *typedef) seen(gotype reflect.Type) bool { + _, ok := def.decoders[unmarshalKey{gotype: gotype, shape: shapeNULL}] return ok } -func (plan *typedef) handle(key unmarshalKey, fn decodeFunc) { - if _, ok := plan.decoders[key]; ok { +func (def *typedef) handle(key unmarshalKey, fn decodeFunc) { + if _, ok := def.decoders[key]; ok { return } - plan.decoders[key] = fn + def.decoders[key] = fn // debugf("handle %#v -> %s", key, runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()) } func (def *typedef) encodeItem(rv reflect.Value) (Item, error) { - // out := rv rv = indirectPtrNoAlloc(rv) if shouldBypassEncodeItem(rv.Type()) { return def.encodeItemBypass(rv.Interface()) @@ -153,7 +101,7 @@ func (def *typedef) encodeItem(rv reflect.Value) (Item, error) { return encodeItem(def.fields, rv) } -func (plan *typedef) encodeItemBypass(in any) (item map[string]*dynamodb.AttributeValue, err error) { +func (def *typedef) encodeItemBypass(in any) (item map[string]*dynamodb.AttributeValue, err error) { switch x := in.(type) { case map[string]*dynamodb.AttributeValue: item = x @@ -170,15 +118,15 @@ func (plan *typedef) encodeItemBypass(in any) (item map[string]*dynamodb.Attribu return } -func (plan *typedef) decodeItem(item map[string]*dynamodb.AttributeValue, outv reflect.Value) error { +func (def *typedef) decodeItem(item map[string]*dynamodb.AttributeValue, outv reflect.Value) error { out := outv outv = indirectPtr(outv) if shouldBypassDecodeItem(outv.Type()) { - return plan.decodeItemBypass(item, outv.Interface()) + return def.decodeItemBypass(item, outv.Interface()) } outv = indirect(outv) if shouldBypassDecodeItem(outv.Type()) { - return plan.decodeItemBypass(item, outv.Interface()) + return def.decodeItemBypass(item, outv.Interface()) } if !outv.CanSet() { @@ -188,16 +136,16 @@ func (plan *typedef) decodeItem(item map[string]*dynamodb.AttributeValue, outv r // debugf("decode item: %v -> %T(%v)", item, out, out) switch outv.Kind() { case reflect.Struct: - return decodeStruct(plan, flagNone, &dynamodb.AttributeValue{M: item}, outv) + return decodeStruct(def, flagNone, &dynamodb.AttributeValue{M: item}, outv) case reflect.Map: - return plan.decodeAttr(flagNone, &dynamodb.AttributeValue{M: item}, outv) + return def.decodeAttr(flagNone, &dynamodb.AttributeValue{M: item}, outv) } bad: return fmt.Errorf("dynamo: cannot unmarshal item into type %v (must be a pointer to a map or struct, or a supported interface)", out.Type()) } -func (plan *typedef) decodeItemBypass(item map[string]*dynamodb.AttributeValue, out any) error { +func (def *typedef) decodeItemBypass(item map[string]*dynamodb.AttributeValue, out any) error { switch x := out.(type) { case *map[string]*dynamodb.AttributeValue: *x = item @@ -210,7 +158,7 @@ func (plan *typedef) decodeItemBypass(item map[string]*dynamodb.AttributeValue, return nil } -func (plan *typedef) decodeAttr(flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { +func (def *typedef) decodeAttr(flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { if !rv.IsValid() || av == nil { return nil } @@ -218,14 +166,14 @@ func (plan *typedef) decodeAttr(flags encodeFlags, av *dynamodb.AttributeValue, // debugf("decodeAttr: %v(%v) <- %v", rv.Type(), rv, av) if av.NULL != nil { - return decodeNull(plan, flags, av, rv) + return decodeNull(def, flags, av, rv) } rv = indirectPtr(rv) retry: gotype := rv.Type() - ok, err := plan.decodeType(unmarshalKey{gotype: gotype, shape: shapeOf(av)}, flags, av, rv) + ok, err := def.decodeType(unmarshalKey{gotype: gotype, shape: shapeOf(av)}, flags, av, rv) if err != nil { return err } @@ -233,12 +181,12 @@ retry: // debugf("lookup1 %#v -> %v", unmarshalKey{gotype: gotype, shape: shapeOf(av)}, rv) return nil } - ok, err = plan.decodeType(unmarshalKey{gotype: gotype, shape: '_'}, flags, av, rv) + ok, err = def.decodeType(unmarshalKey{gotype: gotype, shape: shapeAny}, flags, av, rv) if err != nil { return err } if ok { - // debugf("lookup2 %#v -> %v", unmarshalKey{gotype: gotype, shape: '_'}, rv) + // debugf("lookup2 %#v -> %v", unmarshalKey{gotype: gotype, shape: shapeAny}, rv) return nil } @@ -251,18 +199,18 @@ retry: return fmt.Errorf("dynamo: cannot unmarshal %s attribute value into type %s", avTypeName(av), rv.Type().String()) } -func (plan *typedef) decodeType(key unmarshalKey, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) (bool, error) { - do, ok := plan.decoders[key] +func (def *typedef) decodeType(key unmarshalKey, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) (bool, error) { + do, ok := def.decoders[key] if !ok { return false, nil } - err := do(plan, flags, av, rv) + err := do(def, flags, av, rv) return true, err } -func (plan *typedef) learn(rt reflect.Type) { - if plan.decoders == nil { - plan.decoders = make(map[unmarshalKey]decodeFunc) +func (def *typedef) learn(rt reflect.Type) { + if def.decoders == nil { + def.decoders = make(map[unmarshalKey]decodeFunc) } this := func(db shapeKey) unmarshalKey { @@ -270,11 +218,11 @@ func (plan *typedef) learn(rt reflect.Type) { } switch { - case plan.seen(rt): + case def.seen(rt): return } - plan.handle(this(shapeNULL), decodeNull) + def.handle(this(shapeNULL), decodeNull) try := rt if try.Kind() != reflect.Pointer { @@ -283,31 +231,31 @@ func (plan *typedef) learn(rt reflect.Type) { for { switch try { case rtypeAttr: - plan.handle(this(shapeAny), decode2(func(dst *dynamodb.AttributeValue, src *dynamodb.AttributeValue) error { + def.handle(this(shapeAny), decode2(func(dst *dynamodb.AttributeValue, src *dynamodb.AttributeValue) error { *dst = *src return nil })) return case rtypeTimePtr, rtypeTime: - plan.handle(this(shapeN), decodeUnixTime) - plan.handle(this(shapeS), decode2(func(t encoding.TextUnmarshaler, av *dynamodb.AttributeValue) error { + def.handle(this(shapeN), decodeUnixTime) + def.handle(this(shapeS), decode2(func(t encoding.TextUnmarshaler, av *dynamodb.AttributeValue) error { return t.UnmarshalText([]byte(*av.S)) })) return } switch { case try.Implements(rtypeUnmarshaler): - plan.handle(this(shapeAny), decode2(func(t Unmarshaler, av *dynamodb.AttributeValue) error { + def.handle(this(shapeAny), decode2(func(t Unmarshaler, av *dynamodb.AttributeValue) error { return t.UnmarshalDynamo(av) })) return case try.Implements(rtypeAWSUnmarshaler): - plan.handle(this(shapeAny), decode2(func(t dynamodbattribute.Unmarshaler, av *dynamodb.AttributeValue) error { + def.handle(this(shapeAny), decode2(func(t dynamodbattribute.Unmarshaler, av *dynamodb.AttributeValue) error { return t.UnmarshalDynamoDBAttributeValue(av) })) return case try.Implements(rtypeTextUnmarshaler): - plan.handle(this(shapeS), decode2(func(t encoding.TextUnmarshaler, av *dynamodb.AttributeValue) error { + def.handle(this(shapeS), decode2(func(t encoding.TextUnmarshaler, av *dynamodb.AttributeValue) error { return t.UnmarshalText([]byte(*av.S)) })) return @@ -322,74 +270,74 @@ func (plan *typedef) learn(rt reflect.Type) { switch rt.Kind() { case reflect.Ptr: - plan.learn(rt.Elem()) - plan.handle(this(shapeAny), decodePtr) + def.learn(rt.Elem()) + def.handle(this(shapeAny), decodePtr) case reflect.Bool: - plan.handle(this(shapeBOOL), decodeBool) + def.handle(this(shapeBOOL), decodeBool) case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - plan.handle(this(shapeN), decodeInt) + def.handle(this(shapeN), decodeInt) case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - plan.handle(this(shapeN), decodeUint) + def.handle(this(shapeN), decodeUint) case reflect.Float64, reflect.Float32: - plan.handle(this(shapeN), decodeFloat) + def.handle(this(shapeN), decodeFloat) case reflect.String: - plan.handle(this(shapeS), decodeString) + def.handle(this(shapeS), decodeString) case reflect.Struct: visitTypeFields(rt, nil, nil, func(_ string, _ []int, flags encodeFlags, vt reflect.Type) error { - plan.learn(vt) + def.learn(vt) return nil }) - plan.handle(this(shapeM), decodeStruct) + def.handle(this(shapeM), decodeStruct) case reflect.Map: - plan.learn(rt.Key()) - plan.learn(rt.Elem()) + def.learn(rt.Key()) + def.learn(rt.Elem()) decodeKey := decodeMapKeyFunc(rt) - plan.handle(this(shapeM), decodeMap(decodeKey)) + def.handle(this(shapeM), decodeMap(decodeKey)) truthy := truthy(rt) if !truthy.IsValid() { - bad := func(plan *typedef, flags encodeFlags, av *dynamodb.AttributeValue, rv reflect.Value) error { + bad := func(_ *typedef, _ encodeFlags, _ *dynamodb.AttributeValue, _ reflect.Value) error { return fmt.Errorf("dynamo: unmarshal map set: value type must be struct{} or bool, got %v", rt) } - plan.handle(this(shapeSS), bad) - plan.handle(this(shapeNS), bad) - plan.handle(this(shapeBS), bad) + def.handle(this(shapeSS), bad) + def.handle(this(shapeNS), bad) + def.handle(this(shapeBS), bad) return } - plan.handle(this(shapeSS), decodeMapSS(decodeKey, truthy)) - plan.handle(this(shapeNS), decodeMapNS(decodeKey, truthy)) - plan.handle(this(shapeBS), decodeMapBS(decodeKey, truthy)) + def.handle(this(shapeSS), decodeMapSS(decodeKey, truthy)) + def.handle(this(shapeNS), decodeMapNS(decodeKey, truthy)) + def.handle(this(shapeBS), decodeMapBS(decodeKey, truthy)) case reflect.Slice: - plan.learn(rt.Elem()) + def.learn(rt.Elem()) if rt.Elem().Kind() == reflect.Uint8 { - plan.handle(this(shapeB), decodeBytes) + def.handle(this(shapeB), decodeBytes) } /* else { - plan.handle(this(shapeB), decodeSliceB) + def.handle(this(shapeB), decodeSliceB) } */ - plan.handle(this(shapeL), decodeSliceL) - plan.handle(this(shapeBS), decodeSliceBS) - plan.handle(this(shapeSS), decodeSliceSS) - plan.handle(this(shapeNS), decodeSliceNS) + def.handle(this(shapeL), decodeSliceL) + def.handle(this(shapeBS), decodeSliceBS) + def.handle(this(shapeSS), decodeSliceSS) + def.handle(this(shapeNS), decodeSliceNS) case reflect.Array: - plan.learn(rt.Elem()) - plan.handle(this(shapeB), decodeArrayB) - plan.handle(this(shapeL), decodeArrayL) + def.learn(rt.Elem()) + def.handle(this(shapeB), decodeArrayB) + def.handle(this(shapeL), decodeArrayL) case reflect.Interface: // interface{} if rt.NumMethod() == 0 { - plan.handle(this(shapeAny), decodeAny) + def.handle(this(shapeAny), decodeAny) } } } @@ -416,43 +364,52 @@ func shouldBypassEncodeItem(rt reflect.Type) bool { return false } -func encodeBytes(rt reflect.Type, flags encodeFlags) encodeFunc { - if rt.Kind() == reflect.Array { - size := rt.Len() - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { - if rv.IsZero() { - switch { - case flags&flagNull != 0: - return nullAV, nil - case flags&flagAllowEmpty != 0: - return emptyB, nil - } - return nil, nil - } - data := make([]byte, size) - reflect.Copy(reflect.ValueOf(data), rv) - return &dynamodb.AttributeValue{B: data}, nil - } +type unmarshalKey struct { + gotype reflect.Type + shape shapeKey +} + +func (key unmarshalKey) GoString() string { + return fmt.Sprintf("%s:%v", key.shape.GoString(), key.gotype.String()) +} + +func (key unmarshalKey) Less(other unmarshalKey) bool { + if key.gotype == other.gotype { + return key.shape < other.shape } + return key.gotype.String() < other.gotype.String() +} - return func(rv reflect.Value, flags encodeFlags) (*dynamodb.AttributeValue, error) { - if rv.IsNil() { - if flags&flagNull != 0 { - return nullAV, nil - } - return nil, nil +type structField struct { + index []int + name string + flags encodeFlags + enc encodeFunc + isZero func(reflect.Value) bool +} + +func structFields(rt reflect.Type) ([]structField, error) { + var fields []structField + err := visitTypeFields(rt, nil, nil, func(name string, index []int, flags encodeFlags, vt reflect.Type) error { + enc, err := encodeType(vt, flags) + if err != nil { + return err } - if rv.Len() == 0 { - if flags&flagAllowEmpty != 0 { - return emptyB, nil - } - return nil, nil + field := structField{ + index: index, + name: name, + flags: flags, + enc: enc, + isZero: isZeroFunc(vt), } - return &dynamodb.AttributeValue{B: rv.Bytes()}, nil - } + fields = append(fields, field) + return nil + }) + return fields, err } var ( nullAV = &dynamodb.AttributeValue{NULL: aws.Bool(true)} emptyB = &dynamodb.AttributeValue{B: []byte("")} + emptyS = &dynamodb.AttributeValue{S: new(string)} )