diff --git a/go.mod b/go.mod index 080cdcfd8e..66c7a974ad 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 github.com/yosida95/uritemplate/v3 v3.0.2 + golang.org/x/oauth2 v0.35.0 ) require ( @@ -40,7 +41,6 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/net v0.38.0 // indirect - golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.28.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/internal/oauth/callback.go b/internal/oauth/callback.go new file mode 100644 index 0000000000..1e643e207d --- /dev/null +++ b/internal/oauth/callback.go @@ -0,0 +1,157 @@ +package oauth + +import ( + "context" + "embed" + "fmt" + "html/template" + "net" + "net/http" + "time" +) + +//go:embed templates/*.html +var templateFS embed.FS + +var ( + errorTemplate = template.Must(template.ParseFS(templateFS, "templates/error.html")) + successTemplate = template.Must(template.ParseFS(templateFS, "templates/success.html")) +) + +// callbackResult is delivered by the callback server once the browser redirect +// arrives. Exactly one of code or err is set. +type callbackResult struct { + code string + err error +} + +// callbackServer is a short-lived local HTTP server that captures the +// authorization code from the OAuth redirect. +type callbackServer struct { + server *http.Server + listener net.Listener + redirect string + results chan callbackResult +} + +// listenCallback binds the local callback listener. +// +// It binds to loopback (127.0.0.1) by default so the callback server is never +// exposed on other interfaces. bindAll is set only inside a container, where +// Docker's published-port DNAT delivers traffic to the container's eth0 rather +// than to loopback; host-side exposure is still constrained by the publish +// (e.g. -p 127.0.0.1:8085:8085). A native run — even with a fixed port — stays +// on loopback. +func listenCallback(port int, bindAll bool) (net.Listener, error) { + host := "127.0.0.1" + if bindAll { + host = "0.0.0.0" + } + addr := fmt.Sprintf("%s:%d", host, port) + listener, err := net.Listen("tcp", addr) + if err != nil { + return nil, fmt.Errorf("starting callback listener on %s: %w", addr, err) + } + return listener, nil +} + +// newCallbackServer starts a callback server on listener that validates state +// and reports the result on a buffered channel. The redirect URI always uses +// localhost so it matches the value registered on the OAuth/GitHub App. +func newCallbackServer(listener net.Listener, expectedState string) *callbackServer { + cs := &callbackServer{ + server: &http.Server{ReadHeaderTimeout: 10 * time.Second}, // ReadHeaderTimeout guards against Slowloris. + listener: listener, + redirect: fmt.Sprintf("http://localhost:%d/callback", listener.Addr().(*net.TCPAddr).Port), + results: make(chan callbackResult, 1), + } + cs.server.Handler = cs.handler(expectedState) + + go func() { + if err := cs.server.Serve(listener); err != nil && err != http.ErrServerClosed { + cs.report(callbackResult{err: fmt.Errorf("callback server: %w", err)}) + } + }() + + return cs +} + +// handler renders the callback endpoint. It reports the outcome exactly once and +// always shows the user a friendly page. +func (cs *callbackServer) handler(expectedState string) http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + if errCode := q.Get("error"); errCode != "" { + msg := errCode + if desc := q.Get("error_description"); desc != "" { + msg = fmt.Sprintf("%s: %s", errCode, desc) + } + cs.report(callbackResult{err: fmt.Errorf("authorization failed: %s", msg)}) + renderError(w, msg) + return + } + + if q.Get("state") != expectedState { + cs.report(callbackResult{err: fmt.Errorf("state mismatch (possible CSRF)")}) + renderError(w, "state mismatch") + return + } + + code := q.Get("code") + if code == "" { + cs.report(callbackResult{err: fmt.Errorf("no authorization code in callback")}) + renderError(w, "no authorization code received") + return + } + + cs.report(callbackResult{code: code}) + renderSuccess(w) + }) + return mux +} + +// report delivers the first outcome and drops later ones (the channel is +// buffered for one; subsequent redirect retries must not block the handler). +func (cs *callbackServer) report(res callbackResult) { + select { + case cs.results <- res: + default: + } +} + +// wait blocks for the callback outcome or ctx cancellation, then shuts the +// server down. It is safe to call once per server. +func (cs *callbackServer) wait(ctx context.Context) (string, error) { + defer cs.close() + select { + case res := <-cs.results: + return res.code, res.err + case <-ctx.Done(): + return "", ctx.Err() + } +} + +func (cs *callbackServer) close() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = cs.server.Shutdown(shutdownCtx) + _ = cs.listener.Close() +} + +func renderSuccess(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := successTemplate.Execute(w, nil); err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + } +} + +// renderError shows the failure page. html/template auto-escapes msg, so a +// hostile error_description cannot inject markup. +func renderError(w http.ResponseWriter, msg string) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := errorTemplate.Execute(w, struct{ ErrorMessage string }{ErrorMessage: msg}); err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + } +} diff --git a/internal/oauth/callback_test.go b/internal/oauth/callback_test.go new file mode 100644 index 0000000000..45a8fa71c4 --- /dev/null +++ b/internal/oauth/callback_test.go @@ -0,0 +1,92 @@ +package oauth + +import ( + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// serveCallback drives the callback handler with the given query string and +// returns the recorded response and the single reported result. +func serveCallback(t *testing.T, expectedState, query string) (*httptest.ResponseRecorder, callbackResult) { + t.Helper() + cs := &callbackServer{results: make(chan callbackResult, 1)} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/callback?"+query, nil) + + cs.handler(expectedState).ServeHTTP(rec, req) + + select { + case res := <-cs.results: + return rec, res + default: + t.Fatal("handler did not report a result") + return nil, callbackResult{} + } +} + +func TestCallbackHandlerSuccess(t *testing.T) { + rec, res := serveCallback(t, "state123", "code=the-code&state=state123") + + require.NoError(t, res.err) + assert.Equal(t, "the-code", res.code) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Body.String(), "Authorization Successful") +} + +func TestCallbackHandlerStateMismatch(t *testing.T) { + rec, res := serveCallback(t, "expected", "code=the-code&state=attacker") + + require.Error(t, res.err) + assert.Empty(t, res.code) + assert.Contains(t, res.err.Error(), "state mismatch") + assert.Contains(t, rec.Body.String(), "state mismatch") +} + +func TestCallbackHandlerMissingCode(t *testing.T) { + _, res := serveCallback(t, "state123", "state=state123") + + require.Error(t, res.err) + assert.Contains(t, res.err.Error(), "no authorization code") +} + +func TestCallbackHandlerOAuthError(t *testing.T) { + _, res := serveCallback(t, "state123", "error=access_denied&error_description=user+said+no") + + require.Error(t, res.err) + assert.Contains(t, res.err.Error(), "access_denied") + assert.Contains(t, res.err.Error(), "user said no") +} + +func TestCallbackHandlerEscapesError(t *testing.T) { + rec, _ := serveCallback(t, "state123", "error=evil&error_description=%3Cscript%3Ealert(1)%3C%2Fscript%3E") + + body := rec.Body.String() + assert.NotContains(t, body, "