Skip to content

Commit

Permalink
addressing pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
renan-campos committed Jun 29, 2022
1 parent d7e73ea commit ca547e1
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 34 deletions.
21 changes: 8 additions & 13 deletions cmd/edit/service/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"os"

"github.com/spf13/cobra"
"github.com/spf13/pflag"

cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
"github.com/openshift/rosa/pkg/arguments"
Expand Down Expand Up @@ -54,13 +53,11 @@ func init() {
}

func run(cmd *cobra.Command, argv []string) {
r := rosa.NewRuntime().WithOCM().WithFlagChecker()
r := rosa.NewRuntime().WithOCM()
defer r.Cleanup()

// Adding known flags to flag checker before parsing the unknown flags
cmd.Flags().VisitAll(func(flag *pflag.Flag) {
r.FlagChecker.AddValidFlag(flag)
})
flagChecker := arguments.NewFlagCheck(cmd.Flags())

err := arguments.ParseUnknownFlags(cmd, argv)
if err != nil {
Expand All @@ -87,7 +84,6 @@ func run(cmd *cobra.Command, argv []string) {
os.Exit(1)
}

// Setting parameter flags as valid
addOn, err := r.OCMClient.GetAddOn(service.Service())
if err != nil {
r.Reporter.Errorf("Failed to get add-on %q: %s", service.Service(), err)
Expand All @@ -96,18 +92,17 @@ func run(cmd *cobra.Command, argv []string) {

addonParameters := addOn.Parameters()
addonParameters.Each(func(param *cmv1.AddOnParameter) bool {
r.FlagChecker.AddValidParameter(param.ID())
flagChecker.AddValidFlag(param.ID())
return true
})

// Now that rosa knows the expected fields to validate,
// Validate that all of the user-specified flags are valid.
cmd.Flags().VisitAll(func(flag *pflag.Flag) {
if !r.FlagChecker.IsValidFlag(flag) {
r.Reporter.Errorf("%q is not a valid flag", flag.Name)
os.Exit(1)
}
})
err = flagChecker.ValidateFlags(cmd.Flags())
if err != nil {
r.Reporter.Errorf(err.Error())
os.Exit(1)
}

args.Parameters = map[string]string{}
addonParameters.Each(func(param *cmv1.AddOnParameter) bool {
Expand Down
40 changes: 26 additions & 14 deletions pkg/arguments/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,46 @@

package arguments

import "github.com/spf13/pflag"
import (
"fmt"

"github.com/spf13/pflag"
)

type FlagCheck struct {
validFlags map[string]struct{}
}

func NewFlagCheck() *FlagCheck {
return &FlagCheck{
func NewFlagCheck(flags *pflag.FlagSet) *FlagCheck {
flagCheck := FlagCheck{
validFlags: map[string]struct{}{},
}
}

func (f *FlagCheck) AddValidFlag(flag *pflag.Flag) *FlagCheck {
f.validFlags[flag.Name] = struct{}{}
return f
flags.VisitAll(func(flag *pflag.Flag) {
flagCheck.AddValidFlag(flag.Name)
})
return &flagCheck
}

func (f *FlagCheck) AddValidParameter(parameterName string) *FlagCheck {
f.validFlags[parameterName] = struct{}{}
func (f *FlagCheck) AddValidFlag(flagName string) *FlagCheck {
f.validFlags[flagName] = struct{}{}
return f
}

func (f *FlagCheck) IsValidFlag(flag *pflag.Flag) bool {
_, found := f.validFlags[flag.Name]
func (f *FlagCheck) IsValidFlag(flagName string) bool {
_, found := f.validFlags[flagName]
return found
}

func (f *FlagCheck) IsValidParameter(parameterName string) bool {
_, found := f.validFlags[parameterName]
return found
func (f *FlagCheck) ValidateFlags(flags *pflag.FlagSet) error {
var invalidFlags string
flags.VisitAll(func(flag *pflag.Flag) {
if !f.IsValidFlag(flag.Name) {
invalidFlags += fmt.Sprintf("%q, ", flag.Name)
}
})
if invalidFlags != "" {
return fmt.Errorf("Invalid flags: %s", invalidFlags[:len(invalidFlags)-2])
}
return nil
}
7 changes: 0 additions & 7 deletions pkg/rosa/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,6 @@ func (r *Runtime) WithAWS() *Runtime {
return r
}

func (r *Runtime) WithFlagChecker() *Runtime {
if r.FlagChecker == nil {
r.FlagChecker = arguments.NewFlagCheck()
}
return r
}

func (r *Runtime) Cleanup() {
if r.OCMClient != nil {
if err := r.OCMClient.Close(); err != nil {
Expand Down

0 comments on commit ca547e1

Please sign in to comment.