From 6250ccfa2ae053c2795d50d0f28c8d570db87feb Mon Sep 17 00:00:00 2001 From: Kenta Ishizaki Date: Thu, 23 Apr 2026 19:51:13 +0900 Subject: [PATCH] oauth2: join both probe errors when auth style is unknown When Endpoint.AuthStyle is AuthStyleUnknown (the default), RetrieveToken probes the token endpoint by trying credentials in the Authorization header first, then in form params if that fails. If both attempts failed, only the second error was returned, silently discarding the first. This is misleading when the first request fails for a reason unrelated to auth style (e.g. misconfiguration or an expired signing key) and the provider has already consumed the authorization code. The second request then fails with a different error (e.g. "code already redeemed") that hides the real cause. Join both errors with errors.Join so callers see the full picture, and update retrieveToken to convert every wrapped *internal.RetrieveError to the public *RetrieveError type so errors.As keeps working. Note: errors.As now unwraps to the first (header) probe error rather than the second (params) one. This is intentional, as the header error is typically the root cause. Fixes golang/oauth2#786 --- internal/token.go | 8 +++++ internal/token_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++ oauth2_test.go | 64 +++++++++++++++++++++++++++--------- token.go | 25 ++++++++++++--- 4 files changed, 150 insertions(+), 20 deletions(-) diff --git a/internal/token.go b/internal/token.go index 8389f2462..324ea1651 100644 --- a/internal/token.go +++ b/internal/token.go @@ -240,9 +240,17 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, // We used to maintain a big table in this code of all the sites and which way // they went, but maintaining it didn't scale & got annoying. // So just try both ways. + headerErr := err authStyle = AuthStyleInParams // the second way we'll try req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) token, err = doTokenRoundTrip(ctx, req) + if err != nil { + // Return both errors so callers can see the original + // failure from the header probe, which may be the real + // cause (e.g. misconfiguration) while the server already + // consumed the authorization code on the first attempt. + err = errors.Join(headerErr, err) + } } if needsAuthStyleProbe && err == nil { styleCache.setAuthStyle(tokenURL, clientID, authStyle) diff --git a/internal/token_test.go b/internal/token_test.go index ef28c1162..4870ae0df 100644 --- a/internal/token_test.go +++ b/internal/token_test.go @@ -6,12 +6,14 @@ package internal import ( "context" + "errors" "fmt" "io" "math" "net/http" "net/http/httptest" "net/url" + "strings" "testing" ) @@ -76,6 +78,77 @@ func TestExpiresInUpperBound(t *testing.T) { } } +// TestRetrieveToken_AuthStyleUnknown_BothProbesFail verifies that when +// the auth style is unknown and both the header-probe and the params-probe +// fail, RetrieveToken returns both errors joined together rather than only +// the second one. This matters because the first request may have already +// consumed the authorization code at the provider, so the second error is +// often a misleading consequence (e.g. "code already redeemed") of an +// unrelated first failure (e.g. misconfiguration). +func TestRetrieveToken_AuthStyleUnknown_BothProbesFail(t *testing.T) { + styleCache := new(AuthStyleCache) + const clientID = "client-id" + + const headerErrCode = "invalid_client" + const paramsErrCode = "invalid_grant" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if r.Header.Get("Authorization") != "" { + io.WriteString(w, `{"error":"`+headerErrCode+`"}`) + return + } + io.WriteString(w, `{"error":"`+paramsErrCode+`"}`) + })) + defer ts.Close() + + _, err := RetrieveToken(context.Background(), clientID, "secret", ts.URL, url.Values{}, AuthStyleUnknown, styleCache) + if err == nil { + t.Fatalf("RetrieveToken returned nil error; want joined error") + } + + msg := err.Error() + if !strings.Contains(msg, headerErrCode) || !strings.Contains(msg, paramsErrCode) { + t.Errorf("error = %q; want it to mention both %q and %q", msg, headerErrCode, paramsErrCode) + } + + var re *RetrieveError + if !errors.As(err, &re) { + t.Fatalf("errors.As(err, *RetrieveError) = false; want true") + } +} + +// TestRetrieveToken_AuthStyleUnknown_HeaderFailsParamsSucceeds verifies that +// the probe still succeeds transparently when the header attempt fails but +// the params attempt succeeds. This is the long-standing behavior we must +// not break with the error-joining change. +func TestRetrieveToken_AuthStyleUnknown_HeaderFailsParamsSucceeds(t *testing.T) { + styleCache := new(AuthStyleCache) + const clientID = "client-id" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.Header.Get("Authorization") != "" { + w.WriteHeader(http.StatusBadRequest) + io.WriteString(w, `{"error":"invalid_client"}`) + return + } + if got, want := r.FormValue("client_id"), clientID; got != want { + t.Errorf("client_id = %q; want %q", got, want) + } + io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) + })) + defer ts.Close() + + tok, err := RetrieveToken(context.Background(), clientID, "secret", ts.URL, url.Values{}, AuthStyleUnknown, styleCache) + if err != nil { + t.Fatalf("RetrieveToken = %v; want no error", err) + } + if tok.AccessToken != "ACCESS_TOKEN" { + t.Errorf("AccessToken = %q; want %q", tok.AccessToken, "ACCESS_TOKEN") + } +} + func TestAuthStyleCache(t *testing.T) { var c LazyAuthStyleCache diff --git a/oauth2_test.go b/oauth2_test.go index e996b8013..9f23b7319 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -12,6 +12,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" ) @@ -490,15 +491,13 @@ func TestTokenRetrieveError(t *testing.T) { if err == nil { t.Fatalf("got no error, expected one") } - re, ok := err.(*RetrieveError) - if !ok { - t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err) + // AuthStyleUnknown probes the server twice; both probes fail here, so + // the returned error is a join of two *RetrieveError values. + var re *RetrieveError + if !errors.As(err, &re) { + t.Fatalf("got %T error, expected to unwrap *RetrieveError; error was: %v", err, err) } - expected := `oauth2: "invalid_grant"` - if errStr := err.Error(); errStr != expected { - t.Fatalf("got %#v, expected %#v", errStr, expected) - } - expected = "invalid_grant" + expected := "invalid_grant" if re.ErrorCode != expected { t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected) } @@ -519,20 +518,53 @@ func TestTokenRetrieveError200(t *testing.T) { if err == nil { t.Fatalf("got no error, expected one") } - re, ok := err.(*RetrieveError) - if !ok { - t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err) - } - expected := `oauth2: "invalid_grant"` - if errStr := err.Error(); errStr != expected { - t.Fatalf("got %#v, expected %#v", errStr, expected) + var re *RetrieveError + if !errors.As(err, &re) { + t.Fatalf("got %T error, expected to unwrap *RetrieveError; error was: %v", err, err) } - expected = "invalid_grant" + expected := "invalid_grant" if re.ErrorCode != expected { t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected) } } +// TestTokenRetrieveError_AuthStyleProbeJoinsErrors verifies that when +// Endpoint.AuthStyle is unset (AuthStyleUnknown) and both probes fail with +// distinct error codes, Exchange returns an error from which errors.As can +// still recover a public *RetrieveError and whose message mentions both +// failures. This prevents the first failure from being silently discarded +// (see golang/oauth2#786). +func TestTokenRetrieveError_AuthStyleProbeJoinsErrors(t *testing.T) { + const headerErrCode = "invalid_client" + const paramsErrCode = "invalid_grant" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if r.Header.Get("Authorization") != "" { + w.Write([]byte(`{"error":"` + headerErrCode + `"}`)) + return + } + w.Write([]byte(`{"error":"` + paramsErrCode + `"}`)) + })) + defer ts.Close() + + conf := newConf(ts.URL) + _, err := conf.Exchange(context.Background(), "exchange-code") + if err == nil { + t.Fatalf("got no error, expected one") + } + + var re *RetrieveError + if !errors.As(err, &re) { + t.Fatalf("got %T error, expected to unwrap *RetrieveError; error was: %v", err, err) + } + + msg := err.Error() + if !strings.Contains(msg, headerErrCode) || !strings.Contains(msg, paramsErrCode) { + t.Errorf("err.Error() = %q; want both %q and %q", msg, headerErrCode, paramsErrCode) + } +} + func TestRefreshToken_RefreshTokenReplacement(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") diff --git a/token.go b/token.go index e995eebb5..1641aec02 100644 --- a/token.go +++ b/token.go @@ -6,6 +6,7 @@ package oauth2 import ( "context" + "errors" "fmt" "net/http" "net/url" @@ -174,14 +175,30 @@ func tokenFromInternal(t *internal.Token) *Token { func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get()) if err != nil { - if rErr, ok := err.(*internal.RetrieveError); ok { - return nil, (*RetrieveError)(rErr) - } - return nil, err + return nil, convertRetrieveError(err) } return tokenFromInternal(tk), nil } +// convertRetrieveError rewrites any *internal.RetrieveError values inside err +// to the public *RetrieveError type, preserving any errors.Join wrapping. +// This keeps errors.As(err, *RetrieveError) working even when +// internal.RetrieveToken joined two failures from its AuthStyle probe. +func convertRetrieveError(err error) error { + if rErr, ok := err.(*internal.RetrieveError); ok { + return (*RetrieveError)(rErr) + } + if u, ok := err.(interface{ Unwrap() []error }); ok { + wrapped := u.Unwrap() + converted := make([]error, len(wrapped)) + for i, e := range wrapped { + converted[i] = convertRetrieveError(e) + } + return errors.Join(converted...) + } + return err +} + // RetrieveError is the error returned when the token endpoint returns a // non-2XX HTTP status code or populates RFC 6749's 'error' parameter. // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2