Skip to content

Commit

Permalink
Add SortFlags option (spf13#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
n10v authored and eparis committed Mar 25, 2017
1 parent 9ff6c69 commit d90f37a
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 11 deletions.
52 changes: 41 additions & 11 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,16 @@ type FlagSet struct {
// a custom error handler.
Usage func()

// SortFlags is used to indicate, if user wants to have sorted flags in
// help/usage messages.
SortFlags bool

name string
parsed bool
actual map[NormalizedName]*Flag
orderedActual []*Flag
formal map[NormalizedName]*Flag
orderedFormal []*Flag
shorthands map[byte]*Flag
args []string // arguments after flags
argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no --
Expand All @@ -156,7 +162,7 @@ type Flag struct {
Value Value // value as set
DefValue string // default value (as text); for usage message
Changed bool // If the user set the value (or if left to default)
NoOptDefVal string //default value (as text); if the flag is on the command line without any options
NoOptDefVal string // default value (as text); if the flag is on the command line without any options
Deprecated string // If this flag is deprecated, this string is the new or now thing to use
Hidden bool // used by cobra.Command to allow flags to be hidden from help/usage text
ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use
Expand Down Expand Up @@ -194,10 +200,12 @@ func sortFlags(flags map[NormalizedName]*Flag) []*Flag {
// "--getUrl" which may also be translated to "geturl" and everything will work.
func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) {
f.normalizeNameFunc = n
f.orderedFormal = f.orderedFormal[:0]
for k, v := range f.formal {
delete(f.formal, k)
nname := f.normalizeFlagName(string(k))
f.formal[nname] = v
f.orderedFormal = append(f.orderedFormal, v)
v.Name = string(nname)
}
}
Expand Down Expand Up @@ -229,10 +237,18 @@ func (f *FlagSet) SetOutput(output io.Writer) {
f.output = output
}

// VisitAll visits the flags in lexicographical order, calling fn for each.
// VisitAll visits the flags in lexicographical order or
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits all flags, even those not set.
func (f *FlagSet) VisitAll(fn func(*Flag)) {
for _, flag := range sortFlags(f.formal) {
var flags []*Flag
if f.SortFlags {
flags = sortFlags(f.formal)
} else {
flags = f.orderedFormal
}

for _, flag := range flags {
fn(flag)
}
}
Expand All @@ -253,22 +269,32 @@ func (f *FlagSet) HasAvailableFlags() bool {
return false
}

// VisitAll visits the command-line flags in lexicographical order, calling
// fn for each. It visits all flags, even those not set.
// VisitAll visits the command-line flags in lexicographical order or
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits all flags, even those not set.
func VisitAll(fn func(*Flag)) {
CommandLine.VisitAll(fn)
}

// Visit visits the flags in lexicographical order, calling fn for each.
// Visit visits the flags in lexicographical order or
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits only those flags that have been set.
func (f *FlagSet) Visit(fn func(*Flag)) {
for _, flag := range sortFlags(f.actual) {
var flags []*Flag
if f.SortFlags {
flags = sortFlags(f.actual)
} else {
flags = f.orderedActual
}

for _, flag := range flags {
fn(flag)
}
}

// Visit visits the command-line flags in lexicographical order, calling fn
// for each. It visits only those flags that have been set.
// Visit visits the command-line flags in lexicographical order or
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits only those flags that have been set.
func Visit(fn func(*Flag)) {
CommandLine.Visit(fn)
}
Expand Down Expand Up @@ -373,6 +399,7 @@ func (f *FlagSet) Set(name, value string) error {
f.actual = make(map[NormalizedName]*Flag)
}
f.actual[normalName] = flag
f.orderedActual = append(f.orderedActual, flag)
flag.Changed = true
if len(flag.Deprecated) > 0 {
fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
Expand Down Expand Up @@ -729,6 +756,7 @@ func (f *FlagSet) AddFlag(flag *Flag) {

flag.Name = string(normalizedFlagName)
f.formal[normalizedFlagName] = flag
f.orderedFormal = append(f.orderedFormal, flag)

if len(flag.Shorthand) == 0 {
return
Expand Down Expand Up @@ -807,6 +835,7 @@ func (f *FlagSet) setFlag(flag *Flag, value string, origArg string) error {
f.actual = make(map[NormalizedName]*Flag)
}
f.actual[f.normalizeFlagName(flag.Name)] = flag
f.orderedActual = append(f.orderedActual, flag)
flag.Changed = true
if len(flag.Deprecated) > 0 {
fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
Expand Down Expand Up @@ -1036,14 +1065,15 @@ func Parsed() bool {
// CommandLine is the default set of command-line flags, parsed from os.Args.
var CommandLine = NewFlagSet(os.Args[0], ExitOnError)

// NewFlagSet returns a new, empty flag set with the specified name and
// error handling property.
// NewFlagSet returns a new, empty flag set with the specified name,
// error handling property and SortFlags set to true.
func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet {
f := &FlagSet{
name: name,
errorHandling: errorHandling,
argsLenAtDash: -1,
interspersed: true,
SortFlags: true,
}
return f
}
Expand Down
35 changes: 35 additions & 0 deletions flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1004,3 +1004,38 @@ func TestPrintDefaults(t *testing.T) {
t.Errorf("got %q want %q\n", got, defaultOutput)
}
}

func TestVisitAllFlagOrder(t *testing.T) {
fs := NewFlagSet("TestVisitAllFlagOrder", ContinueOnError)
fs.SortFlags = false
names := []string{"C", "B", "A", "D"}
for _, name := range names {
fs.Bool(name, false, "")
}

i := 0
fs.VisitAll(func(f *Flag) {
if names[i] != f.Name {
t.Errorf("Incorrect order. Expected %v, got %v", names[i], f.Name)
}
i++
})
}

func TestVisitFlagOrder(t *testing.T) {
fs := NewFlagSet("TestVisitFlagOrder", ContinueOnError)
fs.SortFlags = false
names := []string{"C", "B", "A", "D"}
for _, name := range names {
fs.Bool(name, false, "")
fs.Set(name, "true")
}

i := 0
fs.Visit(func(f *Flag) {
if names[i] != f.Name {
t.Errorf("Incorrect order. Expected %v, got %v", names[i], f.Name)
}
i++
})
}

0 comments on commit d90f37a

Please sign in to comment.