Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions internal/guest/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,31 @@ func GenerateClaudeInitScript(mounts []session.VMMount, projectDir string, polic
sb.WriteString("XSEL_EOF\n")
sb.WriteString("chmod +x /usr/local/bin/xsel\n\n")

// xdg-open shim — signals the host to open a URL in the browser via VirtioFS
sb.WriteString("# Install browser-open shim (xdg-open)\n")
sb.WriteString("cat > /usr/local/bin/xdg-open << 'XDGOPEN_EOF'\n")
sb.WriteString("#!/bin/sh\n")
sb.WriteString("# Signals the host to open a URL in the default browser.\n")
sb.WriteString("# Writes the URL to a VirtioFS file; the host polls and opens it.\n")
sb.WriteString("URL=\"$1\"\n")
sb.WriteString("if [ -z \"$URL\" ]; then\n")
sb.WriteString(" exit 0\n")
sb.WriteString("fi\n")
sb.WriteString("# Atomic write via temp file + mv\n")
sb.WriteString("TMPFILE=$(mktemp /mnt/bootstrap/.open-url.XXXXXX 2>/dev/null) || exit 0\n")
sb.WriteString("printf '%s' \"$URL\" > \"$TMPFILE\"\n")
sb.WriteString("mv \"$TMPFILE\" /mnt/bootstrap/open-url\n")
sb.WriteString("# Wait up to 5s for host to acknowledge (remove the file)\n")
sb.WriteString("i=0\n")
sb.WriteString("while [ $i -lt 10 ] && [ -f /mnt/bootstrap/open-url ]; do\n")
sb.WriteString(" sleep 0.5\n")
sb.WriteString(" i=$((i + 1))\n")
sb.WriteString("done\n")
sb.WriteString("exit 0\n")
sb.WriteString("XDGOPEN_EOF\n")
sb.WriteString("chmod +x /usr/local/bin/xdg-open\n")
sb.WriteString("ln -sf /usr/local/bin/xdg-open /usr/local/bin/open\n\n")

// Create Claude config directory
sb.WriteString("# Create Claude configuration directory\n")
sb.WriteString("mkdir -p /home/claude/.claude\n")
Expand Down Expand Up @@ -508,6 +533,24 @@ func GenerateClaudeInitScript(mounts []session.VMMount, projectDir string, polic
sb.WriteString(" fi\n")
sb.WriteString("fi\n\n")

// Background OAuth callback relay poller
sb.WriteString("# Background OAuth callback relay poller\n")
sb.WriteString("(\n")
sb.WriteString(" while true; do\n")
sb.WriteString(" if [ -f /mnt/bootstrap/auth-callback ]; then\n")
sb.WriteString(" mv /mnt/bootstrap/auth-callback /tmp/auth-callback-$$ 2>/dev/null || { sleep 1; continue; }\n")
sb.WriteString(" CALLBACK_URL=$(cat /tmp/auth-callback-$$ 2>/dev/null) || true\n")
sb.WriteString(" rm -f /tmp/auth-callback-$$\n")
sb.WriteString(" case \"$CALLBACK_URL\" in\n")
sb.WriteString(" http://localhost:[0-9]*/*) \n")
sb.WriteString(" wget -q -O /dev/null \"$CALLBACK_URL\" 2>/dev/null || true\n")
sb.WriteString(" ;;\n")
sb.WriteString(" esac\n")
sb.WriteString(" fi\n")
sb.WriteString(" sleep 1\n")
sb.WriteString(" done\n")
sb.WriteString(") &\n\n")

// Background terminal resize watcher — polls VirtioFS termsize file and
// resizes PTYs when the host terminal dimensions change.
sb.WriteString("# Background terminal resize watcher\n")
Expand Down
24 changes: 19 additions & 5 deletions internal/vm/console_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ const escapeHelp = "\r\nSupported escape sequences:\r\n ~. Disconnect from ses
// EscapeWriter is not safe for concurrent use from multiple goroutines.
// It expects sequential Write() calls from a single source (stdin).
type EscapeWriter struct {
w io.Writer // underlying writer to forward bytes to
afterNewline bool // true if last byte was newline or at start
pendingTilde bool // true if we saw ~ and waiting for next char
detachCh chan struct{} // closed when ~. detected
stdout io.Writer // for printing help message
w io.Writer // underlying writer to forward bytes to
afterNewline bool // true if last byte was newline or at start
pendingTilde bool // true if we saw ~ and waiting for next char
detachCh chan struct{} // closed when ~. detected
stdout io.Writer // for printing help message
}

// NewEscapeWriter creates a new EscapeWriter that wraps w
Expand Down Expand Up @@ -110,6 +110,7 @@ type ConsoleClient struct {
conn net.Conn
termsizePath string
clipboardDir string
openURLDir string
}

// SetTermsizePath sets the path to the termsize file used for propagating
Expand All @@ -124,6 +125,12 @@ func (c *ConsoleClient) SetClipboardDir(path string) {
c.clipboardDir = path
}

// SetOpenURLDir sets the path to the bootstrap directory used for watching
// URL open requests from the VM guest via VirtioFS.
func (c *ConsoleClient) SetOpenURLDir(path string) {
c.openURLDir = path
}

// NewConsoleClient connects to a VM console Unix socket
func NewConsoleClient(socketPath string) (*ConsoleClient, error) {
conn, err := net.Dial("unix", socketPath)
Expand Down Expand Up @@ -185,6 +192,13 @@ func (c *ConsoleClient) Attach(stdin io.Reader, stdout io.Writer) error {
return fmt.Errorf("failed to read from console: %w", err)
}

// Start URL open watcher to handle guest browser-open requests via VirtioFS
openURLDone := make(chan struct{})
defer close(openURLDone)
if c.openURLDir != "" {
go watchOpenURL(openURLDone, c.openURLDir)
}

// Create escape writer for detecting ~. sequence
escapeWriter := NewEscapeWriter(c.conn, stdout)

Expand Down
106 changes: 106 additions & 0 deletions internal/vm/oauth_relay_darwin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
//go:build darwin

package vm

import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"sync"
"time"
)

// parseOAuthRedirect extracts the localhost port from an OAuth authorization URL's
// redirect_uri parameter. Returns the port and true if redirect_uri is
// http://localhost:<port>/..., otherwise returns ("", false).
func parseOAuthRedirect(rawURL string) (string, bool) {
u, err := url.Parse(rawURL)
if err != nil {
return "", false
}

redirectURI := u.Query().Get("redirect_uri")
if redirectURI == "" {
return "", false
}

r, err := url.Parse(redirectURI)
if err != nil {
return "", false
}

if r.Scheme != "http" {
return "", false
}

host := r.Hostname()
port := r.Port()
if host != "localhost" || port == "" {
return "", false
}

n, err := strconv.Atoi(port)
if err != nil || n < 1024 || n > 65535 {
return "", false
}

return port, true
}

// startOAuthRelay starts an HTTP server on 127.0.0.1:<port> that captures a single
// OAuth callback request, writes the full reconstructed URL to bootstrapDir/auth-callback,
// and responds with a success page. Shuts down after one request, on done channel close,
// or after a 5-minute timeout.
func startOAuthRelay(done <-chan struct{}, bootstrapDir string, port string) error {
ln, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", port))
if err != nil {
return err
}

mux := http.NewServeMux()

handled := make(chan struct{})
var once sync.Once

mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fired := false
once.Do(func() { fired = true })
if !fired {
http.Error(w, "already handled", http.StatusGone)
return
}

reconstructed := "http://localhost:" + port + r.URL.RequestURI()

callbackFile := filepath.Join(bootstrapDir, "auth-callback")
_ = os.WriteFile(callbackFile, []byte(reconstructed), 0o600)

debugLog("OAuth callback received, relaying to VM")

w.Header().Set("Content-Type", "text/html; charset=utf-8")
_, _ = fmt.Fprint(w, "<!DOCTYPE html><html><body><p>Authentication successful. You can close this tab.</p></body></html>")

close(handled)
})

srv := &http.Server{Handler: mux}

go func() {
select {
case <-handled:
case <-done:
case <-time.After(5 * time.Minute):
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = srv.Shutdown(ctx)
}()

go srv.Serve(ln)
return nil
}
6 changes: 6 additions & 0 deletions internal/vm/oauth_relay_stub.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
//go:build !darwin

package vm

func parseOAuthRedirect(rawURL string) (string, bool) { return "", false }

Check failure on line 5 in internal/vm/oauth_relay_stub.go

View workflow job for this annotation

GitHub Actions / Lint

func parseOAuthRedirect is unused (unused)
func startOAuthRelay(done <-chan struct{}, bootstrapDir string, port string) error { return nil }

Check failure on line 6 in internal/vm/oauth_relay_stub.go

View workflow job for this annotation

GitHub Actions / Lint

func startOAuthRelay is unused (unused)
173 changes: 173 additions & 0 deletions internal/vm/oauth_relay_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
//go:build darwin

package vm

import (
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"testing"
)

func TestParseOAuthRedirect(t *testing.T) {
tests := []struct {
name string
rawURL string
wantPort string
wantMatch bool
}{
{
name: "standard OAuth URL",
rawURL: "https://auth.example.com/authorize?client_id=abc&redirect_uri=http%3A%2F%2Flocalhost%3A38449%2Fcallback&state=xyz",
wantPort: "38449",
wantMatch: true,
},
{
name: "different port",
rawURL: "https://auth.example.com/authorize?redirect_uri=http%3A%2F%2Flocalhost%3A12345%2Fcallback",
wantPort: "12345",
wantMatch: true,
},
{
name: "redirect_uri with path and query",
rawURL: "https://auth.example.com/authorize?redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Foauth%2Fcallback%3Ffoo%3Dbar",
wantPort: "8080",
wantMatch: true,
},
{
name: "no redirect_uri param",
rawURL: "https://auth.example.com/authorize?client_id=abc",
wantPort: "",
wantMatch: false,
},
{
name: "non-localhost redirect",
rawURL: "https://auth.example.com/authorize?redirect_uri=http%3A%2F%2Fexample.com%3A8080%2Fcallback",
wantPort: "",
wantMatch: false,
},
{
name: "HTTPS redirect_uri",
rawURL: "https://auth.example.com/authorize?redirect_uri=https%3A%2F%2Flocalhost%3A8080%2Fcallback",
wantPort: "",
wantMatch: false,
},
{
name: "localhost without port",
rawURL: "https://auth.example.com/authorize?redirect_uri=http%3A%2F%2Flocalhost%2Fcallback",
wantPort: "",
wantMatch: false,
},
{
name: "empty URL",
rawURL: "",
wantPort: "",
wantMatch: false,
},
{
name: "malformed URL",
rawURL: "://not-a-url",
wantPort: "",
wantMatch: false,
},
{
name: "127.0.0.1 instead of localhost",
rawURL: "https://auth.example.com/authorize?redirect_uri=http%3A%2F%2F127.0.0.1%3A8080%2Fcallback",
wantPort: "",
wantMatch: false,
},
{
name: "privileged port",
rawURL: "https://auth.example.com/authorize?redirect_uri=http%3A%2F%2Flocalhost%3A80%2Fcallback",
wantPort: "",
wantMatch: false,
},
{
name: "port zero",
rawURL: "https://auth.example.com/authorize?redirect_uri=http%3A%2F%2Flocalhost%3A0%2Fcallback",
wantPort: "",
wantMatch: false,
},
{
name: "port overflow",
rawURL: "https://auth.example.com/authorize?redirect_uri=http%3A%2F%2Flocalhost%3A99999%2Fcallback",
wantPort: "",
wantMatch: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotPort, gotMatch := parseOAuthRedirect(tt.rawURL)
if gotMatch != tt.wantMatch {
t.Errorf("parseOAuthRedirect(%q) match = %v, want %v", tt.rawURL, gotMatch, tt.wantMatch)
}
if gotPort != tt.wantPort {
t.Errorf("parseOAuthRedirect(%q) port = %q, want %q", tt.rawURL, gotPort, tt.wantPort)
}
})
}
}

func TestStartOAuthRelay(t *testing.T) {
// Pick a free port
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
port := ln.Addr().(*net.TCPAddr).Port
ln.Close()

tmpDir := t.TempDir()
done := make(chan struct{})
defer close(done)

portStr := fmt.Sprintf("%d", port)
if err := startOAuthRelay(done, tmpDir, portStr); err != nil {
t.Fatalf("startOAuthRelay: %v", err)
}

// Hit the relay
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/callback?code=abc", port))
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Errorf("status = %d, want 200", resp.StatusCode)
}

// Check the callback file was written
data, err := os.ReadFile(filepath.Join(tmpDir, "auth-callback"))
if err != nil {
t.Fatalf("read auth-callback: %v", err)
}

want := "http://localhost:" + portStr + "/callback?code=abc"
if string(data) != want {
t.Errorf("auth-callback = %q, want %q", string(data), want)
}
}

func TestStartOAuthRelayPortConflict(t *testing.T) {
// Bind a port
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()

port := ln.Addr().(*net.TCPAddr).Port
portStr := fmt.Sprintf("%d", port)

done := make(chan struct{})
defer close(done)

// Should fail because port is already bound
if err := startOAuthRelay(done, t.TempDir(), portStr); err == nil {
t.Error("expected error for occupied port, got nil")
}
}
Loading
Loading