Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 109 additions & 5 deletions cmd/src/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@ import (
"fmt"
"io"
"os"
"os/exec"
"runtime"
"strings"
"time"

"github.com/sourcegraph/src-cli/internal/api"
"github.com/sourcegraph/src-cli/internal/cmderrors"
"github.com/sourcegraph/src-cli/internal/oauth"
)

func init() {
usage := `'src login' helps you authenticate 'src' to access a Sourcegraph instance with your user credentials.

Usage:

src login SOURCEGRAPH_URL
src login [flags] SOURCEGRAPH_URL

Examples:

Expand All @@ -28,6 +32,15 @@ Examples:
Authenticate to Sourcegraph.com:

$ src login https://sourcegraph.com

Use OAuth device flow to authenticate:

$ src login --oauth https://sourcegraph.com


Override the default client id used during device flow when authenticating:

$ src login --oauth https://sourcegraph.com
`

flagSet := flag.NewFlagSet("login", flag.ExitOnError)
Expand All @@ -38,6 +51,7 @@ Examples:

var (
apiFlags = api.NewFlags(flagSet)
useOAuth = flagSet.Bool("oauth", false, "Use OAuth device flow to obtain an access token interactively")
)

handler := func(args []string) error {
Expand All @@ -54,7 +68,15 @@ Examples:

client := cfg.apiClient(apiFlags, io.Discard)

return loginCmd(context.Background(), cfg, client, endpoint, os.Stdout)
return loginCmd(context.Background(), loginParams{
cfg: cfg,
client: client,
endpoint: endpoint,
out: os.Stdout,
useOAuth: *useOAuth,
apiFlags: apiFlags,
deviceFlowClient: oauth.NewClient(oauth.DefaultClientID),
})
}

commands = append(commands, &command{
Expand All @@ -64,8 +86,21 @@ Examples:
})
}

func loginCmd(ctx context.Context, cfg *config, client api.Client, endpointArg string, out io.Writer) error {
endpointArg = cleanEndpoint(endpointArg)
type loginParams struct {
cfg *config
client api.Client
endpoint string
out io.Writer
useOAuth bool
apiFlags *api.Flags
deviceFlowClient oauth.Client
}

func loginCmd(ctx context.Context, p loginParams) error {
endpointArg := cleanEndpoint(p.endpoint)
cfg := p.cfg
client := p.client
out := p.out

printProblem := func(problem string) {
fmt.Fprintf(out, "❌ Problem: %s\n", problem)
Expand All @@ -86,7 +121,19 @@ func loginCmd(ctx context.Context, cfg *config, client api.Client, endpointArg s

noToken := cfg.AccessToken == ""
endpointConflict := endpointArg != cfg.Endpoint
if noToken || endpointConflict {

if p.useOAuth {
token, err := runOAuthDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient)
if err != nil {
printProblem(fmt.Sprintf("OAuth Device flow authentication failed: %s", err))
fmt.Fprintln(out, createAccessTokenMessage)
return cmderrors.ExitCode1
}

cfg.AccessToken = token
cfg.Endpoint = endpointArg
client = cfg.apiClient(p.apiFlags, out)
} else if noToken || endpointConflict {
fmt.Fprintln(out)
switch {
case noToken:
Expand Down Expand Up @@ -122,6 +169,63 @@ func loginCmd(ctx context.Context, cfg *config, client api.Client, endpointArg s
}
fmt.Fprintln(out)
fmt.Fprintf(out, "✔️ Authenticated as %s on %s\n", result.CurrentUser.Username, endpointArg)

if p.useOAuth {
fmt.Fprintln(out)
fmt.Fprintf(out, "Authenticated with OAuth credentials")
}

fmt.Fprintln(out)
return nil
}

func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (string, error) {
authResp, err := client.Start(ctx, endpoint, nil)
if err != nil {
return "", err
}

authURL := authResp.VerificationURIComplete
msg := fmt.Sprintf("If your browser did not open automatically, visit %s.", authURL)
if authURL == "" {
authURL = authResp.VerificationURI
msg = fmt.Sprintf("If your browser did not open automatically, visit %s and enter the user code %s", authURL, authResp.DeviceCode)
}
_ = openInBrowser(authURL)
fmt.Fprintln(out)
fmt.Fprint(out, msg)

fmt.Fprintln(out)
fmt.Fprint(out, "Waiting for authorization...")
defer fmt.Fprintf(out, "DONE\n\n")

interval := time.Duration(authResp.Interval) * time.Second
if interval <= 0 {
interval = 5 * time.Second
}

tokenResp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn)
if err != nil {
return "", err
}

return tokenResp.AccessToken, nil
}

func openInBrowser(url string) error {
if url == "" {
return nil
}

var cmd *exec.Cmd
switch runtime.GOOS {
case "darwin":
cmd = exec.Command("open", url)
case "windows":
// "start" is a cmd.exe built-in; the empty string is the window title.
cmd = exec.Command("cmd", "/c", "start", "", url)
default:
cmd = exec.Command("xdg-open", url)
}
return cmd.Run()
}
2 changes: 1 addition & 1 deletion cmd/src/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestLogin(t *testing.T) {
t.Helper()

var out bytes.Buffer
err = loginCmd(context.Background(), cfg, cfg.apiClient(nil, io.Discard), endpointArg, &out)
err = loginCmd(context.Background(), loginParams{cfg: cfg, client: cfg.apiClient(nil, io.Discard), endpoint: endpointArg, out: &out})
return strings.TrimSpace(out.String()), err
}

Expand Down
Loading
Loading