From 94c1bee8d0211fbcb01826ef58ff9dfbfbfc7ae6 Mon Sep 17 00:00:00 2001 From: Vincent Cheung <71457708+vcheung-stripe@users.noreply.github.com> Date: Fri, 18 Oct 2024 13:44:33 -0700 Subject: [PATCH] Validate base URLs (#1264) * Validate base URLs * Fix lint error * allow 127.0.0.1 for tests * cleanup --- pkg/cmd/fixtures.go | 4 +++ pkg/cmd/listen.go | 4 +++ pkg/cmd/login.go | 4 +++ pkg/cmd/logs/tail.go | 4 +++ pkg/cmd/plugin/install.go | 4 +++ pkg/cmd/plugin/upgrade.go | 4 +++ pkg/cmd/resource/operation.go | 5 +++ pkg/cmd/trigger.go | 4 +++ pkg/requests/base.go | 4 +++ pkg/stripe/client.go | 9 ----- pkg/stripe/url.go | 64 +++++++++++++++++++++++++++++++++++ pkg/stripe/url_test.go | 35 +++++++++++++++++++ 12 files changed, 136 insertions(+), 9 deletions(-) create mode 100644 pkg/stripe/url.go create mode 100644 pkg/stripe/url_test.go diff --git a/pkg/cmd/fixtures.go b/pkg/cmd/fixtures.go index 73b1f066..d0d0ec90 100644 --- a/pkg/cmd/fixtures.go +++ b/pkg/cmd/fixtures.go @@ -58,6 +58,10 @@ func newFixturesCmd(cfg *config.Config) *FixturesCmd { func (fc *FixturesCmd) runFixturesCmd(cmd *cobra.Command, args []string) error { version.CheckLatestVersion() + if err := stripe.ValidateAPIBaseURL(fc.apiBaseURL); err != nil { + return err + } + apiKey, err := fc.Cfg.Profile.GetAPIKey(false) if err != nil { return err diff --git a/pkg/cmd/listen.go b/pkg/cmd/listen.go index 37c5e313..c6539afa 100644 --- a/pkg/cmd/listen.go +++ b/pkg/cmd/listen.go @@ -116,6 +116,10 @@ Stripe account.`, // Normally, this function would be listed alphabetically with the others declared in this file, // but since it's acting as the core functionality for the cmd above, I'm keeping it close. func (lc *listenCmd) runListenCmd(cmd *cobra.Command, args []string) error { + if err := stripe.ValidateAPIBaseURL(lc.apiBaseURL); err != nil { + return err + } + if !lc.printJSON && !lc.onlyPrintSecret && !lc.skipUpdate { version.CheckLatestVersion() } diff --git a/pkg/cmd/login.go b/pkg/cmd/login.go index e5fd503a..4b70ad3b 100644 --- a/pkg/cmd/login.go +++ b/pkg/cmd/login.go @@ -34,6 +34,10 @@ func newLoginCmd() *loginCmd { } func (lc *loginCmd) runLoginCmd(cmd *cobra.Command, args []string) error { + if err := stripe.ValidateDashboardBaseURL(lc.dashboardBaseURL); err != nil { + return err + } + if lc.interactive { return login.InteractiveLogin(cmd.Context(), &Config) } diff --git a/pkg/cmd/logs/tail.go b/pkg/cmd/logs/tail.go index e414ffec..9501a202 100644 --- a/pkg/cmd/logs/tail.go +++ b/pkg/cmd/logs/tail.go @@ -146,6 +146,10 @@ func withSIGTERMCancel(ctx context.Context, onCancel func()) context.Context { } func (tailCmd *TailCmd) runTailCmd(cmd *cobra.Command, args []string) error { + if err := stripe.ValidateAPIBaseURL(tailCmd.apiBaseURL); err != nil { + return err + } + err := tailCmd.validateArgs() if err != nil { return err diff --git a/pkg/cmd/plugin/install.go b/pkg/cmd/plugin/install.go index bf094663..e3dd7b46 100644 --- a/pkg/cmd/plugin/install.go +++ b/pkg/cmd/plugin/install.go @@ -90,6 +90,10 @@ func (ic *InstallCmd) installPluginByName(cmd *cobra.Command, arg string) error } func (ic *InstallCmd) runInstallCmd(cmd *cobra.Command, args []string) error { + if err := stripe.ValidateAPIBaseURL(ic.apiBaseURL); err != nil { + return err + } + var err error color := ansi.Color(os.Stdout) diff --git a/pkg/cmd/plugin/upgrade.go b/pkg/cmd/plugin/upgrade.go index 39b6ecbe..a63c7c44 100644 --- a/pkg/cmd/plugin/upgrade.go +++ b/pkg/cmd/plugin/upgrade.go @@ -46,6 +46,10 @@ func NewUpgradeCmd(config *config.Config) *UpgradeCmd { } func (uc *UpgradeCmd) runUpgradeCmd(cmd *cobra.Command, args []string) error { + if err := stripe.ValidateAPIBaseURL(uc.apiBaseURL); err != nil { + return err + } + ctx := withSIGTERMCancel(cmd.Context(), func() { log.WithFields(log.Fields{ "prefix": "cmd.upgradeCmd.runUpgradeCmd", diff --git a/pkg/cmd/resource/operation.go b/pkg/cmd/resource/operation.go index 6d347d9e..e2bab8fc 100644 --- a/pkg/cmd/resource/operation.go +++ b/pkg/cmd/resource/operation.go @@ -12,6 +12,7 @@ import ( "github.com/stripe/stripe-cli/pkg/ansi" "github.com/stripe/stripe-cli/pkg/config" "github.com/stripe/stripe-cli/pkg/requests" + "github.com/stripe/stripe-cli/pkg/stripe" "github.com/stripe/stripe-cli/pkg/validators" ) @@ -40,6 +41,10 @@ type OperationCmd struct { } func (oc *OperationCmd) runOperationCmd(cmd *cobra.Command, args []string) error { + if err := stripe.ValidateAPIBaseURL(oc.APIBaseURL); err != nil { + return err + } + apiKey, err := oc.Profile.GetAPIKey(oc.Livemode) if err != nil { return err diff --git a/pkg/cmd/trigger.go b/pkg/cmd/trigger.go index f4fa1dae..cfc4298f 100644 --- a/pkg/cmd/trigger.go +++ b/pkg/cmd/trigger.go @@ -69,6 +69,10 @@ needed to create the triggered event as well as the corresponding API objects. func (tc *triggerCmd) runTriggerCmd(cmd *cobra.Command, args []string) error { version.CheckLatestVersion() + if err := stripe.ValidateAPIBaseURL(tc.apiBaseURL); err != nil { + return err + } + if len(args) == 0 { cmd.Help() diff --git a/pkg/requests/base.go b/pkg/requests/base.go index 4d17ab61..907b93bb 100644 --- a/pkg/requests/base.go +++ b/pkg/requests/base.go @@ -109,6 +109,10 @@ var confirmationCommands = map[string]bool{http.MethodDelete: true} // RunRequestsCmd is the interface exposed for the CLI to run network requests through func (rb *Base) RunRequestsCmd(cmd *cobra.Command, args []string) error { + if err := stripe.ValidateAPIBaseURL(rb.APIBaseURL); err != nil { + return err + } + if len(args) > 1 { return fmt.Errorf("this command only supports one argument. Run with the --help flag to see usage and examples") } diff --git a/pkg/stripe/client.go b/pkg/stripe/client.go index 9708df73..9577ae48 100644 --- a/pkg/stripe/client.go +++ b/pkg/stripe/client.go @@ -15,15 +15,6 @@ import ( "github.com/stripe/stripe-cli/pkg/useragent" ) -// DefaultAPIBaseURL is the default base URL for API requests -const DefaultAPIBaseURL = "https://api.stripe.com" - -// DefaultFilesAPIBaseURL is the default base URL for Files API requsts -const DefaultFilesAPIBaseURL = "https://files.stripe.com" - -// DefaultDashboardBaseURL is the default base URL for dashboard requests -const DefaultDashboardBaseURL = "https://dashboard.stripe.com" - // APIVersion is API version used in CLI const APIVersion = "2019-03-14" diff --git a/pkg/stripe/url.go b/pkg/stripe/url.go new file mode 100644 index 00000000..8d7a9165 --- /dev/null +++ b/pkg/stripe/url.go @@ -0,0 +1,64 @@ +package stripe + +import ( + "errors" + "regexp" +) + +const ( + // DefaultAPIBaseURL is the default base URL for API requests + DefaultAPIBaseURL = "https://api.stripe.com" + qaAPIBaseURL = "https://qa-api.stripe.com" + devAPIBaseURLRegexp = `http(s)?:\/\/[A-Za-z0-9\-]+api-mydev.dev.stripe.me` + + // DefaultFilesAPIBaseURL is the default base URL for Files API requsts + DefaultFilesAPIBaseURL = "https://files.stripe.com" + + // DefaultDashboardBaseURL is the default base URL for dashboard requests + DefaultDashboardBaseURL = "https://dashboard.stripe.com" + qaDashboardBaseURL = "https://qa-dashboard.stripe.com" + devDashboardBaseURLRegexp = `http(s)?:\/\/[A-Za-z0-9\-]+manage-mydev\.dev\.stripe\.me` + + // localhostURLRegexp is used in tests + localhostURLRegexp = `http:\/\/127\.0\.0\.1(:[0-9]+)?` +) + +var ( + errInvalidAPIBaseURL = errors.New("invalid API base URL") + errInvalidDashboardBaseURL = errors.New("invalid dashboard base URL") +) + +func isValid(url string, exactStrings []string, regexpStrings []string) bool { + for _, s := range exactStrings { + if url == s { + return true + } + } + for _, r := range regexpStrings { + matched, err := regexp.Match(r, []byte(url)) + if err == nil && matched { + return true + } + } + return false +} + +// ValidateAPIBaseURL returns an error if apiBaseURL isn't allowed +func ValidateAPIBaseURL(apiBaseURL string) error { + exactStrings := []string{DefaultAPIBaseURL, qaAPIBaseURL} + regexpStrings := []string{devAPIBaseURLRegexp, localhostURLRegexp} + if isValid(apiBaseURL, exactStrings, regexpStrings) { + return nil + } + return errInvalidAPIBaseURL +} + +// ValidateDashboardBaseURL returns an error if dashboardBaseURL isn't allowed +func ValidateDashboardBaseURL(dashboardBaseURL string) error { + exactStrings := []string{DefaultDashboardBaseURL, qaDashboardBaseURL} + regexpStrings := []string{devDashboardBaseURLRegexp, localhostURLRegexp} + if isValid(dashboardBaseURL, exactStrings, regexpStrings) { + return nil + } + return errInvalidDashboardBaseURL +} diff --git a/pkg/stripe/url_test.go b/pkg/stripe/url_test.go new file mode 100644 index 00000000..91df1dcd --- /dev/null +++ b/pkg/stripe/url_test.go @@ -0,0 +1,35 @@ +package stripe + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateAPIBaseURLWorks(t *testing.T) { + assert.Nil(t, ValidateAPIBaseURL("https://api.stripe.com")) + assert.Nil(t, ValidateAPIBaseURL("https://qa-api.stripe.com")) + assert.Nil(t, ValidateAPIBaseURL("http://foo-api-mydev.dev.stripe.me")) + assert.Nil(t, ValidateAPIBaseURL("https://foo-lv5r9y--api-mydev.dev.stripe.me/")) + assert.Nil(t, ValidateAPIBaseURL("http://127.0.0.1")) + assert.Nil(t, ValidateAPIBaseURL("http://127.0.0.1:1337")) + + assert.ErrorIs(t, ValidateAPIBaseURL("https://example.com"), errInvalidAPIBaseURL) + assert.ErrorIs(t, ValidateAPIBaseURL("https://unknowndomain"), errInvalidAPIBaseURL) + assert.ErrorIs(t, ValidateAPIBaseURL("localhost"), errInvalidAPIBaseURL) + assert.ErrorIs(t, ValidateAPIBaseURL("anything_else"), errInvalidAPIBaseURL) +} + +func TestValidateDashboardBaseURLWorks(t *testing.T) { + assert.Nil(t, ValidateDashboardBaseURL("https://dashboard.stripe.com")) + assert.Nil(t, ValidateDashboardBaseURL("https://qa-dashboard.stripe.com")) + assert.Nil(t, ValidateDashboardBaseURL("http://foo-manage-mydev.dev.stripe.me")) + assert.Nil(t, ValidateDashboardBaseURL("https://foo-lv5r9y--manage-mydev.dev.stripe.me/")) + assert.Nil(t, ValidateDashboardBaseURL("http://127.0.0.1")) + assert.Nil(t, ValidateDashboardBaseURL("http://127.0.0.1:1337")) + + assert.ErrorIs(t, ValidateDashboardBaseURL("https://example.com"), errInvalidDashboardBaseURL) + assert.ErrorIs(t, ValidateDashboardBaseURL("https://unknowndomain"), errInvalidDashboardBaseURL) + assert.ErrorIs(t, ValidateDashboardBaseURL("localhost"), errInvalidDashboardBaseURL) + assert.ErrorIs(t, ValidateDashboardBaseURL("anything_else"), errInvalidDashboardBaseURL) +}