Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ type Command struct {
SliceFlagSeparator string `json:"sliceFlagSeparator"`
// DisableSliceFlagSeparator is used to disable SliceFlagSeparator, the default is false
DisableSliceFlagSeparator bool `json:"disableSliceFlagSeparator"`
// MapFlagKeyValueSeparator is used to customize the separator for MapFlag, the default is "="
MapFlagKeyValueSeparator string `json:"mapFlagKeyValueSeparator"`
// Boolean to enable short-option handling so user can combine several
// single-character bool arguments into one
// i.e. foobar -o -v -> foobar -ov
Expand Down Expand Up @@ -155,6 +157,10 @@ type Command struct {
didSetupDefaults bool
// whether in shell completion mode
shellCompletion bool
// whether global help flag was added
globaHelpFlagAdded bool
// whether global version flag was added
globaVersionFlagAdded bool
}

// FullName returns the full name of the command.
Expand Down Expand Up @@ -349,6 +355,7 @@ func (cmd *Command) Root() *Command {

func (cmd *Command) set(fName string, f Flag, val string) error {
cmd.setFlags[f] = struct{}{}
cmd.setMultiValueParsingConfig(f)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strictly speaking this call should be moved into PreParse phase.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @dearchap
I tried to avoid breaking things, and there are two set methods set and Set

Initially I thought to call setMultiValueParsingConfig from inside parseFlags(args Args) (Args, error) { or run(... but found that Set can be called without Run , at least in tests , so I moved it where setMultiValueParsingConfig will be called for sure

if err := f.Set(fName, val); err != nil {
return fmt.Errorf("invalid value %q for flag -%s: %v", val, fName, err)
}
Expand Down Expand Up @@ -440,9 +447,21 @@ func (cmd *Command) NumFlags() int {
return count // cmd.flagSet.NFlag()
}

func (cmd *Command) setMultiValueParsingConfig(f Flag) {
tracef("setMultiValueParsingConfig %T, %+v", f, f)
if cf, ok := f.(multiValueParsingConfigSetter); ok {
cf.setMultiValueParsingConfig(multiValueParsingConfig{
SliceFlagSeparator: cmd.SliceFlagSeparator,
DisableSliceFlagSeparator: cmd.DisableSliceFlagSeparator,
MapFlagKeyValueSeparator: cmd.MapFlagKeyValueSeparator,
})
}
}

// Set sets a context flag to a value.
func (cmd *Command) Set(name, value string) error {
if f := cmd.lookupFlag(name); f != nil {
cmd.setMultiValueParsingConfig(f)
return f.Set(name, value)
}

Expand Down
43 changes: 26 additions & 17 deletions command_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,18 @@ func (cmd *Command) setupDefaults(osArgs []string) {

if !cmd.HideVersion && isRoot {
tracef("appending version flag (cmd=%[1]q)", cmd.Name)
cmd.appendFlag(VersionFlag)
if !cmd.globaVersionFlagAdded {
var localVersionFlag Flag
if globalVersionFlag, ok := VersionFlag.(*BoolFlag); ok {
flag := *globalVersionFlag
localVersionFlag = &flag
} else {
localVersionFlag = VersionFlag
}

cmd.appendFlag(localVersionFlag)
cmd.globaVersionFlagAdded = true
}
}

if cmd.PrefixMatchCommands && cmd.SuggestCommandFunc == nil {
Expand Down Expand Up @@ -130,14 +141,6 @@ func (cmd *Command) setupDefaults(osArgs []string) {
cmd.Metadata = map[string]any{}
}

if len(cmd.SliceFlagSeparator) != 0 {
tracef("setting defaultSliceFlagSeparator from cmd.SliceFlagSeparator (cmd=%[1]q)", cmd.Name)
defaultSliceFlagSeparator = cmd.SliceFlagSeparator
}

tracef("setting disableSliceFlagSeparator from cmd.DisableSliceFlagSeparator (cmd=%[1]q)", cmd.Name)
disableSliceFlagSeparator = cmd.DisableSliceFlagSeparator

cmd.setFlags = map[Flag]struct{}{}
}

Expand Down Expand Up @@ -200,15 +203,21 @@ func (cmd *Command) ensureHelp() {
}

if HelpFlag != nil {
// TODO need to remove hack
if hf, ok := HelpFlag.(*BoolFlag); ok {
hf.applied = false
hf.hasBeenSet = false
hf.Value = false
hf.value = nil
if !cmd.globaHelpFlagAdded {
var localHelpFlag Flag
if globalHelpFlag, ok := HelpFlag.(*BoolFlag); ok {
flag := *globalHelpFlag
localHelpFlag = &flag
} else {
localHelpFlag = HelpFlag
}

tracef("appending HelpFlag (cmd=%[1]q)", cmd.Name)
cmd.appendFlag(localHelpFlag)
cmd.globaHelpFlagAdded = true
} else {
tracef("HelpFlag already added, skip (cmd=%[1]q)", cmd.Name)
}
tracef("appending HelpFlag (cmd=%[1]q)", cmd.Name)
cmd.appendFlag(HelpFlag)
}
}
}
61 changes: 56 additions & 5 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4386,11 +4386,6 @@ func TestCommandCategories(t *testing.T) {
}

func TestCommandSliceFlagSeparator(t *testing.T) {
oldSep := defaultSliceFlagSeparator
defer func() {
defaultSliceFlagSeparator = oldSep
}()

cmd := &Command{
SliceFlagSeparator: ";",
Flags: []Flag{
Expand All @@ -4405,6 +4400,26 @@ func TestCommandSliceFlagSeparator(t *testing.T) {
r.Equal([]string{"ff", "dd", "gg", "t,u"}, cmd.Value("foo"))
}

func TestCommandMapKeyValueFlagSeparator(t *testing.T) {
cmd := &Command{
MapFlagKeyValueSeparator: ":",
Flags: []Flag{
&StringMapFlag{
Name: "f_string_map",
},
},
}

r := require.New(t)
r.NoError(cmd.Run(buildTestContext(t), []string{"app", "--f_string_map", "s1:s2,s3:", "--f_string_map", "s4:s5"}))
exp := map[string]string{
"s1": "s2",
"s3": "",
"s4": "s5",
}
r.Equal(exp, cmd.Value("f_string_map"))
}

// TestStringFlagTerminator tests the string flag "--flag" with "--" terminator.
func TestStringFlagTerminator(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -4754,6 +4769,7 @@ func TestJSONExportCommand(t *testing.T) {
"metadata": null,
"sliceFlagSeparator": "",
"disableSliceFlagSeparator": false,
"mapFlagKeyValueSeparator": "",
"useShortOptionHandling": false,
"suggest": false,
"allowExtFlags": false,
Expand Down Expand Up @@ -4817,6 +4833,7 @@ func TestJSONExportCommand(t *testing.T) {
"metadata": null,
"sliceFlagSeparator": "",
"disableSliceFlagSeparator": false,
"mapFlagKeyValueSeparator": "",
"useShortOptionHandling": false,
"suggest": false,
"allowExtFlags": false,
Expand Down Expand Up @@ -4851,6 +4868,7 @@ func TestJSONExportCommand(t *testing.T) {
"metadata": null,
"sliceFlagSeparator": "",
"disableSliceFlagSeparator": false,
"mapFlagKeyValueSeparator": "",
"useShortOptionHandling": false,
"suggest": false,
"allowExtFlags": false,
Expand Down Expand Up @@ -4882,6 +4900,7 @@ func TestJSONExportCommand(t *testing.T) {
"metadata": null,
"sliceFlagSeparator": "",
"disableSliceFlagSeparator": false,
"mapFlagKeyValueSeparator": "",
"useShortOptionHandling": false,
"suggest": false,
"allowExtFlags": false,
Expand Down Expand Up @@ -4932,6 +4951,7 @@ func TestJSONExportCommand(t *testing.T) {
"metadata": null,
"sliceFlagSeparator": "",
"disableSliceFlagSeparator": false,
"mapFlagKeyValueSeparator": "",
"useShortOptionHandling": false,
"suggest": false,
"allowExtFlags": false,
Expand Down Expand Up @@ -4999,6 +5019,7 @@ func TestJSONExportCommand(t *testing.T) {
"metadata": null,
"sliceFlagSeparator": "",
"disableSliceFlagSeparator": false,
"mapFlagKeyValueSeparator": "",
"useShortOptionHandling": false,
"suggest": false,
"allowExtFlags": false,
Expand Down Expand Up @@ -5062,6 +5083,7 @@ func TestJSONExportCommand(t *testing.T) {
"metadata": null,
"sliceFlagSeparator": "",
"disableSliceFlagSeparator": false,
"mapFlagKeyValueSeparator": "",
"useShortOptionHandling": false,
"suggest": false,
"allowExtFlags": false,
Expand Down Expand Up @@ -5169,6 +5191,7 @@ func TestJSONExportCommand(t *testing.T) {
"metadata": null,
"sliceFlagSeparator": "",
"disableSliceFlagSeparator": false,
"mapFlagKeyValueSeparator": "",
"useShortOptionHandling": false,
"suggest": false,
"allowExtFlags": false,
Expand Down Expand Up @@ -5284,3 +5307,31 @@ func TestCommand_ExclusiveFlagsWithAfter(t *testing.T) {
}))
require.True(t, called)
}

func TestCommand_ParallelRun(t *testing.T) {
t.Parallel()

for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("run_%d", i), func(t *testing.T) {
t.Parallel()

defer func() {
if r := recover(); r != nil {
t.Errorf("unexpected panic - '%s'", r)
}
}()

cmd := &Command{
Name: "debug",
Usage: "make an explosive entrance",
Action: func(_ context.Context, cmd *Command) error {
return nil
},
}

if err := cmd.Run(context.Background(), nil); err != nil {
fmt.Printf("%s\n", err)
}
})
}
}
11 changes: 7 additions & 4 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

const defaultPlaceholder = "value"

var (
const (
defaultSliceFlagSeparator = ","
defaultMapFlagKeyValueSeparator = "="
disableSliceFlagSeparator = false
Expand Down Expand Up @@ -222,10 +222,13 @@ func hasFlag(flags []Flag, fl Flag) bool {
return false
}

func flagSplitMultiValues(val string) []string {
if disableSliceFlagSeparator {
func flagSplitMultiValues(val string, sliceSeparator string, disableSliceSeparator bool) []string {
if disableSliceSeparator {
return []string{val}
}

return strings.Split(val, defaultSliceFlagSeparator)
if len(sliceSeparator) == 0 {
sliceSeparator = defaultSliceFlagSeparator
}
return strings.Split(val, sliceSeparator)
}
22 changes: 22 additions & 0 deletions flag_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@ type boolFlag interface {
IsBoolFlag() bool
}

type multiValueParsingConfig struct {
// SliceFlagSeparator is used to customize the separator for SliceFlag, the default is ","
SliceFlagSeparator string
// DisableSliceFlagSeparator is used to disable SliceFlagSeparator, the default is false
DisableSliceFlagSeparator bool
// MapFlagKeyValueSeparator is used to customize the separator for MapFlag, the default is "="
MapFlagKeyValueSeparator string
}

type multiValueParsingConfigSetter interface {
// configuration of parsing
setMultiValueParsingConfig(c multiValueParsingConfig)
}

// ValueCreator is responsible for creating a flag.Value emulation
// as well as custom formatting
//
Expand Down Expand Up @@ -134,6 +148,14 @@ func (f *FlagBase[T, C, V]) PostParse() error {
return nil
}

// pass configuration of parsing to value
func (f *FlagBase[T, C, V]) setMultiValueParsingConfig(c multiValueParsingConfig) {
tracef("setMultiValueParsingConfig %T, %+v", f.value, f.value)
if cf, ok := f.value.(multiValueParsingConfigSetter); ok {
cf.setMultiValueParsingConfig(c)
}
}

func (f *FlagBase[T, C, V]) PreParse() error {
newVal := f.Value

Expand Down
38 changes: 32 additions & 6 deletions flag_map_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (

// MapBase wraps map[string]T to satisfy flag.Value
type MapBase[T any, C any, VC ValueCreator[T, C]] struct {
dict *map[string]T
hasBeenSet bool
value Value
dict *map[string]T
hasBeenSet bool
value Value
multiValueConfig multiValueParsingConfig
}

func (i MapBase[T, C, VC]) Create(val map[string]T, p *map[string]T, c C) Value {
Expand All @@ -36,6 +37,18 @@ func NewMapBase[T any, C any, VC ValueCreator[T, C]](defaults map[string]T) *Map
}
}

// configuration of slicing
func (i *MapBase[T, C, VC]) setMultiValueParsingConfig(c multiValueParsingConfig) {
i.multiValueConfig = c
mvc := &i.multiValueConfig
tracef(
"set map parsing config - keyValueSeparator '%s', slice separator '%s', disable separator:%v",
mvc.MapFlagKeyValueSeparator,
mvc.SliceFlagSeparator,
mvc.DisableSliceFlagSeparator,
)
}

// Set parses the value and appends it to the list of values
func (i *MapBase[T, C, VC]) Set(value string) error {
if !i.hasBeenSet {
Expand All @@ -50,10 +63,23 @@ func (i *MapBase[T, C, VC]) Set(value string) error {
return nil
}

for _, item := range flagSplitMultiValues(value) {
key, value, ok := strings.Cut(item, defaultMapFlagKeyValueSeparator)
mvc := &i.multiValueConfig
keyValueSeparator := mvc.MapFlagKeyValueSeparator
if len(keyValueSeparator) == 0 {
keyValueSeparator = defaultMapFlagKeyValueSeparator
}

tracef(
"splitting map value '%s', keyValueSeparator '%s', slice separator '%s', disable separator:%v",
value,
keyValueSeparator,
mvc.SliceFlagSeparator,
mvc.DisableSliceFlagSeparator,
)
for _, item := range flagSplitMultiValues(value, mvc.SliceFlagSeparator, mvc.DisableSliceFlagSeparator) {
key, value, ok := strings.Cut(item, keyValueSeparator)
if !ok {
return fmt.Errorf("item %q is missing separator %q", item, defaultMapFlagKeyValueSeparator)
return fmt.Errorf("item %q is missing separator %q", item, keyValueSeparator)
}
if err := i.value.Set(value); err != nil {
return err
Expand Down
Loading