Skip to content

Commit

Permalink
feat(autocli): add map support (#15770)
Browse files Browse the repository at this point in the history
  • Loading branch information
JeancarloBarrios authored May 30, 2023
1 parent f358214 commit 62f0c6f
Show file tree
Hide file tree
Showing 14 changed files with 2,306 additions and 80 deletions.
70 changes: 67 additions & 3 deletions client/v2/autocli/flag/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package flag

import (
"context"
"strconv"

autocliv1 "cosmossdk.io/api/cosmos/autocli/v1"
cosmos_proto "github.com/cosmos/cosmos-proto"
"github.com/spf13/pflag"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"

autocliv1 "cosmossdk.io/api/cosmos/autocli/v1"
cosmos_proto "github.com/cosmos/cosmos-proto"

"cosmossdk.io/client/v2/internal/util"
)

Expand Down Expand Up @@ -60,8 +62,13 @@ func (b *Builder) addFieldFlag(ctx context.Context, flagSet *pflag.FlagSet, fiel

// use the built-in pflag StringP, Int32P, etc. functions
var val HasValue

if field.IsList() {
val = bindSimpleListFlag(flagSet, field.Kind(), name, shorthand, usage)
} else if field.IsMap() {
keyKind := field.MapKey().Kind()
valKind := field.MapValue().Kind()
val = bindSimpleMapFlag(flagSet, keyKind, valKind, name, shorthand, usage)
} else {
val = bindSimpleFlag(flagSet, field.Kind(), name, shorthand, usage)
}
Expand All @@ -81,7 +88,65 @@ func (b *Builder) resolveFlagType(field protoreflect.FieldDescriptor) Type {
if typ != nil {
return compositeListType{simpleType: typ}
}
return nil
}
if field.IsMap() {
keyKind := field.MapKey().Kind()
valType := b.resolveFlagType(field.MapValue())
if valType != nil {
switch keyKind {
case protoreflect.StringKind:
ct := new(compositeMapType[string])
ct.keyValueResolver = func(s string) (string, error) { return s, nil }
ct.valueType = valType
ct.keyType = "string"
return ct
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
ct := new(compositeMapType[int32])
ct.keyValueResolver = func(s string) (int32, error) {
i, err := strconv.ParseInt(s, 10, 32)
return int32(i), err
}
ct.valueType = valType
ct.keyType = "int32"
return ct
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
ct := new(compositeMapType[int64])
ct.keyValueResolver = func(s string) (int64, error) {
i, err := strconv.ParseInt(s, 10, 64)
return i, err
}
ct.valueType = valType
ct.keyType = "int64"
return ct
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
ct := new(compositeMapType[uint32])
ct.keyValueResolver = func(s string) (uint32, error) {
i, err := strconv.ParseUint(s, 10, 32)
return uint32(i), err
}
ct.valueType = valType
ct.keyType = "uint32"
return ct
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
ct := new(compositeMapType[uint64])
ct.keyValueResolver = func(s string) (uint64, error) {
i, err := strconv.ParseUint(s, 10, 64)
return i, err
}
ct.valueType = valType
ct.keyType = "uint64"
return ct
case protoreflect.BoolKind:
ct := new(compositeMapType[bool])
ct.keyValueResolver = strconv.ParseBool
ct.valueType = valType
ct.keyType = "bool"
return ct
}
return nil

}
return nil
}

Expand All @@ -107,7 +172,6 @@ func (b *Builder) resolveFlagTypeBasic(field protoreflect.FieldDescriptor) Type
if flagType, ok := b.messageFlagTypes[field.Message().FullName()]; ok {
return flagType
}

return jsonMessageFlagType{
messageDesc: field.Message(),
}
Expand Down
252 changes: 252 additions & 0 deletions client/v2/autocli/flag/map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
package flag

import (
"context"
"fmt"
"strings"

"github.com/cockroachdb/errors"
"github.com/spf13/pflag"
"google.golang.org/protobuf/reflect/protoreflect"
)

func bindSimpleMapFlag(flagSet *pflag.FlagSet, keyKind, valueKind protoreflect.Kind, name, shorthand, usage string) HasValue {
switch keyKind {
case protoreflect.StringKind:
switch valueKind {
case protoreflect.StringKind:
val := flagSet.StringToStringP(name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfString)
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
val := StringToInt32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt32)
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
val := flagSet.StringToInt64P(name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt64)
case protoreflect.Uint32Kind:
val := StringToUint32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint32)
case protoreflect.Uint64Kind:
val := StringToUint64P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint64)
case protoreflect.BoolKind:
val := StringToBoolP(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfBool)
}

case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
switch valueKind {
case protoreflect.StringKind:
val := Int32ToStringP(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfString)
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
val := Int32ToInt32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt32)
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
val := Int32ToInt64P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt64)
case protoreflect.Uint32Kind:
val := Int32ToUint32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint32)
case protoreflect.Uint64Kind:
val := Int32ToUint64P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint64)
case protoreflect.BoolKind:
val := Int32ToBoolP(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfBool)
}

case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
switch valueKind {
case protoreflect.StringKind:
val := Int64ToStringP(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfString)
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
val := Int64ToInt32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt32)
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
val := Int64ToInt64P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt64)
case protoreflect.Uint32Kind:
val := Int64ToUint32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint32)
case protoreflect.Uint64Kind:
val := Int64ToUint64P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint64)
case protoreflect.BoolKind:
val := Int64ToBoolP(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfBool)
}
case protoreflect.Uint32Kind:
switch valueKind {
case protoreflect.StringKind:
val := Uint32ToStringP(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfString)
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
val := Uint32ToInt32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt32)
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
val := Uint32ToInt64P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt64)
case protoreflect.Uint32Kind:
val := Uint32ToUint32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint32)
case protoreflect.Uint64Kind:
val := Uint32ToUint64P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint64)
case protoreflect.BoolKind:
val := Uint32ToBoolP(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfBool)
}
case protoreflect.Uint64Kind:
switch valueKind {
case protoreflect.StringKind:
val := Uint64ToStringP(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfString)
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
val := Uint64ToInt32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt32)
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
val := Uint64ToInt64P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt64)
case protoreflect.Uint32Kind:
val := Uint64ToUint32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint32)
case protoreflect.Uint64Kind:
val := Uint64ToUint64P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint64)
case protoreflect.BoolKind:
val := Uint64ToBoolP(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfBool)
}
case protoreflect.BoolKind:
switch valueKind {
case protoreflect.StringKind:
val := BoolToStringP(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfString)
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
val := BoolToInt32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt32)
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
val := BoolToInt64P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfInt64)
case protoreflect.Uint32Kind:
val := BoolToUint32P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint32)
case protoreflect.Uint64Kind:
val := BoolToUint64P(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfUint64)
case protoreflect.BoolKind:
val := BoolToBoolP(flagSet, name, shorthand, nil, usage)
return newMapValue(val, protoreflect.ValueOfBool)
}

}
return nil
}

type mapValue[K comparable, V any] struct {
value *map[K]V
toProtoreflectValue func(V) protoreflect.Value
}

func newMapValue[K comparable, V any](mapV *map[K]V, toProtoreflectValue func(V) protoreflect.Value) mapValue[K, V] {
return mapValue[K, V]{value: mapV, toProtoreflectValue: toProtoreflectValue}
}

func (v mapValue[K, V]) Get(mutable protoreflect.Value) (protoreflect.Value, error) {
protoMap := mutable.Map()
for k, val := range *v.value {
protoMap.Set(protoreflect.MapKey(protoreflect.ValueOf(k)), v.toProtoreflectValue(val))
}
return mutable, nil
}

// keyValueResolver is a function that converts a string to a key that is primitive Type T
type keyValueResolver[T comparable] func(string) (T, error)

// compositeMapType is a map type that is composed of a key and value type that are both primitive types
type compositeMapType[T comparable] struct {
keyValueResolver keyValueResolver[T]
keyType string
valueType Type
}

// compositeMapValue is a map value that is composed of a key and value type that are both primitive types
type compositeMapValue[T comparable] struct {
keyValueResolver keyValueResolver[T]
keyType string
valueType Type
values map[T]protoreflect.Value
ctx context.Context
opts *Builder
}

func (m compositeMapType[T]) DefaultValue() string {
return ""
}

func (m compositeMapType[T]) NewValue(ctx context.Context, opts *Builder) Value {
return &compositeMapValue[T]{
keyValueResolver: m.keyValueResolver,
valueType: m.valueType,
keyType: m.keyType,
ctx: ctx,
opts: opts,
values: nil,
}
}

func (m *compositeMapValue[T]) Set(s string) error {
comaArgs := strings.Split(s, ",")
for _, arg := range comaArgs {
parts := strings.SplitN(arg, "=", 2)
if len(parts) != 2 {
return errors.New("invalid format, expected key=value")
}
key, val := parts[0], parts[1]

keyValue, err := m.keyValueResolver(key)
if err != nil {
return err
}

simpleVal := m.valueType.NewValue(m.ctx, m.opts)
err = simpleVal.Set(val)
if err != nil {
return err
}
protoValue, err := simpleVal.Get(protoreflect.Value{})
if err != nil {
return err
}
if m.values == nil {
m.values = make(map[T]protoreflect.Value)
}

m.values[keyValue] = protoValue
}

return nil
}

func (m *compositeMapValue[T]) Get(mutable protoreflect.Value) (protoreflect.Value, error) {
protoMap := mutable.Map()
for key, value := range m.values {
keyVal := protoreflect.ValueOf(key)
protoMap.Set(keyVal.MapKey(), value)
}
return protoreflect.ValueOfMap(protoMap), nil
}

func (m *compositeMapValue[T]) String() string {
if m.values == nil {
return ""
}

return fmt.Sprintf("%+v", m.values)
}

func (m *compositeMapValue[T]) Type() string {
return fmt.Sprintf("map[%s]%s", m.keyType, m.valueType.NewValue(m.ctx, m.opts).Type())
}
Loading

0 comments on commit 62f0c6f

Please sign in to comment.