diff --git a/internal/token.go b/internal/token.go index 8389f2462..324ea1651 100644 --- a/internal/token.go +++ b/internal/token.go @@ -240,9 +240,17 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, // We used to maintain a big table in this code of all the sites and which way // they went, but maintaining it didn't scale & got annoying. // So just try both ways. + headerErr := err authStyle = AuthStyleInParams // the second way we'll try req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) token, err = doTokenRoundTrip(ctx, req) + if err != nil { + // Return both errors so callers can see the original + // failure from the header probe, which may be the real + // cause (e.g. misconfiguration) while the server already + // consumed the authorization code on the first attempt. + err = errors.Join(headerErr, err) + } } if needsAuthStyleProbe && err == nil { styleCache.setAuthStyle(tokenURL, clientID, authStyle) diff --git a/internal/token_test.go b/internal/token_test.go index ef28c1162..4870ae0df 100644 --- a/internal/token_test.go +++ b/internal/token_test.go @@ -6,12 +6,14 @@ package internal import ( "context" + "errors" "fmt" "io" "math" "net/http" "net/http/httptest" "net/url" + "strings" "testing" ) @@ -76,6 +78,77 @@ func TestExpiresInUpperBound(t *testing.T) { } } +// TestRetrieveToken_AuthStyleUnknown_BothProbesFail verifies that when +// the auth style is unknown and both the header-probe and the params-probe +// fail, RetrieveToken returns both errors joined together rather than only +// the second one. This matters because the first request may have already +// consumed the authorization code at the provider, so the second error is +// often a misleading consequence (e.g. "code already redeemed") of an +// unrelated first failure (e.g. misconfiguration). +func TestRetrieveToken_AuthStyleUnknown_BothProbesFail(t *testing.T) { + styleCache := new(AuthStyleCache) + const clientID = "client-id" + + const headerErrCode = "invalid_client" + const paramsErrCode = "invalid_grant" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if r.Header.Get("Authorization") != "" { + io.WriteString(w, `{"error":"`+headerErrCode+`"}`) + return + } + io.WriteString(w, `{"error":"`+paramsErrCode+`"}`) + })) + defer ts.Close() + + _, err := RetrieveToken(context.Background(), clientID, "secret", ts.URL, url.Values{}, AuthStyleUnknown, styleCache) + if err == nil { + t.Fatalf("RetrieveToken returned nil error; want joined error") + } + + msg := err.Error() + if !strings.Contains(msg, headerErrCode) || !strings.Contains(msg, paramsErrCode) { + t.Errorf("error = %q; want it to mention both %q and %q", msg, headerErrCode, paramsErrCode) + } + + var re *RetrieveError + if !errors.As(err, &re) { + t.Fatalf("errors.As(err, *RetrieveError) = false; want true") + } +} + +// TestRetrieveToken_AuthStyleUnknown_HeaderFailsParamsSucceeds verifies that +// the probe still succeeds transparently when the header attempt fails but +// the params attempt succeeds. This is the long-standing behavior we must +// not break with the error-joining change. +func TestRetrieveToken_AuthStyleUnknown_HeaderFailsParamsSucceeds(t *testing.T) { + styleCache := new(AuthStyleCache) + const clientID = "client-id" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if r.Header.Get("Authorization") != "" { + w.WriteHeader(http.StatusBadRequest) + io.WriteString(w, `{"error":"invalid_client"}`) + return + } + if got, want := r.FormValue("client_id"), clientID; got != want { + t.Errorf("client_id = %q; want %q", got, want) + } + io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) + })) + defer ts.Close() + + tok, err := RetrieveToken(context.Background(), clientID, "secret", ts.URL, url.Values{}, AuthStyleUnknown, styleCache) + if err != nil { + t.Fatalf("RetrieveToken = %v; want no error", err) + } + if tok.AccessToken != "ACCESS_TOKEN" { + t.Errorf("AccessToken = %q; want %q", tok.AccessToken, "ACCESS_TOKEN") + } +} + func TestAuthStyleCache(t *testing.T) { var c LazyAuthStyleCache diff --git a/oauth2_test.go b/oauth2_test.go index e996b8013..9f23b7319 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -12,6 +12,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" ) @@ -490,15 +491,13 @@ func TestTokenRetrieveError(t *testing.T) { if err == nil { t.Fatalf("got no error, expected one") } - re, ok := err.(*RetrieveError) - if !ok { - t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err) + // AuthStyleUnknown probes the server twice; both probes fail here, so + // the returned error is a join of two *RetrieveError values. + var re *RetrieveError + if !errors.As(err, &re) { + t.Fatalf("got %T error, expected to unwrap *RetrieveError; error was: %v", err, err) } - expected := `oauth2: "invalid_grant"` - if errStr := err.Error(); errStr != expected { - t.Fatalf("got %#v, expected %#v", errStr, expected) - } - expected = "invalid_grant" + expected := "invalid_grant" if re.ErrorCode != expected { t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected) } @@ -519,20 +518,53 @@ func TestTokenRetrieveError200(t *testing.T) { if err == nil { t.Fatalf("got no error, expected one") } - re, ok := err.(*RetrieveError) - if !ok { - t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err) - } - expected := `oauth2: "invalid_grant"` - if errStr := err.Error(); errStr != expected { - t.Fatalf("got %#v, expected %#v", errStr, expected) + var re *RetrieveError + if !errors.As(err, &re) { + t.Fatalf("got %T error, expected to unwrap *RetrieveError; error was: %v", err, err) } - expected = "invalid_grant" + expected := "invalid_grant" if re.ErrorCode != expected { t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected) } } +// TestTokenRetrieveError_AuthStyleProbeJoinsErrors verifies that when +// Endpoint.AuthStyle is unset (AuthStyleUnknown) and both probes fail with +// distinct error codes, Exchange returns an error from which errors.As can +// still recover a public *RetrieveError and whose message mentions both +// failures. This prevents the first failure from being silently discarded +// (see golang/oauth2#786). +func TestTokenRetrieveError_AuthStyleProbeJoinsErrors(t *testing.T) { + const headerErrCode = "invalid_client" + const paramsErrCode = "invalid_grant" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if r.Header.Get("Authorization") != "" { + w.Write([]byte(`{"error":"` + headerErrCode + `"}`)) + return + } + w.Write([]byte(`{"error":"` + paramsErrCode + `"}`)) + })) + defer ts.Close() + + conf := newConf(ts.URL) + _, err := conf.Exchange(context.Background(), "exchange-code") + if err == nil { + t.Fatalf("got no error, expected one") + } + + var re *RetrieveError + if !errors.As(err, &re) { + t.Fatalf("got %T error, expected to unwrap *RetrieveError; error was: %v", err, err) + } + + msg := err.Error() + if !strings.Contains(msg, headerErrCode) || !strings.Contains(msg, paramsErrCode) { + t.Errorf("err.Error() = %q; want both %q and %q", msg, headerErrCode, paramsErrCode) + } +} + func TestRefreshToken_RefreshTokenReplacement(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") diff --git a/token.go b/token.go index e995eebb5..1641aec02 100644 --- a/token.go +++ b/token.go @@ -6,6 +6,7 @@ package oauth2 import ( "context" + "errors" "fmt" "net/http" "net/url" @@ -174,14 +175,30 @@ func tokenFromInternal(t *internal.Token) *Token { func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get()) if err != nil { - if rErr, ok := err.(*internal.RetrieveError); ok { - return nil, (*RetrieveError)(rErr) - } - return nil, err + return nil, convertRetrieveError(err) } return tokenFromInternal(tk), nil } +// convertRetrieveError rewrites any *internal.RetrieveError values inside err +// to the public *RetrieveError type, preserving any errors.Join wrapping. +// This keeps errors.As(err, *RetrieveError) working even when +// internal.RetrieveToken joined two failures from its AuthStyle probe. +func convertRetrieveError(err error) error { + if rErr, ok := err.(*internal.RetrieveError); ok { + return (*RetrieveError)(rErr) + } + if u, ok := err.(interface{ Unwrap() []error }); ok { + wrapped := u.Unwrap() + converted := make([]error, len(wrapped)) + for i, e := range wrapped { + converted[i] = convertRetrieveError(e) + } + return errors.Join(converted...) + } + return err +} + // RetrieveError is the error returned when the token endpoint returns a // non-2XX HTTP status code or populates RFC 6749's 'error' parameter. // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2