Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions transcoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ type Transcoder struct {
// services and transcoding protocols and message encoding as needed.
func (t *Transcoder) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
op := t.newOperation(writer, request)
defer op.cancel()
err := op.validate(t)

if t.unknownHandler != nil && errors.Is(err, errNotFound) {
request.Header = op.originalHeaders // restore headers, just in case initialization removed keys
t.unknownHandler.ServeHTTP(writer, request)
op.request.Header = op.originalHeaders // restore headers, just in case initialization removed keys
t.unknownHandler.ServeHTTP(writer, op.request)
return
}

Expand All @@ -77,8 +78,8 @@ func (t *Transcoder) ServeHTTP(writer http.ResponseWriter, request *http.Request
op.client.reqCompression.Name() == op.server.reqCompression.Name() {
// No transformation needed. But we do need to restore the original headers first
// since extracting request metadata may have removed keys.
request.Header = op.originalHeaders
op.methodConf.handler.ServeHTTP(writer, request)
op.request.Header = op.originalHeaders
op.methodConf.handler.ServeHTTP(writer, op.request)
return
}

Expand Down
58 changes: 58 additions & 0 deletions transcoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1740,3 +1740,61 @@ func rot13(data []byte) {
data[index] = char
}
}

func TestTranscoder_ContextLeak(t *testing.T) {
t.Parallel()

t.Run("success_pass_through", func(t *testing.T) {
t.Parallel()
ctxChan := make(chan context.Context, 1)
rpcHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxChan <- r.Context()
w.WriteHeader(http.StatusOK)
})
services := []*Service{
NewService(testv1connect.LibraryServiceName, rpcHandler),
}
handler, err := NewTranscoder(services)
require.NoError(t, err)

req := httptest.NewRequest(http.MethodPost, "/vanguard.test.v1.LibraryService/GetBook", nil)
req.ProtoMajor = 2
req.ProtoMinor = 0
req.Header.Set("Content-Type", "application/grpc")
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

var handlerCtx context.Context
select {
case handlerCtx = <-ctxChan:
default:
}
require.NotNil(t, handlerCtx)
assert.ErrorIs(t, handlerCtx.Err(), context.Canceled)
})

t.Run("not_found_unknown_handler", func(t *testing.T) {
t.Parallel()
ctxChan := make(chan context.Context, 1)
unknownHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxChan <- r.Context()
w.WriteHeader(http.StatusNotFound)
})
handler, err := NewTranscoder(nil, WithUnknownHandler(unknownHandler))
require.NoError(t, err)

req := httptest.NewRequest(http.MethodGet, "/unknown/path", nil)
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

var unknownCtx context.Context
select {
case unknownCtx = <-ctxChan:
default:
}
require.NotNil(t, unknownCtx)
assert.ErrorIs(t, unknownCtx.Err(), context.Canceled)
})
}
Loading