Skip to content

Commit a0a6ae0

Browse files
authored
Improve API to get flag completion function (#2063)
The new API is simpler and matches the `c.RegisterFlagCompletionFunc()` API. By removing the global function `GetFlagCompletion()` we are more future proof if we ever move from a global map of flag completion functions to something associated with the command. The commit also makes this API work with persistent flags by using `c.Flag(flagName)` instead of `c.Flags().Lookup(flagName)`. The commit also adds unit tests. Signed-off-by: Marc Khouzam <marc.khouzam@gmail.com>
1 parent 890302a commit a0a6ae0

File tree

2 files changed

+97
-12
lines changed

2 files changed

+97
-12
lines changed

completions.go

+7-12
Original file line numberDiff line numberDiff line change
@@ -145,25 +145,20 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman
145145
return nil
146146
}
147147

148-
// GetFlagCompletion returns the completion function for the given flag, if available.
149-
func GetFlagCompletion(flag *pflag.Flag) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) {
148+
// GetFlagCompletionFunc returns the completion function for the given flag of the command, if available.
149+
func (c *Command) GetFlagCompletionFunc(flagName string) (func(*Command, []string, string) ([]string, ShellCompDirective), bool) {
150+
flag := c.Flag(flagName)
151+
if flag == nil {
152+
return nil, false
153+
}
154+
150155
flagCompletionMutex.RLock()
151156
defer flagCompletionMutex.RUnlock()
152157

153158
completionFunc, exists := flagCompletionFunctions[flag]
154159
return completionFunc, exists
155160
}
156161

157-
// GetFlagCompletionByName returns the completion function for the given flag in the command by name, if available.
158-
func (c *Command) GetFlagCompletionByName(flagName string) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) {
159-
flag := c.Flags().Lookup(flagName)
160-
if flag == nil {
161-
return nil, false
162-
}
163-
164-
return GetFlagCompletion(flag)
165-
}
166-
167162
// Returns a string listing the different directive enabled in the specified parameter
168163
func (d ShellCompDirective) string() string {
169164
var directives []string

completions_test.go

+90
Original file line numberDiff line numberDiff line change
@@ -3427,3 +3427,93 @@ Completion ended with directive: ShellCompDirectiveNoFileComp
34273427
})
34283428
}
34293429
}
3430+
3431+
func TestGetFlagCompletion(t *testing.T) {
3432+
rootCmd := &Command{Use: "root", Run: emptyRun}
3433+
3434+
rootCmd.Flags().String("rootflag", "", "root flag")
3435+
_ = rootCmd.RegisterFlagCompletionFunc("rootflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
3436+
return []string{"rootvalue"}, ShellCompDirectiveKeepOrder
3437+
})
3438+
3439+
rootCmd.PersistentFlags().String("persistentflag", "", "persistent flag")
3440+
_ = rootCmd.RegisterFlagCompletionFunc("persistentflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
3441+
return []string{"persistentvalue"}, ShellCompDirectiveDefault
3442+
})
3443+
3444+
childCmd := &Command{Use: "child", Run: emptyRun}
3445+
3446+
childCmd.Flags().String("childflag", "", "child flag")
3447+
_ = childCmd.RegisterFlagCompletionFunc("childflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
3448+
return []string{"childvalue"}, ShellCompDirectiveNoFileComp | ShellCompDirectiveNoSpace
3449+
})
3450+
3451+
rootCmd.AddCommand(childCmd)
3452+
3453+
testcases := []struct {
3454+
desc string
3455+
cmd *Command
3456+
flagName string
3457+
exists bool
3458+
comps []string
3459+
directive ShellCompDirective
3460+
}{
3461+
{
3462+
desc: "get flag completion function for command",
3463+
cmd: rootCmd,
3464+
flagName: "rootflag",
3465+
exists: true,
3466+
comps: []string{"rootvalue"},
3467+
directive: ShellCompDirectiveKeepOrder,
3468+
},
3469+
{
3470+
desc: "get persistent flag completion function for command",
3471+
cmd: rootCmd,
3472+
flagName: "persistentflag",
3473+
exists: true,
3474+
comps: []string{"persistentvalue"},
3475+
directive: ShellCompDirectiveDefault,
3476+
},
3477+
{
3478+
desc: "get flag completion function for child command",
3479+
cmd: childCmd,
3480+
flagName: "childflag",
3481+
exists: true,
3482+
comps: []string{"childvalue"},
3483+
directive: ShellCompDirectiveNoFileComp | ShellCompDirectiveNoSpace,
3484+
},
3485+
{
3486+
desc: "get persistent flag completion function for child command",
3487+
cmd: childCmd,
3488+
flagName: "persistentflag",
3489+
exists: true,
3490+
comps: []string{"persistentvalue"},
3491+
directive: ShellCompDirectiveDefault,
3492+
},
3493+
{
3494+
desc: "cannot get flag completion function for local parent flag",
3495+
cmd: childCmd,
3496+
flagName: "rootflag",
3497+
exists: false,
3498+
},
3499+
}
3500+
3501+
for _, tc := range testcases {
3502+
t.Run(tc.desc, func(t *testing.T) {
3503+
compFunc, exists := tc.cmd.GetFlagCompletionFunc(tc.flagName)
3504+
if tc.exists != exists {
3505+
t.Errorf("Unexpected result looking for flag completion function")
3506+
}
3507+
3508+
if exists {
3509+
comps, directive := compFunc(tc.cmd, []string{}, "")
3510+
if strings.Join(tc.comps, " ") != strings.Join(comps, " ") {
3511+
t.Errorf("Unexpected completions %q", comps)
3512+
}
3513+
if tc.directive != directive {
3514+
t.Errorf("Unexpected directive %q", directive)
3515+
}
3516+
}
3517+
})
3518+
}
3519+
}

0 commit comments

Comments
 (0)