From 40e879f17be08c5b47ab994fd0ea113335cfb34b Mon Sep 17 00:00:00 2001 From: Kevin McDonald Date: Thu, 11 Jun 2026 14:33:04 +0200 Subject: [PATCH 1/3] fix: remove memory leak potential by not calling op.cancel() Signed-off-by: Kevin McDonald --- transcoder.go | 1 + transcoder_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/transcoder.go b/transcoder.go index a0c1edd..2068df9 100644 --- a/transcoder.go +++ b/transcoder.go @@ -59,6 +59,7 @@ type Transcoder struct { // services and transcoding protocols and message encoding as needed. func (t *Transcoder) ServeHTTP(writer http.ResponseWriter, request *http.Request) { op := t.newOperation(writer, request) + defer op.cancel() err := op.validate(t) if t.unknownHandler != nil && errors.Is(err, errNotFound) { diff --git a/transcoder_test.go b/transcoder_test.go index 346a439..95246d8 100644 --- a/transcoder_test.go +++ b/transcoder_test.go @@ -1740,3 +1740,48 @@ func rot13(data []byte) { data[index] = char } } + +func TestTranscoder_ContextLeak(t *testing.T) { + t.Parallel() + + var handlerCtx context.Context + rpcHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCtx = r.Context() + w.WriteHeader(http.StatusOK) + }) + + var unknownCtx context.Context + unknownHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + unknownCtx = r.Context() + w.WriteHeader(http.StatusNotFound) + }) + + services := []*Service{ + NewService(testv1connect.LibraryServiceName, rpcHandler), + } + handler, err := NewTranscoder(services, WithUnknownHandler(unknownHandler)) + require.NoError(t, err) + + t.Run("success_pass_through", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/vanguard.test.v1.LibraryService/GetBook", nil) + req.ProtoMajor = 2 + req.ProtoMinor = 0 + req.Header.Set("Content-Type", "application/grpc") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + require.NotNil(t, handlerCtx) + assert.ErrorIs(t, handlerCtx.Err(), context.Canceled) + }) + + t.Run("not_found_unknown_handler", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/unknown/path", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + require.NotNil(t, unknownCtx) + assert.ErrorIs(t, unknownCtx.Err(), context.Canceled) + }) +} From c62ab50dd48c2de6bf097b94e1734a85b01b4d25 Mon Sep 17 00:00:00 2001 From: Kevin McDonald Date: Sun, 28 Jun 2026 07:26:54 +0200 Subject: [PATCH 2/3] chore: use op.request for downstream handlers Signed-off-by: Kevin McDonald --- transcoder.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transcoder.go b/transcoder.go index 2068df9..66c6476 100644 --- a/transcoder.go +++ b/transcoder.go @@ -63,8 +63,8 @@ func (t *Transcoder) ServeHTTP(writer http.ResponseWriter, request *http.Request err := op.validate(t) if t.unknownHandler != nil && errors.Is(err, errNotFound) { - request.Header = op.originalHeaders // restore headers, just in case initialization removed keys - t.unknownHandler.ServeHTTP(writer, request) + op.request.Header = op.originalHeaders // restore headers, just in case initialization removed keys + t.unknownHandler.ServeHTTP(writer, op.request) return } @@ -78,8 +78,8 @@ func (t *Transcoder) ServeHTTP(writer http.ResponseWriter, request *http.Request op.client.reqCompression.Name() == op.server.reqCompression.Name() { // No transformation needed. But we do need to restore the original headers first // since extracting request metadata may have removed keys. - request.Header = op.originalHeaders - op.methodConf.handler.ServeHTTP(writer, request) + op.request.Header = op.originalHeaders + op.methodConf.handler.ServeHTTP(writer, op.request) return } From 25b2e35523ffb307278977ef91df84d7c54d4762 Mon Sep 17 00:00:00 2001 From: Kevin McDonald Date: Sun, 28 Jun 2026 07:33:01 +0200 Subject: [PATCH 3/3] chore: fix linter Signed-off-by: Kevin McDonald --- transcoder_test.go | 49 +++++++++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/transcoder_test.go b/transcoder_test.go index 95246d8..b58d680 100644 --- a/transcoder_test.go +++ b/transcoder_test.go @@ -1744,25 +1744,19 @@ func rot13(data []byte) { func TestTranscoder_ContextLeak(t *testing.T) { t.Parallel() - var handlerCtx context.Context - rpcHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handlerCtx = r.Context() - w.WriteHeader(http.StatusOK) - }) - - var unknownCtx context.Context - unknownHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - unknownCtx = r.Context() - w.WriteHeader(http.StatusNotFound) - }) - - services := []*Service{ - NewService(testv1connect.LibraryServiceName, rpcHandler), - } - handler, err := NewTranscoder(services, WithUnknownHandler(unknownHandler)) - require.NoError(t, err) - t.Run("success_pass_through", func(t *testing.T) { + t.Parallel() + ctxChan := make(chan context.Context, 1) + rpcHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctxChan <- r.Context() + w.WriteHeader(http.StatusOK) + }) + services := []*Service{ + NewService(testv1connect.LibraryServiceName, rpcHandler), + } + handler, err := NewTranscoder(services) + require.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/vanguard.test.v1.LibraryService/GetBook", nil) req.ProtoMajor = 2 req.ProtoMinor = 0 @@ -1771,16 +1765,35 @@ func TestTranscoder_ContextLeak(t *testing.T) { handler.ServeHTTP(rec, req) + var handlerCtx context.Context + select { + case handlerCtx = <-ctxChan: + default: + } require.NotNil(t, handlerCtx) assert.ErrorIs(t, handlerCtx.Err(), context.Canceled) }) t.Run("not_found_unknown_handler", func(t *testing.T) { + t.Parallel() + ctxChan := make(chan context.Context, 1) + unknownHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctxChan <- r.Context() + w.WriteHeader(http.StatusNotFound) + }) + handler, err := NewTranscoder(nil, WithUnknownHandler(unknownHandler)) + require.NoError(t, err) + req := httptest.NewRequest(http.MethodGet, "/unknown/path", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) + var unknownCtx context.Context + select { + case unknownCtx = <-ctxChan: + default: + } require.NotNil(t, unknownCtx) assert.ErrorIs(t, unknownCtx.Err(), context.Canceled) })