From 92db2833d6543a105f4168f0e15d66ce61a4fa60 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Tue, 16 Jun 2026 10:55:00 +0200 Subject: [PATCH 1/4] feat(oauth): add stdio OAuth 2.1 login core library Introduce internal/oauth, a self-contained library that performs the user-facing GitHub OAuth login the stdio server uses to obtain a token without a pre-provisioned PAT. It is independent of MCP: client concerns (elicitation) sit behind the Prompter interface so the flows are testable without a live session. What it provides: - Authorization-code + PKCE flow with a local loopback callback server, state/CSRF validation, and XSS-safe result pages. - Device-authorization flow as a fallback (headless, containers). - A Manager that selects the most secure available channel (browser auto-open -> URL elicitation -> last-resort user action), runs a single flow at a time, and exposes a refreshing token source. Both GitHub OAuth Apps and GitHub Apps are supported without special casing: the token is modeled as an x/oauth2 refreshing TokenSource, so expiring GitHub App user tokens are renewed transparently (the gap that made a stored-token approach silently die after ~8h). When a client lacks secure URL elicitation and the flow falls back to a tool-response message, the message advises the user that their agent/CLI/ IDE does not appear to support URL elicitation and suggests requesting it for improved security. Tests exercise real protocol behavior against an httptest GitHub stand-in: PKCE challenge/verifier, GitHub App refresh-on-expiry, device polling, URL elicitation, declined prompts, the last-resort action with advisory, and single-flight concurrency. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- go.mod | 2 +- internal/oauth/callback.go | 156 +++++++++++++++++ internal/oauth/callback_test.go | 82 +++++++++ internal/oauth/env.go | 56 ++++++ internal/oauth/flow.go | 164 ++++++++++++++++++ internal/oauth/manager.go | 241 ++++++++++++++++++++++++++ internal/oauth/manager_test.go | 230 ++++++++++++++++++++++++ internal/oauth/oauth.go | 100 +++++++++++ internal/oauth/oauth_test.go | 64 +++++++ internal/oauth/prompter.go | 55 ++++++ internal/oauth/templates/error.html | 60 +++++++ internal/oauth/templates/success.html | 56 ++++++ internal/oauth/testutil_test.go | 216 +++++++++++++++++++++++ 13 files changed, 1481 insertions(+), 1 deletion(-) create mode 100644 internal/oauth/callback.go create mode 100644 internal/oauth/callback_test.go create mode 100644 internal/oauth/env.go create mode 100644 internal/oauth/flow.go create mode 100644 internal/oauth/manager.go create mode 100644 internal/oauth/manager_test.go create mode 100644 internal/oauth/oauth.go create mode 100644 internal/oauth/oauth_test.go create mode 100644 internal/oauth/prompter.go create mode 100644 internal/oauth/templates/error.html create mode 100644 internal/oauth/templates/success.html create mode 100644 internal/oauth/testutil_test.go diff --git a/go.mod b/go.mod index 080cdcfd8e..66c7a974ad 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 github.com/yosida95/uritemplate/v3 v3.0.2 + golang.org/x/oauth2 v0.35.0 ) require ( @@ -40,7 +41,6 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/net v0.38.0 // indirect - golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.28.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/internal/oauth/callback.go b/internal/oauth/callback.go new file mode 100644 index 0000000000..6c2dfb0631 --- /dev/null +++ b/internal/oauth/callback.go @@ -0,0 +1,156 @@ +package oauth + +import ( + "context" + "embed" + "fmt" + "html/template" + "net" + "net/http" + "time" +) + +//go:embed templates/*.html +var templateFS embed.FS + +var ( + errorTemplate = template.Must(template.ParseFS(templateFS, "templates/error.html")) + successTemplate = template.Must(template.ParseFS(templateFS, "templates/success.html")) +) + +// callbackResult is delivered by the callback server once the browser redirect +// arrives. Exactly one of code or err is set. +type callbackResult struct { + code string + err error +} + +// callbackServer is a short-lived local HTTP server that captures the +// authorization code from the OAuth redirect. +type callbackServer struct { + server *http.Server + listener net.Listener + redirect string + results chan callbackResult +} + +// listenCallback binds the local callback listener. +// +// A random port (port == 0) binds to 127.0.0.1 only: the redirect target is +// loopback and never reachable off-host. A fixed port binds to all interfaces +// because Docker's published-port DNAT delivers traffic to the container's eth0 +// rather than to loopback; exposure is still constrained by the host-side +// publish (e.g. -p 127.0.0.1:8085:8085). +func listenCallback(port int) (net.Listener, error) { + host := "127.0.0.1" + if port > 0 { + host = "0.0.0.0" + } + addr := fmt.Sprintf("%s:%d", host, port) + listener, err := net.Listen("tcp", addr) + if err != nil { + return nil, fmt.Errorf("starting callback listener on %s: %w", addr, err) + } + return listener, nil +} + +// newCallbackServer starts a callback server on listener that validates state +// and reports the result on a buffered channel. The redirect URI always uses +// localhost so it matches the value registered on the OAuth/GitHub App. +func newCallbackServer(listener net.Listener, expectedState string) *callbackServer { + cs := &callbackServer{ + server: &http.Server{ReadHeaderTimeout: 10 * time.Second}, // ReadHeaderTimeout guards against Slowloris. + listener: listener, + redirect: fmt.Sprintf("http://localhost:%d/callback", listener.Addr().(*net.TCPAddr).Port), + results: make(chan callbackResult, 1), + } + cs.server.Handler = cs.handler(expectedState) + + go func() { + if err := cs.server.Serve(listener); err != nil && err != http.ErrServerClosed { + cs.report(callbackResult{err: fmt.Errorf("callback server: %w", err)}) + } + }() + + return cs +} + +// handler renders the callback endpoint. It reports the outcome exactly once and +// always shows the user a friendly page. +func (cs *callbackServer) handler(expectedState string) http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + if errCode := q.Get("error"); errCode != "" { + msg := errCode + if desc := q.Get("error_description"); desc != "" { + msg = fmt.Sprintf("%s: %s", errCode, desc) + } + cs.report(callbackResult{err: fmt.Errorf("authorization failed: %s", msg)}) + renderError(w, msg) + return + } + + if q.Get("state") != expectedState { + cs.report(callbackResult{err: fmt.Errorf("state mismatch (possible CSRF)")}) + renderError(w, "state mismatch") + return + } + + code := q.Get("code") + if code == "" { + cs.report(callbackResult{err: fmt.Errorf("no authorization code in callback")}) + renderError(w, "no authorization code received") + return + } + + cs.report(callbackResult{code: code}) + renderSuccess(w) + }) + return mux +} + +// report delivers the first outcome and drops later ones (the channel is +// buffered for one; subsequent redirect retries must not block the handler). +func (cs *callbackServer) report(res callbackResult) { + select { + case cs.results <- res: + default: + } +} + +// wait blocks for the callback outcome or ctx cancellation, then shuts the +// server down. It is safe to call once per server. +func (cs *callbackServer) wait(ctx context.Context) (string, error) { + defer cs.close() + select { + case res := <-cs.results: + return res.code, res.err + case <-ctx.Done(): + return "", ctx.Err() + } +} + +func (cs *callbackServer) close() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = cs.server.Shutdown(shutdownCtx) + _ = cs.listener.Close() +} + +func renderSuccess(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := successTemplate.Execute(w, nil); err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + } +} + +// renderError shows the failure page. html/template auto-escapes msg, so a +// hostile error_description cannot inject markup. +func renderError(w http.ResponseWriter, msg string) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := errorTemplate.Execute(w, struct{ ErrorMessage string }{ErrorMessage: msg}); err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + } +} diff --git a/internal/oauth/callback_test.go b/internal/oauth/callback_test.go new file mode 100644 index 0000000000..df517dff50 --- /dev/null +++ b/internal/oauth/callback_test.go @@ -0,0 +1,82 @@ +package oauth + +import ( + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// serveCallback drives the callback handler with the given query string and +// returns the recorded response and the single reported result. +func serveCallback(t *testing.T, expectedState, query string) (*httptest.ResponseRecorder, callbackResult) { + t.Helper() + cs := &callbackServer{results: make(chan callbackResult, 1)} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/callback?"+query, nil) + + cs.handler(expectedState).ServeHTTP(rec, req) + + select { + case res := <-cs.results: + return rec, res + default: + t.Fatal("handler did not report a result") + return nil, callbackResult{} + } +} + +func TestCallbackHandlerSuccess(t *testing.T) { + rec, res := serveCallback(t, "state123", "code=the-code&state=state123") + + require.NoError(t, res.err) + assert.Equal(t, "the-code", res.code) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Body.String(), "Authorization Successful") +} + +func TestCallbackHandlerStateMismatch(t *testing.T) { + rec, res := serveCallback(t, "expected", "code=the-code&state=attacker") + + require.Error(t, res.err) + assert.Empty(t, res.code) + assert.Contains(t, res.err.Error(), "state mismatch") + assert.Contains(t, rec.Body.String(), "state mismatch") +} + +func TestCallbackHandlerMissingCode(t *testing.T) { + _, res := serveCallback(t, "state123", "state=state123") + + require.Error(t, res.err) + assert.Contains(t, res.err.Error(), "no authorization code") +} + +func TestCallbackHandlerOAuthError(t *testing.T) { + _, res := serveCallback(t, "state123", "error=access_denied&error_description=user+said+no") + + require.Error(t, res.err) + assert.Contains(t, res.err.Error(), "access_denied") + assert.Contains(t, res.err.Error(), "user said no") +} + +func TestCallbackHandlerEscapesError(t *testing.T) { + rec, _ := serveCallback(t, "state123", "error=evil&error_description=%3Cscript%3Ealert(1)%3C%2Fscript%3E") + + body := rec.Body.String() + assert.NotContains(t, body, "