Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,17 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
// We used to maintain a big table in this code of all the sites and which way
// they went, but maintaining it didn't scale & got annoying.
// So just try both ways.
headerErr := err
authStyle = AuthStyleInParams // the second way we'll try
req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle)
token, err = doTokenRoundTrip(ctx, req)
if err != nil {
// Return both errors so callers can see the original
// failure from the header probe, which may be the real
// cause (e.g. misconfiguration) while the server already
// consumed the authorization code on the first attempt.
err = errors.Join(headerErr, err)
}
}
if needsAuthStyleProbe && err == nil {
styleCache.setAuthStyle(tokenURL, clientID, authStyle)
Expand Down
73 changes: 73 additions & 0 deletions internal/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ package internal

import (
"context"
"errors"
"fmt"
"io"
"math"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)

Expand Down Expand Up @@ -76,6 +78,77 @@ func TestExpiresInUpperBound(t *testing.T) {
}
}

// TestRetrieveToken_AuthStyleUnknown_BothProbesFail verifies that when
// the auth style is unknown and both the header-probe and the params-probe
// fail, RetrieveToken returns both errors joined together rather than only
// the second one. This matters because the first request may have already
// consumed the authorization code at the provider, so the second error is
// often a misleading consequence (e.g. "code already redeemed") of an
// unrelated first failure (e.g. misconfiguration).
func TestRetrieveToken_AuthStyleUnknown_BothProbesFail(t *testing.T) {
styleCache := new(AuthStyleCache)
const clientID = "client-id"

const headerErrCode = "invalid_client"
const paramsErrCode = "invalid_grant"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if r.Header.Get("Authorization") != "" {
io.WriteString(w, `{"error":"`+headerErrCode+`"}`)
return
}
io.WriteString(w, `{"error":"`+paramsErrCode+`"}`)
}))
defer ts.Close()

_, err := RetrieveToken(context.Background(), clientID, "secret", ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
if err == nil {
t.Fatalf("RetrieveToken returned nil error; want joined error")
}

msg := err.Error()
if !strings.Contains(msg, headerErrCode) || !strings.Contains(msg, paramsErrCode) {
t.Errorf("error = %q; want it to mention both %q and %q", msg, headerErrCode, paramsErrCode)
}

var re *RetrieveError
if !errors.As(err, &re) {
t.Fatalf("errors.As(err, *RetrieveError) = false; want true")
}
}

// TestRetrieveToken_AuthStyleUnknown_HeaderFailsParamsSucceeds verifies that
// the probe still succeeds transparently when the header attempt fails but
// the params attempt succeeds. This is the long-standing behavior we must
// not break with the error-joining change.
func TestRetrieveToken_AuthStyleUnknown_HeaderFailsParamsSucceeds(t *testing.T) {
styleCache := new(AuthStyleCache)
const clientID = "client-id"

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.Header.Get("Authorization") != "" {
w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, `{"error":"invalid_client"}`)
return
}
if got, want := r.FormValue("client_id"), clientID; got != want {
t.Errorf("client_id = %q; want %q", got, want)
}
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
}))
defer ts.Close()

tok, err := RetrieveToken(context.Background(), clientID, "secret", ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
if err != nil {
t.Fatalf("RetrieveToken = %v; want no error", err)
}
if tok.AccessToken != "ACCESS_TOKEN" {
t.Errorf("AccessToken = %q; want %q", tok.AccessToken, "ACCESS_TOKEN")
}
}

func TestAuthStyleCache(t *testing.T) {
var c LazyAuthStyleCache

Expand Down
64 changes: 48 additions & 16 deletions oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
)
Expand Down Expand Up @@ -490,15 +491,13 @@ func TestTokenRetrieveError(t *testing.T) {
if err == nil {
t.Fatalf("got no error, expected one")
}
re, ok := err.(*RetrieveError)
if !ok {
t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
// AuthStyleUnknown probes the server twice; both probes fail here, so
// the returned error is a join of two *RetrieveError values.
var re *RetrieveError
if !errors.As(err, &re) {
t.Fatalf("got %T error, expected to unwrap *RetrieveError; error was: %v", err, err)
}
expected := `oauth2: "invalid_grant"`
if errStr := err.Error(); errStr != expected {
t.Fatalf("got %#v, expected %#v", errStr, expected)
}
expected = "invalid_grant"
expected := "invalid_grant"
if re.ErrorCode != expected {
t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected)
}
Expand All @@ -519,20 +518,53 @@ func TestTokenRetrieveError200(t *testing.T) {
if err == nil {
t.Fatalf("got no error, expected one")
}
re, ok := err.(*RetrieveError)
if !ok {
t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
}
expected := `oauth2: "invalid_grant"`
if errStr := err.Error(); errStr != expected {
t.Fatalf("got %#v, expected %#v", errStr, expected)
var re *RetrieveError
if !errors.As(err, &re) {
t.Fatalf("got %T error, expected to unwrap *RetrieveError; error was: %v", err, err)
}
expected = "invalid_grant"
expected := "invalid_grant"
if re.ErrorCode != expected {
t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected)
}
}

// TestTokenRetrieveError_AuthStyleProbeJoinsErrors verifies that when
// Endpoint.AuthStyle is unset (AuthStyleUnknown) and both probes fail with
// distinct error codes, Exchange returns an error from which errors.As can
// still recover a public *RetrieveError and whose message mentions both
// failures. This prevents the first failure from being silently discarded
// (see golang/oauth2#786).
func TestTokenRetrieveError_AuthStyleProbeJoinsErrors(t *testing.T) {
const headerErrCode = "invalid_client"
const paramsErrCode = "invalid_grant"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if r.Header.Get("Authorization") != "" {
w.Write([]byte(`{"error":"` + headerErrCode + `"}`))
return
}
w.Write([]byte(`{"error":"` + paramsErrCode + `"}`))
}))
defer ts.Close()

conf := newConf(ts.URL)
_, err := conf.Exchange(context.Background(), "exchange-code")
if err == nil {
t.Fatalf("got no error, expected one")
}

var re *RetrieveError
if !errors.As(err, &re) {
t.Fatalf("got %T error, expected to unwrap *RetrieveError; error was: %v", err, err)
}

msg := err.Error()
if !strings.Contains(msg, headerErrCode) || !strings.Contains(msg, paramsErrCode) {
t.Errorf("err.Error() = %q; want both %q and %q", msg, headerErrCode, paramsErrCode)
}
}

func TestRefreshToken_RefreshTokenReplacement(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
Expand Down
25 changes: 21 additions & 4 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package oauth2

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -174,14 +175,30 @@ func tokenFromInternal(t *internal.Token) *Token {
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*RetrieveError)(rErr)
}
return nil, err
return nil, convertRetrieveError(err)
}
return tokenFromInternal(tk), nil
}

// convertRetrieveError rewrites any *internal.RetrieveError values inside err
// to the public *RetrieveError type, preserving any errors.Join wrapping.
// This keeps errors.As(err, *RetrieveError) working even when
// internal.RetrieveToken joined two failures from its AuthStyle probe.
func convertRetrieveError(err error) error {
if rErr, ok := err.(*internal.RetrieveError); ok {
return (*RetrieveError)(rErr)
}
if u, ok := err.(interface{ Unwrap() []error }); ok {
wrapped := u.Unwrap()
converted := make([]error, len(wrapped))
for i, e := range wrapped {
converted[i] = convertRetrieveError(e)
}
return errors.Join(converted...)
}
return err
}

// RetrieveError is the error returned when the token endpoint returns a
// non-2XX HTTP status code or populates RFC 6749's 'error' parameter.
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
Expand Down