Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sso): support active session account switching #1207

Merged
merged 12 commits into from
Jun 10, 2024
10 changes: 10 additions & 0 deletions pkg/app/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/fastly/cli/pkg/auth"
"github.com/fastly/cli/pkg/commands"
"github.com/fastly/cli/pkg/commands/compute"
"github.com/fastly/cli/pkg/commands/sso"
"github.com/fastly/cli/pkg/commands/update"
"github.com/fastly/cli/pkg/commands/version"
"github.com/fastly/cli/pkg/config"
Expand Down Expand Up @@ -356,6 +357,12 @@ func processToken(cmds []argparser.Command, data *global.Data) (token string, to
// Otherwise, for an existing SSO token, check its freshness.
reauth, err := checkAndRefreshSSOToken(profileData, profileName, data)
if err != nil {
// The following scenario is when the user wants to switch to another SSO
// profile that exists under a different auth session.
if errors.Is(err, auth.ErrInvalidGrant) {
sso.ForceReAuth = true
return ssoAuthentication("We can't refresh your token", cmds, data)
}
return token, tokenSource, fmt.Errorf("failed to check access/refresh token: %w", err)
}
if reauth {
Expand Down Expand Up @@ -394,6 +401,9 @@ func checkAndRefreshSSOToken(profileData *config.Profile, profileName string, da

updatedJWT, err := data.AuthServer.RefreshAccessToken(profileData.RefreshToken)
if err != nil {
if errors.Is(err, auth.ErrInvalidGrant) {
return false, err
}
return false, fmt.Errorf("failed to refresh access token: %w", err)
}

Expand Down
57 changes: 47 additions & 10 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"time"
Expand All @@ -34,6 +35,9 @@ const RedirectURL = "http://localhost:8080/callback"
// https://swagger.io/docs/specification/authentication/openid-connect-discovery/
const OIDCMetadata = "%s/realms/fastly/.well-known/openid-configuration"

// ErrInvalidGrant represents an error refreshing the user's token.
var ErrInvalidGrant = errors.New("failed to refresh token: invalid grant")

// WellKnownEndpoints represents the OpenID Connect metadata.
type WellKnownEndpoints struct {
// Auth is the authorization_endpoint.
Expand All @@ -54,6 +58,9 @@ type Runner interface {
// RefreshAccessToken constructs and calls the token_endpoint with the
// refresh token so we can refresh and return the access token.
RefreshAccessToken(refreshToken string) (JWT, error)
// SetParam sets the specified parameter for the authorization_endpoint.
// https://openid.net/specs/openid-connect-basic-1_0.html#rfc.section.2.1.1.1
SetParam(field, value string)
// Start starts a local server for handling authentication processing.
Start() error
// ValidateAndRetrieveAPIToken verifies the signature and the claims and
Expand All @@ -71,6 +78,8 @@ type Server struct {
DebugMode string
// HTTPClient is a HTTP client used to call the API to exchange the access token for a session token.
HTTPClient api.HTTPClient
// Params are additional parameters for the authorization_endpoint.
Params []Param
// Result is a channel that reports the result of authorization.
Result chan AuthorizationResult
// Router is an HTTP request multiplexer.
Expand All @@ -81,23 +90,36 @@ type Server struct {
WellKnownEndpoints WellKnownEndpoints
}

// Param is an individual parameter set on the authorization_endpoint.
type Param struct {
Field string
Value string
}

// AuthURL returns a fully qualified authorization_endpoint.
// i.e. path + audience + scope + code_challenge etc.
func (s Server) AuthURL() (string, error) {
challenge, err := oidc.CreateCodeChallenge(s.Verifier)
if err != nil {
return "", err
}
params := url.Values{}
params.Add("audience", s.APIEndpoint)
params.Add("scope", "openid")
params.Add("response_type", "code")
params.Add("client_id", ClientID)
params.Add("code_challenge", challenge)
params.Add("code_challenge_method", "S256")
params.Add("redirect_uri", RedirectURL)
for _, p := range s.Params {
params.Add(p.Field, p.Value)
}
return fmt.Sprintf("%s?%s", s.WellKnownEndpoints.Auth, params.Encode()), nil
}

authorizationURL := fmt.Sprintf(
"%s?audience=%s"+
"&scope=openid"+
"&response_type=code&client_id=%s"+
"&code_challenge=%s"+
"&code_challenge_method=S256&redirect_uri=%s",
s.WellKnownEndpoints.Auth, s.APIEndpoint, ClientID, challenge, RedirectURL)

return authorizationURL, nil
// SetParam sets the specified parameter for the authorization_endpoint.
func (s *Server) SetParam(field, value string) {
s.Params = append(s.Params, Param{field, value})
}

// GetResult returns the result channel.
Expand Down Expand Up @@ -367,7 +389,16 @@ func (s *Server) RefreshAccessToken(refreshToken string) (JWT, error) {
}

if res.StatusCode != http.StatusOK {
return JWT{}, fmt.Errorf("failed to refresh the access token (status: %s)", res.Status)
var re RefreshError
err = json.Unmarshal(body, &re)
if err != nil {
return JWT{}, err
}

if re.Error == "invalid_grant" {
return JWT{}, ErrInvalidGrant
}
return JWT{}, fmt.Errorf("non-2xx status: %s", res.Status)
}

var j JWT
Expand All @@ -379,6 +410,12 @@ func (s *Server) RefreshAccessToken(refreshToken string) (JWT, error) {
return j, nil
}

// RefreshError represents an error when refreshing the user's token.
type RefreshError struct {
Error string `json:"error"`
Description string `json:"error_description"`
}

// APIToken is returned from the /login-enhanced endpoint.
type APIToken struct {
// AccessToken is used to access the Fastly API.
Expand Down
11 changes: 8 additions & 3 deletions pkg/commands/profile/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ func (c *CreateCommand) Exec(in io.Reader, out io.Writer) (err error) {
return fsterr.ErrInvalidProfileSSOCombo
}

if c.Globals.Verbose() {
text.Break(out)
}
text.Output(out, "Creating profile '%s'", c.profile)

if profile.Exist(c.profile, c.Globals.Config.Profiles) {
return fsterr.RemediationError{
Inner: fmt.Errorf("profile '%s' already exists", c.profile),
Expand All @@ -72,8 +77,8 @@ func (c *CreateCommand) Exec(in io.Reader, out io.Writer) (err error) {
if err != nil {
return err
}
text.Break(out)
}
text.Break(out)

if c.sso {
// IMPORTANT: We need to set profile fields for `sso` command.
Expand Down Expand Up @@ -134,7 +139,7 @@ func (c *CreateCommand) staticTokenFlow(makeDefault bool, in io.Reader, out io.W
}

func promptForToken(in io.Reader, out io.Writer, errLog fsterr.LogInterface) (string, error) {
text.Output(out, "\nAn API token is used to authenticate requests to the Fastly API. To create a token, visit https://manage.fastly.com/account/personal/tokens\n\n")
text.Output(out, "An API token is used to authenticate requests to the Fastly API. To create a token, visit https://manage.fastly.com/account/personal/tokens\n\n")
token, err := text.InputSecure(out, text.Prompt("Fastly API token: "), in, validateTokenNotEmpty)
if err != nil {
errLog.Add(err)
Expand Down Expand Up @@ -268,7 +273,7 @@ func displayCfgPath(path string, out io.Writer) {
}

func (c *CreateCommand) promptForDefault(in io.Reader, out io.Writer) (bool, error) {
cont, err := text.AskYesNo(out, "Set this profile to be your default? [y/N] ", in)
cont, err := text.AskYesNo(out, "\nSet this profile to be your default? [y/N] ", in)
if err != nil {
c.Globals.ErrLog.Add(err)
return false, err
Expand Down
3 changes: 3 additions & 0 deletions pkg/commands/profile/switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ func (c *SwitchCommand) Exec(_ io.Reader, out io.Writer) error {
return fmt.Errorf("error saving config file: %w", err)
}

if c.Globals.Verbose() {
text.Break(out)
}
text.Success(out, "Profile switched to '%s'", c.profile)
return nil
}
3 changes: 3 additions & 0 deletions pkg/commands/profile/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ func (c *UpdateCommand) Exec(in io.Reader, out io.Writer) error {
if err != nil {
return fmt.Errorf("failed to identify the profile to update: %w", err)
}
if c.Globals.Verbose() {
text.Break(out)
}
text.Info(out, "Profile being updated: '%s'.\n\n", profileName)

err = c.updateToken(profileName, p, in, out)
Expand Down
21 changes: 18 additions & 3 deletions pkg/commands/sso/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ import (
"github.com/fastly/cli/pkg/text"
)

// ForceReAuth indicates we want to force a re-auth of the user's session.
// This variable is overridden by ../../app/run.go to force a re-auth.
var ForceReAuth = false

// RootCommand is the parent command for all subcommands in this package.
// It should be installed under the primary root command.
type RootCommand struct {
Expand Down Expand Up @@ -47,13 +51,14 @@ func NewRootCommand(parent argparser.Registerer, g *global.Data) *RootCommand {

// Exec implements the command interface.
func (c *RootCommand) Exec(in io.Reader, out io.Writer) error {
profileName, _ := c.identifyProfileAndFlow()

// We need to prompt the user, so they know we're about to open their web
// browser, but we also need to handle the scenario where the `sso` command is
// invoked indirectly via ../../app/run.go as that package will have its own
// (similar) prompt before invoking this command. So to avoid a double prompt,
// the app package will set `SkipAuthPrompt: true`.
if !c.Globals.SkipAuthPrompt && !c.Globals.Flags.AutoYes && !c.Globals.Flags.NonInteractive {
profileName, _ := c.identifyProfileAndFlow()
msg := fmt.Sprintf("We're going to authenticate the '%s' profile", profileName)
text.Important(out, "%s. We need to open your browser to authenticate you.", msg)
text.Break(out)
Expand Down Expand Up @@ -83,6 +88,13 @@ func (c *RootCommand) Exec(in io.Reader, out io.Writer) error {

text.Info(out, "Starting a local server to handle the authentication flow.")

// For creating/updating a profile we set `prompt` because we want to ensure
// that another session (from a different profile) doesn't cause unexpected
// errors for the user flow. This forces a re-auth.
if c.InvokedFromProfileCreate || c.InvokedFromProfileUpdate || ForceReAuth {
c.Globals.AuthServer.SetParam("prompt", "login")
}

authorizationURL, err := c.Globals.AuthServer.AuthURL()
if err != nil {
return fsterr.RemediationError{
Expand Down Expand Up @@ -215,8 +227,8 @@ func (c *RootCommand) processCreateProfile(ar auth.AuthorizationResult, profileN
// we'll call Set for its side effect of resetting all other profiles to have
// their Default field set to false.
if c.ProfileDefault { // this is set by the `profile create` command.
if p, ok := profile.SetDefault(c.ProfileCreateName, c.Globals.Config.Profiles); ok {
c.Globals.Config.Profiles = p
if ps, ok := profile.SetDefault(c.ProfileCreateName, c.Globals.Config.Profiles); ok {
c.Globals.Config.Profiles = ps
}
}
}
Expand Down Expand Up @@ -259,6 +271,9 @@ func createNewProfile(profileName string, makeDefault bool, p config.Profiles, a
return p
}

// editProfile mutates the given profile with JWT details returned from the SSO
// authentication process.
//
// IMPORTANT: Mutates the config.Profiles map type.
// We need to return the modified type so it can be safely reassigned.
func editProfile(profileName string, makeDefault bool, p config.Profiles, ar auth.AuthorizationResult) (config.Profiles, error) {
Expand Down
5 changes: 5 additions & 0 deletions pkg/testutil/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ type MockAuthServer struct {
Result chan auth.AuthorizationResult
}

// SetParam sets the specified parameter for the authorization_endpoint.
func (s MockAuthServer) SetParam(_, _ string) {
// no-op
}

// AuthURL returns a fully qualified authorization_endpoint.
// i.e. path + audience + scope + code_challenge etc.
func (s MockAuthServer) AuthURL() (string, error) {
Expand Down
Loading