Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,31 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client {
return internal.ContextClient(ctx)
}
cc := internal.ContextClient(ctx)
// Wrap the caller's CheckRedirect (or nil for the default) to stop
// cross-host redirects. The oauth2.Transport RoundTripper re-injects
// the Authorization header on every RoundTrip call, including redirect
// hops; Go's http.Client strips sensitive headers for cross-host
// redirects, but the Transport immediately adds them back. Stopping
// at the first cross-host redirect (returning the 3xx response to the
// caller) prevents token leakage to unintended hosts.
callerRedirect := cc.CheckRedirect
checkRedirect := func(req *http.Request, via []*http.Request) error {
if callerRedirect != nil {
if err := callerRedirect(req, via); err != nil {
return err
}
}
if len(via) > 0 && req.URL.Host != via[0].URL.Host {
return http.ErrUseLastResponse
}
return nil
}
return &http.Client{
Transport: &Transport{
Base: cc.Transport,
Source: ReuseTokenSource(nil, src),
},
CheckRedirect: cc.CheckRedirect,
CheckRedirect: checkRedirect,
Jar: cc.Jar,
Timeout: cc.Timeout,
}
Expand Down
70 changes: 70 additions & 0 deletions transport_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package oauth2

import (
"context"
"errors"
"io"
"net/http"
Expand Down Expand Up @@ -154,3 +155,72 @@ func TestExpiredWithExpiry(t *testing.T) {
func newMockServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(handler))
}

// TestNewClientCrossHostRedirectDoesNotLeakToken verifies that a client
// created with NewClient stops following redirects when the redirect target
// is on a different host, preventing the oauth2 transport from forwarding
// the bearer token to an unintended destination.
func TestNewClientCrossHostRedirectDoesNotLeakToken(t *testing.T) {
tokenReceived := false
victim := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "" {
tokenReceived = true
}
w.WriteHeader(http.StatusOK)
}))
defer victim.Close()

// redirector lives on a different address (different host:port) from victim.
redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, victim.URL+"/stolen", http.StatusFound)
}))
defer redirector.Close()

ts := StaticTokenSource(&Token{AccessToken: "SECRET_TOKEN"})
client := NewClient(context.Background(), ts)

resp, err := client.Get(redirector.URL + "/api")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()

// The client must NOT have followed the cross-host redirect.
if resp.StatusCode != http.StatusFound {
t.Errorf("got status %d; want %d (redirect should not be followed)", resp.StatusCode, http.StatusFound)
}
if tokenReceived {
t.Error("token was forwarded to the redirect target; cross-host redirect should have been stopped")
}
}

// TestNewClientSameHostRedirectIsFollowed verifies that same-host redirects
// are still followed (backward compatibility).
func TestNewClientSameHostRedirectIsFollowed(t *testing.T) {
called := false
mux := http.NewServeMux()
mux.HandleFunc("/redirectme", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/final", http.StatusFound)
})
mux.HandleFunc("/final", func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})
server := httptest.NewServer(mux)
defer server.Close()

ts := StaticTokenSource(&Token{AccessToken: "MY_TOKEN"})
client := NewClient(context.Background(), ts)

resp, err := client.Get(server.URL + "/redirectme")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()
if !called {
t.Error("same-host redirect was not followed")
}
if resp.StatusCode != http.StatusOK {
t.Errorf("got status %d; want 200", resp.StatusCode)
}
}