From 90277549d4f946577e0ce1563b8976e667b0be27 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 13:54:45 -0400 Subject: [PATCH 01/57] Redo Signed-off-by: Steve Ayers --- client.go | 46 ++- client_ext_test.go | 10 +- connect_ext_test.go | 460 +++++++++++++++++++++-------- context.go | 133 ++++++--- error_example_test.go | 11 +- error_not_modified_example_test.go | 26 +- example_init_test.go | 2 +- handler.go | 58 ++-- interceptor_ext_test.go | 6 +- 9 files changed, 532 insertions(+), 220 deletions(-) diff --git a/client.go b/client.go index ffb9336f..792430bb 100644 --- a/client.go +++ b/client.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "io" + "maps" "net/http" "net/url" "strings" @@ -127,16 +128,31 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) if c.err != nil { return nil, c.err } - return c.callUnary(ctx, request) + ctx, ci := NewOutgoingContext(ctx) + call, ok := ci.(*callInfo) + if ok { + call.requestHeader = request.Header() + } + + resp, err := c.callUnary(ctx, request) + if err != nil { + return nil, err + } + + if ok { + call.peer = request.Peer() + call.spec = request.Spec() + call.method = request.HTTPMethod() + maps.Copy(call.ResponseHeader(), resp.Header()) + maps.Copy(call.ResponseTrailer(), resp.Trailer()) + } + + return resp, nil } -// CallUnarySimple calls a request-response procedure using the function signature -// associated with the "simple" generation option. -// -// This option eliminates the [Request] and [Response] wrappers, and instead uses the -// context.Context to propagate information such as headers. -func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, requestMsg *Req) (*Res, error) { - response, err := c.CallUnary(ctx, requestFromContext(ctx, requestMsg)) +// CallUnary calls a request-response procedure. +func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, request *Req) (*Res, error) { + response, err := c.CallUnary(ctx, requestFromContext(ctx, request)) if response != nil { return response.Msg, err } @@ -159,12 +175,21 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if c.err != nil { return nil, c.err } + ctx, ctxCallInfo := NewOutgoingContext(ctx) + // Note we don't need to check ok here because it should always be in context + // because of the above call to NewOutgoingContext + info, _ := ctxCallInfo.(*callInfo) conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method + info.method = r.Method }) request.spec = conn.Spec() request.peer = conn.Peer() mergeHeaders(conn.RequestHeader(), request.header) + + info.peer = conn.Peer() + info.spec = conn.Spec() + mergeHeaders(conn.RequestHeader(), info.requestHeader) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. @@ -182,11 +207,6 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques }, nil } -// CallServerStreamSimple calls a server streaming procedure using the function signature -// associated with the "simple" generation option. -// -// This option eliminates the [Request] wrapper, and instead uses the context.Context to -// propagate information such as headers. func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { return c.CallServerStream(ctx, requestFromContext(ctx, requestMsg)) } diff --git a/client_ext_test.go b/client_ext_test.go index 4e5b8351..14ae2773 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -89,7 +89,7 @@ func TestNewClient_InitFailure(t *testing.T) { func TestClientPeer(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{})) server := memhttptest.NewServer(t, mux) run := func(t *testing.T, unaryHTTPMethod string, opts ...connect.ClientOption) { @@ -205,7 +205,7 @@ func TestGetNoContentHeaders(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(&pingServer{})) + mux.Handle(pingv1connect.NewPingServiceHandler(&pingServerGenerics{})) server := memhttptest.NewServer(t, http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) { if len(req.Header.Values("content-type")) > 0 || len(req.Header.Values("content-encoding")) > 0 || @@ -283,7 +283,7 @@ func TestSpecSchema(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler( - pingServer{}, + pingServerGenerics{}, connect.WithInterceptors(&assertSchemaInterceptor{t}), )) server := memhttptest.NewServer(t, mux) @@ -320,7 +320,7 @@ func TestSpecSchema(t *testing.T) { func TestDynamicClient(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{})) server := memhttptest.NewServer(t, mux) ctx := context.Background() initializer := func(spec connect.Spec, msg any) error { @@ -494,7 +494,7 @@ func TestClientDeadlineHandling(t *testing.T) { // detector enabled. That's partly why the makefile only runs "slow" // tests with the race detector disabled. - _, handler := pingv1connect.NewPingServiceHandler(pingServer{}) + _, handler := pingv1connect.NewPingServiceHandler(pingServerGenerics{}) svr := httptest.NewUnstartedServer(http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) { if req.Context().Err() != nil { return diff --git a/connect_ext_test.go b/connect_ext_test.go index 22e392e3..61908a13 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -29,6 +29,7 @@ import ( rand "math/rand/v2" "net" "net/http" + "net/http/httptest" "runtime" "strings" "sync" @@ -38,8 +39,9 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" - "connectrpc.com/connect/internal/gen/generics/connect/import/v1/importv1connect" - "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + pingv1connectgenerics "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/import/v1/importv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/proto" @@ -61,9 +63,132 @@ const ( clientMiddlewareErrorHeader = "Connect-Trigger-HTTP-Error" ) +func TestCallInfo(t *testing.T) { + t.Parallel() + t.Run("simple_api", func(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{checkMetadata: true}, + )) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + t.Run("unary", func(t *testing.T) { + num := int64(42) + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + callInfo.RequestHeader().Set(clientHeader, headerValue) + expect := &pingv1.PingResponse{Number: num} + + response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) + assert.Equal(t, response, expect) + assert.Nil(t, err) + + // Assert call info values are correctly populated + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + }) + t.Run("server_stream", func(t *testing.T) { + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + callInfo.RequestHeader().Set(clientHeader, headerValue) + stream, err := client.CountUp(ctx, &pingv1.CountUpRequest{ + Number: 1, + }) + assert.Nil(t, err) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), 1) + assert.Nil(t, stream.Close()) + + // Assert call info values are correctly populated + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + }) + }) + t.Run("generics_api", func(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connectgenerics.NewPingServiceHandler( + pingServerGenerics{checkMetadata: true}, + )) + server := memhttptest.NewServer(t, mux) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + t.Run("unary", func(t *testing.T) { + num := int64(42) + request := connect.NewRequest(&pingv1.PingRequest{Number: num}) + request.Header().Set(clientHeader, headerValue) + expect := &pingv1.PingResponse{Number: num} + + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + response, err := client.Ping(ctx, request) + assert.Nil(t, err) + assert.Equal(t, response.Msg, expect) + + assert.Equal(t, request.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, request.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, request.Spec().IsClient) + assert.Equal(t, request.Peer().Addr, httptest.DefaultRemoteAddr) + + // Verify that spec and peer on the callInfo are the same as the request wrapper + assert.Equal(t, callInfo.Spec().StreamType, request.Spec().StreamType) + assert.Equal(t, callInfo.Spec().Procedure, request.Spec().Procedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, request.Peer().Addr) + + // Verify that the response headers and trailers are the same on callInfo and the response + assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + + }) + t.Run("server_stream", func(t *testing.T) { + req := connect.NewRequest(&pingv1.CountUpRequest{ + Number: 1, + }) + req.Header().Set(clientHeader, headerValue) + stream, err := client.CountUp(context.Background(), req) + assert.Nil(t, err) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), 1) + assert.Nil(t, stream.Close()) + assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + // num := int64(42) + // ctx, callInfo := connect.NewOutgoingContext(context.Background()) + // callInfo.RequestHeader().Set(clientHeader, headerValue) + // expect := &pingv1.PingResponse{Number: num} + + // response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) + // assert.Equal(t, response, expect) + // assert.Nil(t, err) + + // // Assert call info values are correctly populated + // assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) + // assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) + // assert.True(t, callInfo.Spec().IsClient) + // assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + // assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + }) + }) +} + func TestServer(t *testing.T) { t.Parallel() - testPing := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + testPing := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper t.Run("ping", func(t *testing.T) { num := int64(42) request := connect.NewRequest(&pingv1.PingRequest{Number: num}) @@ -117,7 +242,7 @@ func TestServer(t *testing.T) { assert.Equal(t, connect.CodeOf(err), connect.CodeDeadlineExceeded) }) } - testSum := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + testSum := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper t.Run("sum", func(t *testing.T) { const ( upTo = 10 @@ -153,7 +278,7 @@ func TestServer(t *testing.T) { assert.Equal(t, got.Header().Values(handlerHeader), []string{headerValue}) }) } - testCountUp := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + testCountUp := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper t.Run("count_up", func(t *testing.T) { const upTo = 5 got := make([]int64, 0, upTo) @@ -211,7 +336,7 @@ func TestServer(t *testing.T) { assert.Nil(t, stream.Close()) }) } - testCumSum := func(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper + testCumSum := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, expectSuccess bool) { //nolint:thelper t.Run("cumsum", func(t *testing.T) { send := []int64{3, 5, 1} expect := []int64{3, 8, 9} @@ -326,7 +451,7 @@ func TestServer(t *testing.T) { assert.Nil(t, stream.CloseResponse()) }) } - testErrors := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + testErrors := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper assertIsHTTPMiddlewareError := func(tb testing.TB, err error) { tb.Helper() assert.NotNil(tb, err) @@ -377,7 +502,7 @@ func TestServer(t *testing.T) { testMatrix := func(t *testing.T, client *http.Client, url string, bidi bool) { //nolint:thelper run := func(t *testing.T, opts ...connect.ClientOption) { t.Helper() - client := pingv1connect.NewPingServiceClient(client, url, opts...) + client := pingv1connectgenerics.NewPingServiceClient(client, url, opts...) testPing(t, client) testSum(t, client) testCountUp(t, client) @@ -442,8 +567,8 @@ func TestServer(t *testing.T) { } mux := http.NewServeMux() - pingRoute, pingHandler := pingv1connect.NewPingServiceHandler( - pingServer{checkMetadata: true}, + pingRoute, pingHandler := pingv1connectgenerics.NewPingServiceHandler( + pingServerGenerics{checkMetadata: true}, ) errorWriter := connect.NewErrorWriter() // Add net/http middleware to the ping service to evaluate HTTP state. @@ -464,15 +589,15 @@ func TestServer(t *testing.T) { } // Check Content-Length is set correctly. switch request.URL.Path { - case pingv1connect.PingServicePingProcedure, - pingv1connect.PingServiceFailProcedure, - pingv1connect.PingServiceCountUpProcedure: + case pingv1connectgenerics.PingServicePingProcedure, + pingv1connectgenerics.PingServiceFailProcedure, + pingv1connectgenerics.PingServiceCountUpProcedure: // Unary requests set Content-Length to the length of the request body. if request.ContentLength < 0 { t.Errorf("%s: expected Content-Length >= 0, got %d", request.URL.Path, request.ContentLength) } - case pingv1connect.PingServiceSumProcedure, - pingv1connect.PingServiceCumSumProcedure: + case pingv1connectgenerics.PingServiceSumProcedure, + pingv1connectgenerics.PingServiceCumSumProcedure: // Streaming requests set Content-Length to -1 or 0 on empty requests. if request.ContentLength > 0 { t.Errorf("%s: expected Content-Length -1 or 0, got %d", request.URL.Path, request.ContentLength) @@ -503,7 +628,7 @@ func TestConcurrentStreams(t *testing.T) { } t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{})) server := memhttptest.NewServer(t, mux) var done, start sync.WaitGroup start.Add(1) @@ -511,7 +636,7 @@ func TestConcurrentStreams(t *testing.T) { done.Add(1) go func() { defer done.Done() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) var total int64 sum := client.CumSum(context.Background()) start.Wait() @@ -575,7 +700,7 @@ func TestErrorHeaderPropagation(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) assertError := func(t *testing.T, err error, allowCustomHeaders bool) { @@ -612,7 +737,7 @@ func TestErrorHeaderPropagation(t *testing.T) { assert.Equal(t, meta.Values("X-Test"), []string(nil)) } } - testServices := func(t *testing.T, client pingv1connect.PingServiceClient) { + testServices := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { t.Helper() t.Run("unary", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) @@ -660,17 +785,17 @@ func TestErrorHeaderPropagation(t *testing.T) { } t.Run("connect", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) testServices(t, client) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) testServices(t, client) }) t.Run("grpc-web", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) testServices(t, client) }) } @@ -692,10 +817,10 @@ func TestHeaderBasic(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) request := connect.NewRequest(&pingv1.PingRequest{}) request.Header().Set(key, cval) response, err := client.Ping(context.Background(), request) @@ -721,12 +846,12 @@ func TestHeaderHost(t *testing.T) { newHTTP2Server := func(t *testing.T) *memhttp.Server { t.Helper() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) return server } - callWithHost := func(t *testing.T, client pingv1connect.PingServiceClient) { + callWithHost := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { t.Helper() request := connect.NewRequest(&pingv1.PingRequest{}) @@ -739,21 +864,21 @@ func TestHeaderHost(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) callWithHost(t, client) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) callWithHost(t, client) }) t.Run("grpc-web", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) callWithHost(t, client) }) } @@ -772,12 +897,12 @@ func TestTimeoutParsing(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) } @@ -786,7 +911,7 @@ func TestFailCodec(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) server := memhttptest.NewServer(t, handler) - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), connect.WithCodec(failCodec{}), @@ -803,7 +928,7 @@ func TestContextError(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) server := memhttptest.NewServer(t, handler) - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), ) @@ -822,8 +947,8 @@ func TestGRPCMarshalStatusError(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler( - pingServer{ + mux.Handle(pingv1connectgenerics.NewPingServiceHandler( + pingServerGenerics{ // Include error details in the response, so that the Status protobuf will be marshaled. includeErrorDetails: true, }, @@ -834,7 +959,7 @@ func TestGRPCMarshalStatusError(t *testing.T) { assertInternalError := func(tb testing.TB, opts ...connect.ClientOption) { tb.Helper() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), opts...) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), opts...) request := connect.NewRequest(&pingv1.FailRequest{Code: int32(connect.CodeResourceExhausted)}) _, err := client.Fail(context.Background(), request) tb.Log(err) @@ -871,7 +996,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { pingServer{checkMetadata: true}, )) server := memhttptest.NewServer(t, trimTrailers(mux)) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) assertErrorNoTrailers := func(t *testing.T, err error) { t.Helper() @@ -935,7 +1060,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { func TestUnavailableIfHostInvalid(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( http.DefaultClient, "https://api.invalid/", ) @@ -955,7 +1080,7 @@ func TestBidiRequiresHTTP2(t *testing.T) { assert.Nil(t, err) }) server := memhttptest.NewServer(t, handler) - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( &http.Client{Transport: server.TransportHTTP1()}, server.URL(), ) @@ -986,7 +1111,7 @@ func TestCompressMinBytesClient(t *testing.T) { assert.Equal(tb, request.Header.Get("Content-Encoding"), expect) })) server := memhttptest.NewServer(t, mux) - _, err := pingv1connect.NewPingServiceClient( + _, err := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), connect.WithSendGzip(), @@ -1075,7 +1200,7 @@ func TestCustomCompression(t *testing.T) { connect.WithCompression(compressionName, decompressor, compressor), )) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithAcceptCompression(compressionName, decompressor, compressor), connect.WithSendCompression(compressionName), @@ -1094,7 +1219,7 @@ func TestClientWithoutGzipSupport(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithAcceptCompression("gzip", nil, nil), connect.WithSendGzip(), @@ -1144,7 +1269,7 @@ func TestInterceptorReturnsWrongType(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { if _, err := next(ctx, request); err != nil { return nil, err @@ -1176,7 +1301,7 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { return options }), )) - readMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { + readMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_read_max", func(t *testing.T) { t.Parallel() @@ -1227,37 +1352,37 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) } @@ -1268,9 +1393,9 @@ func TestHandlerWithHTTPMaxBytes(t *testing.T) { t.Parallel() const readMaxBytes = 128 mux := http.NewServeMux() - pingRoute, pingHandler := pingv1connect.NewPingServiceHandler(pingServer{}) + pingRoute, pingHandler := pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}) mux.Handle(pingRoute, http.MaxBytesHandler(pingHandler, readMaxBytes)) - run := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { + run := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { t.Helper() t.Run("below_read_max", func(t *testing.T) { t.Parallel() @@ -1308,37 +1433,37 @@ func TestHandlerWithHTTPMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) run(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) run(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) run(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) run(t, client, true) }) } @@ -1354,14 +1479,14 @@ func TestClientWithReadMaxBytes(t *testing.T) { } else { compressionOption = connect.WithCompressMinBytes(math.MaxInt) } - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, compressionOption)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}, compressionOption)) server := memhttptest.NewServer(t, mux) return server } serverUncompressed := createServer(t, false) serverCompressed := createServer(t, true) readMaxBytes := 1024 - readMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { + readMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_read_max", func(t *testing.T) { t.Parallel() @@ -1403,32 +1528,32 @@ func TestClientWithReadMaxBytes(t *testing.T) { } t.Run("connect", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, true) }) } @@ -1436,7 +1561,7 @@ func TestClientWithReadMaxBytes(t *testing.T) { func TestHandlerWithSendMaxBytes(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - sendMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { + sendMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_send_max", func(t *testing.T) { t.Parallel() @@ -1506,37 +1631,37 @@ func TestHandlerWithSendMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, true) }) } @@ -1546,7 +1671,7 @@ func TestClientWithSendMaxBytes(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - sendMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, sendMaxBytes int, compressed bool) { + sendMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, sendMaxBytes int, compressed bool) { t.Helper() t.Run("equal_send_max", func(t *testing.T) { t.Parallel() @@ -1597,37 +1722,37 @@ func TestClientWithSendMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes)) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes)) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) } @@ -1644,9 +1769,9 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), connect.WithClientOptions(opts...), @@ -1680,12 +1805,12 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { func TestStreamForServer(t *testing.T) { t.Parallel() - newPingClient := func(t *testing.T, pingServer pingv1connect.PingServiceHandler) pingv1connect.PingServiceClient { + newPingClient := func(t *testing.T, pingServer pingv1connectgenerics.PingServiceHandler) pingv1connectgenerics.PingServiceClient { t.Helper() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), ) @@ -1851,7 +1976,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { return nil, connect.NewError(connectCode, errors.New("error")) }, } - mux.Handle(pingv1connect.NewPingServiceHandler(pluggableServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pluggableServer)) server := memhttptest.NewServer(t, mux) req, err := http.NewRequestWithContext( context.Background(), @@ -1865,7 +1990,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { assert.Nil(t, err) defer resp.Body.Close() assert.Equal(t, wantHttpStatus, resp.StatusCode) - connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + connectClient := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) connectResp, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) assert.Nil(t, connectResp) @@ -1957,7 +2082,7 @@ func TestFailCompression(t *testing.T) { ), ) server := memhttptest.NewServer(t, mux) - pingclient := pingv1connect.NewPingServiceClient( + pingclient := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), connect.WithAcceptCompression(compressorName, decompressor, compressor), @@ -2006,7 +2131,7 @@ func TestUnflushableResponseWriter(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - pingclient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), tt.options...) + pingclient := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), tt.options...) stream, err := pingclient.CountUp( context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 5}), @@ -2062,10 +2187,10 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { func TestConnectProtocolHeaderSentByDefault(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithRequireConnectProtocolHeader())) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}, connect.WithRequireConnectProtocolHeader())) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) @@ -2114,7 +2239,7 @@ func TestAllowCustomUserAgent(t *testing.T) { const customAgent = "custom" mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ ping: func(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { agent := req.Header().Get("User-Agent") assert.Equal(t, agent, customAgent) @@ -2133,7 +2258,7 @@ func TestAllowCustomUserAgent(t *testing.T) { {"grpcweb", []connect.ClientOption{connect.WithGRPCWeb()}}, } for _, testCase := range tests { - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), testCase.opts...) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), testCase.opts...) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) req.Header().Set("User-Agent", customAgent) _, err := client.Ping(context.Background(), req) @@ -2145,7 +2270,7 @@ func TestWebXUserAgent(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ ping: func(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { agent := req.Header().Get("User-Agent") assert.NotZero(t, agent) @@ -2159,7 +2284,7 @@ func TestWebXUserAgent(t *testing.T) { })) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) _, err := client.Ping(context.Background(), req) assert.Nil(t, err) @@ -2174,7 +2299,7 @@ func TestBidiOverHTTP1(t *testing.T) { // Clients expecting a full-duplex connection that end up with a simplex // HTTP/1.1 connection shouldn't hang. Instead, the server should close the // TCP connection. - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( &http.Client{Transport: server.TransportHTTP1()}, server.URL(), ) @@ -2210,7 +2335,7 @@ func TestHandlerReturnsNilResponse(t *testing.T) { } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ ping: func(ctx context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { return nil, nil //nolint: nilnil }, @@ -2219,7 +2344,7 @@ func TestHandlerReturnsNilResponse(t *testing.T) { }, }, connect.WithRecover(recoverPanic))) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) @@ -2465,7 +2590,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { for _, testcase := range testcases { t.Run(testcase.name, func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), testcase.options..., @@ -2539,12 +2664,12 @@ func TestClientDisconnect(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) var clientConn net.Conn transport := captureTransport(server, &clientConn, gotRequest) serverClient := &http.Client{Transport: transport} - client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(serverClient, server.URL()) stream := client.Sum(context.Background()) // Send header. assert.Nil(t, stream.Send(nil)) @@ -2582,12 +2707,12 @@ func TestClientDisconnect(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) var clientConn net.Conn transport := captureTransport(server, &clientConn, gotRequest) serverClient := &http.Client{Transport: transport} - client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(serverClient, server.URL()) stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) if !assert.Nil(t, err) { return @@ -2652,9 +2777,9 @@ func TestSetProtocolHeaders(t *testing.T) { testcase := tt t.Run(testcase.name, func(t *testing.T) { t.Parallel() - pingServer := &pingServer{} + pingServer := &pingServerGenerics{} mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) clientOpts := []connect.ClientOption{} @@ -2662,7 +2787,7 @@ func TestSetProtocolHeaders(t *testing.T) { // Use a different protocol to test the override. clientOpts = append(clientOpts, connect.WithGRPC()) } - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), clientOpts...) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), clientOpts...) pingProxyServer := &pluggablePingServer{ ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { @@ -2670,14 +2795,14 @@ func TestSetProtocolHeaders(t *testing.T) { }, } proxyMux := http.NewServeMux() - proxyMux.Handle(pingv1connect.NewPingServiceHandler(pingProxyServer)) + proxyMux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingProxyServer)) proxyServer := memhttptest.NewServer(t, proxyMux) proxyClientOpts := []connect.ClientOption{} if testcase.clientOption != nil { proxyClientOpts = append(proxyClientOpts, testcase.clientOption) } - proxyClient := pingv1connect.NewPingServiceClient(proxyServer.Client(), proxyServer.URL(), proxyClientOpts...) + proxyClient := pingv1connectgenerics.NewPingServiceClient(proxyServer.Client(), proxyServer.URL(), proxyClientOpts...) request := connect.NewRequest(&pingv1.PingRequest{Number: 42}) request.Header().Set("X-Test", t.Name()) @@ -2731,7 +2856,7 @@ func (c failCodec) Unmarshal(data []byte, message any) error { } type pluggablePingServer struct { - pingv1connect.UnimplementedPingServiceHandler + pingv1connectgenerics.UnimplementedPingServiceHandler ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) sum func(context.Context, *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) @@ -2792,7 +2917,7 @@ func expectClientHeader(check bool, req connect.AnyRequest) error { return expectMetadata(req.Header(), "header", clientHeader, headerValue) } -func expectMetadata(meta http.Header, metaType, key, value string) error { +func expectMetadata(meta http.Header, metaType, key, value string) error { //nolint:unparam if got := meta.Get(key); got != value { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( "%s %q: got %q, expected %q", @@ -2805,14 +2930,14 @@ func expectMetadata(meta http.Header, metaType, key, value string) error { return nil } -type pingServer struct { - pingv1connect.UnimplementedPingServiceHandler +type pingServerGenerics struct { + pingv1connectgenerics.UnimplementedPingServiceHandler checkMetadata bool includeErrorDetails bool } -func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { +func (p pingServerGenerics) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2833,7 +2958,7 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi return response, nil } -func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { +func (p pingServerGenerics) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2859,7 +2984,7 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa return nil, err } -func (p pingServer) Sum( +func (p pingServerGenerics) Sum( ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest], ) (*connect.Response[pingv1.SumResponse], error) { @@ -2887,7 +3012,7 @@ func (p pingServer) Sum( return response, nil } -func (p pingServer) CountUp( +func (p pingServerGenerics) CountUp( ctx context.Context, request *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse], @@ -2917,7 +3042,7 @@ func (p pingServer) CountUp( return nil } -func (p pingServer) CumSum( +func (p pingServerGenerics) CumSum( ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], ) error { @@ -2949,6 +3074,107 @@ func (p pingServer) CumSum( } } +func expectClientHeaderInCallInfo(check bool, callInfo connect.CallInfo) error { + if !check { + return nil + } + return expectMetadata(callInfo.RequestHeader(), "header", clientHeader, headerValue) +} + +type pingServer struct { + pingv1connect.UnimplementedPingServiceHandler + + checkMetadata bool + includeErrorDetails bool +} + +func (p pingServer) Ping(ctx context.Context, request *pingv1.PingRequest) (*pingv1.PingResponse, error) { + callInfo, ok := connect.CallInfoFromContext(ctx) + if !ok { + return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) + } + if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { + return nil, err + } + if callInfo.Peer().Addr == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } + if callInfo.Peer().Protocol == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } + response := &pingv1.PingResponse{ + Number: request.GetNumber(), + Text: request.GetText(), + } + callInfo.ResponseHeader().Set(handlerHeader, headerValue) + callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) + return response, nil +} + +func (p pingServer) CountUp( + ctx context.Context, + request *pingv1.CountUpRequest, + stream *connect.ServerStream[pingv1.CountUpResponse], +) error { + callInfo, ok := connect.CallInfoFromContext(ctx) + if !ok { + return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) + } + if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { + return err + } + if callInfo.Peer().Addr == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } + if callInfo.Peer().Protocol == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } + if request.GetNumber() <= 0 { + return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( + "number must be positive: got %v", + request.GetNumber(), + )) + } + callInfo.ResponseHeader().Set(handlerHeader, headerValue) + callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) + for i := range request.GetNumber() { + if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { + return err + } + } + return nil +} + +func (p pingServer) Fail(ctx context.Context, request *pingv1.FailRequest) (*pingv1.FailResponse, error) { + callInfo, ok := connect.CallInfoFromContext(ctx) + if !ok { + return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) + } + if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { + return nil, err + } + if callInfo.Peer().Addr == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } + if callInfo.Peer().Protocol == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } + err := connect.NewError( + connect.Code(request.GetCode()), + errors.New(errorMessage), + ) + err.Meta().Set(handlerHeader, headerValue) + err.Meta().Set(handlerTrailer, trailerValue) + if p.includeErrorDetails { + detail, derr := connect.NewErrorDetail(&pingv1.FailRequest{Code: request.GetCode()}) + if derr != nil { + return nil, derr + } + err.AddDetail(detail) + } + return nil, err +} + type deflateReader struct { r io.ReadCloser } diff --git a/context.go b/context.go index 7a7bdbd5..b99e7430 100644 --- a/context.go +++ b/context.go @@ -19,68 +19,119 @@ import ( "net/http" ) -type requestIncomingHeaderContextKey struct{} -type requestOutgoingHeaderContextKey struct{} -type responseHeaderAddressContextKey struct{} -type responseTrailerAddressContextKey struct{} - -// HeaderFromIncomingContext gets the header from a request sent to a handler. -func HeaderFromIncomingContext(ctx context.Context) (http.Header, bool) { - value, ok := ctx.Value(requestIncomingHeaderContextKey{}).(http.Header) - return value, ok +type CallInfo interface { + // Spec returns a description of this call. + Spec() Spec + // Peer describes the other party for this call. + Peer() Peer + // HTTPMethod returns the HTTP method for this request. This is nearly always + // POST, but side-effect-free unary RPCs could be made via a GET. + // + // On a newly created request, via NewRequest, this will return the empty + // string until the actual request is actually sent and the HTTP method + // determined. This means that client interceptor functions will see the + // empty string until *after* they delegate to the handler they wrapped. It + // is even possible for this to return the empty string after such delegation, + // if the request was never actually sent to the server (and thus no + // determination ever made about the HTTP method). + HTTPMethod() string + // RequestHeader returns the HTTP headers for this request. Headers beginning with + // "Connect-" and "Grpc-" are reserved for use by the Connect and gRPC + // protocols: applications may read them but shouldn't write them. + RequestHeader() http.Header + // ResponseHeader returns the HTTP headers for this response. Headers beginning with + // "Connect-" and "Grpc-" are reserved for use by the Connect and gRPC + // protocols: applications may read them but shouldn't write them. + ResponseHeader() http.Header + // ResponseTrailer returns the trailers for this response. Depending on the underlying + // RPC protocol, trailers may be sent as HTTP trailers or a protocol-specific + // block of in-body metadata. + // + // Trailers beginning with "Connect-" and "Grpc-" are reserved for use by the + // Connect and gRPC protocols: applications may read them but shouldn't write + // them. + ResponseTrailer() http.Header + + internalOnly() } -// HeaderFromOutgoingContext gets the header from a request sent by a client. -func HeaderFromOutgoingContext(ctx context.Context) (http.Header, bool) { - value, ok := ctx.Value(requestOutgoingHeaderContextKey{}).(http.Header) - return value, ok +type callInfo struct { + spec Spec + peer Peer + method string + requestHeader http.Header + responseHeader http.Header + responseTrailer http.Header } -// WithIncomingHeader adds the header to the context from a request sent to a handler. -func WithIncomingHeader(ctx context.Context, header http.Header) context.Context { - return context.WithValue(ctx, requestIncomingHeaderContextKey{}, header) +func (c *callInfo) Spec() Spec { + return c.spec } -// WithOutgoingHeader adds the header to the context from a request sent by a client. -func WithOutgoingHeader(ctx context.Context, header http.Header) context.Context { - return context.WithValue(ctx, requestOutgoingHeaderContextKey{}, header) +func (c *callInfo) Peer() Peer { + return c.peer } -// WithStoreResponseHeader returns a new context to be given to a client when making a request -// that will result in the header pointer being set to the response header. -func WithStoreResponseHeader(ctx context.Context, header *http.Header) context.Context { - return context.WithValue(ctx, responseHeaderAddressContextKey{}, header) +func (c *callInfo) RequestHeader() http.Header { + if c.requestHeader == nil { + c.requestHeader = make(http.Header) + } + return c.requestHeader } -// WithStoreResponseTrailer returns a new context to be given to a client when making a request -// that will result in the trailer pointer being set to the response trailer. -func WithStoreResponseTrailer(ctx context.Context, trailer *http.Header) context.Context { - return context.WithValue(ctx, responseTrailerAddressContextKey{}, trailer) +func (c *callInfo) ResponseHeader() http.Header { + if c.responseHeader == nil { + c.responseHeader = make(http.Header) + } + return c.responseHeader } -// SetResponseHeader sets the response header within a simple handler implementation. -func SetResponseHeader(ctx context.Context, header http.Header) { - responseHeaderAddress, ok := ctx.Value(responseHeaderAddressContextKey{}).(*http.Header) - if !ok { - return +func (c *callInfo) ResponseTrailer() http.Header { + if c.responseTrailer == nil { + c.responseTrailer = make(http.Header) } - *responseHeaderAddress = header + return c.responseTrailer +} + +func (c *callInfo) HTTPMethod() string { + return c.method } -// SetResponseTrailer sets the response trailer within a simple handler implementation. -func SetResponseTrailer(ctx context.Context, trailer http.Header) { - responseTrailerAddress, ok := ctx.Value(responseTrailerAddressContextKey{}).(*http.Header) +// internalOnly implements CallInfo. +func (c *callInfo) internalOnly() {} + +type callInfoContextKey struct{} + +// Create a new request context for use from a client. When the returned +// context is passed to RPCs, the returned call info can be used to set +// request metadata before the RPC is invoked and to inspect response +// metadata after the RPC completes. +// +// The returned context may be re-used across RPCs as long as they are +// not concurrent. Results of all CallInfo methods other than +// RequestHeader() are undefined if the context is used with concurrent RPCs. +// If the given context is already associated with an outgoing CallInfo, then +// ctx and the existing CallInfo are returned. +func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { + info, ok := ctx.Value(callInfoContextKey{}).(CallInfo) if !ok { - return + info = &callInfo{} + return context.WithValue(ctx, callInfoContextKey{}, info), info } - *responseTrailerAddress = trailer + return ctx, info +} + +// CallInfoFromContext returns the CallInfo for the given context, if there is one. +func CallInfoFromContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(callInfoContextKey{}).(CallInfo) + return value, ok } func requestFromContext[T any](ctx context.Context, message *T) *Request[T] { - request := NewRequest[T](message) - header, ok := HeaderFromOutgoingContext(ctx) + request := NewRequest(message) + callInfo, ok := CallInfoFromContext(ctx) if ok { - request.setHeader(header) + request.setHeader(callInfo.RequestHeader()) } return request } diff --git a/error_example_test.go b/error_example_test.go index d8155f75..30930a97 100644 --- a/error_example_test.go +++ b/error_example_test.go @@ -22,7 +22,7 @@ import ( connect "connectrpc.com/connect" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" - "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" ) func ExampleError_Message() { @@ -47,14 +47,15 @@ func ExampleIsNotModifiedError() { // Enable client-side support for HTTP GETs. connect.WithHTTPGet(), ) - req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) - first, err := client.Ping(context.Background(), req) + req := &pingv1.PingRequest{Number: 42} + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + _, err := client.Ping(ctx, req) if err != nil { fmt.Println(err) return } // If the server set an Etag, we can use it to cache the response. - etag := first.Header().Get("Etag") + etag := callInfo.ResponseHeader().Get("Etag") if etag == "" { fmt.Println("no Etag in response headers") return @@ -62,7 +63,7 @@ func ExampleIsNotModifiedError() { fmt.Println("cached response with Etag", etag) // Now we'd like to make the same request again, but avoid re-fetching the // response if possible. - req.Header().Set("If-None-Match", etag) + callInfo.RequestHeader().Set("If-None-Match", etag) _, err = client.Ping(context.Background(), req) if connect.IsNotModifiedError(err) { fmt.Println("can reuse cached response") diff --git a/error_not_modified_example_test.go b/error_not_modified_example_test.go index fe520548..fc9c9925 100644 --- a/error_not_modified_example_test.go +++ b/error_not_modified_example_test.go @@ -16,12 +16,13 @@ package connect_test import ( "context" + "errors" "net/http" "strconv" connect "connectrpc.com/connect" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" - "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" ) // ExampleCachingServer is an example of how servers can take advantage the @@ -35,22 +36,27 @@ type ExampleCachingPingServer struct { // indicates this), so clients using the Connect protocol may call it with HTTP // GET requests. This implementation uses Etags to manage client-side caching. func (*ExampleCachingPingServer) Ping( - _ context.Context, - req *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { - resp := connect.NewResponse(&pingv1.PingResponse{ - Number: req.Msg.GetNumber(), - }) + ctx context.Context, + req *pingv1.PingRequest, +) (*pingv1.PingResponse, error) { + resp := &pingv1.PingResponse{ + Number: req.GetNumber(), + } + callInfo, ok := connect.CallInfoFromContext(ctx) + if !ok { + return nil, errors.New("not call info found in context") + } + // Our hashing logic is simple: we use the number in the PingResponse. - hash := strconv.FormatInt(resp.Msg.GetNumber(), 10) + hash := strconv.FormatInt(resp.GetNumber(), 10) // If the request was an HTTP GET, we'll need to check if the client already // has the response cached. - if req.HTTPMethod() == http.MethodGet && req.Header().Get("If-None-Match") == hash { + if callInfo.HTTPMethod() == http.MethodGet && callInfo.RequestHeader().Get("If-None-Match") == hash { return nil, connect.NewNotModifiedError(http.Header{ "Etag": []string{hash}, }) } - resp.Header().Set("Etag", hash) + callInfo.ResponseHeader().Set("Etag", hash) return resp, nil } diff --git a/example_init_test.go b/example_init_test.go index e7abee52..d14275b6 100644 --- a/example_init_test.go +++ b/example_init_test.go @@ -17,7 +17,7 @@ package connect_test import ( "net/http" - "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp" ) diff --git a/handler.go b/handler.go index 9fc9bb6e..9bbd8d8e 100644 --- a/handler.go +++ b/handler.go @@ -16,6 +16,7 @@ package connect import ( "context" + "maps" "net/http" ) @@ -66,12 +67,28 @@ func NewUnaryHandler[Req, Res any]( if err != nil { return err } + // Add the request header to the context, and store the response header + // and trailer to propagate back to the caller. + ctx, ci := NewOutgoingContext(ctx) + call, ok := ci.(*callInfo) + if ok { + call.peer = request.Peer() + call.spec = request.Spec() + call.method = request.HTTPMethod() + call.requestHeader = request.Header() + } response, err := untyped(ctx, request) if err != nil { return err } + // Add response headers/trailers into the conn mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) + + // Add response headers/trailers into the context callinfo also + mergeNonProtocolHeaders(call.ResponseHeader(), response.Header()) + mergeNonProtocolHeaders(call.ResponseTrailer(), response.Trailer()) + return conn.Send(response.Any()) } @@ -98,28 +115,17 @@ func NewUnaryHandlerSimple[Req, Res any]( return NewUnaryHandler( procedure, func(ctx context.Context, request *Request[Req]) (*Response[Res], error) { - var responseHeader http.Header - var responseTrailer http.Header - // Add the request header to the context, and store the response header - // and trailer to propagate back to the caller. - ctx = WithIncomingHeader( - WithStoreResponseHeader( - WithStoreResponseTrailer( - ctx, - &responseTrailer, - ), - &responseHeader, - ), - request.Header(), - ) responseMsg, err := unary(ctx, request.Msg) - if responseMsg != nil { - response := NewResponse(responseMsg) - response.setHeader(responseHeader) - response.setTrailer(responseHeader) - return response, err + if err != nil { + return nil, err } - return nil, err + response := NewResponse(responseMsg) + callInfo, ok := CallInfoFromContext(ctx) + if ok { + response.setHeader(callInfo.ResponseHeader()) + response.setTrailer(callInfo.ResponseTrailer()) + } + return response, err }, options..., ) @@ -169,6 +175,13 @@ func NewServerStreamHandler[Req, Res any]( if err != nil { return err } + ctx, ci := NewOutgoingContext(ctx) + callInfo, _ := ci.(*callInfo) + callInfo.peer = req.Peer() + callInfo.spec = req.Spec() + callInfo.method = req.HTTPMethod() + maps.Copy(callInfo.RequestHeader(), req.Header()) + return implementation(ctx, req, &ServerStream[Res]{conn: conn}) }, ) @@ -187,11 +200,6 @@ func NewServerStreamHandlerSimple[Req, Res any]( return NewServerStreamHandler( procedure, func(ctx context.Context, request *Request[Req], serverStream *ServerStream[Res]) error { - // Add the request header to the context. - ctx = WithIncomingHeader( - ctx, - request.Header(), - ) return implementation(ctx, request.Msg, serverStream) }, options..., diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index b5892fde..a630cff9 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -123,7 +123,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( - pingServer{}, + pingServerGenerics{}, handlerOnion, ), ) @@ -171,7 +171,7 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { return next(ctx, request) } }) - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithInterceptors(interceptor))) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{}, connect.WithInterceptors(interceptor))) server := memhttptest.NewServer(t, mux) connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(interceptor)) _, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) @@ -197,7 +197,7 @@ func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( - pingServer{}, + pingServerGenerics{}, connect.WithInterceptors(handlerChecker), ), ) From bf18abf845f9cbebf19bf6976671d14c978ab4dd Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 16:19:35 -0400 Subject: [PATCH 02/57] Tests Signed-off-by: Steve Ayers --- client.go | 2 +- connect.go | 2 ++ connect_ext_test.go | 20 ++------------------ handler.go | 1 + 4 files changed, 6 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 792430bb..a2d0240e 100644 --- a/client.go +++ b/client.go @@ -189,7 +189,7 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques info.peer = conn.Peer() info.spec = conn.Spec() - mergeHeaders(conn.RequestHeader(), info.requestHeader) + mergeHeaders(info.RequestHeader(), request.header) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. diff --git a/connect.go b/connect.go index 596bc7fd..5843f62a 100644 --- a/connect.go +++ b/connect.go @@ -392,6 +392,8 @@ func receiveUnaryResponse[T any](conn StreamingClientConn, initializer maybeInit if err != nil { return nil, err } + fmt.Printf("Header %+v", conn.ResponseHeader()) + fmt.Printf("trailer %+v", conn.ResponseTrailer()) return &Response[T]{ Msg: msg, header: conn.ResponseHeader(), diff --git a/connect_ext_test.go b/connect_ext_test.go index 61908a13..2a738536 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -166,22 +166,6 @@ func TestCallInfo(t *testing.T) { assert.Nil(t, stream.Close()) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) - // num := int64(42) - // ctx, callInfo := connect.NewOutgoingContext(context.Background()) - // callInfo.RequestHeader().Set(clientHeader, headerValue) - // expect := &pingv1.PingResponse{Number: num} - - // response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) - // assert.Equal(t, response, expect) - // assert.Nil(t, err) - - // // Assert call info values are correctly populated - // assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) - // assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) - // assert.True(t, callInfo.Spec().IsClient) - // assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) - // assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) } @@ -3135,8 +3119,8 @@ func (p pingServer) CountUp( request.GetNumber(), )) } - callInfo.ResponseHeader().Set(handlerHeader, headerValue) - callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) + stream.Conn().ResponseHeader().Set(handlerHeader, headerValue) + stream.Conn().ResponseTrailer().Set(handlerTrailer, trailerValue) for i := range request.GetNumber() { if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { return err diff --git a/handler.go b/handler.go index 9bbd8d8e..3612600e 100644 --- a/handler.go +++ b/handler.go @@ -81,6 +81,7 @@ func NewUnaryHandler[Req, Res any]( if err != nil { return err } + // Add response headers/trailers into the conn mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) From a40af0cbd98da2b23a5fb6e134b19cfbf9088c66 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 16:22:13 -0400 Subject: [PATCH 03/57] Remove print Signed-off-by: Steve Ayers --- connect.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/connect.go b/connect.go index 5843f62a..596bc7fd 100644 --- a/connect.go +++ b/connect.go @@ -392,8 +392,6 @@ func receiveUnaryResponse[T any](conn StreamingClientConn, initializer maybeInit if err != nil { return nil, err } - fmt.Printf("Header %+v", conn.ResponseHeader()) - fmt.Printf("trailer %+v", conn.ResponseTrailer()) return &Response[T]{ Msg: msg, header: conn.ResponseHeader(), From 38f5be6421f9e05ba16003b6283c188f996f723a Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 16:36:13 -0400 Subject: [PATCH 04/57] Simplify Signed-off-by: Steve Ayers --- client_ext_test.go | 10 +- connect_ext_test.go | 256 ++++++++++++++++++++-------------------- example_init_test.go | 2 +- interceptor_ext_test.go | 6 +- 4 files changed, 137 insertions(+), 137 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index 14ae2773..4e5b8351 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -89,7 +89,7 @@ func TestNewClient_InitFailure(t *testing.T) { func TestClientPeer(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) run := func(t *testing.T, unaryHTTPMethod string, opts ...connect.ClientOption) { @@ -205,7 +205,7 @@ func TestGetNoContentHeaders(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(&pingServerGenerics{})) + mux.Handle(pingv1connect.NewPingServiceHandler(&pingServer{})) server := memhttptest.NewServer(t, http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) { if len(req.Header.Values("content-type")) > 0 || len(req.Header.Values("content-encoding")) > 0 || @@ -283,7 +283,7 @@ func TestSpecSchema(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler( - pingServerGenerics{}, + pingServer{}, connect.WithInterceptors(&assertSchemaInterceptor{t}), )) server := memhttptest.NewServer(t, mux) @@ -320,7 +320,7 @@ func TestSpecSchema(t *testing.T) { func TestDynamicClient(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) ctx := context.Background() initializer := func(spec connect.Spec, msg any) error { @@ -494,7 +494,7 @@ func TestClientDeadlineHandling(t *testing.T) { // detector enabled. That's partly why the makefile only runs "slow" // tests with the race detector disabled. - _, handler := pingv1connect.NewPingServiceHandler(pingServerGenerics{}) + _, handler := pingv1connect.NewPingServiceHandler(pingServer{}) svr := httptest.NewUnstartedServer(http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) { if req.Context().Err() != nil { return diff --git a/connect_ext_test.go b/connect_ext_test.go index 2a738536..6327d5de 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -39,9 +39,9 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" - pingv1connectgenerics "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" - "connectrpc.com/connect/internal/gen/simple/connect/import/v1/importv1connect" - "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/generics/connect/import/v1/importv1connect" + "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + pingv1connectsimple "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/proto" @@ -68,11 +68,11 @@ func TestCallInfo(t *testing.T) { t.Run("simple_api", func(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler( - pingServer{checkMetadata: true}, + mux.Handle(pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{checkMetadata: true}, )) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) t.Run("unary", func(t *testing.T) { num := int64(42) ctx, callInfo := connect.NewOutgoingContext(context.Background()) @@ -117,11 +117,11 @@ func TestCallInfo(t *testing.T) { t.Run("generics_api", func(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler( - pingServerGenerics{checkMetadata: true}, + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{checkMetadata: true}, )) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) t.Run("unary", func(t *testing.T) { num := int64(42) request := connect.NewRequest(&pingv1.PingRequest{Number: num}) @@ -172,7 +172,7 @@ func TestCallInfo(t *testing.T) { func TestServer(t *testing.T) { t.Parallel() - testPing := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper + testPing := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("ping", func(t *testing.T) { num := int64(42) request := connect.NewRequest(&pingv1.PingRequest{Number: num}) @@ -226,7 +226,7 @@ func TestServer(t *testing.T) { assert.Equal(t, connect.CodeOf(err), connect.CodeDeadlineExceeded) }) } - testSum := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper + testSum := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("sum", func(t *testing.T) { const ( upTo = 10 @@ -262,7 +262,7 @@ func TestServer(t *testing.T) { assert.Equal(t, got.Header().Values(handlerHeader), []string{headerValue}) }) } - testCountUp := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper + testCountUp := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("count_up", func(t *testing.T) { const upTo = 5 got := make([]int64, 0, upTo) @@ -320,7 +320,7 @@ func TestServer(t *testing.T) { assert.Nil(t, stream.Close()) }) } - testCumSum := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, expectSuccess bool) { //nolint:thelper + testCumSum := func(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper t.Run("cumsum", func(t *testing.T) { send := []int64{3, 5, 1} expect := []int64{3, 8, 9} @@ -435,7 +435,7 @@ func TestServer(t *testing.T) { assert.Nil(t, stream.CloseResponse()) }) } - testErrors := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper + testErrors := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper assertIsHTTPMiddlewareError := func(tb testing.TB, err error) { tb.Helper() assert.NotNil(tb, err) @@ -486,7 +486,7 @@ func TestServer(t *testing.T) { testMatrix := func(t *testing.T, client *http.Client, url string, bidi bool) { //nolint:thelper run := func(t *testing.T, opts ...connect.ClientOption) { t.Helper() - client := pingv1connectgenerics.NewPingServiceClient(client, url, opts...) + client := pingv1connect.NewPingServiceClient(client, url, opts...) testPing(t, client) testSum(t, client) testCountUp(t, client) @@ -551,8 +551,8 @@ func TestServer(t *testing.T) { } mux := http.NewServeMux() - pingRoute, pingHandler := pingv1connectgenerics.NewPingServiceHandler( - pingServerGenerics{checkMetadata: true}, + pingRoute, pingHandler := pingv1connect.NewPingServiceHandler( + pingServer{checkMetadata: true}, ) errorWriter := connect.NewErrorWriter() // Add net/http middleware to the ping service to evaluate HTTP state. @@ -573,15 +573,15 @@ func TestServer(t *testing.T) { } // Check Content-Length is set correctly. switch request.URL.Path { - case pingv1connectgenerics.PingServicePingProcedure, - pingv1connectgenerics.PingServiceFailProcedure, - pingv1connectgenerics.PingServiceCountUpProcedure: + case pingv1connect.PingServicePingProcedure, + pingv1connect.PingServiceFailProcedure, + pingv1connect.PingServiceCountUpProcedure: // Unary requests set Content-Length to the length of the request body. if request.ContentLength < 0 { t.Errorf("%s: expected Content-Length >= 0, got %d", request.URL.Path, request.ContentLength) } - case pingv1connectgenerics.PingServiceSumProcedure, - pingv1connectgenerics.PingServiceCumSumProcedure: + case pingv1connect.PingServiceSumProcedure, + pingv1connect.PingServiceCumSumProcedure: // Streaming requests set Content-Length to -1 or 0 on empty requests. if request.ContentLength > 0 { t.Errorf("%s: expected Content-Length -1 or 0, got %d", request.URL.Path, request.ContentLength) @@ -612,7 +612,7 @@ func TestConcurrentStreams(t *testing.T) { } t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) var done, start sync.WaitGroup start.Add(1) @@ -620,7 +620,7 @@ func TestConcurrentStreams(t *testing.T) { done.Add(1) go func() { defer done.Done() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) var total int64 sum := client.CumSum(context.Background()) start.Wait() @@ -684,7 +684,7 @@ func TestErrorHeaderPropagation(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) assertError := func(t *testing.T, err error, allowCustomHeaders bool) { @@ -721,7 +721,7 @@ func TestErrorHeaderPropagation(t *testing.T) { assert.Equal(t, meta.Values("X-Test"), []string(nil)) } } - testServices := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { + testServices := func(t *testing.T, client pingv1connect.PingServiceClient) { t.Helper() t.Run("unary", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) @@ -769,17 +769,17 @@ func TestErrorHeaderPropagation(t *testing.T) { } t.Run("connect", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) testServices(t, client) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) testServices(t, client) }) t.Run("grpc-web", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) testServices(t, client) }) } @@ -801,10 +801,10 @@ func TestHeaderBasic(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) request := connect.NewRequest(&pingv1.PingRequest{}) request.Header().Set(key, cval) response, err := client.Ping(context.Background(), request) @@ -830,12 +830,12 @@ func TestHeaderHost(t *testing.T) { newHTTP2Server := func(t *testing.T) *memhttp.Server { t.Helper() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) return server } - callWithHost := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { + callWithHost := func(t *testing.T, client pingv1connect.PingServiceClient) { t.Helper() request := connect.NewRequest(&pingv1.PingRequest{}) @@ -848,21 +848,21 @@ func TestHeaderHost(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) callWithHost(t, client) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) callWithHost(t, client) }) t.Run("grpc-web", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) callWithHost(t, client) }) } @@ -881,12 +881,12 @@ func TestTimeoutParsing(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) } @@ -895,7 +895,7 @@ func TestFailCodec(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) server := memhttptest.NewServer(t, handler) - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), connect.WithCodec(failCodec{}), @@ -912,7 +912,7 @@ func TestContextError(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) server := memhttptest.NewServer(t, handler) - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), ) @@ -931,8 +931,8 @@ func TestGRPCMarshalStatusError(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler( - pingServerGenerics{ + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{ // Include error details in the response, so that the Status protobuf will be marshaled. includeErrorDetails: true, }, @@ -943,7 +943,7 @@ func TestGRPCMarshalStatusError(t *testing.T) { assertInternalError := func(tb testing.TB, opts ...connect.ClientOption) { tb.Helper() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), opts...) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), opts...) request := connect.NewRequest(&pingv1.FailRequest{Code: int32(connect.CodeResourceExhausted)}) _, err := client.Fail(context.Background(), request) tb.Log(err) @@ -980,7 +980,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { pingServer{checkMetadata: true}, )) server := memhttptest.NewServer(t, trimTrailers(mux)) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) assertErrorNoTrailers := func(t *testing.T, err error) { t.Helper() @@ -1044,7 +1044,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { func TestUnavailableIfHostInvalid(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( http.DefaultClient, "https://api.invalid/", ) @@ -1064,7 +1064,7 @@ func TestBidiRequiresHTTP2(t *testing.T) { assert.Nil(t, err) }) server := memhttptest.NewServer(t, handler) - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( &http.Client{Transport: server.TransportHTTP1()}, server.URL(), ) @@ -1095,7 +1095,7 @@ func TestCompressMinBytesClient(t *testing.T) { assert.Equal(tb, request.Header.Get("Content-Encoding"), expect) })) server := memhttptest.NewServer(t, mux) - _, err := pingv1connectgenerics.NewPingServiceClient( + _, err := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), connect.WithSendGzip(), @@ -1184,7 +1184,7 @@ func TestCustomCompression(t *testing.T) { connect.WithCompression(compressionName, decompressor, compressor), )) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithAcceptCompression(compressionName, decompressor, compressor), connect.WithSendCompression(compressionName), @@ -1203,7 +1203,7 @@ func TestClientWithoutGzipSupport(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithAcceptCompression("gzip", nil, nil), connect.WithSendGzip(), @@ -1253,7 +1253,7 @@ func TestInterceptorReturnsWrongType(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { if _, err := next(ctx, request); err != nil { return nil, err @@ -1285,7 +1285,7 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { return options }), )) - readMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { + readMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_read_max", func(t *testing.T) { t.Parallel() @@ -1336,37 +1336,37 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) } @@ -1377,9 +1377,9 @@ func TestHandlerWithHTTPMaxBytes(t *testing.T) { t.Parallel() const readMaxBytes = 128 mux := http.NewServeMux() - pingRoute, pingHandler := pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}) + pingRoute, pingHandler := pingv1connect.NewPingServiceHandler(pingServer{}) mux.Handle(pingRoute, http.MaxBytesHandler(pingHandler, readMaxBytes)) - run := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { + run := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { t.Helper() t.Run("below_read_max", func(t *testing.T) { t.Parallel() @@ -1417,37 +1417,37 @@ func TestHandlerWithHTTPMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) run(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) run(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) run(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) run(t, client, true) }) } @@ -1463,14 +1463,14 @@ func TestClientWithReadMaxBytes(t *testing.T) { } else { compressionOption = connect.WithCompressMinBytes(math.MaxInt) } - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}, compressionOption)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, compressionOption)) server := memhttptest.NewServer(t, mux) return server } serverUncompressed := createServer(t, false) serverCompressed := createServer(t, true) readMaxBytes := 1024 - readMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { + readMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_read_max", func(t *testing.T) { t.Parallel() @@ -1512,32 +1512,32 @@ func TestClientWithReadMaxBytes(t *testing.T) { } t.Run("connect", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, true) }) } @@ -1545,7 +1545,7 @@ func TestClientWithReadMaxBytes(t *testing.T) { func TestHandlerWithSendMaxBytes(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - sendMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { + sendMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_send_max", func(t *testing.T) { t.Parallel() @@ -1615,37 +1615,37 @@ func TestHandlerWithSendMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, true) }) } @@ -1655,7 +1655,7 @@ func TestClientWithSendMaxBytes(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - sendMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, sendMaxBytes int, compressed bool) { + sendMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, sendMaxBytes int, compressed bool) { t.Helper() t.Run("equal_send_max", func(t *testing.T) { t.Parallel() @@ -1706,37 +1706,37 @@ func TestClientWithSendMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes)) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes)) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) } @@ -1753,9 +1753,9 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), connect.WithClientOptions(opts...), @@ -1789,12 +1789,12 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { func TestStreamForServer(t *testing.T) { t.Parallel() - newPingClient := func(t *testing.T, pingServer pingv1connectgenerics.PingServiceHandler) pingv1connectgenerics.PingServiceClient { + newPingClient := func(t *testing.T, pingServer pingv1connect.PingServiceHandler) pingv1connect.PingServiceClient { t.Helper() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), ) @@ -1960,7 +1960,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { return nil, connect.NewError(connectCode, errors.New("error")) }, } - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pluggableServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pluggableServer)) server := memhttptest.NewServer(t, mux) req, err := http.NewRequestWithContext( context.Background(), @@ -1974,7 +1974,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { assert.Nil(t, err) defer resp.Body.Close() assert.Equal(t, wantHttpStatus, resp.StatusCode) - connectClient := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) connectResp, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) assert.Nil(t, connectResp) @@ -2066,7 +2066,7 @@ func TestFailCompression(t *testing.T) { ), ) server := memhttptest.NewServer(t, mux) - pingclient := pingv1connectgenerics.NewPingServiceClient( + pingclient := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), connect.WithAcceptCompression(compressorName, decompressor, compressor), @@ -2115,7 +2115,7 @@ func TestUnflushableResponseWriter(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - pingclient := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), tt.options...) + pingclient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), tt.options...) stream, err := pingclient.CountUp( context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 5}), @@ -2171,10 +2171,10 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { func TestConnectProtocolHeaderSentByDefault(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}, connect.WithRequireConnectProtocolHeader())) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithRequireConnectProtocolHeader())) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) @@ -2223,7 +2223,7 @@ func TestAllowCustomUserAgent(t *testing.T) { const customAgent = "custom" mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ ping: func(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { agent := req.Header().Get("User-Agent") assert.Equal(t, agent, customAgent) @@ -2242,7 +2242,7 @@ func TestAllowCustomUserAgent(t *testing.T) { {"grpcweb", []connect.ClientOption{connect.WithGRPCWeb()}}, } for _, testCase := range tests { - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), testCase.opts...) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), testCase.opts...) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) req.Header().Set("User-Agent", customAgent) _, err := client.Ping(context.Background(), req) @@ -2254,7 +2254,7 @@ func TestWebXUserAgent(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ ping: func(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { agent := req.Header().Get("User-Agent") assert.NotZero(t, agent) @@ -2268,7 +2268,7 @@ func TestWebXUserAgent(t *testing.T) { })) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) _, err := client.Ping(context.Background(), req) assert.Nil(t, err) @@ -2283,7 +2283,7 @@ func TestBidiOverHTTP1(t *testing.T) { // Clients expecting a full-duplex connection that end up with a simplex // HTTP/1.1 connection shouldn't hang. Instead, the server should close the // TCP connection. - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( &http.Client{Transport: server.TransportHTTP1()}, server.URL(), ) @@ -2319,7 +2319,7 @@ func TestHandlerReturnsNilResponse(t *testing.T) { } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ ping: func(ctx context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { return nil, nil //nolint: nilnil }, @@ -2328,7 +2328,7 @@ func TestHandlerReturnsNilResponse(t *testing.T) { }, }, connect.WithRecover(recoverPanic))) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) @@ -2574,7 +2574,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { for _, testcase := range testcases { t.Run(testcase.name, func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), testcase.options..., @@ -2648,12 +2648,12 @@ func TestClientDisconnect(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) var clientConn net.Conn transport := captureTransport(server, &clientConn, gotRequest) serverClient := &http.Client{Transport: transport} - client := pingv1connectgenerics.NewPingServiceClient(serverClient, server.URL()) + client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) stream := client.Sum(context.Background()) // Send header. assert.Nil(t, stream.Send(nil)) @@ -2691,12 +2691,12 @@ func TestClientDisconnect(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) var clientConn net.Conn transport := captureTransport(server, &clientConn, gotRequest) serverClient := &http.Client{Transport: transport} - client := pingv1connectgenerics.NewPingServiceClient(serverClient, server.URL()) + client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) if !assert.Nil(t, err) { return @@ -2761,9 +2761,9 @@ func TestSetProtocolHeaders(t *testing.T) { testcase := tt t.Run(testcase.name, func(t *testing.T) { t.Parallel() - pingServer := &pingServerGenerics{} + pingServer := &pingServer{} mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) clientOpts := []connect.ClientOption{} @@ -2771,7 +2771,7 @@ func TestSetProtocolHeaders(t *testing.T) { // Use a different protocol to test the override. clientOpts = append(clientOpts, connect.WithGRPC()) } - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), clientOpts...) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), clientOpts...) pingProxyServer := &pluggablePingServer{ ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { @@ -2779,14 +2779,14 @@ func TestSetProtocolHeaders(t *testing.T) { }, } proxyMux := http.NewServeMux() - proxyMux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingProxyServer)) + proxyMux.Handle(pingv1connect.NewPingServiceHandler(pingProxyServer)) proxyServer := memhttptest.NewServer(t, proxyMux) proxyClientOpts := []connect.ClientOption{} if testcase.clientOption != nil { proxyClientOpts = append(proxyClientOpts, testcase.clientOption) } - proxyClient := pingv1connectgenerics.NewPingServiceClient(proxyServer.Client(), proxyServer.URL(), proxyClientOpts...) + proxyClient := pingv1connect.NewPingServiceClient(proxyServer.Client(), proxyServer.URL(), proxyClientOpts...) request := connect.NewRequest(&pingv1.PingRequest{Number: 42}) request.Header().Set("X-Test", t.Name()) @@ -2840,7 +2840,7 @@ func (c failCodec) Unmarshal(data []byte, message any) error { } type pluggablePingServer struct { - pingv1connectgenerics.UnimplementedPingServiceHandler + pingv1connect.UnimplementedPingServiceHandler ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) sum func(context.Context, *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) @@ -2914,14 +2914,14 @@ func expectMetadata(meta http.Header, metaType, key, value string) error { //nol return nil } -type pingServerGenerics struct { - pingv1connectgenerics.UnimplementedPingServiceHandler +type pingServer struct { + pingv1connect.UnimplementedPingServiceHandler checkMetadata bool includeErrorDetails bool } -func (p pingServerGenerics) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { +func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2942,7 +2942,7 @@ func (p pingServerGenerics) Ping(ctx context.Context, request *connect.Request[p return response, nil } -func (p pingServerGenerics) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { +func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2968,7 +2968,7 @@ func (p pingServerGenerics) Fail(ctx context.Context, request *connect.Request[p return nil, err } -func (p pingServerGenerics) Sum( +func (p pingServer) Sum( ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest], ) (*connect.Response[pingv1.SumResponse], error) { @@ -2996,7 +2996,7 @@ func (p pingServerGenerics) Sum( return response, nil } -func (p pingServerGenerics) CountUp( +func (p pingServer) CountUp( ctx context.Context, request *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse], @@ -3026,7 +3026,7 @@ func (p pingServerGenerics) CountUp( return nil } -func (p pingServerGenerics) CumSum( +func (p pingServer) CumSum( ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], ) error { @@ -3065,14 +3065,14 @@ func expectClientHeaderInCallInfo(check bool, callInfo connect.CallInfo) error { return expectMetadata(callInfo.RequestHeader(), "header", clientHeader, headerValue) } -type pingServer struct { - pingv1connect.UnimplementedPingServiceHandler +type pingServerSimple struct { + pingv1connectsimple.UnimplementedPingServiceHandler checkMetadata bool includeErrorDetails bool } -func (p pingServer) Ping(ctx context.Context, request *pingv1.PingRequest) (*pingv1.PingResponse, error) { +func (p pingServerSimple) Ping(ctx context.Context, request *pingv1.PingRequest) (*pingv1.PingResponse, error) { callInfo, ok := connect.CallInfoFromContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) @@ -3095,7 +3095,7 @@ func (p pingServer) Ping(ctx context.Context, request *pingv1.PingRequest) (*pin return response, nil } -func (p pingServer) CountUp( +func (p pingServerSimple) CountUp( ctx context.Context, request *pingv1.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], @@ -3129,7 +3129,7 @@ func (p pingServer) CountUp( return nil } -func (p pingServer) Fail(ctx context.Context, request *pingv1.FailRequest) (*pingv1.FailResponse, error) { +func (p pingServerSimple) Fail(ctx context.Context, request *pingv1.FailRequest) (*pingv1.FailResponse, error) { callInfo, ok := connect.CallInfoFromContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) diff --git a/example_init_test.go b/example_init_test.go index d14275b6..e79d95b8 100644 --- a/example_init_test.go +++ b/example_init_test.go @@ -32,6 +32,6 @@ func init() { // deadlock, see: // (https://github.com/golang/go/issues/48394) mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServerSimple{})) examplePingServer = memhttp.NewServer(mux) } diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index a630cff9..b5892fde 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -123,7 +123,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( - pingServerGenerics{}, + pingServer{}, handlerOnion, ), ) @@ -171,7 +171,7 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { return next(ctx, request) } }) - mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{}, connect.WithInterceptors(interceptor))) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithInterceptors(interceptor))) server := memhttptest.NewServer(t, mux) connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(interceptor)) _, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) @@ -197,7 +197,7 @@ func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( - pingServerGenerics{}, + pingServer{}, connect.WithInterceptors(handlerChecker), ), ) From 4763170549339373ee314ada68ba81d558593549 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 20:37:30 -0400 Subject: [PATCH 05/57] Feedback Signed-off-by: Steve Ayers --- client.go | 44 ++++++++++------- connect_ext_test.go | 21 +++++--- context.go | 78 +++++++++++++++++++++++++----- error_not_modified_example_test.go | 4 +- handler.go | 31 +++++------- 5 files changed, 122 insertions(+), 56 deletions(-) diff --git a/client.go b/client.go index a2d0240e..2f542557 100644 --- a/client.go +++ b/client.go @@ -19,7 +19,6 @@ import ( "errors" "fmt" "io" - "maps" "net/http" "net/url" "strings" @@ -128,23 +127,26 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) if c.err != nil { return nil, c.err } - ctx, ci := NewOutgoingContext(ctx) - call, ok := ci.(*callInfo) - if ok { - call.requestHeader = request.Header() - } + ctx, callInfo := newOutgoingContext(ctx) + callInfo.requestHeader = request.Header() resp, err := c.callUnary(ctx, request) if err != nil { return nil, err } - if ok { - call.peer = request.Peer() - call.spec = request.Spec() - call.method = request.HTTPMethod() - maps.Copy(call.ResponseHeader(), resp.Header()) - maps.Copy(call.ResponseTrailer(), resp.Trailer()) + callInfo.peer = request.Peer() + callInfo.spec = request.Spec() + callInfo.method = request.HTTPMethod() + if callInfo.responseHeader == nil { + callInfo.responseHeader = resp.Header() + } else { + mergeHeaders(callInfo.ResponseHeader(), resp.Header()) + } + if callInfo.responseTrailer == nil { + callInfo.responseTrailer = resp.Trailer() + } else { + mergeHeaders(callInfo.ResponseTrailer(), resp.Trailer()) } return resp, nil @@ -175,21 +177,21 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if c.err != nil { return nil, c.err } - ctx, ctxCallInfo := NewOutgoingContext(ctx) - // Note we don't need to check ok here because it should always be in context - // because of the above call to NewOutgoingContext - info, _ := ctxCallInfo.(*callInfo) conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method - info.method = r.Method }) request.spec = conn.Spec() request.peer = conn.Peer() mergeHeaders(conn.RequestHeader(), request.header) + ctx, ctxCallInfo := NewOutgoingContext(ctx) + // Note we don't need to check ok here because it should always be in context + // because of the above call to NewOutgoingContext + info, _ := ctxCallInfo.(*callInfo) info.peer = conn.Peer() info.spec = conn.Spec() mergeHeaders(info.RequestHeader(), request.header) + // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. @@ -198,15 +200,23 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques _ = conn.CloseResponse() return nil, err } + info.responseHeader = conn.ResponseHeader() + info.responseTrailer = conn.ResponseTrailer() if err := conn.CloseRequest(); err != nil { return nil, err } + return &ServerStreamForClient[Res]{ conn: conn, initializer: c.config.Initializer, }, nil } +// CallServerStreamSimple calls a server streaming procedure using the function signature +// associated with the "simple" generation option. +// +// This option eliminates the [Request] wrapper, and instead uses the context.Context to +// propagate information such as headers. func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { return c.CallServerStream(ctx, requestFromContext(ctx, requestMsg)) } diff --git a/connect_ext_test.go b/connect_ext_test.go index 6327d5de..dbb1e392 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -111,7 +111,7 @@ func TestCallInfo(t *testing.T) { assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + assert.Equal(t, callInfo.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) }) }) t.Run("generics_api", func(t *testing.T) { @@ -156,7 +156,8 @@ func TestCallInfo(t *testing.T) { Number: 1, }) req.Header().Set(clientHeader, headerValue) - stream, err := client.CountUp(context.Background(), req) + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + stream, err := client.CountUp(ctx, req) assert.Nil(t, err) assert.True(t, stream.Receive()) assert.Nil(t, stream.Err()) @@ -165,7 +166,10 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, msg.GetNumber(), 1) assert.Nil(t, stream.Close()) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + // assert.Equal(t, stream.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) + // assert.Equal(t, stream.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) + assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + // assert.Equal(t, callInfo.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) }) }) } @@ -3073,7 +3077,7 @@ type pingServerSimple struct { } func (p pingServerSimple) Ping(ctx context.Context, request *pingv1.PingRequest) (*pingv1.PingResponse, error) { - callInfo, ok := connect.CallInfoFromContext(ctx) + callInfo, ok := connect.CallInfoFromIncomingContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } @@ -3100,7 +3104,8 @@ func (p pingServerSimple) CountUp( request *pingv1.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { - callInfo, ok := connect.CallInfoFromContext(ctx) + fmt.Println("Count Up server") + callInfo, ok := connect.CallInfoFromIncomingContext(ctx) if !ok { return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } @@ -3119,8 +3124,8 @@ func (p pingServerSimple) CountUp( request.GetNumber(), )) } - stream.Conn().ResponseHeader().Set(handlerHeader, headerValue) - stream.Conn().ResponseTrailer().Set(handlerTrailer, trailerValue) + callInfo.ResponseHeader().Set(handlerHeader, headerValue) + callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) for i := range request.GetNumber() { if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { return err @@ -3130,7 +3135,7 @@ func (p pingServerSimple) CountUp( } func (p pingServerSimple) Fail(ctx context.Context, request *pingv1.FailRequest) (*pingv1.FailResponse, error) { - callInfo, ok := connect.CallInfoFromContext(ctx) + callInfo, ok := connect.CallInfoFromIncomingContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } diff --git a/context.go b/context.go index b99e7430..72ef37a2 100644 --- a/context.go +++ b/context.go @@ -20,10 +20,7 @@ import ( ) type CallInfo interface { - // Spec returns a description of this call. - Spec() Spec - // Peer describes the other party for this call. - Peer() Peer + StreamCallInfo // HTTPMethod returns the HTTP method for this request. This is nearly always // POST, but side-effect-free unary RPCs could be made via a GET. // @@ -35,6 +32,13 @@ type CallInfo interface { // if the request was never actually sent to the server (and thus no // determination ever made about the HTTP method). HTTPMethod() string +} + +type StreamCallInfo interface { + // Spec returns a description of this call. + Spec() Spec + // Peer describes the other party for this call. + Peer() Peer // RequestHeader returns the HTTP headers for this request. Headers beginning with // "Connect-" and "Grpc-" are reserved for use by the Connect and gRPC // protocols: applications may read them but shouldn't write them. @@ -100,7 +104,40 @@ func (c *callInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *callInfo) internalOnly() {} -type callInfoContextKey struct{} +type streamCallInfo struct { + conn StreamingHandlerConn +} + +func (c *streamCallInfo) Spec() Spec { + return c.conn.Spec() +} + +func (c *streamCallInfo) Peer() Peer { + return c.conn.Peer() +} + +func (c *streamCallInfo) RequestHeader() http.Header { + return c.conn.RequestHeader() +} + +func (c *streamCallInfo) ResponseHeader() http.Header { + return c.conn.ResponseHeader() +} + +func (c *streamCallInfo) ResponseTrailer() http.Header { + return c.conn.ResponseHeader() +} + +func (c *streamCallInfo) HTTPMethod() string { + // All stream calls are POSTs + return http.MethodPost +} + +// internalOnly implements CallInfo. +func (c *streamCallInfo) internalOnly() {} + +type outgoingCallInfoContextKey struct{} +type incomingCallInfoContextKey struct{} // Create a new request context for use from a client. When the returned // context is passed to RPCs, the returned call info can be used to set @@ -113,23 +150,42 @@ type callInfoContextKey struct{} // If the given context is already associated with an outgoing CallInfo, then // ctx and the existing CallInfo are returned. func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { - info, ok := ctx.Value(callInfoContextKey{}).(CallInfo) + info, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) + if !ok { + info = &callInfo{} + return context.WithValue(ctx, outgoingCallInfoContextKey{}, info), info + } + return ctx, info +} + +func newOutgoingContext(ctx context.Context) (context.Context, *callInfo) { + info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*callInfo) if !ok { info = &callInfo{} - return context.WithValue(ctx, callInfoContextKey{}, info), info + return context.WithValue(ctx, outgoingCallInfoContextKey{}, info), info } return ctx, info } -// CallInfoFromContext returns the CallInfo for the given context, if there is one. -func CallInfoFromContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(callInfoContextKey{}).(CallInfo) +func newIncomingContext(ctx context.Context, info CallInfo) context.Context { + return context.WithValue(ctx, incomingCallInfoContextKey{}, info) +} + +// CallInfoFromOutgoingContext returns the CallInfo for the given context, if there is one. +func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) + return value, ok +} + +// CallInfoFromIncomingContext returns the CallInfo for the given context, if there is one. +func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) return value, ok } func requestFromContext[T any](ctx context.Context, message *T) *Request[T] { request := NewRequest(message) - callInfo, ok := CallInfoFromContext(ctx) + callInfo, ok := CallInfoFromOutgoingContext(ctx) if ok { request.setHeader(callInfo.RequestHeader()) } diff --git a/error_not_modified_example_test.go b/error_not_modified_example_test.go index fc9c9925..3daf8223 100644 --- a/error_not_modified_example_test.go +++ b/error_not_modified_example_test.go @@ -42,9 +42,9 @@ func (*ExampleCachingPingServer) Ping( resp := &pingv1.PingResponse{ Number: req.GetNumber(), } - callInfo, ok := connect.CallInfoFromContext(ctx) + callInfo, ok := connect.CallInfoFromIncomingContext(ctx) if !ok { - return nil, errors.New("not call info found in context") + return nil, errors.New("no call info found in context") } // Our hashing logic is simple: we use the number in the PingResponse. diff --git a/handler.go b/handler.go index 3612600e..dc4ac3d6 100644 --- a/handler.go +++ b/handler.go @@ -16,7 +16,6 @@ package connect import ( "context" - "maps" "net/http" ) @@ -69,14 +68,13 @@ func NewUnaryHandler[Req, Res any]( } // Add the request header to the context, and store the response header // and trailer to propagate back to the caller. - ctx, ci := NewOutgoingContext(ctx) - call, ok := ci.(*callInfo) - if ok { - call.peer = request.Peer() - call.spec = request.Spec() - call.method = request.HTTPMethod() - call.requestHeader = request.Header() + info := &callInfo{ + peer: request.Peer(), + spec: request.Spec(), + method: request.HTTPMethod(), + requestHeader: request.Header(), } + ctx = newIncomingContext(ctx, info) response, err := untyped(ctx, request) if err != nil { return err @@ -87,8 +85,8 @@ func NewUnaryHandler[Req, Res any]( mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) // Add response headers/trailers into the context callinfo also - mergeNonProtocolHeaders(call.ResponseHeader(), response.Header()) - mergeNonProtocolHeaders(call.ResponseTrailer(), response.Trailer()) + mergeNonProtocolHeaders(info.ResponseHeader(), response.Header()) + mergeNonProtocolHeaders(info.ResponseTrailer(), response.Trailer()) return conn.Send(response.Any()) } @@ -121,7 +119,7 @@ func NewUnaryHandlerSimple[Req, Res any]( return nil, err } response := NewResponse(responseMsg) - callInfo, ok := CallInfoFromContext(ctx) + callInfo, ok := CallInfoFromIncomingContext(ctx) if ok { response.setHeader(callInfo.ResponseHeader()) response.setTrailer(callInfo.ResponseTrailer()) @@ -176,13 +174,10 @@ func NewServerStreamHandler[Req, Res any]( if err != nil { return err } - ctx, ci := NewOutgoingContext(ctx) - callInfo, _ := ci.(*callInfo) - callInfo.peer = req.Peer() - callInfo.spec = req.Spec() - callInfo.method = req.HTTPMethod() - maps.Copy(callInfo.RequestHeader(), req.Header()) - + info := &streamCallInfo{ + conn: conn, + } + ctx = newIncomingContext(ctx, info) return implementation(ctx, req, &ServerStream[Res]{conn: conn}) }, ) From 6a95dc5a9ff8d18da704ed3a7a538a4e56bbee81 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 20:56:39 -0400 Subject: [PATCH 06/57] Feedback Signed-off-by: Steve Ayers --- client.go | 27 +++++++++++++-------------- connect_ext_test.go | 14 -------------- context.go | 6 +++--- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/client.go b/client.go index 2f542557..bdc471e7 100644 --- a/client.go +++ b/client.go @@ -154,7 +154,7 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) // CallUnary calls a request-response procedure. func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, request *Req) (*Res, error) { - response, err := c.CallUnary(ctx, requestFromContext(ctx, request)) + response, err := c.CallUnary(ctx, requestFromOutgoingContext(ctx, request)) if response != nil { return response.Msg, err } @@ -180,17 +180,16 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method }) - request.spec = conn.Spec() + _, callInfo := newOutgoingContext(ctx) + callInfo.peer = conn.Peer() + callInfo.spec = conn.Spec() request.peer = conn.Peer() - mergeHeaders(conn.RequestHeader(), request.header) + request.spec = conn.Spec() - ctx, ctxCallInfo := NewOutgoingContext(ctx) - // Note we don't need to check ok here because it should always be in context - // because of the above call to NewOutgoingContext - info, _ := ctxCallInfo.(*callInfo) - info.peer = conn.Peer() - info.spec = conn.Spec() - mergeHeaders(info.RequestHeader(), request.header) + // Merge any callInfo request headers first, then do the request. + // so that context headers show first in the list of headers + mergeHeaders(conn.RequestHeader(), callInfo.RequestHeader()) + mergeHeaders(conn.RequestHeader(), request.header) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the @@ -200,12 +199,12 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques _ = conn.CloseResponse() return nil, err } - info.responseHeader = conn.ResponseHeader() - info.responseTrailer = conn.ResponseTrailer() + callInfo.responseHeader = conn.ResponseHeader() + callInfo.responseTrailer = conn.ResponseTrailer() + if err := conn.CloseRequest(); err != nil { return nil, err } - return &ServerStreamForClient[Res]{ conn: conn, initializer: c.config.Initializer, @@ -218,7 +217,7 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques // This option eliminates the [Request] wrapper, and instead uses the context.Context to // propagate information such as headers. func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { - return c.CallServerStream(ctx, requestFromContext(ctx, requestMsg)) + return c.CallServerStream(ctx, requestFromOutgoingContext(ctx, requestMsg)) } // CallBidiStream calls a bidirectional streaming procedure. diff --git a/connect_ext_test.go b/connect_ext_test.go index dbb1e392..f38dbe3a 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -82,8 +82,6 @@ func TestCallInfo(t *testing.T) { response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) assert.Equal(t, response, expect) assert.Nil(t, err) - - // Assert call info values are correctly populated assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) assert.True(t, callInfo.Spec().IsClient) @@ -104,8 +102,6 @@ func TestCallInfo(t *testing.T) { assert.NotNil(t, msg) assert.Equal(t, msg.GetNumber(), 1) assert.Nil(t, stream.Close()) - - // Assert call info values are correctly populated assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeServer) assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) assert.True(t, callInfo.Spec().IsClient) @@ -132,24 +128,18 @@ func TestCallInfo(t *testing.T) { response, err := client.Ping(ctx, request) assert.Nil(t, err) assert.Equal(t, response.Msg, expect) - assert.Equal(t, request.Spec().StreamType, connect.StreamTypeUnary) assert.Equal(t, request.Spec().Procedure, pingv1connect.PingServicePingProcedure) assert.True(t, request.Spec().IsClient) assert.Equal(t, request.Peer().Addr, httptest.DefaultRemoteAddr) - - // Verify that spec and peer on the callInfo are the same as the request wrapper assert.Equal(t, callInfo.Spec().StreamType, request.Spec().StreamType) assert.Equal(t, callInfo.Spec().Procedure, request.Spec().Procedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, request.Peer().Addr) - - // Verify that the response headers and trailers are the same on callInfo and the response assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) - }) t.Run("server_stream", func(t *testing.T) { req := connect.NewRequest(&pingv1.CountUpRequest{ @@ -166,10 +156,7 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, msg.GetNumber(), 1) assert.Nil(t, stream.Close()) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) - // assert.Equal(t, stream.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) - // assert.Equal(t, stream.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - // assert.Equal(t, callInfo.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) }) }) } @@ -3104,7 +3091,6 @@ func (p pingServerSimple) CountUp( request *pingv1.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { - fmt.Println("Count Up server") callInfo, ok := connect.CallInfoFromIncomingContext(ctx) if !ok { return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) diff --git a/context.go b/context.go index 72ef37a2..ff5895f5 100644 --- a/context.go +++ b/context.go @@ -171,19 +171,19 @@ func newIncomingContext(ctx context.Context, info CallInfo) context.Context { return context.WithValue(ctx, incomingCallInfoContextKey{}, info) } -// CallInfoFromOutgoingContext returns the CallInfo for the given context, if there is one. +// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) return value, ok } -// CallInfoFromIncomingContext returns the CallInfo for the given context, if there is one. +// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) return value, ok } -func requestFromContext[T any](ctx context.Context, message *T) *Request[T] { +func requestFromOutgoingContext[T any](ctx context.Context, message *T) *Request[T] { request := NewRequest(message) callInfo, ok := CallInfoFromOutgoingContext(ctx) if ok { From ee8fafd6af2674d2bc3c339fb3d88809a3c46df9 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 21:20:07 -0400 Subject: [PATCH 07/57] Feedback Signed-off-by: Steve Ayers --- connect.go | 10 ---------- connect_ext_test.go | 4 +++- context.go | 3 ++- handler.go | 21 +++++++-------------- 4 files changed, 12 insertions(+), 26 deletions(-) diff --git a/connect.go b/connect.go index 596bc7fd..caaf838b 100644 --- a/connect.go +++ b/connect.go @@ -287,16 +287,6 @@ func (r *Response[_]) Trailer() http.Header { return r.trailer } -// setHeader sets the response header. -func (r *Response[_]) setHeader(header http.Header) { - r.header = header -} - -// setTrailer sets the response trailer. -func (r *Response[_]) setTrailer(trailer http.Header) { - r.trailer = trailer -} - // internalOnly implements AnyResponse. func (r *Response[_]) internalOnly() {} diff --git a/connect_ext_test.go b/connect_ext_test.go index f38dbe3a..07190db2 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -107,7 +107,7 @@ func TestCallInfo(t *testing.T) { assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, callInfo.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) + // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) t.Run("generics_api", func(t *testing.T) { @@ -157,6 +157,8 @@ func TestCallInfo(t *testing.T) { assert.Nil(t, stream.Close()) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + // assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) } diff --git a/context.go b/context.go index ff5895f5..38087a07 100644 --- a/context.go +++ b/context.go @@ -104,6 +104,7 @@ func (c *callInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *callInfo) internalOnly() {} +// streamCallInfo is a CallInfo implementation used for streaming RPCs. type streamCallInfo struct { conn StreamingHandlerConn } @@ -125,7 +126,7 @@ func (c *streamCallInfo) ResponseHeader() http.Header { } func (c *streamCallInfo) ResponseTrailer() http.Header { - return c.conn.ResponseHeader() + return c.conn.ResponseTrailer() } func (c *streamCallInfo) HTTPMethod() string { diff --git a/handler.go b/handler.go index dc4ac3d6..5a7884a0 100644 --- a/handler.go +++ b/handler.go @@ -80,14 +80,14 @@ func NewUnaryHandler[Req, Res any]( return err } + // Add response headers/trailers into the context callinfo + mergeNonProtocolHeaders(conn.ResponseHeader(), info.ResponseHeader()) + mergeNonProtocolHeaders(conn.ResponseTrailer(), info.ResponseTrailer()) + // Add response headers/trailers into the conn mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) - // Add response headers/trailers into the context callinfo also - mergeNonProtocolHeaders(info.ResponseHeader(), response.Header()) - mergeNonProtocolHeaders(info.ResponseTrailer(), response.Trailer()) - return conn.Send(response.Any()) } @@ -118,13 +118,7 @@ func NewUnaryHandlerSimple[Req, Res any]( if err != nil { return nil, err } - response := NewResponse(responseMsg) - callInfo, ok := CallInfoFromIncomingContext(ctx) - if ok { - response.setHeader(callInfo.ResponseHeader()) - response.setTrailer(callInfo.ResponseTrailer()) - } - return response, err + return NewResponse(responseMsg), nil }, options..., ) @@ -174,10 +168,9 @@ func NewServerStreamHandler[Req, Res any]( if err != nil { return err } - info := &streamCallInfo{ + ctx = newIncomingContext(ctx, &streamCallInfo{ conn: conn, - } - ctx = newIncomingContext(ctx, info) + }) return implementation(ctx, req, &ServerStream[Res]{conn: conn}) }, ) From 9af9940ce4f05a6496cb52721d688a47f4cca114 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 21:23:16 -0400 Subject: [PATCH 08/57] Cleanup Signed-off-by: Steve Ayers --- client.go | 6 +++++- connect_ext_test.go | 3 --- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index bdc471e7..5b0a1e84 100644 --- a/client.go +++ b/client.go @@ -152,7 +152,11 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) return resp, nil } -// CallUnary calls a request-response procedure. +// CallUnarySimple calls a request-response procedure using the function signature +// associated with the "simple" generation option. +// +// This option eliminates the [Request] and [Response] wrappers, and instead uses the +// context.Context to propagate information such as headers. func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, request *Req) (*Res, error) { response, err := c.CallUnary(ctx, requestFromOutgoingContext(ctx, request)) if response != nil { diff --git a/connect_ext_test.go b/connect_ext_test.go index 07190db2..918a3c33 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -107,7 +107,6 @@ func TestCallInfo(t *testing.T) { assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) t.Run("generics_api", func(t *testing.T) { @@ -157,8 +156,6 @@ func TestCallInfo(t *testing.T) { assert.Nil(t, stream.Close()) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - // assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) - // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) } From 1cd1311ea87ba41fdbbcad1b80e9e16c5ba7e2f5 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 11:20:12 -0400 Subject: [PATCH 09/57] Update context.go Co-authored-by: Joshua Humphries <2035234+jhump@users.noreply.github.com> Signed-off-by: Steve Ayers Signed-off-by: Steve Ayers --- context.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/context.go b/context.go index 38087a07..4e0ecf1c 100644 --- a/context.go +++ b/context.go @@ -151,12 +151,7 @@ type incomingCallInfoContextKey struct{} // If the given context is already associated with an outgoing CallInfo, then // ctx and the existing CallInfo are returned. func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { - info, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) - if !ok { - info = &callInfo{} - return context.WithValue(ctx, outgoingCallInfoContextKey{}, info), info - } - return ctx, info + return newOutgoingContext(ctx) } func newOutgoingContext(ctx context.Context) (context.Context, *callInfo) { From 93cacf839990ca987983ab1878b84b3bbb0c9cb8 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 11:20:44 -0400 Subject: [PATCH 10/57] Feedback Signed-off-by: Steve Ayers --- context.go | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/context.go b/context.go index 4e0ecf1c..71dfdf40 100644 --- a/context.go +++ b/context.go @@ -20,21 +20,6 @@ import ( ) type CallInfo interface { - StreamCallInfo - // HTTPMethod returns the HTTP method for this request. This is nearly always - // POST, but side-effect-free unary RPCs could be made via a GET. - // - // On a newly created request, via NewRequest, this will return the empty - // string until the actual request is actually sent and the HTTP method - // determined. This means that client interceptor functions will see the - // empty string until *after* they delegate to the handler they wrapped. It - // is even possible for this to return the empty string after such delegation, - // if the request was never actually sent to the server (and thus no - // determination ever made about the HTTP method). - HTTPMethod() string -} - -type StreamCallInfo interface { // Spec returns a description of this call. Spec() Spec // Peer describes the other party for this call. @@ -57,6 +42,17 @@ type StreamCallInfo interface { ResponseTrailer() http.Header internalOnly() + // HTTPMethod returns the HTTP method for this request. This is nearly always + // POST, but side-effect-free unary RPCs could be made via a GET. + // + // On a newly created request, via NewRequest, this will return the empty + // string until the actual request is actually sent and the HTTP method + // determined. This means that client interceptor functions will see the + // empty string until *after* they delegate to the handler they wrapped. It + // is even possible for this to return the empty string after such delegation, + // if the request was never actually sent to the server (and thus no + // determination ever made about the HTTP method). + HTTPMethod() string } type callInfo struct { From adc81b17852fc3e859cdbe4fc5daf8f4bff39f18 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 13:53:15 -0400 Subject: [PATCH 11/57] Interceptors Signed-off-by: Steve Ayers --- client.go | 40 ++++++++++++------- context.go | 88 +++++++++++++++++++++++++++++++++++------ handler.go | 14 ++++--- interceptor_ext_test.go | 38 ++++++++++++++---- 4 files changed, 140 insertions(+), 40 deletions(-) diff --git a/client.go b/client.go index 5b0a1e84..51065c35 100644 --- a/client.go +++ b/client.go @@ -76,6 +76,8 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien // once at client creation. unarySpec := config.newSpec(StreamTypeUnary) unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { + ctx, callInfo := newOutgoingContext(ctx) + fmt.Printf("unary func call info: %+v\n\n", callInfo) conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) conn.onRequestSend(func(r *http.Request) { request.setRequestMethod(r.Method) @@ -109,6 +111,14 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien request.spec = unarySpec request.peer = client.protocolClient.Peer() protocolClient.WriteRequestHeader(StreamTypeUnary, request.Header()) + + // Also set them in the context so interceptors can inspect context for this information. + ctx, callInfo := newOutgoingContext(ctx) + callInfo.peer = request.Peer() + callInfo.spec = request.Spec() + + fmt.Printf("call unary call info: %+v\n\n", callInfo) + response, err := unaryFunc(ctx, request) if err != nil { return nil, err @@ -122,6 +132,18 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien return client } +type wrapper[Res any] struct { + response *Response[Res] +} + +func (w *wrapper[Res]) ResponseHeader() http.Header { + return w.response.Header() +} + +func (w *wrapper[Res]) ResponseTrailer() http.Header { + return w.response.Trailer() +} + // CallUnary calls a request-response procedure. func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) (*Response[Res], error) { if c.err != nil { @@ -135,18 +157,9 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) return nil, err } - callInfo.peer = request.Peer() - callInfo.spec = request.Spec() callInfo.method = request.HTTPMethod() - if callInfo.responseHeader == nil { - callInfo.responseHeader = resp.Header() - } else { - mergeHeaders(callInfo.ResponseHeader(), resp.Header()) - } - if callInfo.responseTrailer == nil { - callInfo.responseTrailer = resp.Trailer() - } else { - mergeHeaders(callInfo.ResponseTrailer(), resp.Trailer()) + callInfo.responseSource = &wrapper[Res]{ + response: resp, } return resp, nil @@ -187,6 +200,8 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques _, callInfo := newOutgoingContext(ctx) callInfo.peer = conn.Peer() callInfo.spec = conn.Spec() + callInfo.responseSource = conn + request.peer = conn.Peer() request.spec = conn.Spec() @@ -203,9 +218,6 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques _ = conn.CloseResponse() return nil, err } - callInfo.responseHeader = conn.ResponseHeader() - callInfo.responseTrailer = conn.ResponseTrailer() - if err := conn.CloseRequest(); err != nil { return nil, err } diff --git a/context.go b/context.go index 71dfdf40..0db47eae 100644 --- a/context.go +++ b/context.go @@ -19,6 +19,10 @@ import ( "net/http" ) +// CallInfo represents information relevant to an RPC call. +// Values returned by these methods are not thread-safe. Users should expect +// data races if they create an outgoing CallInfo in context and then pass that +// CallInfo to another goroutine and try to call methods on it concurrent with the RPC. type CallInfo interface { // Spec returns a description of this call. Spec() Spec @@ -31,6 +35,9 @@ type CallInfo interface { // ResponseHeader returns the HTTP headers for this response. Headers beginning with // "Connect-" and "Grpc-" are reserved for use by the Connect and gRPC // protocols: applications may read them but shouldn't write them. + // On the client side, this method returns nil before + // the call is actually made. After the call is made, for streaming operations, + // this method will block for the server to actually return response headers. ResponseHeader() http.Header // ResponseTrailer returns the trailers for this response. Depending on the underlying // RPC protocol, trailers may be sent as HTTP trailers or a protocol-specific @@ -39,9 +46,11 @@ type CallInfo interface { // Trailers beginning with "Connect-" and "Grpc-" are reserved for use by the // Connect and gRPC protocols: applications may read them but shouldn't write // them. + // + // On the client side, this method returns nil before + // the call is actually made. After the call is made, for streaming operations, + // this method will block for the server to actually return response trailers. ResponseTrailer() http.Header - - internalOnly() // HTTPMethod returns the HTTP method for this request. This is nearly always // POST, but side-effect-free unary RPCs could be made via a GET. // @@ -53,9 +62,12 @@ type CallInfo interface { // if the request was never actually sent to the server (and thus no // determination ever made about the HTTP method). HTTPMethod() string + + internalOnly() } -type callInfo struct { +// handlerCallInfo is a CallInfo implementation used for handlers. +type handlerCallInfo struct { spec Spec peer Peer method string @@ -64,41 +76,41 @@ type callInfo struct { responseTrailer http.Header } -func (c *callInfo) Spec() Spec { +func (c *handlerCallInfo) Spec() Spec { return c.spec } -func (c *callInfo) Peer() Peer { +func (c *handlerCallInfo) Peer() Peer { return c.peer } -func (c *callInfo) RequestHeader() http.Header { +func (c *handlerCallInfo) RequestHeader() http.Header { if c.requestHeader == nil { c.requestHeader = make(http.Header) } return c.requestHeader } -func (c *callInfo) ResponseHeader() http.Header { +func (c *handlerCallInfo) ResponseHeader() http.Header { if c.responseHeader == nil { c.responseHeader = make(http.Header) } return c.responseHeader } -func (c *callInfo) ResponseTrailer() http.Header { +func (c *handlerCallInfo) ResponseTrailer() http.Header { if c.responseTrailer == nil { c.responseTrailer = make(http.Header) } return c.responseTrailer } -func (c *callInfo) HTTPMethod() string { +func (c *handlerCallInfo) HTTPMethod() string { return c.method } // internalOnly implements CallInfo. -func (c *callInfo) internalOnly() {} +func (c *handlerCallInfo) internalOnly() {} // streamCallInfo is a CallInfo implementation used for streaming RPCs. type streamCallInfo struct { @@ -133,6 +145,56 @@ func (c *streamCallInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *streamCallInfo) internalOnly() {} +type responseSource interface { + ResponseHeader() http.Header + ResponseTrailer() http.Header +} + +// clientCallInfo is a CallInfo implementation used for clients. +type clientCallInfo struct { + responseSource + spec Spec + peer Peer + method string + requestHeader http.Header +} + +func (c *clientCallInfo) Spec() Spec { + return c.spec +} + +func (c *clientCallInfo) Peer() Peer { + return c.peer +} + +func (c *clientCallInfo) RequestHeader() http.Header { + if c.requestHeader == nil { + c.requestHeader = make(http.Header) + } + return c.requestHeader +} + +func (c *clientCallInfo) ResponseHeader() http.Header { + if c.responseSource == nil { + return nil + } + return c.responseSource.ResponseHeader() +} + +func (c *clientCallInfo) ResponseTrailer() http.Header { + if c.responseSource == nil { + return nil + } + return c.responseSource.ResponseTrailer() +} + +func (c *clientCallInfo) HTTPMethod() string { + return c.method +} + +// internalOnly implements CallInfo. +func (c *clientCallInfo) internalOnly() {} + type outgoingCallInfoContextKey struct{} type incomingCallInfoContextKey struct{} @@ -150,10 +212,10 @@ func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { return newOutgoingContext(ctx) } -func newOutgoingContext(ctx context.Context) (context.Context, *callInfo) { - info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*callInfo) +func newOutgoingContext(ctx context.Context) (context.Context, *clientCallInfo) { + info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*clientCallInfo) if !ok { - info = &callInfo{} + info = &clientCallInfo{} return context.WithValue(ctx, outgoingCallInfoContextKey{}, info), info } return ctx, info diff --git a/handler.go b/handler.go index 5a7884a0..fe11bc45 100644 --- a/handler.go +++ b/handler.go @@ -68,7 +68,7 @@ func NewUnaryHandler[Req, Res any]( } // Add the request header to the context, and store the response header // and trailer to propagate back to the caller. - info := &callInfo{ + info := &handlerCallInfo{ peer: request.Peer(), spec: request.Spec(), method: request.HTTPMethod(), @@ -80,11 +80,15 @@ func NewUnaryHandler[Req, Res any]( return err } - // Add response headers/trailers into the context callinfo - mergeNonProtocolHeaders(conn.ResponseHeader(), info.ResponseHeader()) - mergeNonProtocolHeaders(conn.ResponseTrailer(), info.ResponseTrailer()) + // Add response headers/trailers from the context callinfo into the conn if they exist + if info.responseHeader != nil { + mergeNonProtocolHeaders(conn.ResponseHeader(), info.responseHeader) + } + if info.responseTrailer != nil { + mergeNonProtocolHeaders(conn.ResponseTrailer(), info.responseTrailer) + } - // Add response headers/trailers into the conn + // Add response headers/trailers from the response into the conn mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index b5892fde..46e6f173 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -191,8 +191,8 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { t.Parallel() - clientChecker := &httpMethodChecker{client: true} - handlerChecker := &httpMethodChecker{} + clientChecker := &callInfoChecker{client: true} + handlerChecker := &callInfoChecker{} mux := http.NewServeMux() mux.Handle( @@ -344,25 +344,39 @@ func (cc *headerInspectingClientConn) Receive(msg any) error { return err } -type httpMethodChecker struct { +type callInfoChecker struct { client bool count atomic.Int32 } -func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { +func (h *callInfoChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { h.count.Add(1) if h.client { + outgoingCallInfo, ok := connect.CallInfoFromOutgoingContext(ctx) + if !ok { + return nil, fmt.Errorf("no call info found in outgoing context") + } // should be blank until after we make request + if outgoingCallInfo.HTTPMethod() != "" { + return nil, fmt.Errorf("expected blank HTTP method in outgoing context but instead got %q", outgoingCallInfo.HTTPMethod()) + } if req.HTTPMethod() != "" { - return nil, fmt.Errorf("expected blank HTTP method but instead got %q", req.HTTPMethod()) + return nil, fmt.Errorf("expected blank HTTP method in request but instead got %q", req.HTTPMethod()) } } else { + incomingCallInfo, ok := connect.CallInfoFromIncomingContext(ctx) + if !ok { + return nil, fmt.Errorf("no call info found in incoming context") + } // server interceptors see method from the start // NB: In theory, the method could also be GET, not just POST. But for the // configuration under test, it will always be POST. + if incomingCallInfo.HTTPMethod() != http.MethodPost { + return nil, fmt.Errorf("expected HTTP method %s in incoming context but instead got %q", http.MethodPost, incomingCallInfo.HTTPMethod()) + } if req.HTTPMethod() != http.MethodPost { - return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) + return nil, fmt.Errorf("expected HTTP method %s in request but instead got %q", http.MethodPost, req.HTTPMethod()) } } resp, err := unaryFunc(ctx, req) @@ -371,11 +385,19 @@ func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.Unary if req.HTTPMethod() != http.MethodPost { return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) } + // Method should now be set on the outgoing context + // callInfo, ok := connect.CallInfoFromOutgoingContext(ctx) + // if !ok { + // return nil, fmt.Errorf("no call info found in outgoing context after request") + // } + // if callInfo.HTTPMethod() != http.MethodPost { + // return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, callInfo.HTTPMethod()) + // } return resp, err } } -func (h *httpMethodChecker) WrapStreamingClient(clientFunc connect.StreamingClientFunc) connect.StreamingClientFunc { +func (h *callInfoChecker) WrapStreamingClient(clientFunc connect.StreamingClientFunc) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { // method not exposed to streaming interceptor, but that's okay because it's always POST for streams h.count.Add(1) @@ -383,7 +405,7 @@ func (h *httpMethodChecker) WrapStreamingClient(clientFunc connect.StreamingClie } } -func (h *httpMethodChecker) WrapStreamingHandler(handlerFunc connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (h *callInfoChecker) WrapStreamingHandler(handlerFunc connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { // method not exposed to streaming interceptor, but that's okay because it's always POST for streams h.count.Add(1) From 83978658adb4c534a5cfc8dca76c99cb18bb734c Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 14:43:44 -0400 Subject: [PATCH 12/57] Interceptor tests Signed-off-by: Steve Ayers --- client.go | 12 ++--- interceptor_ext_test.go | 104 +++++++++++++++++++++++++++------------- 2 files changed, 75 insertions(+), 41 deletions(-) diff --git a/client.go b/client.go index 51065c35..1b8bd665 100644 --- a/client.go +++ b/client.go @@ -77,10 +77,10 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien unarySpec := config.newSpec(StreamTypeUnary) unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { ctx, callInfo := newOutgoingContext(ctx) - fmt.Printf("unary func call info: %+v\n\n", callInfo) conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) conn.onRequestSend(func(r *http.Request) { request.setRequestMethod(r.Method) + callInfo.method = r.Method }) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the @@ -117,8 +117,6 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien callInfo.peer = request.Peer() callInfo.spec = request.Spec() - fmt.Printf("call unary call info: %+v\n\n", callInfo) - response, err := unaryFunc(ctx, request) if err != nil { return nil, err @@ -127,6 +125,9 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien if !ok { return nil, errorf(CodeInternal, "unexpected client response type %T", response) } + callInfo.responseSource = &wrapper[Res]{ + response: typed, + } return typed, nil } return client @@ -157,11 +158,6 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) return nil, err } - callInfo.method = request.HTTPMethod() - callInfo.responseSource = &wrapper[Res]{ - response: resp, - } - return resp, nil } diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 46e6f173..9201dcc6 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -16,6 +16,7 @@ package connect_test import ( "context" + "errors" "fmt" "net/http" "sync/atomic" @@ -349,50 +350,87 @@ type callInfoChecker struct { count atomic.Int32 } +func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connect.AnyRequest, prerequest bool) error { + // method should be blank until after we make request + if prerequest { //nolint:nestif + if callInfo.HTTPMethod() != "" { + return fmt.Errorf("expected blank HTTP method in outgoing context but instead got %q", callInfo.HTTPMethod()) + } + if req.HTTPMethod() != "" { + return fmt.Errorf("expected blank HTTP method in request but instead got %q", req.HTTPMethod()) + } + } else { + // server interceptors see method from the start + // NB: In theory, the method could also be GET, not just POST. But for the + // configuration under test, it will always be POST. + if callInfo.HTTPMethod() != http.MethodPost { + return fmt.Errorf("expected HTTP method %s in outgoing context but instead got %q", http.MethodPost, callInfo.HTTPMethod()) + } + if req.HTTPMethod() != http.MethodPost { + return fmt.Errorf("expected HTTP method %s in request but instead got %q", http.MethodPost, req.HTTPMethod()) + } + } + if callInfo.Peer().Addr == "" { + return errors.New("no peer set on call info") + } + if callInfo.Spec().Procedure != pingv1connect.PingServicePingProcedure { + return fmt.Errorf("expected spec procedure %s but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) + } + return nil +} + +func (h *callInfoChecker) getCallInfo(ctx context.Context) (connect.CallInfo, error) { + var callInfo connect.CallInfo + if h.client { + info, ok := connect.CallInfoFromOutgoingContext(ctx) + if !ok { + return nil, errors.New("no call info found in outgoing context") + } + callInfo = info + } else { + info, ok := connect.CallInfoFromIncomingContext(ctx) + if !ok { + return nil, errors.New("no call info found in incoming context") + } + callInfo = info + } + return callInfo, nil +} + func (h *callInfoChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { h.count.Add(1) + + callInfo, err := h.getCallInfo(ctx) + if err != nil { + return nil, err + } + if h.client { - outgoingCallInfo, ok := connect.CallInfoFromOutgoingContext(ctx) - if !ok { - return nil, fmt.Errorf("no call info found in outgoing context") - } - // should be blank until after we make request - if outgoingCallInfo.HTTPMethod() != "" { - return nil, fmt.Errorf("expected blank HTTP method in outgoing context but instead got %q", outgoingCallInfo.HTTPMethod()) - } - if req.HTTPMethod() != "" { - return nil, fmt.Errorf("expected blank HTTP method in request but instead got %q", req.HTTPMethod()) + if err := h.validateCallInfo(callInfo, req, true); err != nil { + return nil, err } } else { - incomingCallInfo, ok := connect.CallInfoFromIncomingContext(ctx) - if !ok { - return nil, fmt.Errorf("no call info found in incoming context") - } - // server interceptors see method from the start - // NB: In theory, the method could also be GET, not just POST. But for the - // configuration under test, it will always be POST. - if incomingCallInfo.HTTPMethod() != http.MethodPost { - return nil, fmt.Errorf("expected HTTP method %s in incoming context but instead got %q", http.MethodPost, incomingCallInfo.HTTPMethod()) - } - if req.HTTPMethod() != http.MethodPost { - return nil, fmt.Errorf("expected HTTP method %s in request but instead got %q", http.MethodPost, req.HTTPMethod()) + if err := h.validateCallInfo(callInfo, req, false); err != nil { + return nil, err } } + resp, err := unaryFunc(ctx, req) - // NB: In theory, the method could also be GET, not just POST. But for the - // configuration under test, it will always be POST. - if req.HTTPMethod() != http.MethodPost { - return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) + if err != nil { + return nil, err } + // Method should now be set on the outgoing context - // callInfo, ok := connect.CallInfoFromOutgoingContext(ctx) - // if !ok { - // return nil, fmt.Errorf("no call info found in outgoing context after request") - // } - // if callInfo.HTTPMethod() != http.MethodPost { - // return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, callInfo.HTTPMethod()) - // } + callInfo, err = h.getCallInfo(ctx) + if err != nil { + return nil, err + } + + if err := h.validateCallInfo(callInfo, req, false); err != nil { + return nil, err + } + return resp, err } } From bc250b00b4c769243c0f97e5a752a4e8c2eb9bb3 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:01:44 -0400 Subject: [PATCH 13/57] Feedback Signed-off-by: Steve Ayers --- interceptor_ext_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 9201dcc6..b27943c7 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -190,7 +190,7 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { assert.Nil(t, countUpStream.Close()) } -func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { +func TestInterceptorFuncAccessingCallInfo(t *testing.T) { t.Parallel() clientChecker := &callInfoChecker{client: true} handlerChecker := &callInfoChecker{} @@ -350,9 +350,9 @@ type callInfoChecker struct { count atomic.Int32 } -func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connect.AnyRequest, prerequest bool) error { +func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connect.AnyRequest, expectMethod bool) error { // method should be blank until after we make request - if prerequest { //nolint:nestif + if !expectMethod { //nolint:nestif if callInfo.HTTPMethod() != "" { return fmt.Errorf("expected blank HTTP method in outgoing context but instead got %q", callInfo.HTTPMethod()) } @@ -407,11 +407,11 @@ func (h *callInfoChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFu } if h.client { - if err := h.validateCallInfo(callInfo, req, true); err != nil { + if err := h.validateCallInfo(callInfo, req, false); err != nil { return nil, err } } else { - if err := h.validateCallInfo(callInfo, req, false); err != nil { + if err := h.validateCallInfo(callInfo, req, true); err != nil { return nil, err } } @@ -427,7 +427,7 @@ func (h *callInfoChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFu return nil, err } - if err := h.validateCallInfo(callInfo, req, false); err != nil { + if err := h.validateCallInfo(callInfo, req, true); err != nil { return nil, err } From 0a44db9505651764dee1d6c0f86d6ec5be453fe7 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:04:13 -0400 Subject: [PATCH 14/57] Feedback Signed-off-by: Steve Ayers --- client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client.go b/client.go index 1b8bd665..3c342970 100644 --- a/client.go +++ b/client.go @@ -116,6 +116,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien ctx, callInfo := newOutgoingContext(ctx) callInfo.peer = request.Peer() callInfo.spec = request.Spec() + callInfo.requestHeader = request.Header() response, err := unaryFunc(ctx, request) if err != nil { From 94dbb48fa6079e7d264a0471d46b11c7c8e249b2 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:13:41 -0400 Subject: [PATCH 15/57] Update header setting Signed-off-by: Steve Ayers --- handler.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/handler.go b/handler.go index fe11bc45..9975712d 100644 --- a/handler.go +++ b/handler.go @@ -88,9 +88,13 @@ func NewUnaryHandler[Req, Res any]( mergeNonProtocolHeaders(conn.ResponseTrailer(), info.responseTrailer) } - // Add response headers/trailers from the response into the conn - mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) - mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) + // Add response headers/trailers from the response into the conn if they exist + if len(response.Header()) != 0 { + mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) + } + if len(response.Trailer()) != 0 { + mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) + } return conn.Send(response.Any()) } From 57e869899a862d613b01610f67260bae9aecb92c Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:19:40 -0400 Subject: [PATCH 16/57] Fix responseWrapper docs Signed-off-by: Steve Ayers --- client.go | 14 +------------- context.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 3c342970..cdd676c4 100644 --- a/client.go +++ b/client.go @@ -126,7 +126,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien if !ok { return nil, errorf(CodeInternal, "unexpected client response type %T", response) } - callInfo.responseSource = &wrapper[Res]{ + callInfo.responseSource = &responseWrapper[Res]{ response: typed, } return typed, nil @@ -134,18 +134,6 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien return client } -type wrapper[Res any] struct { - response *Response[Res] -} - -func (w *wrapper[Res]) ResponseHeader() http.Header { - return w.response.Header() -} - -func (w *wrapper[Res]) ResponseTrailer() http.Header { - return w.response.Trailer() -} - // CallUnary calls a request-response procedure. func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) (*Response[Res], error) { if c.err != nil { diff --git a/context.go b/context.go index 0db47eae..00716158 100644 --- a/context.go +++ b/context.go @@ -150,6 +150,19 @@ type responseSource interface { ResponseTrailer() http.Header } +// responseWrapper wraps a Response object so that it can implement the responseSource interface. +type responseWrapper[Res any] struct { + response *Response[Res] +} + +func (w *responseWrapper[Res]) ResponseHeader() http.Header { + return w.response.Header() +} + +func (w *responseWrapper[Res]) ResponseTrailer() http.Header { + return w.response.Trailer() +} + // clientCallInfo is a CallInfo implementation used for clients. type clientCallInfo struct { responseSource From e422ba2707fdebb437fadc27da76e87b2a3a26f6 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:21:48 -0400 Subject: [PATCH 17/57] Fix again Signed-off-by: Steve Ayers --- client.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/client.go b/client.go index cdd676c4..4dd629d9 100644 --- a/client.go +++ b/client.go @@ -139,9 +139,6 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) if c.err != nil { return nil, c.err } - ctx, callInfo := newOutgoingContext(ctx) - callInfo.requestHeader = request.Header() - resp, err := c.callUnary(ctx, request) if err != nil { return nil, err From 3cdb5e187b0cd9e6580e4747257464033e9e4856 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:25:10 -0400 Subject: [PATCH 18/57] Update tests Signed-off-by: Steve Ayers --- interceptor_ext_test.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index b27943c7..949a6afc 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -354,7 +354,7 @@ func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connec // method should be blank until after we make request if !expectMethod { //nolint:nestif if callInfo.HTTPMethod() != "" { - return fmt.Errorf("expected blank HTTP method in outgoing context but instead got %q", callInfo.HTTPMethod()) + return fmt.Errorf("expected blank HTTP method in context but instead got %q", callInfo.HTTPMethod()) } if req.HTTPMethod() != "" { return fmt.Errorf("expected blank HTTP method in request but instead got %q", req.HTTPMethod()) @@ -364,7 +364,7 @@ func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connec // NB: In theory, the method could also be GET, not just POST. But for the // configuration under test, it will always be POST. if callInfo.HTTPMethod() != http.MethodPost { - return fmt.Errorf("expected HTTP method %s in outgoing context but instead got %q", http.MethodPost, callInfo.HTTPMethod()) + return fmt.Errorf("expected HTTP method %s in context but instead got %q", http.MethodPost, callInfo.HTTPMethod()) } if req.HTTPMethod() != http.MethodPost { return fmt.Errorf("expected HTTP method %s in request but instead got %q", http.MethodPost, req.HTTPMethod()) @@ -373,8 +373,14 @@ func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connec if callInfo.Peer().Addr == "" { return errors.New("no peer set on call info") } + if req.Peer().Addr == "" { + return errors.New("no peer set on request") + } if callInfo.Spec().Procedure != pingv1connect.PingServicePingProcedure { - return fmt.Errorf("expected spec procedure %s but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) + return fmt.Errorf("expected spec procedure %s in context but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) + } + if req.Spec().Procedure != pingv1connect.PingServicePingProcedure { + return fmt.Errorf("expected spec procedure %s on request but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) } return nil } From 6a3ed80a578d79ee0463ca0fc2e3cee8d83deb99 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 19:13:15 -0400 Subject: [PATCH 19/57] Style Signed-off-by: Steve Ayers --- client.go | 8 +--- connect_ext_test.go | 114 ++++++++++++++++++++++---------------------- context.go | 68 +++++++++++++------------- 3 files changed, 94 insertions(+), 96 deletions(-) diff --git a/client.go b/client.go index 4dd629d9..78365943 100644 --- a/client.go +++ b/client.go @@ -126,6 +126,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien if !ok { return nil, errorf(CodeInternal, "unexpected client response type %T", response) } + // Wrap the response and set it into the context callinfo callInfo.responseSource = &responseWrapper[Res]{ response: typed, } @@ -139,12 +140,7 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) if c.err != nil { return nil, c.err } - resp, err := c.callUnary(ctx, request) - if err != nil { - return nil, err - } - - return resp, nil + return c.callUnary(ctx, request) } // CallUnarySimple calls a request-response procedure using the function signature diff --git a/connect_ext_test.go b/connect_ext_test.go index 918a3c33..6cbfa9fa 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2867,43 +2867,6 @@ func (p *pluggablePingServer) CumSum( return p.cumSum(ctx, stream) } -func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { - tb.Helper() - if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { - assert.ErrorIs(tb, err, io.EOF) - assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown) - } - assert.Nil(tb, stream.CloseRequest()) - _, err := stream.Receive() - assert.NotNil(tb, err) // should be 505 - assert.True( - tb, - strings.Contains(err.Error(), "HTTP status 505"), - assert.Sprintf("expected 505, got %v", err), - ) - assert.Nil(tb, stream.CloseResponse()) -} - -func expectClientHeader(check bool, req connect.AnyRequest) error { - if !check { - return nil - } - return expectMetadata(req.Header(), "header", clientHeader, headerValue) -} - -func expectMetadata(meta http.Header, metaType, key, value string) error { //nolint:unparam - if got := meta.Get(key); got != value { - return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( - "%s %q: got %q, expected %q", - metaType, - key, - got, - value, - )) - } - return nil -} - type pingServer struct { pingv1connect.UnimplementedPingServiceHandler @@ -2912,8 +2875,10 @@ type pingServer struct { } func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { - if err := expectClientHeader(p.checkMetadata, request); err != nil { - return nil, err + if p.checkMetadata { + if err := expectMetadata(request.Header()); err != nil { + return nil, err + } } if request.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -2933,8 +2898,10 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi } func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { - if err := expectClientHeader(p.checkMetadata, request); err != nil { - return nil, err + if p.checkMetadata { + if err := expectMetadata(request.Header()); err != nil { + return nil, err + } } if request.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -2963,7 +2930,7 @@ func (p pingServer) Sum( stream *connect.ClientStream[pingv1.SumRequest], ) (*connect.Response[pingv1.SumResponse], error) { if p.checkMetadata { - if err := expectMetadata(stream.RequestHeader(), "header", clientHeader, headerValue); err != nil { + if err := expectMetadata(stream.RequestHeader()); err != nil { return nil, err } } @@ -2991,8 +2958,10 @@ func (p pingServer) CountUp( request *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse], ) error { - if err := expectClientHeader(p.checkMetadata, request); err != nil { - return err + if p.checkMetadata { + if err := expectMetadata(request.Header()); err != nil { + return err + } } if request.Peer().Addr == "" { return connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -3022,7 +2991,7 @@ func (p pingServer) CumSum( ) error { var sum int64 if p.checkMetadata { - if err := expectMetadata(stream.RequestHeader(), "header", clientHeader, headerValue); err != nil { + if err := expectMetadata(stream.RequestHeader()); err != nil { return err } } @@ -3048,13 +3017,6 @@ func (p pingServer) CumSum( } } -func expectClientHeaderInCallInfo(check bool, callInfo connect.CallInfo) error { - if !check { - return nil - } - return expectMetadata(callInfo.RequestHeader(), "header", clientHeader, headerValue) -} - type pingServerSimple struct { pingv1connectsimple.UnimplementedPingServiceHandler @@ -3067,8 +3029,10 @@ func (p pingServerSimple) Ping(ctx context.Context, request *pingv1.PingRequest) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } - if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { - return nil, err + if p.checkMetadata { + if err := expectMetadata(callInfo.RequestHeader()); err != nil { + return nil, err + } } if callInfo.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -3094,8 +3058,10 @@ func (p pingServerSimple) CountUp( if !ok { return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } - if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { - return err + if p.checkMetadata { + if err := expectMetadata(callInfo.RequestHeader()); err != nil { + return err + } } if callInfo.Peer().Addr == "" { return connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -3124,8 +3090,10 @@ func (p pingServerSimple) Fail(ctx context.Context, request *pingv1.FailRequest) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } - if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { - return nil, err + if p.checkMetadata { + if err := expectMetadata(callInfo.RequestHeader()); err != nil { + return nil, err + } } if callInfo.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -3235,3 +3203,33 @@ func (failCompressor) Close() error { } func (failCompressor) Reset(io.Writer) {} + +func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { + tb.Helper() + if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { + assert.ErrorIs(tb, err, io.EOF) + assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown) + } + assert.Nil(tb, stream.CloseRequest()) + _, err := stream.Receive() + assert.NotNil(tb, err) // should be 505 + assert.True( + tb, + strings.Contains(err.Error(), "HTTP status 505"), + assert.Sprintf("expected 505, got %v", err), + ) + assert.Nil(tb, stream.CloseResponse()) +} + +func expectMetadata(meta http.Header) error { + if got := meta.Get(clientHeader); got != headerValue { + return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( + "%s %q: got %q, expected %q", + "header", + clientHeader, + got, + headerValue, + )) + } + return nil +} diff --git a/context.go b/context.go index 00716158..7a1eb4ad 100644 --- a/context.go +++ b/context.go @@ -112,7 +112,7 @@ func (c *handlerCallInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *handlerCallInfo) internalOnly() {} -// streamCallInfo is a CallInfo implementation used for streaming RPCs. +// streamCallInfo is a CallInfo implementation used for streaming RPC handlers. type streamCallInfo struct { conn StreamingHandlerConn } @@ -145,24 +145,6 @@ func (c *streamCallInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *streamCallInfo) internalOnly() {} -type responseSource interface { - ResponseHeader() http.Header - ResponseTrailer() http.Header -} - -// responseWrapper wraps a Response object so that it can implement the responseSource interface. -type responseWrapper[Res any] struct { - response *Response[Res] -} - -func (w *responseWrapper[Res]) ResponseHeader() http.Header { - return w.response.Header() -} - -func (w *responseWrapper[Res]) ResponseTrailer() http.Header { - return w.response.Trailer() -} - // clientCallInfo is a CallInfo implementation used for clients. type clientCallInfo struct { responseSource @@ -211,7 +193,26 @@ func (c *clientCallInfo) internalOnly() {} type outgoingCallInfoContextKey struct{} type incomingCallInfoContextKey struct{} -// Create a new request context for use from a client. When the returned +// responseSource indicates a type that manage response headers and trailers. +type responseSource interface { + ResponseHeader() http.Header + ResponseTrailer() http.Header +} + +// responseWrapper wraps a Response object so that it can implement the responseSource interface. +type responseWrapper[Res any] struct { + response *Response[Res] +} + +func (w *responseWrapper[Res]) ResponseHeader() http.Header { + return w.response.Header() +} + +func (w *responseWrapper[Res]) ResponseTrailer() http.Header { + return w.response.Trailer() +} + +// Create a new outgoing context for use from a client. When the returned // context is passed to RPCs, the returned call info can be used to set // request metadata before the RPC is invoked and to inspect response // metadata after the RPC completes. @@ -225,6 +226,19 @@ func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { return newOutgoingContext(ctx) } +// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. +func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) + return value, ok +} + +// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. +func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) + return value, ok +} + +// Creates a new outgoing context or returns the existing one in context. func newOutgoingContext(ctx context.Context) (context.Context, *clientCallInfo) { info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*clientCallInfo) if !ok { @@ -234,22 +248,12 @@ func newOutgoingContext(ctx context.Context) (context.Context, *clientCallInfo) return ctx, info } +// newOutgoingContext creates a new incoming context. func newIncomingContext(ctx context.Context, info CallInfo) context.Context { return context.WithValue(ctx, incomingCallInfoContextKey{}, info) } -// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. -func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) - return value, ok -} - -// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. -func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) - return value, ok -} - +// requestFromOutgoingContext creates a new Request using the given context and message. func requestFromOutgoingContext[T any](ctx context.Context, message *T) *Request[T] { request := NewRequest(message) callInfo, ok := CallInfoFromOutgoingContext(ctx) From e14c0d70f15ab1b478cd906587fe3d02ad0ece4d Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Wed, 2 Jul 2025 12:46:31 -0400 Subject: [PATCH 20/57] Move func Signed-off-by: Steve Ayers --- context.go | 52 ++++++++++++++++++++++++++-------------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/context.go b/context.go index 7a1eb4ad..41ea749d 100644 --- a/context.go +++ b/context.go @@ -66,6 +66,32 @@ type CallInfo interface { internalOnly() } +// Create a new outgoing context for use from a client. When the returned +// context is passed to RPCs, the returned call info can be used to set +// request metadata before the RPC is invoked and to inspect response +// metadata after the RPC completes. +// +// The returned context may be re-used across RPCs as long as they are +// not concurrent. Results of all CallInfo methods other than +// RequestHeader() are undefined if the context is used with concurrent RPCs. +// If the given context is already associated with an outgoing CallInfo, then +// ctx and the existing CallInfo are returned. +func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { + return newOutgoingContext(ctx) +} + +// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. +func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) + return value, ok +} + +// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. +func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) + return value, ok +} + // handlerCallInfo is a CallInfo implementation used for handlers. type handlerCallInfo struct { spec Spec @@ -212,32 +238,6 @@ func (w *responseWrapper[Res]) ResponseTrailer() http.Header { return w.response.Trailer() } -// Create a new outgoing context for use from a client. When the returned -// context is passed to RPCs, the returned call info can be used to set -// request metadata before the RPC is invoked and to inspect response -// metadata after the RPC completes. -// -// The returned context may be re-used across RPCs as long as they are -// not concurrent. Results of all CallInfo methods other than -// RequestHeader() are undefined if the context is used with concurrent RPCs. -// If the given context is already associated with an outgoing CallInfo, then -// ctx and the existing CallInfo are returned. -func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { - return newOutgoingContext(ctx) -} - -// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. -func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) - return value, ok -} - -// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. -func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) - return value, ok -} - // Creates a new outgoing context or returns the existing one in context. func newOutgoingContext(ctx context.Context) (context.Context, *clientCallInfo) { info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*clientCallInfo) From a2e3f4e7dd5c7bf4b6994ef4c5d2e0111c939a14 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Thu, 3 Jul 2025 15:29:51 -0400 Subject: [PATCH 21/57] Fix server stream tests Signed-off-by: Steve Ayers --- connect_ext_test.go | 60 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index 6cbfa9fa..3e52b5ec 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -92,21 +92,32 @@ func TestCallInfo(t *testing.T) { t.Run("server_stream", func(t *testing.T) { ctx, callInfo := connect.NewOutgoingContext(context.Background()) callInfo.RequestHeader().Set(clientHeader, headerValue) + + val := 3 stream, err := client.CountUp(ctx, &pingv1.CountUpRequest{ - Number: 1, + Number: int64(val), }) assert.Nil(t, err) - assert.True(t, stream.Receive()) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) assert.Nil(t, stream.Err()) - msg := stream.Msg() - assert.NotNil(t, msg) - assert.Equal(t, msg.GetNumber(), 1) assert.Nil(t, stream.Close()) assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeServer) assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) - assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) t.Run("generics_api", func(t *testing.T) { @@ -141,21 +152,44 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) t.Run("server_stream", func(t *testing.T) { + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + callInfo.RequestHeader().Set(clientHeader, headerValue) + + val := 3 req := connect.NewRequest(&pingv1.CountUpRequest{ - Number: 1, + Number: int64(val), }) - req.Header().Set(clientHeader, headerValue) - ctx, callInfo := connect.NewOutgoingContext(context.Background()) stream, err := client.CountUp(ctx, req) assert.Nil(t, err) - assert.True(t, stream.Receive()) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) assert.Nil(t, stream.Err()) - msg := stream.Msg() - assert.NotNil(t, msg) - assert.Equal(t, msg.GetNumber(), 1) assert.Nil(t, stream.Close()) + // Assert values on request and stream + assert.Equal(t, req.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, req.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, req.Spec().IsClient) + assert.Equal(t, req.Peer().Addr, httptest.DefaultRemoteAddr) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, req.Spec().StreamType) + assert.Equal(t, callInfo.Spec().Procedure, req.Spec().Procedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, req.Peer().Addr) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) } From b83ca031a21b90a84899f09d7074431c9ba70f65 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Wed, 16 Jul 2025 14:30:24 -0400 Subject: [PATCH 22/57] Rename context methods and always create a new call info when using exported NewClientContextAPI Signed-off-by: Steve Ayers --- client.go | 50 ++++++++------ connect_ext_test.go | 14 ++-- context.go | 53 +++++++-------- error_example_test.go | 2 +- error_not_modified_example_test.go | 2 +- handler.go | 4 +- interceptor_ext_test.go | 104 ++++++----------------------- 7 files changed, 84 insertions(+), 145 deletions(-) diff --git a/client.go b/client.go index 78365943..9e45deaa 100644 --- a/client.go +++ b/client.go @@ -76,11 +76,13 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien // once at client creation. unarySpec := config.newSpec(StreamTypeUnary) unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { - ctx, callInfo := newOutgoingContext(ctx) conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) conn.onRequestSend(func(r *http.Request) { request.setRequestMethod(r.Method) - callInfo.method = r.Method + callInfo, ok := getClientCallInfoFromContext(ctx) + if ok { + callInfo.method = r.Method + } }) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the @@ -112,11 +114,13 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien request.peer = client.protocolClient.Peer() protocolClient.WriteRequestHeader(StreamTypeUnary, request.Header()) - // Also set them in the context so interceptors can inspect context for this information. - ctx, callInfo := newOutgoingContext(ctx) - callInfo.peer = request.Peer() - callInfo.spec = request.Spec() - callInfo.requestHeader = request.Header() + // Also set them in the context if there's a call info present + callInfo, callInfoOk := getClientCallInfoFromContext(ctx) + if callInfoOk { + callInfo.peer = request.Peer() + callInfo.spec = request.Spec() + callInfo.requestHeader = request.Header() + } response, err := unaryFunc(ctx, request) if err != nil { @@ -126,9 +130,11 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien if !ok { return nil, errorf(CodeInternal, "unexpected client response type %T", response) } - // Wrap the response and set it into the context callinfo - callInfo.responseSource = &responseWrapper[Res]{ - response: typed, + if callInfoOk { + // Wrap the response and set it into the context callinfo + callInfo.responseSource = &responseWrapper[Res]{ + response: typed, + } } return typed, nil } @@ -149,7 +155,7 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) // This option eliminates the [Request] and [Response] wrappers, and instead uses the // context.Context to propagate information such as headers. func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, request *Req) (*Res, error) { - response, err := c.CallUnary(ctx, requestFromOutgoingContext(ctx, request)) + response, err := c.CallUnary(ctx, requestFromClientContext(ctx, request)) if response != nil { return response.Msg, err } @@ -175,17 +181,21 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method }) - _, callInfo := newOutgoingContext(ctx) - callInfo.peer = conn.Peer() - callInfo.spec = conn.Spec() - callInfo.responseSource = conn - request.peer = conn.Peer() request.spec = conn.Spec() - // Merge any callInfo request headers first, then do the request. - // so that context headers show first in the list of headers - mergeHeaders(conn.RequestHeader(), callInfo.RequestHeader()) + callInfo, ok := getClientCallInfoFromContext(ctx) + // Set values in the context if there's a call info present + if ok { + callInfo.peer = conn.Peer() + callInfo.spec = conn.Spec() + callInfo.responseSource = conn + + // Merge any callInfo request headers first, then do the request. + // so that context headers show first in the list of headers + mergeHeaders(conn.RequestHeader(), callInfo.RequestHeader()) + } + mergeHeaders(conn.RequestHeader(), request.header) // Send always returns an io.EOF unless the error is from the client-side. @@ -211,7 +221,7 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques // This option eliminates the [Request] wrapper, and instead uses the context.Context to // propagate information such as headers. func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { - return c.CallServerStream(ctx, requestFromOutgoingContext(ctx, requestMsg)) + return c.CallServerStream(ctx, requestFromClientContext(ctx, requestMsg)) } // CallBidiStream calls a bidirectional streaming procedure. diff --git a/connect_ext_test.go b/connect_ext_test.go index 3e52b5ec..9492b119 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -75,7 +75,7 @@ func TestCallInfo(t *testing.T) { client := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) t.Run("unary", func(t *testing.T) { num := int64(42) - ctx, callInfo := connect.NewOutgoingContext(context.Background()) + ctx, callInfo := connect.NewClientContext(context.Background()) callInfo.RequestHeader().Set(clientHeader, headerValue) expect := &pingv1.PingResponse{Number: num} @@ -90,7 +90,7 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) t.Run("server_stream", func(t *testing.T) { - ctx, callInfo := connect.NewOutgoingContext(context.Background()) + ctx, callInfo := connect.NewClientContext(context.Background()) callInfo.RequestHeader().Set(clientHeader, headerValue) val := 3 @@ -134,7 +134,7 @@ func TestCallInfo(t *testing.T) { request.Header().Set(clientHeader, headerValue) expect := &pingv1.PingResponse{Number: num} - ctx, callInfo := connect.NewOutgoingContext(context.Background()) + ctx, callInfo := connect.NewClientContext(context.Background()) response, err := client.Ping(ctx, request) assert.Nil(t, err) assert.Equal(t, response.Msg, expect) @@ -152,7 +152,7 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) t.Run("server_stream", func(t *testing.T) { - ctx, callInfo := connect.NewOutgoingContext(context.Background()) + ctx, callInfo := connect.NewClientContext(context.Background()) callInfo.RequestHeader().Set(clientHeader, headerValue) val := 3 @@ -3059,7 +3059,7 @@ type pingServerSimple struct { } func (p pingServerSimple) Ping(ctx context.Context, request *pingv1.PingRequest) (*pingv1.PingResponse, error) { - callInfo, ok := connect.CallInfoFromIncomingContext(ctx) + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } @@ -3088,7 +3088,7 @@ func (p pingServerSimple) CountUp( request *pingv1.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { - callInfo, ok := connect.CallInfoFromIncomingContext(ctx) + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) if !ok { return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } @@ -3120,7 +3120,7 @@ func (p pingServerSimple) CountUp( } func (p pingServerSimple) Fail(ctx context.Context, request *pingv1.FailRequest) (*pingv1.FailResponse, error) { - callInfo, ok := connect.CallInfoFromIncomingContext(ctx) + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } diff --git a/context.go b/context.go index 41ea749d..b34a6255 100644 --- a/context.go +++ b/context.go @@ -21,7 +21,7 @@ import ( // CallInfo represents information relevant to an RPC call. // Values returned by these methods are not thread-safe. Users should expect -// data races if they create an outgoing CallInfo in context and then pass that +// data races if they create an outgoing client CallInfo in context and then pass that // CallInfo to another goroutine and try to call methods on it concurrent with the RPC. type CallInfo interface { // Spec returns a description of this call. @@ -66,29 +66,28 @@ type CallInfo interface { internalOnly() } -// Create a new outgoing context for use from a client. When the returned -// context is passed to RPCs, the returned call info can be used to set +// Create a new client (i.e. outgoing) context for use from a client. When the +// returned context is passed to RPCs, the returned call info can be used to set // request metadata before the RPC is invoked and to inspect response // metadata after the RPC completes. // // The returned context may be re-used across RPCs as long as they are // not concurrent. Results of all CallInfo methods other than // RequestHeader() are undefined if the context is used with concurrent RPCs. -// If the given context is already associated with an outgoing CallInfo, then -// ctx and the existing CallInfo are returned. -func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { - return newOutgoingContext(ctx) +func NewClientContext(ctx context.Context) (context.Context, CallInfo) { + info := &clientCallInfo{} + return context.WithValue(ctx, clientCallInfoContextKey{}, info), info } -// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. -func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) +// CallInfoFromClientContext returns the CallInfo for the given client context, if there is one. +func CallInfoFromClientContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(clientCallInfoContextKey{}).(CallInfo) return value, ok } -// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. -func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) +// CallInfoFromHandlerContext returns the CallInfo for the given handler (i.e. incoming) context, if there is one. +func CallInfoFromHandlerContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(handlerCallInfoContextKey{}).(CallInfo) return value, ok } @@ -216,8 +215,8 @@ func (c *clientCallInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *clientCallInfo) internalOnly() {} -type outgoingCallInfoContextKey struct{} -type incomingCallInfoContextKey struct{} +type clientCallInfoContextKey struct{} +type handlerCallInfoContextKey struct{} // responseSource indicates a type that manage response headers and trailers. type responseSource interface { @@ -238,25 +237,21 @@ func (w *responseWrapper[Res]) ResponseTrailer() http.Header { return w.response.Trailer() } -// Creates a new outgoing context or returns the existing one in context. -func newOutgoingContext(ctx context.Context) (context.Context, *clientCallInfo) { - info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*clientCallInfo) - if !ok { - info = &clientCallInfo{} - return context.WithValue(ctx, outgoingCallInfoContextKey{}, info), info - } - return ctx, info +// Gets a client (i.e. outgoing) call info from context. +func getClientCallInfoFromContext(ctx context.Context) (*clientCallInfo, bool) { + info, ok := ctx.Value(clientCallInfoContextKey{}).(*clientCallInfo) + return info, ok } -// newOutgoingContext creates a new incoming context. -func newIncomingContext(ctx context.Context, info CallInfo) context.Context { - return context.WithValue(ctx, incomingCallInfoContextKey{}, info) +// newHandlerContext creates a new handler (i.e. incoming) context. +func newHandlerContext(ctx context.Context, info CallInfo) context.Context { + return context.WithValue(ctx, handlerCallInfoContextKey{}, info) } -// requestFromOutgoingContext creates a new Request using the given context and message. -func requestFromOutgoingContext[T any](ctx context.Context, message *T) *Request[T] { +// requestFromClientContext creates a new Request using the given context and message. +func requestFromClientContext[T any](ctx context.Context, message *T) *Request[T] { request := NewRequest(message) - callInfo, ok := CallInfoFromOutgoingContext(ctx) + callInfo, ok := CallInfoFromClientContext(ctx) if ok { request.setHeader(callInfo.RequestHeader()) } diff --git a/error_example_test.go b/error_example_test.go index 30930a97..1bcc68c3 100644 --- a/error_example_test.go +++ b/error_example_test.go @@ -48,7 +48,7 @@ func ExampleIsNotModifiedError() { connect.WithHTTPGet(), ) req := &pingv1.PingRequest{Number: 42} - ctx, callInfo := connect.NewOutgoingContext(context.Background()) + ctx, callInfo := connect.NewClientContext(context.Background()) _, err := client.Ping(ctx, req) if err != nil { fmt.Println(err) diff --git a/error_not_modified_example_test.go b/error_not_modified_example_test.go index 3daf8223..e87cd07a 100644 --- a/error_not_modified_example_test.go +++ b/error_not_modified_example_test.go @@ -42,7 +42,7 @@ func (*ExampleCachingPingServer) Ping( resp := &pingv1.PingResponse{ Number: req.GetNumber(), } - callInfo, ok := connect.CallInfoFromIncomingContext(ctx) + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) if !ok { return nil, errors.New("no call info found in context") } diff --git a/handler.go b/handler.go index 9975712d..5eb5e1e0 100644 --- a/handler.go +++ b/handler.go @@ -74,7 +74,7 @@ func NewUnaryHandler[Req, Res any]( method: request.HTTPMethod(), requestHeader: request.Header(), } - ctx = newIncomingContext(ctx, info) + ctx = newHandlerContext(ctx, info) response, err := untyped(ctx, request) if err != nil { return err @@ -176,7 +176,7 @@ func NewServerStreamHandler[Req, Res any]( if err != nil { return err } - ctx = newIncomingContext(ctx, &streamCallInfo{ + ctx = newHandlerContext(ctx, &streamCallInfo{ conn: conn, }) return implementation(ctx, req, &ServerStream[Res]{conn: conn}) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 949a6afc..b5892fde 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -16,7 +16,6 @@ package connect_test import ( "context" - "errors" "fmt" "net/http" "sync/atomic" @@ -190,10 +189,10 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { assert.Nil(t, countUpStream.Close()) } -func TestInterceptorFuncAccessingCallInfo(t *testing.T) { +func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { t.Parallel() - clientChecker := &callInfoChecker{client: true} - handlerChecker := &callInfoChecker{} + clientChecker := &httpMethodChecker{client: true} + handlerChecker := &httpMethodChecker{} mux := http.NewServeMux() mux.Handle( @@ -345,103 +344,38 @@ func (cc *headerInspectingClientConn) Receive(msg any) error { return err } -type callInfoChecker struct { +type httpMethodChecker struct { client bool count atomic.Int32 } -func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connect.AnyRequest, expectMethod bool) error { - // method should be blank until after we make request - if !expectMethod { //nolint:nestif - if callInfo.HTTPMethod() != "" { - return fmt.Errorf("expected blank HTTP method in context but instead got %q", callInfo.HTTPMethod()) - } - if req.HTTPMethod() != "" { - return fmt.Errorf("expected blank HTTP method in request but instead got %q", req.HTTPMethod()) - } - } else { - // server interceptors see method from the start - // NB: In theory, the method could also be GET, not just POST. But for the - // configuration under test, it will always be POST. - if callInfo.HTTPMethod() != http.MethodPost { - return fmt.Errorf("expected HTTP method %s in context but instead got %q", http.MethodPost, callInfo.HTTPMethod()) - } - if req.HTTPMethod() != http.MethodPost { - return fmt.Errorf("expected HTTP method %s in request but instead got %q", http.MethodPost, req.HTTPMethod()) - } - } - if callInfo.Peer().Addr == "" { - return errors.New("no peer set on call info") - } - if req.Peer().Addr == "" { - return errors.New("no peer set on request") - } - if callInfo.Spec().Procedure != pingv1connect.PingServicePingProcedure { - return fmt.Errorf("expected spec procedure %s in context but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) - } - if req.Spec().Procedure != pingv1connect.PingServicePingProcedure { - return fmt.Errorf("expected spec procedure %s on request but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) - } - return nil -} - -func (h *callInfoChecker) getCallInfo(ctx context.Context) (connect.CallInfo, error) { - var callInfo connect.CallInfo - if h.client { - info, ok := connect.CallInfoFromOutgoingContext(ctx) - if !ok { - return nil, errors.New("no call info found in outgoing context") - } - callInfo = info - } else { - info, ok := connect.CallInfoFromIncomingContext(ctx) - if !ok { - return nil, errors.New("no call info found in incoming context") - } - callInfo = info - } - return callInfo, nil -} - -func (h *callInfoChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { +func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { h.count.Add(1) - - callInfo, err := h.getCallInfo(ctx) - if err != nil { - return nil, err - } - if h.client { - if err := h.validateCallInfo(callInfo, req, false); err != nil { - return nil, err + // should be blank until after we make request + if req.HTTPMethod() != "" { + return nil, fmt.Errorf("expected blank HTTP method but instead got %q", req.HTTPMethod()) } } else { - if err := h.validateCallInfo(callInfo, req, true); err != nil { - return nil, err + // server interceptors see method from the start + // NB: In theory, the method could also be GET, not just POST. But for the + // configuration under test, it will always be POST. + if req.HTTPMethod() != http.MethodPost { + return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) } } - resp, err := unaryFunc(ctx, req) - if err != nil { - return nil, err - } - - // Method should now be set on the outgoing context - callInfo, err = h.getCallInfo(ctx) - if err != nil { - return nil, err - } - - if err := h.validateCallInfo(callInfo, req, true); err != nil { - return nil, err + // NB: In theory, the method could also be GET, not just POST. But for the + // configuration under test, it will always be POST. + if req.HTTPMethod() != http.MethodPost { + return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) } - return resp, err } } -func (h *callInfoChecker) WrapStreamingClient(clientFunc connect.StreamingClientFunc) connect.StreamingClientFunc { +func (h *httpMethodChecker) WrapStreamingClient(clientFunc connect.StreamingClientFunc) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { // method not exposed to streaming interceptor, but that's okay because it's always POST for streams h.count.Add(1) @@ -449,7 +383,7 @@ func (h *callInfoChecker) WrapStreamingClient(clientFunc connect.StreamingClient } } -func (h *callInfoChecker) WrapStreamingHandler(handlerFunc connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (h *httpMethodChecker) WrapStreamingHandler(handlerFunc connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { // method not exposed to streaming interceptor, but that's okay because it's always POST for streams h.count.Add(1) From 7d37a59ace20f724834f126815145215ca1b1dd2 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 21 Jul 2025 10:11:25 -0400 Subject: [PATCH 23/57] Interceptor tests Signed-off-by: Steve Ayers --- client.go | 24 +++++- connect.go | 41 ++++++++++ connect_ext_test.go | 1 - context.go | 9 +-- interceptor.go | 39 ++++++++- interceptor_ext_test.go | 173 +++++++++++++++++++++++++++++++++++++++- 6 files changed, 272 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index 9e45deaa..2a7af459 100644 --- a/client.go +++ b/client.go @@ -104,6 +104,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien return response, conn.CloseResponse() }) if interceptor := config.Interceptor; interceptor != nil { + // interceptor here is the chain unaryFunc = interceptor.WrapUnary(unaryFunc) } client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) { @@ -119,7 +120,19 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien if callInfoOk { callInfo.peer = request.Peer() callInfo.spec = request.Spec() + // A client could have set request headers in the call info OR the request wrapper + // So if a callInfo exists in context, merge any headers from there into the request wrapper + // so that all headers are sent in the request + mergeHeaders(request.Header(), callInfo.requestHeader) + // Then, set the full list of merged headers into the call info so users can query the context + // for this information + // TODO - Does this necessarily need done? callInfo.requestHeader = request.Header() + + // Copy the call info into a sentinel value. This is so we can compare + // the sentinel value against the call info in context. If they're different, + // we can stop the request. This protects against changing the context in interceptors. + ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo) } response, err := unaryFunc(ctx, request) @@ -178,15 +191,22 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if c.err != nil { return nil, c.err } + callInfo, callInfoOk := getClientCallInfoFromContext(ctx) + // Set values in the context if there's a call info present + if callInfoOk { + // Copy the call info into a sentinel value. This is so we can compare + // the sentinel value against the call info in context. If they're different, + // we can stop the request. This protects against changing the context in interceptors. + ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo) + } conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method }) request.peer = conn.Peer() request.spec = conn.Spec() - callInfo, ok := getClientCallInfoFromContext(ctx) // Set values in the context if there's a call info present - if ok { + if callInfoOk { callInfo.peer = conn.Peer() callInfo.spec = conn.Spec() callInfo.responseSource = conn diff --git a/connect.go b/connect.go index caaf838b..963c7998 100644 --- a/connect.go +++ b/connect.go @@ -373,6 +373,47 @@ type hasHTTPMethod interface { getHTTPMethod() string } +type errStreamingClientConn struct { + StreamingClientConn + err error +} + +func (c *errStreamingClientConn) Receive(msg any) error { + return c.err +} + +func (c *errStreamingClientConn) Spec() Spec { + return Spec{} +} + +func (c *errStreamingClientConn) Peer() Peer { + return Peer{} +} + +func (c *errStreamingClientConn) Send(msg any) error { + return c.err +} + +func (c *errStreamingClientConn) CloseRequest() error { + return c.err +} + +func (c *errStreamingClientConn) CloseResponse() error { + return c.err +} + +func (c *errStreamingClientConn) RequestHeader() http.Header { + return make(http.Header) +} + +func (c *errStreamingClientConn) ResponseHeader() http.Header { + return make(http.Header) +} + +func (c *errStreamingClientConn) ResponseTrailer() http.Header { + return make(http.Header) +} + // receiveUnaryResponse unmarshals a message from a StreamingClientConn, then // envelopes the message and attaches headers and trailers. It attempts to // consume the response stream and isn't appropriate when receiving multiple diff --git a/connect_ext_test.go b/connect_ext_test.go index 9492b119..eddec15a 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -78,7 +78,6 @@ func TestCallInfo(t *testing.T) { ctx, callInfo := connect.NewClientContext(context.Background()) callInfo.RequestHeader().Set(clientHeader, headerValue) expect := &pingv1.PingResponse{Number: num} - response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) assert.Equal(t, response, expect) assert.Nil(t, err) diff --git a/context.go b/context.go index b34a6255..004be9ea 100644 --- a/context.go +++ b/context.go @@ -79,12 +79,6 @@ func NewClientContext(ctx context.Context) (context.Context, CallInfo) { return context.WithValue(ctx, clientCallInfoContextKey{}, info), info } -// CallInfoFromClientContext returns the CallInfo for the given client context, if there is one. -func CallInfoFromClientContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(clientCallInfoContextKey{}).(CallInfo) - return value, ok -} - // CallInfoFromHandlerContext returns the CallInfo for the given handler (i.e. incoming) context, if there is one. func CallInfoFromHandlerContext(ctx context.Context) (CallInfo, bool) { value, ok := ctx.Value(handlerCallInfoContextKey{}).(CallInfo) @@ -216,6 +210,7 @@ func (c *clientCallInfo) HTTPMethod() string { func (c *clientCallInfo) internalOnly() {} type clientCallInfoContextKey struct{} +type sentinelContextKey struct{} type handlerCallInfoContextKey struct{} // responseSource indicates a type that manage response headers and trailers. @@ -251,7 +246,7 @@ func newHandlerContext(ctx context.Context, info CallInfo) context.Context { // requestFromClientContext creates a new Request using the given context and message. func requestFromClientContext[T any](ctx context.Context, message *T) *Request[T] { request := NewRequest(message) - callInfo, ok := CallInfoFromClientContext(ctx) + callInfo, ok := getClientCallInfoFromContext(ctx) if ok { request.setHeader(callInfo.RequestHeader()) } diff --git a/interceptor.go b/interceptor.go index f0c3620a..f5aac3e1 100644 --- a/interceptor.go +++ b/interceptor.go @@ -16,6 +16,13 @@ package connect import ( "context" + "errors" +) + +var ( + // errNewClientContextProhibited signals that a new client context was created + // in an interceptor, which is prohibited. + errNewClientContextProhibited = errors.New("creating a new context in an interceptor is prohibited") ) // UnaryFunc is the generic signature of a unary RPC. Interceptors may wrap @@ -36,9 +43,8 @@ type StreamingHandlerFunc func(context.Context, StreamingHandlerConn) error // An Interceptor adds logic to a generated handler or client, like the // decorators or middleware you may have seen in other libraries. Interceptors -// may replace the context, mutate requests and responses, handle errors, -// retry, recover from panics, emit logs and metrics, or do nearly anything -// else. +// may mutate requests and responses, handle errors, retry, recover from panics, +// emit logs and metrics, or do nearly anything else. // // The returned functions must be safe to call concurrently. type Interceptor interface { @@ -85,6 +91,7 @@ func newChain(interceptors []Interceptor) *chain { func (c *chain) WrapUnary(next UnaryFunc) UnaryFunc { for _, interceptor := range c.interceptors { + next = unaryThunk(next) next = interceptor.WrapUnary(next) } return next @@ -92,6 +99,7 @@ func (c *chain) WrapUnary(next UnaryFunc) UnaryFunc { func (c *chain) WrapStreamingClient(next StreamingClientFunc) StreamingClientFunc { for _, interceptor := range c.interceptors { + next = streamingClientThunk(next) next = interceptor.WrapStreamingClient(next) } return next @@ -103,3 +111,28 @@ func (c *chain) WrapStreamingHandler(next StreamingHandlerFunc) StreamingHandler } return next } + +func unaryThunk(next UnaryFunc) UnaryFunc { + return func(ctx context.Context, req AnyRequest) (AnyResponse, error) { + if !checkSentinel(ctx) { + return nil, errNewClientContextProhibited + } + return next(ctx, req) + } +} + +func streamingClientThunk(next StreamingClientFunc) StreamingClientFunc { + return func(ctx context.Context, spec Spec) StreamingClientConn { + if !checkSentinel(ctx) { + return &errStreamingClientConn{err: errNewClientContextProhibited} + } + return next(ctx, spec) + } +} + +func checkSentinel(ctx context.Context) bool { + callInfo, _ := ctx.Value(clientCallInfoContextKey{}).(*clientCallInfo) + sentinel, _ := ctx.Value(sentinelContextKey{}).(*clientCallInfo) + // Only verify if there's a sentinel call info to compare it to + return callInfo == sentinel +} diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index b5892fde..bc022c7c 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -28,6 +28,139 @@ import ( "connectrpc.com/connect/internal/memhttp/memhttptest" ) +func TestNewClientContextFails(t *testing.T) { + // Verifies that calling NewClientContext in an interceptor fails when sending the new context downstream + t.Parallel() + t.Run("unary", func(t *testing.T) { + t.Parallel() + t.Run("first_interceptor", func(t *testing.T) { + t.Parallel() + createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + return connect.WithInterceptors( + &contextInterceptor{client: true, count: client1, createNewContext: true}, + &contextInterceptor{client: true, count: client2}, + ) + } + var client1, client2 atomic.Int32 + mux := http.NewServeMux() + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + createInterceptors(&client1, &client2), + ) + _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + + // Since we are creating a new client context, an error will be returned from the invocation + assert.NotNil(t, err) + assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") + // And because we're creating it in the first interceptor, only the first interceptor fires + assert.Equal(t, int32(1), client1.Load()) + assert.Equal(t, int32(0), client2.Load()) + }) + t.Run("subsequent_interceptor", func(t *testing.T) { + t.Parallel() + createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + return connect.WithInterceptors( + &contextInterceptor{client: true, count: client1}, + &contextInterceptor{client: true, count: client2, createNewContext: true}, + ) + } + var client1, client2 atomic.Int32 + mux := http.NewServeMux() + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + createInterceptors(&client1, &client2), + ) + _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + + // Since we are creating a new client context, an error will be returned from the invocation + assert.NotNil(t, err) + assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") + // And because we're creating it in the second interceptor, they both fire + assert.Equal(t, int32(1), client1.Load()) + assert.Equal(t, int32(1), client2.Load()) + }) + }) + t.Run("server_streaming", func(t *testing.T) { + t.Parallel() + t.Run("first_interceptor", func(t *testing.T) { + t.Parallel() + createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + return connect.WithInterceptors( + &contextInterceptor{client: true, count: client1, createNewContext: true}, + &contextInterceptor{client: true, count: client2}, + ) + } + var client1, client2 atomic.Int32 + mux := http.NewServeMux() + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + createInterceptors(&client1, &client2), + ) + responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) + + // Since we are creating a new client context, an error will be returned from the invocation + assert.Nil(t, responses) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") + // And because we're creating it in the first interceptor, only the first interceptor fires + assert.Equal(t, int32(1), client1.Load()) + assert.Equal(t, int32(0), client2.Load()) + }) + t.Run("subsequent_interceptor", func(t *testing.T) { + t.Parallel() + createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + return connect.WithInterceptors( + &contextInterceptor{client: true, count: client1}, + &contextInterceptor{client: true, count: client2, createNewContext: true}, + ) + } + var client1, client2 atomic.Int32 + mux := http.NewServeMux() + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + createInterceptors(&client1, &client2), + ) + responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) + + // Since we are creating a new client context, an error will be returned from the invocation + assert.Nil(t, responses) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") + // And because we're creating it in the second interceptor, all interceptors fire + assert.Equal(t, int32(1), client1.Load()) + assert.Equal(t, int32(1), client2.Load()) + }) + }) +} + func TestOnionOrderingEndToEnd(t *testing.T) { t.Parallel() // Helper function: returns a function that asserts that there's some value @@ -349,7 +482,7 @@ type httpMethodChecker struct { count atomic.Int32 } -func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { +func (h *httpMethodChecker) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { h.count.Add(1) if h.client { @@ -365,7 +498,7 @@ func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.Unary return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) } } - resp, err := unaryFunc(ctx, req) + resp, err := next(ctx, req) // NB: In theory, the method could also be GET, not just POST. But for the // configuration under test, it will always be POST. if req.HTTPMethod() != http.MethodPost { @@ -390,3 +523,39 @@ func (h *httpMethodChecker) WrapStreamingHandler(handlerFunc connect.StreamingHa return handlerFunc(ctx, conn) } } + +type contextInterceptor struct { + client bool + count *atomic.Int32 + // Whether the interceptor should attempt to create a new context (which will cause next() to return an error) + createNewContext bool +} + +func (h *contextInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + h.count.Add(1) + if h.createNewContext { + // This will cause next to return an error + ctx, _ = connect.NewClientContext(ctx) + } + return next(ctx, req) + } +} + +func (h *contextInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + h.count.Add(1) + if h.createNewContext { + // This will cause next to return an error + ctx, _ = connect.NewClientContext(ctx) + } + return next(ctx, spec) + } +} + +func (h *contextInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + h.count.Add(1) + return next(ctx, conn) + } +} From 3cadcc1e828d0d09a5794381f52d03e5939d22d2 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 21 Jul 2025 11:10:02 -0400 Subject: [PATCH 24/57] Side quest tests Signed-off-by: Steve Ayers Fix names Signed-off-by: Steve Ayers --- interceptor_ext_test.go | 186 ++++++++++++++++++++++++++++++---------- 1 file changed, 139 insertions(+), 47 deletions(-) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index bc022c7c..9668aff8 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -25,9 +25,44 @@ import ( "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" ) +func TestSideQuestInInterceptor(t *testing.T) { + t.Parallel() + t.Run("unary", func(t *testing.T) { + t.Parallel() + t.Run("sidequest_succeeds", func(t *testing.T) { + t.Parallel() + createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32, server *memhttp.Server) connect.Option { + return connect.WithInterceptors( + newSideQuestInterceptor(t, clientCounter1, server), + newSideQuestInterceptor(t, clientCounter2, server), + ) + } + var clientCounter1, clientCounter2 atomic.Int32 + mux := http.NewServeMux() + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + createInterceptors(&clientCounter1, &clientCounter2, server), + ) + _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + }) +} + func TestNewClientContextFails(t *testing.T) { // Verifies that calling NewClientContext in an interceptor fails when sending the new context downstream t.Parallel() @@ -35,13 +70,13 @@ func TestNewClientContextFails(t *testing.T) { t.Parallel() t.Run("first_interceptor", func(t *testing.T) { t.Parallel() - createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { return connect.WithInterceptors( - &contextInterceptor{client: true, count: client1, createNewContext: true}, - &contextInterceptor{client: true, count: client2}, + &contextInterceptor{client: true, count: clientCounter1, createNewContext: true}, + &contextInterceptor{client: true, count: clientCounter2}, ) } - var client1, client2 atomic.Int32 + var clientCounter1, clientCounter2 atomic.Int32 mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( @@ -52,7 +87,7 @@ func TestNewClientContextFails(t *testing.T) { client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), - createInterceptors(&client1, &client2), + createInterceptors(&clientCounter1, &clientCounter2), ) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) @@ -60,18 +95,18 @@ func TestNewClientContextFails(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") // And because we're creating it in the first interceptor, only the first interceptor fires - assert.Equal(t, int32(1), client1.Load()) - assert.Equal(t, int32(0), client2.Load()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) }) t.Run("subsequent_interceptor", func(t *testing.T) { t.Parallel() - createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { return connect.WithInterceptors( - &contextInterceptor{client: true, count: client1}, - &contextInterceptor{client: true, count: client2, createNewContext: true}, + &contextInterceptor{client: true, count: clientCounter1}, + &contextInterceptor{client: true, count: clientCounter2, createNewContext: true}, ) } - var client1, client2 atomic.Int32 + var clientCounter1, clientCounter2 atomic.Int32 mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( @@ -82,7 +117,7 @@ func TestNewClientContextFails(t *testing.T) { client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), - createInterceptors(&client1, &client2), + createInterceptors(&clientCounter1, &clientCounter2), ) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) @@ -90,21 +125,21 @@ func TestNewClientContextFails(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") // And because we're creating it in the second interceptor, they both fire - assert.Equal(t, int32(1), client1.Load()) - assert.Equal(t, int32(1), client2.Load()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) }) }) t.Run("server_streaming", func(t *testing.T) { t.Parallel() t.Run("first_interceptor", func(t *testing.T) { t.Parallel() - createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { return connect.WithInterceptors( - &contextInterceptor{client: true, count: client1, createNewContext: true}, - &contextInterceptor{client: true, count: client2}, + &contextInterceptor{client: true, count: clientCounter1, createNewContext: true}, + &contextInterceptor{client: true, count: clientCounter2}, ) } - var client1, client2 atomic.Int32 + var clientCounter1, clientCounter2 atomic.Int32 mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( @@ -115,7 +150,7 @@ func TestNewClientContextFails(t *testing.T) { client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), - createInterceptors(&client1, &client2), + createInterceptors(&clientCounter1, &clientCounter2), ) responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) @@ -124,18 +159,18 @@ func TestNewClientContextFails(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") // And because we're creating it in the first interceptor, only the first interceptor fires - assert.Equal(t, int32(1), client1.Load()) - assert.Equal(t, int32(0), client2.Load()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) }) t.Run("subsequent_interceptor", func(t *testing.T) { t.Parallel() - createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { return connect.WithInterceptors( - &contextInterceptor{client: true, count: client1}, - &contextInterceptor{client: true, count: client2, createNewContext: true}, + &contextInterceptor{client: true, count: clientCounter1}, + &contextInterceptor{client: true, count: clientCounter2, createNewContext: true}, ) } - var client1, client2 atomic.Int32 + var clientCounter1, clientCounter2 atomic.Int32 mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( @@ -146,7 +181,7 @@ func TestNewClientContextFails(t *testing.T) { client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), - createInterceptors(&client1, &client2), + createInterceptors(&clientCounter1, &clientCounter2), ) responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) @@ -155,8 +190,8 @@ func TestNewClientContextFails(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") // And because we're creating it in the second interceptor, all interceptors fire - assert.Equal(t, int32(1), client1.Load()) - assert.Equal(t, int32(1), client2.Load()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) }) }) } @@ -201,7 +236,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { } } - var client1, client2, client3, handler1, handler2, handler3 atomic.Int32 + var clientCounter1, clientCounter2, clientCounter3, handlerCounter1, handlerCounter2, handlerCounter3 atomic.Int32 // The client and handler interceptor onions are the meat of the test. The // order of interceptor execution must be the same for unary and streaming @@ -216,7 +251,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { // intended order clear. clientOnion := connect.WithInterceptors( newHeaderInterceptor( - &client1, + &clientCounter1, // 1 (start). request: should see protocol-related headers func(_ connect.Spec, h http.Header) { assert.NotZero(t, h.Get("Content-Type")) @@ -225,29 +260,29 @@ func TestOnionOrderingEndToEnd(t *testing.T) { assertAllPresent, ), newHeaderInterceptor( - &client2, + &clientCounter2, newInspector("", "one"), // 2. request: add header "one" newInspector("three", "four"), // 11. response: check "three", add "four" ), newHeaderInterceptor( - &client3, + &clientCounter3, newInspector("one", "two"), // 3. request: check "one", add "two" newInspector("two", "three"), // 10. response: check "two", add "three" ), ) handlerOnion := connect.WithInterceptors( newHeaderInterceptor( - &handler1, + &handlerCounter1, newInspector("two", "three"), // 4. request: check "two", add "three" newInspector("one", "two"), // 9. response: check "one", add "two" ), newHeaderInterceptor( - &handler2, + &handlerCounter2, newInspector("three", "four"), // 5. request: check "three", add "four" newInspector("", "one"), // 8. response: add "one" ), newHeaderInterceptor( - &handler3, + &handlerCounter3, assertAllPresent, // 6. request: check "one"-"four" nil, // 7. response: no-op ), @@ -271,12 +306,12 @@ func TestOnionOrderingEndToEnd(t *testing.T) { assert.Nil(t, err) // make sure the interceptors were actually invoked - assert.Equal(t, int32(1), client1.Load()) - assert.Equal(t, int32(1), client2.Load()) - assert.Equal(t, int32(1), client3.Load()) - assert.Equal(t, int32(1), handler1.Load()) - assert.Equal(t, int32(1), handler2.Load()) - assert.Equal(t, int32(1), handler3.Load()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + assert.Equal(t, int32(1), clientCounter3.Load()) + assert.Equal(t, int32(1), handlerCounter1.Load()) + assert.Equal(t, int32(1), handlerCounter2.Load()) + assert.Equal(t, int32(1), handlerCounter3.Load()) responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) assert.Nil(t, err) @@ -288,12 +323,12 @@ func TestOnionOrderingEndToEnd(t *testing.T) { assert.Nil(t, responses.Close()) // make sure the interceptors were invoked again - assert.Equal(t, int32(2), client1.Load()) - assert.Equal(t, int32(2), client2.Load()) - assert.Equal(t, int32(2), client3.Load()) - assert.Equal(t, int32(2), handler1.Load()) - assert.Equal(t, int32(2), handler2.Load()) - assert.Equal(t, int32(2), handler3.Load()) + assert.Equal(t, int32(2), clientCounter1.Load()) + assert.Equal(t, int32(2), clientCounter2.Load()) + assert.Equal(t, int32(2), clientCounter3.Load()) + assert.Equal(t, int32(2), handlerCounter1.Load()) + assert.Equal(t, int32(2), handlerCounter2.Load()) + assert.Equal(t, int32(2), handlerCounter3.Load()) } func TestEmptyUnaryInterceptorFunc(t *testing.T) { @@ -559,3 +594,60 @@ func (h *contextInterceptor) WrapStreamingHandler(next connect.StreamingHandlerF return next(ctx, conn) } } + +type sideQuestInterceptor struct { + count *atomic.Int32 + client pingv1connect.PingServiceClient + t *testing.T +} + +func newSideQuestInterceptor( //nolint:thelper + t *testing.T, + counter *atomic.Int32, + server *memhttp.Server, +) *sideQuestInterceptor { + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + ) + return &sideQuestInterceptor{t: t, client: client, count: counter} +} + +func (h *sideQuestInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + h.count.Add(1) + num := int64(42) + // Create a new client context for the side quest Ping. This should succeed because we aren't + // sending this on through the interceptor chain and reusing this context + newCtx, _ := connect.NewClientContext(ctx) + resp, err := h.client.Ping(newCtx, connect.NewRequest(&pingv1.PingRequest{Number: num})) + assert.Nil(h.t, err) + assert.Equal(h.t, resp.Msg.Number, num) + + return next(ctx, req) + } +} + +func (h *sideQuestInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + h.count.Add(1) + // Create a new context for the side quest CountUp. This should succeed because we aren't + // sending this on through the interceptor chain and reusing this context + newCtx, _ := connect.NewClientContext(ctx) + responses, err := h.client.CountUp(newCtx, connect.NewRequest(&pingv1.CountUpRequest{Number: 3})) + assert.Nil(h.t, err) + var sum int64 + for responses.Receive() { + sum += responses.Msg().GetNumber() + } + assert.Equal(h.t, sum, 6) + assert.Nil(h.t, responses.Close()) + return next(ctx, spec) + } +} + +func (h *sideQuestInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + return next(ctx, conn) + } +} From b2d9bce7c8817377ace4f70c746e7c4fd39c9004 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 21 Jul 2025 17:20:03 -0400 Subject: [PATCH 25/57] Extensive testing for simple and generic APIs using callinfo Signed-off-by: Steve Ayers --- client.go | 8 +- connect.go | 5 - connect_ext_test.go | 453 +++++++++++++++++++++++++++++++++----------- context.go | 10 - 4 files changed, 343 insertions(+), 133 deletions(-) diff --git a/client.go b/client.go index 2a7af459..d2be7070 100644 --- a/client.go +++ b/client.go @@ -124,10 +124,6 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien // So if a callInfo exists in context, merge any headers from there into the request wrapper // so that all headers are sent in the request mergeHeaders(request.Header(), callInfo.requestHeader) - // Then, set the full list of merged headers into the call info so users can query the context - // for this information - // TODO - Does this necessarily need done? - callInfo.requestHeader = request.Header() // Copy the call info into a sentinel value. This is so we can compare // the sentinel value against the call info in context. If they're different, @@ -168,7 +164,7 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) // This option eliminates the [Request] and [Response] wrappers, and instead uses the // context.Context to propagate information such as headers. func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, request *Req) (*Res, error) { - response, err := c.CallUnary(ctx, requestFromClientContext(ctx, request)) + response, err := c.CallUnary(ctx, NewRequest(request)) if response != nil { return response.Msg, err } @@ -241,7 +237,7 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques // This option eliminates the [Request] wrapper, and instead uses the context.Context to // propagate information such as headers. func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { - return c.CallServerStream(ctx, requestFromClientContext(ctx, requestMsg)) + return c.CallServerStream(ctx, NewRequest(requestMsg)) } // CallBidiStream calls a bidirectional streaming procedure. diff --git a/connect.go b/connect.go index 963c7998..8f68c46a 100644 --- a/connect.go +++ b/connect.go @@ -211,11 +211,6 @@ func (r *Request[_]) setRequestMethod(method string) { r.method = method } -// setHeader sets the request header to the given value. -func (r *Request[_]) setHeader(header http.Header) { - r.header = header -} - // AnyRequest is the common method set of every [Request], regardless of type // parameter. It's used in unary interceptors. // diff --git a/connect_ext_test.go b/connect_ext_test.go index eddec15a..fd3d06d1 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -31,6 +31,7 @@ import ( "net/http" "net/http/httptest" "runtime" + "sort" "strings" "sync" "testing" @@ -63,20 +64,26 @@ const ( clientMiddlewareErrorHeader = "Connect-Trigger-HTTP-Error" ) +var ( + expectedHeaderValues = []string{"foo", "bar"} //nolint:gochecknoglobals +) + func TestCallInfo(t *testing.T) { t.Parallel() t.Run("simple_api", func(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connectsimple.NewPingServiceHandler( - pingServerSimple{checkMetadata: true}, + pingServerSimple{}, )) server := memhttptest.NewServer(t, mux) client := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) t.Run("unary", func(t *testing.T) { num := int64(42) ctx, callInfo := connect.NewClientContext(context.Background()) - callInfo.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + callInfo.RequestHeader().Add(clientHeader, el) + } expect := &pingv1.PingResponse{Number: num} response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) assert.Equal(t, response, expect) @@ -85,13 +92,23 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) - assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + // When using the simple API for unary calls, users can only access response headers and trailers + // from the call info in context. + assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + }) + t.Run("unary_no_callinfo", func(t *testing.T) { + num := int64(42) + expect := &pingv1.PingResponse{Number: num} + response, err := client.Ping(context.Background(), &pingv1.PingRequest{Number: num}) + assert.Equal(t, response, expect) + assert.Nil(t, err) }) t.Run("server_stream", func(t *testing.T) { ctx, callInfo := connect.NewClientContext(context.Background()) - callInfo.RequestHeader().Set(clientHeader, headerValue) - + for _, el := range expectedHeaderValues { + callInfo.RequestHeader().Add(clientHeader, el) + } val := 3 stream, err := client.CountUp(ctx, &pingv1.CountUpRequest{ Number: int64(val), @@ -115,8 +132,35 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) - assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + + // On server-streaming calls, users can access response headers and trailers + // either from the call info in context or from the stream itself. + // This verifies that the both the stream and the call info have the same information + assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + }) + t.Run("server_stream_no_callinfo", func(t *testing.T) { + val := 3 + stream, err := client.CountUp(context.Background(), &pingv1.CountUpRequest{ + Number: int64(val), + }) + assert.Nil(t, err) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) + assert.Nil(t, stream.Err()) + assert.Nil(t, stream.Close()) }) }) t.Run("generics_api", func(t *testing.T) { @@ -130,10 +174,14 @@ func TestCallInfo(t *testing.T) { t.Run("unary", func(t *testing.T) { num := int64(42) request := connect.NewRequest(&pingv1.PingRequest{Number: num}) - request.Header().Set(clientHeader, headerValue) - expect := &pingv1.PingResponse{Number: num} ctx, callInfo := connect.NewClientContext(context.Background()) + // With the generics API, a user can use the call info or request wrapper or both to set request headers. + // The resulting headers should be combined and sent in the request. + request.Header().Add(clientHeader, "foo") + callInfo.RequestHeader().Add(clientHeader, "bar") + expect := &pingv1.PingResponse{Number: num} + response, err := client.Ping(ctx, request) assert.Nil(t, err) assert.Equal(t, response.Msg, expect) @@ -145,19 +193,42 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.Spec().Procedure, request.Spec().Procedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, request.Peer().Addr) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) - assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + + // When using the generics API for unary calls, users can access response headers and trailers + // either from the call info in context or the response wrapper. This verifies both have the same information. + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) }) - t.Run("server_stream", func(t *testing.T) { - ctx, callInfo := connect.NewClientContext(context.Background()) - callInfo.RequestHeader().Set(clientHeader, headerValue) + t.Run("unary_no_callinfo", func(t *testing.T) { + num := int64(42) + request := connect.NewRequest(&pingv1.PingRequest{Number: num}) + request.Header().Add(clientHeader, "foo") + request.Header().Add(clientHeader, "bar") + expect := &pingv1.PingResponse{Number: num} + response, err := client.Ping(context.Background(), request) + assert.Nil(t, err) + assert.Equal(t, response.Msg, expect) + assert.Equal(t, request.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, request.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, request.Spec().IsClient) + assert.Equal(t, request.Peer().Addr, httptest.DefaultRemoteAddr) + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) + }) + t.Run("server_stream", func(t *testing.T) { val := 3 req := connect.NewRequest(&pingv1.CountUpRequest{ Number: int64(val), }) + ctx, callInfo := connect.NewClientContext(context.Background()) + // With the generics API, A user can use the call info or request wrapper or both to set request headers. + // The resulting headers should be combined and sent in the request. + callInfo.RequestHeader().Set(clientHeader, "foo") + req.Header().Add(clientHeader, "bar") + stream, err := client.CountUp(ctx, req) assert.Nil(t, err) // Receive expected messages @@ -174,48 +245,88 @@ func TestCallInfo(t *testing.T) { assert.False(t, stream.Receive()) assert.Nil(t, stream.Err()) assert.Nil(t, stream.Close()) - // Assert values on request and stream + // Assert values on request assert.Equal(t, req.Spec().StreamType, connect.StreamTypeServer) assert.Equal(t, req.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) assert.True(t, req.Spec().IsClient) assert.Equal(t, req.Peer().Addr, httptest.DefaultRemoteAddr) - assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) // Assert the same values are in the call info assert.Equal(t, callInfo.Spec().StreamType, req.Spec().StreamType) assert.Equal(t, callInfo.Spec().Procedure, req.Spec().Procedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, req.Peer().Addr) - assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + + // On server-streaming calls, users can access response headers and trailers + // either from the call info in context or from the stream itself. + // This verifies that the both the stream and the call info have the same information + assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + }) + t.Run("server_stream_no_callinfo", func(t *testing.T) { + val := 3 + req := connect.NewRequest(&pingv1.CountUpRequest{ + Number: int64(val), + }) + req.Header().Set(clientHeader, "foo") + req.Header().Add(clientHeader, "bar") + + stream, err := client.CountUp(context.Background(), req) + assert.Nil(t, err) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) + assert.Nil(t, stream.Err()) + assert.Nil(t, stream.Close()) + // Assert values on request + assert.Equal(t, req.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, req.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, req.Spec().IsClient) + assert.Equal(t, req.Peer().Addr, httptest.DefaultRemoteAddr) + assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) }) }) } -func TestServer(t *testing.T) { +func TestServer(t *testing.T) { //nolint:gocyclo t.Parallel() testPing := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("ping", func(t *testing.T) { num := int64(42) request := connect.NewRequest(&pingv1.PingRequest{Number: num}) - request.Header().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Add(clientHeader, el) + } expect := &pingv1.PingResponse{Number: num} response, err := client.Ping(context.Background(), request) assert.Nil(t, err) assert.Equal(t, response.Msg, expect) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("zero_ping", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) - request.Header().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Add(clientHeader, el) + } response, err := client.Ping(context.Background(), request) assert.Nil(t, err) var expect pingv1.PingResponse assert.Equal(t, response.Msg, &expect) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("large_ping", func(t *testing.T) { // Using a large payload splits the request and response over multiple @@ -226,12 +337,14 @@ func TestServer(t *testing.T) { } hellos := strings.Repeat("hello", 1024*1024) // ~5mb request := connect.NewRequest(&pingv1.PingRequest{Text: hellos}) - request.Header().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Add(clientHeader, el) + } response, err := client.Ping(context.Background(), request) assert.Nil(t, err) assert.Equal(t, response.Msg.GetText(), hellos) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("ping_error", func(t *testing.T) { _, err := client.Ping( @@ -244,7 +357,7 @@ func TestServer(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) defer cancel() request := connect.NewRequest(&pingv1.PingRequest{}) - request.Header().Set(clientHeader, headerValue) + request.Header().Set(clientHeader, "foo") _, err := client.Ping(ctx, request) assert.Equal(t, connect.CodeOf(err), connect.CodeDeadlineExceeded) }) @@ -256,7 +369,9 @@ func TestServer(t *testing.T) { expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 ) stream := client.Sum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } for i := range upTo { err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) assert.Nil(t, err, assert.Sprintf("send %d", i)) @@ -264,8 +379,8 @@ func TestServer(t *testing.T) { response, err := stream.CloseAndReceive() assert.Nil(t, err) assert.Equal(t, response.Msg.GetSum(), expect) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("sum_error", func(t *testing.T) { stream := client.Sum(context.Background()) @@ -278,11 +393,14 @@ func TestServer(t *testing.T) { }) t.Run("sum_close_and_receive_without_send", func(t *testing.T) { stream := client.Sum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } got, err := stream.CloseAndReceive() assert.Nil(t, err) assert.Equal(t, got.Msg, &pingv1.SumResponse{}) // receive header only stream - assert.Equal(t, got.Header().Values(handlerHeader), []string{headerValue}) + assert.True(t, compareValues(got.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(got.Trailer().Values(handlerTrailer), expectedHeaderValues)) }) } testCountUp := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper @@ -294,7 +412,8 @@ func TestServer(t *testing.T) { expect = append(expect, int64(i+1)) } request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo}) - request.Header().Set(clientHeader, headerValue) + request.Header().Add(clientHeader, "foo") + request.Header().Add(clientHeader, "bar") stream, err := client.CountUp(context.Background(), request) assert.Nil(t, err) for stream.Receive() { @@ -332,7 +451,8 @@ func TestServer(t *testing.T) { t.Run("count_up_cancel_after_first_response", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) request := connect.NewRequest(&pingv1.CountUpRequest{Number: 5}) - request.Header().Set(clientHeader, headerValue) + request.Header().Add(clientHeader, "foo") + request.Header().Add(clientHeader, "bar") stream, err := client.CountUp(ctx, request) assert.Nil(t, err) assert.True(t, stream.Receive()) @@ -349,7 +469,9 @@ func TestServer(t *testing.T) { expect := []int64{3, 8, 9} var got []int64 stream := client.CumSum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } if !expectSuccess { // server doesn't support HTTP/2 failNoHTTP2(t, stream) return @@ -378,8 +500,8 @@ func TestServer(t *testing.T) { }() wg.Wait() assert.Equal(t, got, expect) - assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("cumsum_error", func(t *testing.T) { stream := client.CumSum(context.Background()) @@ -399,7 +521,9 @@ func TestServer(t *testing.T) { }) t.Run("cumsum_empty_stream", func(t *testing.T) { stream := client.CumSum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } if !expectSuccess { // server doesn't support HTTP/2 failNoHTTP2(t, stream) return @@ -416,7 +540,9 @@ func TestServer(t *testing.T) { t.Run("cumsum_cancel_after_first_response", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) stream := client.CumSum(ctx) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } if !expectSuccess { // server doesn't support HTTP/2 failNoHTTP2(t, stream) cancel() @@ -446,7 +572,9 @@ func TestServer(t *testing.T) { cancel() return } - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 8})) cancel() // On a subsequent send, ensure that we are still catching context @@ -476,7 +604,9 @@ func TestServer(t *testing.T) { request := connect.NewRequest(&pingv1.FailRequest{ Code: int32(connect.CodeResourceExhausted), }) - request.Header().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Add(clientHeader, el) + } response, err := client.Fail(context.Background(), request) assert.Nil(t, response) @@ -488,8 +618,8 @@ func TestServer(t *testing.T) { assert.Equal(t, connectErr.Code(), connect.CodeResourceExhausted) assert.Equal(t, connectErr.Error(), "resource_exhausted: "+errorMessage) assert.Zero(t, connectErr.Details()) - assert.Equal(t, connectErr.Meta().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, connectErr.Meta().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(connectErr.Meta().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(connectErr.Meta().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("middleware_errors_unary", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) @@ -2176,6 +2306,9 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { ) assert.Nil(t, err) req.Header.Set("Content-Type", "application/grpc") + for _, el := range expectedHeaderValues { + req.Header.Add(clientHeader, el) + } res, err := server.Client().Do(req) assert.Nil(t, err) assert.Equal(t, res.StatusCode, http.StatusOK) @@ -2903,51 +3036,57 @@ func (p *pluggablePingServer) CumSum( type pingServer struct { pingv1connect.UnimplementedPingServiceHandler + // Whether to verify metadata sent to the server. Can be used to force an error returned from the server + // by intentionally sending no metadata. checkMetadata bool includeErrorDetails bool } func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + if err := validateRequestInfo(request); err != nil { + return nil, err + } + if err := compareContextAndRequest(ctx, request, request.Header()); err != nil { + return nil, err + } if p.checkMetadata { if err := expectMetadata(request.Header()); err != nil { return nil, err } } - if request.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if request.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } response := connect.NewResponse( &pingv1.PingResponse{ Number: request.Msg.GetNumber(), Text: request.Msg.GetText(), }, ) - response.Header().Set(handlerHeader, headerValue) - response.Trailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := request.Header().Values(clientHeader) + for _, el := range reqHeader { + response.Header().Add(handlerHeader, el) + response.Trailer().Add(handlerTrailer, el) + } + return response, nil } func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { - if p.checkMetadata { - if err := expectMetadata(request.Header()); err != nil { - return nil, err - } + if err := validateRequestInfo(request); err != nil { + return nil, err } - if request.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if request.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + if err := compareContextAndRequest(ctx, request, request.Header()); err != nil { + return nil, err } err := connect.NewError( connect.Code(request.Msg.GetCode()), errors.New(errorMessage), ) - err.Meta().Set(handlerHeader, headerValue) - err.Meta().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the error metadata headers and trailers + reqHeader := request.Header().Values(clientHeader) + for _, el := range reqHeader { + err.Meta().Add(handlerHeader, el) + err.Meta().Add(handlerTrailer, el) + } if p.includeErrorDetails { detail, derr := connect.NewErrorDetail(&pingv1.FailRequest{Code: request.Msg.GetCode()}) if derr != nil { @@ -2962,17 +3101,14 @@ func (p pingServer) Sum( ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest], ) (*connect.Response[pingv1.SumResponse], error) { + if err := validateRequestInfo(stream); err != nil { + return nil, err + } if p.checkMetadata { if err := expectMetadata(stream.RequestHeader()); err != nil { return nil, err } } - if stream.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if stream.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } var sum int64 for stream.Receive() { sum += stream.Msg().GetNumber() @@ -2981,8 +3117,12 @@ func (p pingServer) Sum( return nil, stream.Err() } response := connect.NewResponse(&pingv1.SumResponse{Sum: sum}) - response.Header().Set(handlerHeader, headerValue) - response.Trailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := stream.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + response.Header().Add(handlerHeader, el) + response.Trailer().Add(handlerTrailer, el) + } return response, nil } @@ -2991,25 +3131,29 @@ func (p pingServer) CountUp( request *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse], ) error { + if err := validateRequestInfo(stream.Conn()); err != nil { + return err + } + if err := compareContextAndRequest(ctx, request, request.Header()); err != nil { + return err + } if p.checkMetadata { if err := expectMetadata(request.Header()); err != nil { return err } } - if request.Peer().Addr == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if request.Peer().Protocol == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } if request.Msg.GetNumber() <= 0 { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( "number must be positive: got %v", request.Msg.GetNumber(), )) } - stream.ResponseHeader().Set(handlerHeader, headerValue) - stream.ResponseTrailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := request.Header().Values(clientHeader) + for _, el := range reqHeader { + stream.ResponseHeader().Add(handlerHeader, el) + stream.ResponseTrailer().Add(handlerTrailer, el) + } for i := range request.Msg.GetNumber() { if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { return err @@ -3028,14 +3172,11 @@ func (p pingServer) CumSum( return err } } - if stream.Peer().Addr == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if stream.Peer().Protocol == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + reqHeader := stream.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + stream.ResponseHeader().Add(handlerHeader, el) + stream.ResponseTrailer().Add(handlerTrailer, el) } - stream.ResponseHeader().Set(handlerHeader, headerValue) - stream.ResponseTrailer().Set(handlerTrailer, trailerValue) for { msg, err := stream.Receive() if errors.Is(err, io.EOF) { @@ -3062,23 +3203,24 @@ func (p pingServerSimple) Ping(ctx context.Context, request *pingv1.PingRequest) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } + if err := validateRequestInfo(callInfo); err != nil { + return nil, err + } if p.checkMetadata { if err := expectMetadata(callInfo.RequestHeader()); err != nil { return nil, err } } - if callInfo.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if callInfo.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } response := &pingv1.PingResponse{ Number: request.GetNumber(), Text: request.GetText(), } - callInfo.ResponseHeader().Set(handlerHeader, headerValue) - callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := callInfo.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + callInfo.ResponseHeader().Add(handlerHeader, el) + callInfo.ResponseTrailer().Add(handlerTrailer, el) + } return response, nil } @@ -3091,25 +3233,26 @@ func (p pingServerSimple) CountUp( if !ok { return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } + if err := validateRequestInfo(callInfo); err != nil { + return err + } if p.checkMetadata { if err := expectMetadata(callInfo.RequestHeader()); err != nil { return err } } - if callInfo.Peer().Addr == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if callInfo.Peer().Protocol == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } if request.GetNumber() <= 0 { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( "number must be positive: got %v", request.GetNumber(), )) } - callInfo.ResponseHeader().Set(handlerHeader, headerValue) - callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := callInfo.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + callInfo.ResponseHeader().Add(handlerHeader, el) + callInfo.ResponseTrailer().Add(handlerTrailer, el) + } for i := range request.GetNumber() { if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { return err @@ -3254,15 +3397,101 @@ func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSu assert.Nil(tb, stream.CloseResponse()) } +type requestInfo interface { + Peer() connect.Peer + Spec() connect.Spec +} + +// Validates that the peer and spec information is set correctly in a request. +func validateRequestInfo(request requestInfo) error { + if request.Peer().Addr == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } + if request.Peer().Protocol == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } + if request.Spec().Procedure == "" { + return connect.NewError(connect.CodeInternal, errors.New("no procedure name")) + } + return nil +} + +// Compares the information in the call info in context with the given request information to verify they match. +func compareContextAndRequest(ctx context.Context, request requestInfo, requestHeaders http.Header) error { + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) + if !ok { + return connect.NewError(connect.CodeInternal, errors.New("no call info in handler context")) + } + if request.Peer().Addr != callInfo.Peer().Addr { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched peer address. found %s in request and %s in call info", request.Peer().Addr, callInfo.Peer().Addr)) + } + if request.Peer().Protocol != callInfo.Peer().Protocol { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched peer protocol. found %s in request and %s in call info", request.Peer().Addr, callInfo.Peer().Addr)) + } + if request.Spec().Procedure != callInfo.Spec().Procedure { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched procedure name. found %s in request and %s in call info", request.Spec().Procedure, request.Spec().Procedure)) + } + if valid := compareHeaders(callInfo.RequestHeader(), requestHeaders); !valid { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched request headers. found %+v in request and %+v in call info", callInfo.RequestHeader(), requestHeaders)) + } + return nil +} + +// Returns an error if the given http headers don't contain the expected header values. +// Typically, most methods in the pingServer implementations just read the request headers +// and copy those to the response headers and trailers and let the client verify that way. +// However, this method can be used in conjunction with the server's verifyMetadata setting +// to force an error to be returned if metadata isn't set. For example, see +// TestGRPCMissingTrailersError tests. func expectMetadata(meta http.Header) error { - if got := meta.Get(clientHeader); got != headerValue { + vals := meta.Values(clientHeader) + if ok := compareValues(vals, expectedHeaderValues); !ok { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( - "%s %q: got %q, expected %q", - "header", + "header %q: got %q, expected %q", clientHeader, - got, - headerValue, + vals, + expectedHeaderValues, )) } return nil } + +// Compares two http Header objects to verify they contain the exact same information. +func compareHeaders(hdr1 http.Header, hdr2 http.Header) bool { + if len(hdr1) != len(hdr2) { + return false + } + for key, hdr1Val := range hdr1 { + hdr2Val, ok := hdr2[key] + if !ok || len(hdr1Val) != len(hdr2Val) { + return false + } + + if equal := compareValues(hdr1Val, hdr2Val); !equal { + return false + } + } + return true +} + +// Compares two string slices of header values to verify they are the same, ignoring order. +func compareValues(hdr1 []string, hdr2 []string) bool { + if len(hdr1) != len(hdr2) { + return false + } + // Copy slices to avoid race conditions with other tests trying to read the headers + sorted1 := make([]string, len(hdr1)) + copy(sorted1, hdr1) + sorted2 := make([]string, len(hdr2)) + copy(sorted2, hdr2) + + sort.Strings(sorted1) + sort.Strings(sorted2) + + for idx, el := range sorted1 { + if el != sorted2[idx] { + return false + } + } + return true +} diff --git a/context.go b/context.go index 004be9ea..10aa8d0e 100644 --- a/context.go +++ b/context.go @@ -242,13 +242,3 @@ func getClientCallInfoFromContext(ctx context.Context) (*clientCallInfo, bool) { func newHandlerContext(ctx context.Context, info CallInfo) context.Context { return context.WithValue(ctx, handlerCallInfoContextKey{}, info) } - -// requestFromClientContext creates a new Request using the given context and message. -func requestFromClientContext[T any](ctx context.Context, message *T) *Request[T] { - request := NewRequest(message) - callInfo, ok := getClientCallInfoFromContext(ctx) - if ok { - request.setHeader(callInfo.RequestHeader()) - } - return request -} From 153acb7bca47dc8b3dc92b28d6551da6555e8ba2 Mon Sep 17 00:00:00 2001 From: John Chadwick Date: Wed, 16 Jul 2025 22:02:03 -0400 Subject: [PATCH 26/57] Implement simple for client streaming on handler There's not much to do here, the only thing we want to do is get rid of the wrapper. --- .../simple/gen/genconnect/simple.connect.go | 6 +-- cmd/protoc-gen-connect-go/main.go | 12 +++++- handler.go | 23 +++++++++++ handler_ext_test.go | 41 +++++++++++++++++++ 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go index 68149cc7..cc79b938 100644 --- a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go +++ b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go @@ -137,7 +137,7 @@ func (c *testServiceClient) MethodBidiStream(ctx context.Context, req *gen.Reque // TestServiceHandler is an implementation of the connect.test.simple.TestService service. type TestServiceHandler interface { Method(context.Context, *gen.Request) (*gen.Response, error) - MethodClientStream(context.Context, *connect.ClientStream[gen.Request]) (*connect.Response[gen.Response], error) + MethodClientStream(context.Context, *connect.ClientStream[gen.Request]) (*gen.Response, error) MethodServerStream(context.Context, *gen.Request, *connect.ServerStream[gen.Response]) error MethodBidiStream(context.Context, *gen.Request, *connect.ServerStream[gen.Response]) error } @@ -155,7 +155,7 @@ func NewTestServiceHandler(svc TestServiceHandler, opts ...connect.HandlerOption connect.WithSchema(testServiceMethods.ByName("Method")), connect.WithHandlerOptions(opts...), ) - testServiceMethodClientStreamHandler := connect.NewClientStreamHandler( + testServiceMethodClientStreamHandler := connect.NewClientStreamHandlerSimple( TestServiceMethodClientStreamProcedure, svc.MethodClientStream, connect.WithSchema(testServiceMethods.ByName("MethodClientStream")), @@ -196,7 +196,7 @@ func (UnimplementedTestServiceHandler) Method(context.Context, *gen.Request) (*g return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.test.simple.TestService.Method is not implemented")) } -func (UnimplementedTestServiceHandler) MethodClientStream(context.Context, *connect.ClientStream[gen.Request]) (*connect.Response[gen.Response], error) { +func (UnimplementedTestServiceHandler) MethodClientStream(context.Context, *connect.ClientStream[gen.Request]) (*gen.Response, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.test.simple.TestService.MethodClientStream is not implemented")) } diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index b9018ce1..1572da4c 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -480,7 +480,11 @@ func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, s idempotency := methodIdempotency(method) switch { case isStreamingClient && !isStreamingServer: - g.P(procedureHandlerName(method), ` := `, connectPackage.Ident("NewClientStreamHandler"), "(") + if simple { + g.P(procedureHandlerName(method), ` := `, connectPackage.Ident("NewClientStreamHandlerSimple"), "(") + } else { + g.P(procedureHandlerName(method), ` := `, connectPackage.Ident("NewClientStreamHandler"), "(") + } case !isStreamingClient && isStreamingServer: if simple { g.P(procedureHandlerName(method), ` := `, connectPackage.Ident("NewServerStreamHandlerSimple"), "(") @@ -564,6 +568,12 @@ func serverSignatureParams(g *protogen.GeneratedFile, method *protogen.Method, n } if method.Desc.IsStreamingClient() { // client streaming + if simple { + return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " + + streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStream")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + + ") (*" + g.QualifiedGoIdent(method.Output.GoIdent) + " ,error)" + } return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " + streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStream")) + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + diff --git a/handler.go b/handler.go index 5eb5e1e0..ba25eb64 100644 --- a/handler.go +++ b/handler.go @@ -162,6 +162,29 @@ func NewClientStreamHandler[Req, Res any]( ) } +// NewClientStreamHandlerSimple constructs a [Handler] for a request-streaming procedure +// using the function signature associated with the "simple" generation option. +// +// This option eliminates the [Response] wrapper, and instead uses the context.Context +// to propagate information such as headers. +func NewClientStreamHandlerSimple[Req, Res any]( + procedure string, + implementation func(context.Context, *ClientStream[Req]) (*Res, error), + options ...HandlerOption, +) *Handler { + return NewClientStreamHandler( + procedure, + func(ctx context.Context, stream *ClientStream[Req]) (*Response[Res], error) { + responseMsg, err := implementation(ctx, stream) + if err != nil { + return nil, err + } + return NewResponse(responseMsg), nil + }, + options..., + ) +} + // NewServerStreamHandler constructs a [Handler] for a server streaming procedure. func NewServerStreamHandler[Req, Res any]( procedure string, diff --git a/handler_ext_test.go b/handler_ext_test.go index c5541222..0ac2deb9 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -380,6 +380,47 @@ func TestDynamicHandler(t *testing.T) { } assert.Equal(t, rsp.Msg.Sum, 42*2) }) + t.Run("clientStreamSimple", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Sum") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + dynamicSum := func(_ context.Context, stream *connect.ClientStream[dynamicpb.Message]) (*dynamicpb.Message, error) { + var sum int64 + for stream.Receive() { + got := stream.Msg().Get( + methodDesc.Input().Fields().ByName("number"), + ).Int() + sum += got + } + msg := dynamicpb.NewMessage(methodDesc.Output()) + msg.Set( + methodDesc.Output().Fields().ByName("sum"), + protoreflect.ValueOfInt64(sum), + ) + return msg, nil + } + mux := http.NewServeMux() + mux.Handle("/connect.ping.v1.PingService/Sum", + connect.NewClientStreamHandlerSimple( + "/connect.ping.v1.PingService/Sum", + dynamicSum, + connect.WithSchema(methodDesc), + connect.WithRequestInitializer(initializer), + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + stream := client.Sum(context.Background()) + assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 42})) + assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 42})) + rsp, err := stream.CloseAndReceive() + if !assert.Nil(t, err) { + return + } + assert.Equal(t, rsp.Msg.Sum, 42*2) + }) t.Run("serverStream", func(t *testing.T) { t.Parallel() desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CountUp") From 021a0cb64d1f0a7918cd963db5bd303f01620804 Mon Sep 17 00:00:00 2001 From: John Chadwick Date: Wed, 16 Jul 2025 22:03:56 -0400 Subject: [PATCH 27/57] Implement simple for client streaming on client All we want to do here is make sure headers get sent immediately. We no longer need to wait until first send, because now it is possible to set request headers on the context. Otherwise, the interface is exactly the same. --- client.go | 7 +++ client_ext_test.go | 44 +++++++++++++++++++ .../simple/gen/genconnect/simple.connect.go | 2 +- cmd/protoc-gen-connect-go/main.go | 6 ++- 4 files changed, 57 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index d2be7070..05aa3124 100644 --- a/client.go +++ b/client.go @@ -182,6 +182,13 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo } } +// CallClientStream calls a client streaming procedure in simple mode. +func (c *Client[Req, Res]) CallClientStreamSimple(ctx context.Context) *ClientStreamForClient[Req, Res] { + stream := c.CallClientStream(ctx) + stream.Send(nil) + return stream +} + // CallServerStream calls a server streaming procedure. func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Request[Req]) (*ServerStreamForClient[Res], error) { if c.err != nil { diff --git a/client_ext_test.go b/client_ext_test.go index 4e5b8351..e03eb34d 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -18,10 +18,12 @@ import ( "bytes" "context" "crypto/rand" + "crypto/tls" "errors" "fmt" "io" "log" + "net" "net/http" "net/http/httptest" "runtime" @@ -36,6 +38,7 @@ import ( pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp/memhttptest" + "golang.org/x/net/http2" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" @@ -389,6 +392,47 @@ func TestDynamicClient(t *testing.T) { got := rsp.Msg.Get(methodDesc.Output().Fields().ByName("sum")).Int() assert.Equal(t, got, 42*2) }) + t.Run("clientStreamSimple", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Sum") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + connected := make(chan struct{}) + transport := &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + close(connected) + return server.Transport().DialTLSContext(ctx, network, addr, cfg) + }, + AllowHTTP: true, + } + client := connect.NewClient[dynamicpb.Message, dynamicpb.Message]( + &http.Client{Transport: transport}, + server.URL()+"/connect.ping.v1.PingService/Sum", + connect.WithSchema(methodDesc), + connect.WithResponseInitializer(initializer), + ) + stream := client.CallClientStreamSimple(ctx) + select { + case <-connected: + break + case <-time.After(time.Second): + t.Error("CallClientStreamSimple did not eagerly send headers") + } + msg := dynamicpb.NewMessage(methodDesc.Input()) + msg.Set( + methodDesc.Input().Fields().ByName("number"), + protoreflect.ValueOfInt64(42), + ) + assert.Nil(t, stream.Send(msg)) + assert.Nil(t, stream.Send(msg)) + rsp, err := stream.CloseAndReceive() + if !assert.Nil(t, err) { + return + } + got := rsp.Msg.Get(methodDesc.Output().Fields().ByName("sum")).Int() + assert.Equal(t, got, 42*2) + }) t.Run("serverStream", func(t *testing.T) { t.Parallel() desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CountUp") diff --git a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go index cc79b938..ae227b27 100644 --- a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go +++ b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go @@ -121,7 +121,7 @@ func (c *testServiceClient) Method(ctx context.Context, req *gen.Request) (*gen. // MethodClientStream calls connect.test.simple.TestService.MethodClientStream. func (c *testServiceClient) MethodClientStream(ctx context.Context) *connect.ClientStreamForClient[gen.Request, gen.Response] { - return c.methodClientStream.CallClientStream(ctx) + return c.methodClientStream.CallClientStreamSimple(ctx) } // MethodServerStream calls connect.test.simple.TestService.MethodServerStream. diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index 1572da4c..327c92fe 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -381,7 +381,11 @@ func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, na switch { case isStreamingClient && !isStreamingServer: - g.P("return c.", unexport(method.GoName), ".CallClientStream(ctx)") + if simple { + g.P("return c.", unexport(method.GoName), ".CallClientStreamSimple(ctx)") + } else { + g.P("return c.", unexport(method.GoName), ".CallClientStream(ctx)") + } case !isStreamingClient && isStreamingServer: if simple { g.P("return c.", unexport(method.GoName), ".CallServerStreamSimple(ctx, req)") From 7995c00c72835129b2a252481835f710e1c4f4b5 Mon Sep 17 00:00:00 2001 From: John Chadwick Date: Wed, 16 Jul 2025 22:04:31 -0400 Subject: [PATCH 28/57] Implement simple for bidi streaming on client Like client streaming on client, all we want to do for bidi streaming is ensure that headers are eagerly sent. --- client.go | 7 ++++++ client_ext_test.go | 41 +++++++++++++++++++++++++++++++ cmd/protoc-gen-connect-go/main.go | 6 ++++- 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 05aa3124..aa372965 100644 --- a/client.go +++ b/client.go @@ -258,6 +258,13 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli } } +// CallBidiStreamSimple calls a bidirectional streaming procedure in simple mode. +func (c *Client[Req, Res]) CallBidiStreamSimple(ctx context.Context) *BidiStreamForClient[Req, Res] { + stream := c.CallBidiStream(ctx) + stream.Send(nil) + return stream +} + func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn { newConn := func(ctx context.Context, spec Spec) StreamingClientConn { header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing diff --git a/client_ext_test.go b/client_ext_test.go index e03eb34d..2ae2200f 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -489,6 +489,47 @@ func TestDynamicClient(t *testing.T) { got := out.Get(methodDesc.Output().Fields().ByName("number")).Int() assert.Equal(t, got, 42) }) + t.Run("bidiSimple", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CumSum") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + connected := make(chan struct{}) + transport := &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + close(connected) + return server.Transport().DialTLSContext(ctx, network, addr, cfg) + }, + AllowHTTP: true, + } + client := connect.NewClient[dynamicpb.Message, dynamicpb.Message]( + &http.Client{Transport: transport}, + server.URL()+"/connect.ping.v1.PingService/CumSum", + connect.WithSchema(methodDesc), + connect.WithResponseInitializer(initializer), + ) + stream := client.CallBidiStreamSimple(ctx) + select { + case <-connected: + break + case <-time.After(time.Second): + t.Error("CallBidiStreamSimple did not eagerly send headers") + } + msg := dynamicpb.NewMessage(methodDesc.Input()) + msg.Set( + methodDesc.Input().Fields().ByName("number"), + protoreflect.ValueOfInt64(42), + ) + assert.Nil(t, stream.Send(msg)) + assert.Nil(t, stream.CloseRequest()) + out, err := stream.Receive() + if assert.Nil(t, err) { + return + } + got := out.Get(methodDesc.Output().Fields().ByName("number")).Int() + assert.Equal(t, got, 42) + }) t.Run("option", func(t *testing.T) { t.Parallel() desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping") diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index 327c92fe..fd756287 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -393,7 +393,11 @@ func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, na g.P("return c.", unexport(method.GoName), ".CallServerStream(ctx, req)") } case isStreamingClient && isStreamingServer: - g.P("return c.", unexport(method.GoName), ".CallBidiStream(ctx)") + if simple { + g.P("return c.", unexport(method.GoName), ".CallBidiStreamSimple(ctx)") + } else { + g.P("return c.", unexport(method.GoName), ".CallBidiStream(ctx)") + } default: if simple { g.P("return c.", unexport(method.GoName), ".CallUnarySimple(ctx, req)") From d12c6c940cb71b67e0738cf9d0ed0646765fc236 Mon Sep 17 00:00:00 2001 From: John Chadwick Date: Thu, 17 Jul 2025 12:51:39 -0400 Subject: [PATCH 29/57] Make client/bidi stream fallible for simple Simple client sends requests immediately, but oops, it accidentally swallows errors as it was written. Instead of swallowing errors _or_ returning an error stream, it would be better to just immediately return the error, like the non-client-streaming endpoints do. --- client.go | 22 ++++++++++++++----- client_ext_test.go | 6 +++-- .../simple/gen/genconnect/simple.connect.go | 4 ++-- cmd/protoc-gen-connect-go/main.go | 12 ++++++++++ .../ping/v1/pingv1connect/ping.connect.go | 18 +++++++-------- 5 files changed, 43 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index aa372965..5bc030fa 100644 --- a/client.go +++ b/client.go @@ -183,10 +183,15 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo } // CallClientStream calls a client streaming procedure in simple mode. -func (c *Client[Req, Res]) CallClientStreamSimple(ctx context.Context) *ClientStreamForClient[Req, Res] { +func (c *Client[Req, Res]) CallClientStreamSimple(ctx context.Context) (*ClientStreamForClient[Req, Res], error) { stream := c.CallClientStream(ctx) - stream.Send(nil) - return stream + if stream.err != nil { + return nil, stream.err + } + if err := stream.Send(nil); err != nil { + return nil, err + } + return stream, nil } // CallServerStream calls a server streaming procedure. @@ -259,10 +264,15 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli } // CallBidiStreamSimple calls a bidirectional streaming procedure in simple mode. -func (c *Client[Req, Res]) CallBidiStreamSimple(ctx context.Context) *BidiStreamForClient[Req, Res] { +func (c *Client[Req, Res]) CallBidiStreamSimple(ctx context.Context) (*BidiStreamForClient[Req, Res], error) { stream := c.CallBidiStream(ctx) - stream.Send(nil) - return stream + if stream.err != nil { + return nil, stream.err + } + if err := stream.Send(nil); err != nil { + return nil, err + } + return stream, nil } func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn { diff --git a/client_ext_test.go b/client_ext_test.go index 2ae2200f..f2825588 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -412,7 +412,8 @@ func TestDynamicClient(t *testing.T) { connect.WithSchema(methodDesc), connect.WithResponseInitializer(initializer), ) - stream := client.CallClientStreamSimple(ctx) + stream, err := client.CallClientStreamSimple(ctx) + assert.Nil(t, err) select { case <-connected: break @@ -509,7 +510,8 @@ func TestDynamicClient(t *testing.T) { connect.WithSchema(methodDesc), connect.WithResponseInitializer(initializer), ) - stream := client.CallBidiStreamSimple(ctx) + stream, err := client.CallBidiStreamSimple(ctx) + assert.Nil(t, err) select { case <-connected: break diff --git a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go index ae227b27..a387cbde 100644 --- a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go +++ b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go @@ -63,7 +63,7 @@ const ( // TestServiceClient is a client for the connect.test.simple.TestService service. type TestServiceClient interface { Method(context.Context, *gen.Request) (*gen.Response, error) - MethodClientStream(context.Context) *connect.ClientStreamForClient[gen.Request, gen.Response] + MethodClientStream(context.Context) (*connect.ClientStreamForClient[gen.Request, gen.Response], error) MethodServerStream(context.Context, *gen.Request) (*connect.ServerStreamForClient[gen.Response], error) MethodBidiStream(context.Context, *gen.Request) (*connect.ServerStreamForClient[gen.Response], error) } @@ -120,7 +120,7 @@ func (c *testServiceClient) Method(ctx context.Context, req *gen.Request) (*gen. } // MethodClientStream calls connect.test.simple.TestService.MethodClientStream. -func (c *testServiceClient) MethodClientStream(ctx context.Context) *connect.ClientStreamForClient[gen.Request, gen.Response] { +func (c *testServiceClient) MethodClientStream(ctx context.Context) (*connect.ClientStreamForClient[gen.Request, gen.Response], error) { return c.methodClientStream.CallClientStreamSimple(ctx) } diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index fd756287..fe36bf46 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -417,12 +417,24 @@ func clientSignature(g *protogen.GeneratedFile, method *protogen.Method, named b } if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() { // bidi streaming + if simple { + return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + + "(*" + g.QualifiedGoIdent(connectPackage.Ident("BidiStreamForClient")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + + ", error)" + } return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + "*" + g.QualifiedGoIdent(connectPackage.Ident("BidiStreamForClient")) + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" } if method.Desc.IsStreamingClient() { // client streaming + if simple { + return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + + "(*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStreamForClient")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + + ", error)" + } return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStreamForClient")) + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" diff --git a/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go index 9fe1d615..482020b0 100644 --- a/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go @@ -71,11 +71,11 @@ type PingServiceClient interface { // Fail always fails. Fail(context.Context, *v1.FailRequest) (*v1.FailResponse, error) // Sum calculates the sum of the numbers sent on the stream. - Sum(context.Context) *connect.ClientStreamForClient[v1.SumRequest, v1.SumResponse] + Sum(context.Context) (*connect.ClientStreamForClient[v1.SumRequest, v1.SumResponse], error) // CountUp returns a stream of the numbers up to the given request. CountUp(context.Context, *v1.CountUpRequest) (*connect.ServerStreamForClient[v1.CountUpResponse], error) // CumSum determines the cumulative sum of all the numbers sent on the stream. - CumSum(context.Context) *connect.BidiStreamForClient[v1.CumSumRequest, v1.CumSumResponse] + CumSum(context.Context) (*connect.BidiStreamForClient[v1.CumSumRequest, v1.CumSumResponse], error) } // NewPingServiceClient constructs a client for the connect.ping.v1.PingService service. By default, @@ -143,8 +143,8 @@ func (c *pingServiceClient) Fail(ctx context.Context, req *v1.FailRequest) (*v1. } // Sum calls connect.ping.v1.PingService.Sum. -func (c *pingServiceClient) Sum(ctx context.Context) *connect.ClientStreamForClient[v1.SumRequest, v1.SumResponse] { - return c.sum.CallClientStream(ctx) +func (c *pingServiceClient) Sum(ctx context.Context) (*connect.ClientStreamForClient[v1.SumRequest, v1.SumResponse], error) { + return c.sum.CallClientStreamSimple(ctx) } // CountUp calls connect.ping.v1.PingService.CountUp. @@ -153,8 +153,8 @@ func (c *pingServiceClient) CountUp(ctx context.Context, req *v1.CountUpRequest) } // CumSum calls connect.ping.v1.PingService.CumSum. -func (c *pingServiceClient) CumSum(ctx context.Context) *connect.BidiStreamForClient[v1.CumSumRequest, v1.CumSumResponse] { - return c.cumSum.CallBidiStream(ctx) +func (c *pingServiceClient) CumSum(ctx context.Context) (*connect.BidiStreamForClient[v1.CumSumRequest, v1.CumSumResponse], error) { + return c.cumSum.CallBidiStreamSimple(ctx) } // PingServiceHandler is an implementation of the connect.ping.v1.PingService service. @@ -164,7 +164,7 @@ type PingServiceHandler interface { // Fail always fails. Fail(context.Context, *v1.FailRequest) (*v1.FailResponse, error) // Sum calculates the sum of the numbers sent on the stream. - Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*connect.Response[v1.SumResponse], error) + Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*v1.SumResponse, error) // CountUp returns a stream of the numbers up to the given request. CountUp(context.Context, *v1.CountUpRequest, *connect.ServerStream[v1.CountUpResponse]) error // CumSum determines the cumulative sum of all the numbers sent on the stream. @@ -191,7 +191,7 @@ func NewPingServiceHandler(svc PingServiceHandler, opts ...connect.HandlerOption connect.WithSchema(pingServiceMethods.ByName("Fail")), connect.WithHandlerOptions(opts...), ) - pingServiceSumHandler := connect.NewClientStreamHandler( + pingServiceSumHandler := connect.NewClientStreamHandlerSimple( PingServiceSumProcedure, svc.Sum, connect.WithSchema(pingServiceMethods.ByName("Sum")), @@ -238,7 +238,7 @@ func (UnimplementedPingServiceHandler) Fail(context.Context, *v1.FailRequest) (* return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Fail is not implemented")) } -func (UnimplementedPingServiceHandler) Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*connect.Response[v1.SumResponse], error) { +func (UnimplementedPingServiceHandler) Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*v1.SumResponse, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Sum is not implemented")) } From e4026f2a1902bb4080b1315d4d54e016b33fdc53 Mon Sep 17 00:00:00 2001 From: John Chadwick Date: Thu, 17 Jul 2025 12:52:10 -0400 Subject: [PATCH 30/57] Fix benchmark/example test --- bench_test.go | 10 ++++++++-- handler_example_test.go | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/bench_test.go b/bench_test.go index 333b5bbb..d5ed9416 100644 --- a/bench_test.go +++ b/bench_test.go @@ -112,7 +112,10 @@ func BenchmarkConnect(b *testing.B) { upTo = 1 expect = 1 ) - stream := client.Sum(ctx) + stream, err := client.Sum(ctx) + if err != nil { + b.Error(err) + } for number := int64(1); number <= upTo; number++ { if err := stream.Send(&pingv1.SumRequest{Number: number}); err != nil { b.Error(err) @@ -159,7 +162,10 @@ func BenchmarkConnect(b *testing.B) { const ( upTo = 1 ) - stream := client.CumSum(ctx) + stream, err := client.CumSum(ctx) + if err != nil { + b.Error(err) + } number := int64(1) for ; number <= upTo; number++ { if err := stream.Send(&pingv1.CumSumRequest{Number: number}); err != nil { diff --git a/handler_example_test.go b/handler_example_test.go index 9fe849dd..36843dff 100644 --- a/handler_example_test.go +++ b/handler_example_test.go @@ -43,7 +43,7 @@ func (*ExamplePingServer) Ping( } // Sum implements pingv1connect.PingServiceHandler. -func (p *ExamplePingServer) Sum(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { +func (p *ExamplePingServer) Sum(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*pingv1.SumResponse, error) { var sum int64 for stream.Receive() { sum += stream.Msg().GetNumber() @@ -51,7 +51,7 @@ func (p *ExamplePingServer) Sum(ctx context.Context, stream *connect.ClientStrea if stream.Err() != nil { return nil, stream.Err() } - return connect.NewResponse(&pingv1.SumResponse{Sum: sum}), nil + return &pingv1.SumResponse{Sum: sum}, nil } // CountUp implements pingv1connect.PingServiceHandler. From e34c8c826e90eb3cf4f93170ea8c7cb907b9d979 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 13:54:45 -0400 Subject: [PATCH 31/57] Redo Signed-off-by: Steve Ayers --- client.go | 46 ++- client_ext_test.go | 10 +- connect_ext_test.go | 460 +++++++++++++++++++++-------- context.go | 133 ++++++--- error_example_test.go | 11 +- error_not_modified_example_test.go | 26 +- example_init_test.go | 2 +- handler.go | 58 ++-- interceptor_ext_test.go | 6 +- 9 files changed, 532 insertions(+), 220 deletions(-) diff --git a/client.go b/client.go index ffb9336f..792430bb 100644 --- a/client.go +++ b/client.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "io" + "maps" "net/http" "net/url" "strings" @@ -127,16 +128,31 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) if c.err != nil { return nil, c.err } - return c.callUnary(ctx, request) + ctx, ci := NewOutgoingContext(ctx) + call, ok := ci.(*callInfo) + if ok { + call.requestHeader = request.Header() + } + + resp, err := c.callUnary(ctx, request) + if err != nil { + return nil, err + } + + if ok { + call.peer = request.Peer() + call.spec = request.Spec() + call.method = request.HTTPMethod() + maps.Copy(call.ResponseHeader(), resp.Header()) + maps.Copy(call.ResponseTrailer(), resp.Trailer()) + } + + return resp, nil } -// CallUnarySimple calls a request-response procedure using the function signature -// associated with the "simple" generation option. -// -// This option eliminates the [Request] and [Response] wrappers, and instead uses the -// context.Context to propagate information such as headers. -func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, requestMsg *Req) (*Res, error) { - response, err := c.CallUnary(ctx, requestFromContext(ctx, requestMsg)) +// CallUnary calls a request-response procedure. +func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, request *Req) (*Res, error) { + response, err := c.CallUnary(ctx, requestFromContext(ctx, request)) if response != nil { return response.Msg, err } @@ -159,12 +175,21 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if c.err != nil { return nil, c.err } + ctx, ctxCallInfo := NewOutgoingContext(ctx) + // Note we don't need to check ok here because it should always be in context + // because of the above call to NewOutgoingContext + info, _ := ctxCallInfo.(*callInfo) conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method + info.method = r.Method }) request.spec = conn.Spec() request.peer = conn.Peer() mergeHeaders(conn.RequestHeader(), request.header) + + info.peer = conn.Peer() + info.spec = conn.Spec() + mergeHeaders(conn.RequestHeader(), info.requestHeader) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. @@ -182,11 +207,6 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques }, nil } -// CallServerStreamSimple calls a server streaming procedure using the function signature -// associated with the "simple" generation option. -// -// This option eliminates the [Request] wrapper, and instead uses the context.Context to -// propagate information such as headers. func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { return c.CallServerStream(ctx, requestFromContext(ctx, requestMsg)) } diff --git a/client_ext_test.go b/client_ext_test.go index 4e5b8351..14ae2773 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -89,7 +89,7 @@ func TestNewClient_InitFailure(t *testing.T) { func TestClientPeer(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{})) server := memhttptest.NewServer(t, mux) run := func(t *testing.T, unaryHTTPMethod string, opts ...connect.ClientOption) { @@ -205,7 +205,7 @@ func TestGetNoContentHeaders(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(&pingServer{})) + mux.Handle(pingv1connect.NewPingServiceHandler(&pingServerGenerics{})) server := memhttptest.NewServer(t, http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) { if len(req.Header.Values("content-type")) > 0 || len(req.Header.Values("content-encoding")) > 0 || @@ -283,7 +283,7 @@ func TestSpecSchema(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler( - pingServer{}, + pingServerGenerics{}, connect.WithInterceptors(&assertSchemaInterceptor{t}), )) server := memhttptest.NewServer(t, mux) @@ -320,7 +320,7 @@ func TestSpecSchema(t *testing.T) { func TestDynamicClient(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{})) server := memhttptest.NewServer(t, mux) ctx := context.Background() initializer := func(spec connect.Spec, msg any) error { @@ -494,7 +494,7 @@ func TestClientDeadlineHandling(t *testing.T) { // detector enabled. That's partly why the makefile only runs "slow" // tests with the race detector disabled. - _, handler := pingv1connect.NewPingServiceHandler(pingServer{}) + _, handler := pingv1connect.NewPingServiceHandler(pingServerGenerics{}) svr := httptest.NewUnstartedServer(http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) { if req.Context().Err() != nil { return diff --git a/connect_ext_test.go b/connect_ext_test.go index 22e392e3..61908a13 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -29,6 +29,7 @@ import ( rand "math/rand/v2" "net" "net/http" + "net/http/httptest" "runtime" "strings" "sync" @@ -38,8 +39,9 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" - "connectrpc.com/connect/internal/gen/generics/connect/import/v1/importv1connect" - "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + pingv1connectgenerics "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/import/v1/importv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/proto" @@ -61,9 +63,132 @@ const ( clientMiddlewareErrorHeader = "Connect-Trigger-HTTP-Error" ) +func TestCallInfo(t *testing.T) { + t.Parallel() + t.Run("simple_api", func(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{checkMetadata: true}, + )) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + t.Run("unary", func(t *testing.T) { + num := int64(42) + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + callInfo.RequestHeader().Set(clientHeader, headerValue) + expect := &pingv1.PingResponse{Number: num} + + response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) + assert.Equal(t, response, expect) + assert.Nil(t, err) + + // Assert call info values are correctly populated + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + }) + t.Run("server_stream", func(t *testing.T) { + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + callInfo.RequestHeader().Set(clientHeader, headerValue) + stream, err := client.CountUp(ctx, &pingv1.CountUpRequest{ + Number: 1, + }) + assert.Nil(t, err) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), 1) + assert.Nil(t, stream.Close()) + + // Assert call info values are correctly populated + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + }) + }) + t.Run("generics_api", func(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connectgenerics.NewPingServiceHandler( + pingServerGenerics{checkMetadata: true}, + )) + server := memhttptest.NewServer(t, mux) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + t.Run("unary", func(t *testing.T) { + num := int64(42) + request := connect.NewRequest(&pingv1.PingRequest{Number: num}) + request.Header().Set(clientHeader, headerValue) + expect := &pingv1.PingResponse{Number: num} + + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + response, err := client.Ping(ctx, request) + assert.Nil(t, err) + assert.Equal(t, response.Msg, expect) + + assert.Equal(t, request.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, request.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, request.Spec().IsClient) + assert.Equal(t, request.Peer().Addr, httptest.DefaultRemoteAddr) + + // Verify that spec and peer on the callInfo are the same as the request wrapper + assert.Equal(t, callInfo.Spec().StreamType, request.Spec().StreamType) + assert.Equal(t, callInfo.Spec().Procedure, request.Spec().Procedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, request.Peer().Addr) + + // Verify that the response headers and trailers are the same on callInfo and the response + assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + + }) + t.Run("server_stream", func(t *testing.T) { + req := connect.NewRequest(&pingv1.CountUpRequest{ + Number: 1, + }) + req.Header().Set(clientHeader, headerValue) + stream, err := client.CountUp(context.Background(), req) + assert.Nil(t, err) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), 1) + assert.Nil(t, stream.Close()) + assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + // num := int64(42) + // ctx, callInfo := connect.NewOutgoingContext(context.Background()) + // callInfo.RequestHeader().Set(clientHeader, headerValue) + // expect := &pingv1.PingResponse{Number: num} + + // response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) + // assert.Equal(t, response, expect) + // assert.Nil(t, err) + + // // Assert call info values are correctly populated + // assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) + // assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) + // assert.True(t, callInfo.Spec().IsClient) + // assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + // assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + }) + }) +} + func TestServer(t *testing.T) { t.Parallel() - testPing := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + testPing := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper t.Run("ping", func(t *testing.T) { num := int64(42) request := connect.NewRequest(&pingv1.PingRequest{Number: num}) @@ -117,7 +242,7 @@ func TestServer(t *testing.T) { assert.Equal(t, connect.CodeOf(err), connect.CodeDeadlineExceeded) }) } - testSum := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + testSum := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper t.Run("sum", func(t *testing.T) { const ( upTo = 10 @@ -153,7 +278,7 @@ func TestServer(t *testing.T) { assert.Equal(t, got.Header().Values(handlerHeader), []string{headerValue}) }) } - testCountUp := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + testCountUp := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper t.Run("count_up", func(t *testing.T) { const upTo = 5 got := make([]int64, 0, upTo) @@ -211,7 +336,7 @@ func TestServer(t *testing.T) { assert.Nil(t, stream.Close()) }) } - testCumSum := func(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper + testCumSum := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, expectSuccess bool) { //nolint:thelper t.Run("cumsum", func(t *testing.T) { send := []int64{3, 5, 1} expect := []int64{3, 8, 9} @@ -326,7 +451,7 @@ func TestServer(t *testing.T) { assert.Nil(t, stream.CloseResponse()) }) } - testErrors := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + testErrors := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper assertIsHTTPMiddlewareError := func(tb testing.TB, err error) { tb.Helper() assert.NotNil(tb, err) @@ -377,7 +502,7 @@ func TestServer(t *testing.T) { testMatrix := func(t *testing.T, client *http.Client, url string, bidi bool) { //nolint:thelper run := func(t *testing.T, opts ...connect.ClientOption) { t.Helper() - client := pingv1connect.NewPingServiceClient(client, url, opts...) + client := pingv1connectgenerics.NewPingServiceClient(client, url, opts...) testPing(t, client) testSum(t, client) testCountUp(t, client) @@ -442,8 +567,8 @@ func TestServer(t *testing.T) { } mux := http.NewServeMux() - pingRoute, pingHandler := pingv1connect.NewPingServiceHandler( - pingServer{checkMetadata: true}, + pingRoute, pingHandler := pingv1connectgenerics.NewPingServiceHandler( + pingServerGenerics{checkMetadata: true}, ) errorWriter := connect.NewErrorWriter() // Add net/http middleware to the ping service to evaluate HTTP state. @@ -464,15 +589,15 @@ func TestServer(t *testing.T) { } // Check Content-Length is set correctly. switch request.URL.Path { - case pingv1connect.PingServicePingProcedure, - pingv1connect.PingServiceFailProcedure, - pingv1connect.PingServiceCountUpProcedure: + case pingv1connectgenerics.PingServicePingProcedure, + pingv1connectgenerics.PingServiceFailProcedure, + pingv1connectgenerics.PingServiceCountUpProcedure: // Unary requests set Content-Length to the length of the request body. if request.ContentLength < 0 { t.Errorf("%s: expected Content-Length >= 0, got %d", request.URL.Path, request.ContentLength) } - case pingv1connect.PingServiceSumProcedure, - pingv1connect.PingServiceCumSumProcedure: + case pingv1connectgenerics.PingServiceSumProcedure, + pingv1connectgenerics.PingServiceCumSumProcedure: // Streaming requests set Content-Length to -1 or 0 on empty requests. if request.ContentLength > 0 { t.Errorf("%s: expected Content-Length -1 or 0, got %d", request.URL.Path, request.ContentLength) @@ -503,7 +628,7 @@ func TestConcurrentStreams(t *testing.T) { } t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{})) server := memhttptest.NewServer(t, mux) var done, start sync.WaitGroup start.Add(1) @@ -511,7 +636,7 @@ func TestConcurrentStreams(t *testing.T) { done.Add(1) go func() { defer done.Done() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) var total int64 sum := client.CumSum(context.Background()) start.Wait() @@ -575,7 +700,7 @@ func TestErrorHeaderPropagation(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) assertError := func(t *testing.T, err error, allowCustomHeaders bool) { @@ -612,7 +737,7 @@ func TestErrorHeaderPropagation(t *testing.T) { assert.Equal(t, meta.Values("X-Test"), []string(nil)) } } - testServices := func(t *testing.T, client pingv1connect.PingServiceClient) { + testServices := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { t.Helper() t.Run("unary", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) @@ -660,17 +785,17 @@ func TestErrorHeaderPropagation(t *testing.T) { } t.Run("connect", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) testServices(t, client) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) testServices(t, client) }) t.Run("grpc-web", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) testServices(t, client) }) } @@ -692,10 +817,10 @@ func TestHeaderBasic(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) request := connect.NewRequest(&pingv1.PingRequest{}) request.Header().Set(key, cval) response, err := client.Ping(context.Background(), request) @@ -721,12 +846,12 @@ func TestHeaderHost(t *testing.T) { newHTTP2Server := func(t *testing.T) *memhttp.Server { t.Helper() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) return server } - callWithHost := func(t *testing.T, client pingv1connect.PingServiceClient) { + callWithHost := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { t.Helper() request := connect.NewRequest(&pingv1.PingRequest{}) @@ -739,21 +864,21 @@ func TestHeaderHost(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) callWithHost(t, client) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) callWithHost(t, client) }) t.Run("grpc-web", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) callWithHost(t, client) }) } @@ -772,12 +897,12 @@ func TestTimeoutParsing(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) } @@ -786,7 +911,7 @@ func TestFailCodec(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) server := memhttptest.NewServer(t, handler) - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), connect.WithCodec(failCodec{}), @@ -803,7 +928,7 @@ func TestContextError(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) server := memhttptest.NewServer(t, handler) - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), ) @@ -822,8 +947,8 @@ func TestGRPCMarshalStatusError(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler( - pingServer{ + mux.Handle(pingv1connectgenerics.NewPingServiceHandler( + pingServerGenerics{ // Include error details in the response, so that the Status protobuf will be marshaled. includeErrorDetails: true, }, @@ -834,7 +959,7 @@ func TestGRPCMarshalStatusError(t *testing.T) { assertInternalError := func(tb testing.TB, opts ...connect.ClientOption) { tb.Helper() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), opts...) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), opts...) request := connect.NewRequest(&pingv1.FailRequest{Code: int32(connect.CodeResourceExhausted)}) _, err := client.Fail(context.Background(), request) tb.Log(err) @@ -871,7 +996,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { pingServer{checkMetadata: true}, )) server := memhttptest.NewServer(t, trimTrailers(mux)) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) assertErrorNoTrailers := func(t *testing.T, err error) { t.Helper() @@ -935,7 +1060,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { func TestUnavailableIfHostInvalid(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( http.DefaultClient, "https://api.invalid/", ) @@ -955,7 +1080,7 @@ func TestBidiRequiresHTTP2(t *testing.T) { assert.Nil(t, err) }) server := memhttptest.NewServer(t, handler) - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( &http.Client{Transport: server.TransportHTTP1()}, server.URL(), ) @@ -986,7 +1111,7 @@ func TestCompressMinBytesClient(t *testing.T) { assert.Equal(tb, request.Header.Get("Content-Encoding"), expect) })) server := memhttptest.NewServer(t, mux) - _, err := pingv1connect.NewPingServiceClient( + _, err := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), connect.WithSendGzip(), @@ -1075,7 +1200,7 @@ func TestCustomCompression(t *testing.T) { connect.WithCompression(compressionName, decompressor, compressor), )) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithAcceptCompression(compressionName, decompressor, compressor), connect.WithSendCompression(compressionName), @@ -1094,7 +1219,7 @@ func TestClientWithoutGzipSupport(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithAcceptCompression("gzip", nil, nil), connect.WithSendGzip(), @@ -1144,7 +1269,7 @@ func TestInterceptorReturnsWrongType(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { if _, err := next(ctx, request); err != nil { return nil, err @@ -1176,7 +1301,7 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { return options }), )) - readMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { + readMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_read_max", func(t *testing.T) { t.Parallel() @@ -1227,37 +1352,37 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) } @@ -1268,9 +1393,9 @@ func TestHandlerWithHTTPMaxBytes(t *testing.T) { t.Parallel() const readMaxBytes = 128 mux := http.NewServeMux() - pingRoute, pingHandler := pingv1connect.NewPingServiceHandler(pingServer{}) + pingRoute, pingHandler := pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}) mux.Handle(pingRoute, http.MaxBytesHandler(pingHandler, readMaxBytes)) - run := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { + run := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { t.Helper() t.Run("below_read_max", func(t *testing.T) { t.Parallel() @@ -1308,37 +1433,37 @@ func TestHandlerWithHTTPMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) run(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) run(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) run(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) run(t, client, true) }) } @@ -1354,14 +1479,14 @@ func TestClientWithReadMaxBytes(t *testing.T) { } else { compressionOption = connect.WithCompressMinBytes(math.MaxInt) } - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, compressionOption)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}, compressionOption)) server := memhttptest.NewServer(t, mux) return server } serverUncompressed := createServer(t, false) serverCompressed := createServer(t, true) readMaxBytes := 1024 - readMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { + readMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_read_max", func(t *testing.T) { t.Parallel() @@ -1403,32 +1528,32 @@ func TestClientWithReadMaxBytes(t *testing.T) { } t.Run("connect", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, true) }) } @@ -1436,7 +1561,7 @@ func TestClientWithReadMaxBytes(t *testing.T) { func TestHandlerWithSendMaxBytes(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - sendMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { + sendMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_send_max", func(t *testing.T) { t.Parallel() @@ -1506,37 +1631,37 @@ func TestHandlerWithSendMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, true) }) } @@ -1546,7 +1671,7 @@ func TestClientWithSendMaxBytes(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - sendMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, sendMaxBytes int, compressed bool) { + sendMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, sendMaxBytes int, compressed bool) { t.Helper() t.Run("equal_send_max", func(t *testing.T) { t.Parallel() @@ -1597,37 +1722,37 @@ func TestClientWithSendMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes)) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes)) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) } @@ -1644,9 +1769,9 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), connect.WithClientOptions(opts...), @@ -1680,12 +1805,12 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { func TestStreamForServer(t *testing.T) { t.Parallel() - newPingClient := func(t *testing.T, pingServer pingv1connect.PingServiceHandler) pingv1connect.PingServiceClient { + newPingClient := func(t *testing.T, pingServer pingv1connectgenerics.PingServiceHandler) pingv1connectgenerics.PingServiceClient { t.Helper() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), ) @@ -1851,7 +1976,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { return nil, connect.NewError(connectCode, errors.New("error")) }, } - mux.Handle(pingv1connect.NewPingServiceHandler(pluggableServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pluggableServer)) server := memhttptest.NewServer(t, mux) req, err := http.NewRequestWithContext( context.Background(), @@ -1865,7 +1990,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { assert.Nil(t, err) defer resp.Body.Close() assert.Equal(t, wantHttpStatus, resp.StatusCode) - connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + connectClient := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) connectResp, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) assert.Nil(t, connectResp) @@ -1957,7 +2082,7 @@ func TestFailCompression(t *testing.T) { ), ) server := memhttptest.NewServer(t, mux) - pingclient := pingv1connect.NewPingServiceClient( + pingclient := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), connect.WithAcceptCompression(compressorName, decompressor, compressor), @@ -2006,7 +2131,7 @@ func TestUnflushableResponseWriter(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - pingclient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), tt.options...) + pingclient := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), tt.options...) stream, err := pingclient.CountUp( context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 5}), @@ -2062,10 +2187,10 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { func TestConnectProtocolHeaderSentByDefault(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithRequireConnectProtocolHeader())) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}, connect.WithRequireConnectProtocolHeader())) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) @@ -2114,7 +2239,7 @@ func TestAllowCustomUserAgent(t *testing.T) { const customAgent = "custom" mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ ping: func(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { agent := req.Header().Get("User-Agent") assert.Equal(t, agent, customAgent) @@ -2133,7 +2258,7 @@ func TestAllowCustomUserAgent(t *testing.T) { {"grpcweb", []connect.ClientOption{connect.WithGRPCWeb()}}, } for _, testCase := range tests { - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), testCase.opts...) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), testCase.opts...) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) req.Header().Set("User-Agent", customAgent) _, err := client.Ping(context.Background(), req) @@ -2145,7 +2270,7 @@ func TestWebXUserAgent(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ ping: func(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { agent := req.Header().Get("User-Agent") assert.NotZero(t, agent) @@ -2159,7 +2284,7 @@ func TestWebXUserAgent(t *testing.T) { })) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) _, err := client.Ping(context.Background(), req) assert.Nil(t, err) @@ -2174,7 +2299,7 @@ func TestBidiOverHTTP1(t *testing.T) { // Clients expecting a full-duplex connection that end up with a simplex // HTTP/1.1 connection shouldn't hang. Instead, the server should close the // TCP connection. - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( &http.Client{Transport: server.TransportHTTP1()}, server.URL(), ) @@ -2210,7 +2335,7 @@ func TestHandlerReturnsNilResponse(t *testing.T) { } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ ping: func(ctx context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { return nil, nil //nolint: nilnil }, @@ -2219,7 +2344,7 @@ func TestHandlerReturnsNilResponse(t *testing.T) { }, }, connect.WithRecover(recoverPanic))) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) @@ -2465,7 +2590,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { for _, testcase := range testcases { t.Run(testcase.name, func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient( + client := pingv1connectgenerics.NewPingServiceClient( server.Client(), server.URL(), testcase.options..., @@ -2539,12 +2664,12 @@ func TestClientDisconnect(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) var clientConn net.Conn transport := captureTransport(server, &clientConn, gotRequest) serverClient := &http.Client{Transport: transport} - client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(serverClient, server.URL()) stream := client.Sum(context.Background()) // Send header. assert.Nil(t, stream.Send(nil)) @@ -2582,12 +2707,12 @@ func TestClientDisconnect(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) var clientConn net.Conn transport := captureTransport(server, &clientConn, gotRequest) serverClient := &http.Client{Transport: transport} - client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) + client := pingv1connectgenerics.NewPingServiceClient(serverClient, server.URL()) stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) if !assert.Nil(t, err) { return @@ -2652,9 +2777,9 @@ func TestSetProtocolHeaders(t *testing.T) { testcase := tt t.Run(testcase.name, func(t *testing.T) { t.Parallel() - pingServer := &pingServer{} + pingServer := &pingServerGenerics{} mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) clientOpts := []connect.ClientOption{} @@ -2662,7 +2787,7 @@ func TestSetProtocolHeaders(t *testing.T) { // Use a different protocol to test the override. clientOpts = append(clientOpts, connect.WithGRPC()) } - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), clientOpts...) + client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), clientOpts...) pingProxyServer := &pluggablePingServer{ ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { @@ -2670,14 +2795,14 @@ func TestSetProtocolHeaders(t *testing.T) { }, } proxyMux := http.NewServeMux() - proxyMux.Handle(pingv1connect.NewPingServiceHandler(pingProxyServer)) + proxyMux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingProxyServer)) proxyServer := memhttptest.NewServer(t, proxyMux) proxyClientOpts := []connect.ClientOption{} if testcase.clientOption != nil { proxyClientOpts = append(proxyClientOpts, testcase.clientOption) } - proxyClient := pingv1connect.NewPingServiceClient(proxyServer.Client(), proxyServer.URL(), proxyClientOpts...) + proxyClient := pingv1connectgenerics.NewPingServiceClient(proxyServer.Client(), proxyServer.URL(), proxyClientOpts...) request := connect.NewRequest(&pingv1.PingRequest{Number: 42}) request.Header().Set("X-Test", t.Name()) @@ -2731,7 +2856,7 @@ func (c failCodec) Unmarshal(data []byte, message any) error { } type pluggablePingServer struct { - pingv1connect.UnimplementedPingServiceHandler + pingv1connectgenerics.UnimplementedPingServiceHandler ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) sum func(context.Context, *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) @@ -2792,7 +2917,7 @@ func expectClientHeader(check bool, req connect.AnyRequest) error { return expectMetadata(req.Header(), "header", clientHeader, headerValue) } -func expectMetadata(meta http.Header, metaType, key, value string) error { +func expectMetadata(meta http.Header, metaType, key, value string) error { //nolint:unparam if got := meta.Get(key); got != value { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( "%s %q: got %q, expected %q", @@ -2805,14 +2930,14 @@ func expectMetadata(meta http.Header, metaType, key, value string) error { return nil } -type pingServer struct { - pingv1connect.UnimplementedPingServiceHandler +type pingServerGenerics struct { + pingv1connectgenerics.UnimplementedPingServiceHandler checkMetadata bool includeErrorDetails bool } -func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { +func (p pingServerGenerics) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2833,7 +2958,7 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi return response, nil } -func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { +func (p pingServerGenerics) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2859,7 +2984,7 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa return nil, err } -func (p pingServer) Sum( +func (p pingServerGenerics) Sum( ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest], ) (*connect.Response[pingv1.SumResponse], error) { @@ -2887,7 +3012,7 @@ func (p pingServer) Sum( return response, nil } -func (p pingServer) CountUp( +func (p pingServerGenerics) CountUp( ctx context.Context, request *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse], @@ -2917,7 +3042,7 @@ func (p pingServer) CountUp( return nil } -func (p pingServer) CumSum( +func (p pingServerGenerics) CumSum( ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], ) error { @@ -2949,6 +3074,107 @@ func (p pingServer) CumSum( } } +func expectClientHeaderInCallInfo(check bool, callInfo connect.CallInfo) error { + if !check { + return nil + } + return expectMetadata(callInfo.RequestHeader(), "header", clientHeader, headerValue) +} + +type pingServer struct { + pingv1connect.UnimplementedPingServiceHandler + + checkMetadata bool + includeErrorDetails bool +} + +func (p pingServer) Ping(ctx context.Context, request *pingv1.PingRequest) (*pingv1.PingResponse, error) { + callInfo, ok := connect.CallInfoFromContext(ctx) + if !ok { + return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) + } + if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { + return nil, err + } + if callInfo.Peer().Addr == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } + if callInfo.Peer().Protocol == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } + response := &pingv1.PingResponse{ + Number: request.GetNumber(), + Text: request.GetText(), + } + callInfo.ResponseHeader().Set(handlerHeader, headerValue) + callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) + return response, nil +} + +func (p pingServer) CountUp( + ctx context.Context, + request *pingv1.CountUpRequest, + stream *connect.ServerStream[pingv1.CountUpResponse], +) error { + callInfo, ok := connect.CallInfoFromContext(ctx) + if !ok { + return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) + } + if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { + return err + } + if callInfo.Peer().Addr == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } + if callInfo.Peer().Protocol == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } + if request.GetNumber() <= 0 { + return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( + "number must be positive: got %v", + request.GetNumber(), + )) + } + callInfo.ResponseHeader().Set(handlerHeader, headerValue) + callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) + for i := range request.GetNumber() { + if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { + return err + } + } + return nil +} + +func (p pingServer) Fail(ctx context.Context, request *pingv1.FailRequest) (*pingv1.FailResponse, error) { + callInfo, ok := connect.CallInfoFromContext(ctx) + if !ok { + return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) + } + if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { + return nil, err + } + if callInfo.Peer().Addr == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } + if callInfo.Peer().Protocol == "" { + return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } + err := connect.NewError( + connect.Code(request.GetCode()), + errors.New(errorMessage), + ) + err.Meta().Set(handlerHeader, headerValue) + err.Meta().Set(handlerTrailer, trailerValue) + if p.includeErrorDetails { + detail, derr := connect.NewErrorDetail(&pingv1.FailRequest{Code: request.GetCode()}) + if derr != nil { + return nil, derr + } + err.AddDetail(detail) + } + return nil, err +} + type deflateReader struct { r io.ReadCloser } diff --git a/context.go b/context.go index 7a7bdbd5..b99e7430 100644 --- a/context.go +++ b/context.go @@ -19,68 +19,119 @@ import ( "net/http" ) -type requestIncomingHeaderContextKey struct{} -type requestOutgoingHeaderContextKey struct{} -type responseHeaderAddressContextKey struct{} -type responseTrailerAddressContextKey struct{} - -// HeaderFromIncomingContext gets the header from a request sent to a handler. -func HeaderFromIncomingContext(ctx context.Context) (http.Header, bool) { - value, ok := ctx.Value(requestIncomingHeaderContextKey{}).(http.Header) - return value, ok +type CallInfo interface { + // Spec returns a description of this call. + Spec() Spec + // Peer describes the other party for this call. + Peer() Peer + // HTTPMethod returns the HTTP method for this request. This is nearly always + // POST, but side-effect-free unary RPCs could be made via a GET. + // + // On a newly created request, via NewRequest, this will return the empty + // string until the actual request is actually sent and the HTTP method + // determined. This means that client interceptor functions will see the + // empty string until *after* they delegate to the handler they wrapped. It + // is even possible for this to return the empty string after such delegation, + // if the request was never actually sent to the server (and thus no + // determination ever made about the HTTP method). + HTTPMethod() string + // RequestHeader returns the HTTP headers for this request. Headers beginning with + // "Connect-" and "Grpc-" are reserved for use by the Connect and gRPC + // protocols: applications may read them but shouldn't write them. + RequestHeader() http.Header + // ResponseHeader returns the HTTP headers for this response. Headers beginning with + // "Connect-" and "Grpc-" are reserved for use by the Connect and gRPC + // protocols: applications may read them but shouldn't write them. + ResponseHeader() http.Header + // ResponseTrailer returns the trailers for this response. Depending on the underlying + // RPC protocol, trailers may be sent as HTTP trailers or a protocol-specific + // block of in-body metadata. + // + // Trailers beginning with "Connect-" and "Grpc-" are reserved for use by the + // Connect and gRPC protocols: applications may read them but shouldn't write + // them. + ResponseTrailer() http.Header + + internalOnly() } -// HeaderFromOutgoingContext gets the header from a request sent by a client. -func HeaderFromOutgoingContext(ctx context.Context) (http.Header, bool) { - value, ok := ctx.Value(requestOutgoingHeaderContextKey{}).(http.Header) - return value, ok +type callInfo struct { + spec Spec + peer Peer + method string + requestHeader http.Header + responseHeader http.Header + responseTrailer http.Header } -// WithIncomingHeader adds the header to the context from a request sent to a handler. -func WithIncomingHeader(ctx context.Context, header http.Header) context.Context { - return context.WithValue(ctx, requestIncomingHeaderContextKey{}, header) +func (c *callInfo) Spec() Spec { + return c.spec } -// WithOutgoingHeader adds the header to the context from a request sent by a client. -func WithOutgoingHeader(ctx context.Context, header http.Header) context.Context { - return context.WithValue(ctx, requestOutgoingHeaderContextKey{}, header) +func (c *callInfo) Peer() Peer { + return c.peer } -// WithStoreResponseHeader returns a new context to be given to a client when making a request -// that will result in the header pointer being set to the response header. -func WithStoreResponseHeader(ctx context.Context, header *http.Header) context.Context { - return context.WithValue(ctx, responseHeaderAddressContextKey{}, header) +func (c *callInfo) RequestHeader() http.Header { + if c.requestHeader == nil { + c.requestHeader = make(http.Header) + } + return c.requestHeader } -// WithStoreResponseTrailer returns a new context to be given to a client when making a request -// that will result in the trailer pointer being set to the response trailer. -func WithStoreResponseTrailer(ctx context.Context, trailer *http.Header) context.Context { - return context.WithValue(ctx, responseTrailerAddressContextKey{}, trailer) +func (c *callInfo) ResponseHeader() http.Header { + if c.responseHeader == nil { + c.responseHeader = make(http.Header) + } + return c.responseHeader } -// SetResponseHeader sets the response header within a simple handler implementation. -func SetResponseHeader(ctx context.Context, header http.Header) { - responseHeaderAddress, ok := ctx.Value(responseHeaderAddressContextKey{}).(*http.Header) - if !ok { - return +func (c *callInfo) ResponseTrailer() http.Header { + if c.responseTrailer == nil { + c.responseTrailer = make(http.Header) } - *responseHeaderAddress = header + return c.responseTrailer +} + +func (c *callInfo) HTTPMethod() string { + return c.method } -// SetResponseTrailer sets the response trailer within a simple handler implementation. -func SetResponseTrailer(ctx context.Context, trailer http.Header) { - responseTrailerAddress, ok := ctx.Value(responseTrailerAddressContextKey{}).(*http.Header) +// internalOnly implements CallInfo. +func (c *callInfo) internalOnly() {} + +type callInfoContextKey struct{} + +// Create a new request context for use from a client. When the returned +// context is passed to RPCs, the returned call info can be used to set +// request metadata before the RPC is invoked and to inspect response +// metadata after the RPC completes. +// +// The returned context may be re-used across RPCs as long as they are +// not concurrent. Results of all CallInfo methods other than +// RequestHeader() are undefined if the context is used with concurrent RPCs. +// If the given context is already associated with an outgoing CallInfo, then +// ctx and the existing CallInfo are returned. +func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { + info, ok := ctx.Value(callInfoContextKey{}).(CallInfo) if !ok { - return + info = &callInfo{} + return context.WithValue(ctx, callInfoContextKey{}, info), info } - *responseTrailerAddress = trailer + return ctx, info +} + +// CallInfoFromContext returns the CallInfo for the given context, if there is one. +func CallInfoFromContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(callInfoContextKey{}).(CallInfo) + return value, ok } func requestFromContext[T any](ctx context.Context, message *T) *Request[T] { - request := NewRequest[T](message) - header, ok := HeaderFromOutgoingContext(ctx) + request := NewRequest(message) + callInfo, ok := CallInfoFromContext(ctx) if ok { - request.setHeader(header) + request.setHeader(callInfo.RequestHeader()) } return request } diff --git a/error_example_test.go b/error_example_test.go index d8155f75..30930a97 100644 --- a/error_example_test.go +++ b/error_example_test.go @@ -22,7 +22,7 @@ import ( connect "connectrpc.com/connect" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" - "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" ) func ExampleError_Message() { @@ -47,14 +47,15 @@ func ExampleIsNotModifiedError() { // Enable client-side support for HTTP GETs. connect.WithHTTPGet(), ) - req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) - first, err := client.Ping(context.Background(), req) + req := &pingv1.PingRequest{Number: 42} + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + _, err := client.Ping(ctx, req) if err != nil { fmt.Println(err) return } // If the server set an Etag, we can use it to cache the response. - etag := first.Header().Get("Etag") + etag := callInfo.ResponseHeader().Get("Etag") if etag == "" { fmt.Println("no Etag in response headers") return @@ -62,7 +63,7 @@ func ExampleIsNotModifiedError() { fmt.Println("cached response with Etag", etag) // Now we'd like to make the same request again, but avoid re-fetching the // response if possible. - req.Header().Set("If-None-Match", etag) + callInfo.RequestHeader().Set("If-None-Match", etag) _, err = client.Ping(context.Background(), req) if connect.IsNotModifiedError(err) { fmt.Println("can reuse cached response") diff --git a/error_not_modified_example_test.go b/error_not_modified_example_test.go index fe520548..fc9c9925 100644 --- a/error_not_modified_example_test.go +++ b/error_not_modified_example_test.go @@ -16,12 +16,13 @@ package connect_test import ( "context" + "errors" "net/http" "strconv" connect "connectrpc.com/connect" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" - "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" ) // ExampleCachingServer is an example of how servers can take advantage the @@ -35,22 +36,27 @@ type ExampleCachingPingServer struct { // indicates this), so clients using the Connect protocol may call it with HTTP // GET requests. This implementation uses Etags to manage client-side caching. func (*ExampleCachingPingServer) Ping( - _ context.Context, - req *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { - resp := connect.NewResponse(&pingv1.PingResponse{ - Number: req.Msg.GetNumber(), - }) + ctx context.Context, + req *pingv1.PingRequest, +) (*pingv1.PingResponse, error) { + resp := &pingv1.PingResponse{ + Number: req.GetNumber(), + } + callInfo, ok := connect.CallInfoFromContext(ctx) + if !ok { + return nil, errors.New("not call info found in context") + } + // Our hashing logic is simple: we use the number in the PingResponse. - hash := strconv.FormatInt(resp.Msg.GetNumber(), 10) + hash := strconv.FormatInt(resp.GetNumber(), 10) // If the request was an HTTP GET, we'll need to check if the client already // has the response cached. - if req.HTTPMethod() == http.MethodGet && req.Header().Get("If-None-Match") == hash { + if callInfo.HTTPMethod() == http.MethodGet && callInfo.RequestHeader().Get("If-None-Match") == hash { return nil, connect.NewNotModifiedError(http.Header{ "Etag": []string{hash}, }) } - resp.Header().Set("Etag", hash) + callInfo.ResponseHeader().Set("Etag", hash) return resp, nil } diff --git a/example_init_test.go b/example_init_test.go index e7abee52..d14275b6 100644 --- a/example_init_test.go +++ b/example_init_test.go @@ -17,7 +17,7 @@ package connect_test import ( "net/http" - "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp" ) diff --git a/handler.go b/handler.go index 9fc9bb6e..9bbd8d8e 100644 --- a/handler.go +++ b/handler.go @@ -16,6 +16,7 @@ package connect import ( "context" + "maps" "net/http" ) @@ -66,12 +67,28 @@ func NewUnaryHandler[Req, Res any]( if err != nil { return err } + // Add the request header to the context, and store the response header + // and trailer to propagate back to the caller. + ctx, ci := NewOutgoingContext(ctx) + call, ok := ci.(*callInfo) + if ok { + call.peer = request.Peer() + call.spec = request.Spec() + call.method = request.HTTPMethod() + call.requestHeader = request.Header() + } response, err := untyped(ctx, request) if err != nil { return err } + // Add response headers/trailers into the conn mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) + + // Add response headers/trailers into the context callinfo also + mergeNonProtocolHeaders(call.ResponseHeader(), response.Header()) + mergeNonProtocolHeaders(call.ResponseTrailer(), response.Trailer()) + return conn.Send(response.Any()) } @@ -98,28 +115,17 @@ func NewUnaryHandlerSimple[Req, Res any]( return NewUnaryHandler( procedure, func(ctx context.Context, request *Request[Req]) (*Response[Res], error) { - var responseHeader http.Header - var responseTrailer http.Header - // Add the request header to the context, and store the response header - // and trailer to propagate back to the caller. - ctx = WithIncomingHeader( - WithStoreResponseHeader( - WithStoreResponseTrailer( - ctx, - &responseTrailer, - ), - &responseHeader, - ), - request.Header(), - ) responseMsg, err := unary(ctx, request.Msg) - if responseMsg != nil { - response := NewResponse(responseMsg) - response.setHeader(responseHeader) - response.setTrailer(responseHeader) - return response, err + if err != nil { + return nil, err } - return nil, err + response := NewResponse(responseMsg) + callInfo, ok := CallInfoFromContext(ctx) + if ok { + response.setHeader(callInfo.ResponseHeader()) + response.setTrailer(callInfo.ResponseTrailer()) + } + return response, err }, options..., ) @@ -169,6 +175,13 @@ func NewServerStreamHandler[Req, Res any]( if err != nil { return err } + ctx, ci := NewOutgoingContext(ctx) + callInfo, _ := ci.(*callInfo) + callInfo.peer = req.Peer() + callInfo.spec = req.Spec() + callInfo.method = req.HTTPMethod() + maps.Copy(callInfo.RequestHeader(), req.Header()) + return implementation(ctx, req, &ServerStream[Res]{conn: conn}) }, ) @@ -187,11 +200,6 @@ func NewServerStreamHandlerSimple[Req, Res any]( return NewServerStreamHandler( procedure, func(ctx context.Context, request *Request[Req], serverStream *ServerStream[Res]) error { - // Add the request header to the context. - ctx = WithIncomingHeader( - ctx, - request.Header(), - ) return implementation(ctx, request.Msg, serverStream) }, options..., diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index b5892fde..a630cff9 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -123,7 +123,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( - pingServer{}, + pingServerGenerics{}, handlerOnion, ), ) @@ -171,7 +171,7 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { return next(ctx, request) } }) - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithInterceptors(interceptor))) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{}, connect.WithInterceptors(interceptor))) server := memhttptest.NewServer(t, mux) connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(interceptor)) _, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) @@ -197,7 +197,7 @@ func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( - pingServer{}, + pingServerGenerics{}, connect.WithInterceptors(handlerChecker), ), ) From 194fb35ac989316283496bbe3a88ec986c943976 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 16:19:35 -0400 Subject: [PATCH 32/57] Tests Signed-off-by: Steve Ayers --- client.go | 2 +- connect.go | 2 ++ connect_ext_test.go | 20 ++------------------ handler.go | 1 + 4 files changed, 6 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 792430bb..a2d0240e 100644 --- a/client.go +++ b/client.go @@ -189,7 +189,7 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques info.peer = conn.Peer() info.spec = conn.Spec() - mergeHeaders(conn.RequestHeader(), info.requestHeader) + mergeHeaders(info.RequestHeader(), request.header) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. diff --git a/connect.go b/connect.go index 596bc7fd..5843f62a 100644 --- a/connect.go +++ b/connect.go @@ -392,6 +392,8 @@ func receiveUnaryResponse[T any](conn StreamingClientConn, initializer maybeInit if err != nil { return nil, err } + fmt.Printf("Header %+v", conn.ResponseHeader()) + fmt.Printf("trailer %+v", conn.ResponseTrailer()) return &Response[T]{ Msg: msg, header: conn.ResponseHeader(), diff --git a/connect_ext_test.go b/connect_ext_test.go index 61908a13..2a738536 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -166,22 +166,6 @@ func TestCallInfo(t *testing.T) { assert.Nil(t, stream.Close()) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) - // num := int64(42) - // ctx, callInfo := connect.NewOutgoingContext(context.Background()) - // callInfo.RequestHeader().Set(clientHeader, headerValue) - // expect := &pingv1.PingResponse{Number: num} - - // response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) - // assert.Equal(t, response, expect) - // assert.Nil(t, err) - - // // Assert call info values are correctly populated - // assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) - // assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) - // assert.True(t, callInfo.Spec().IsClient) - // assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) - // assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) } @@ -3135,8 +3119,8 @@ func (p pingServer) CountUp( request.GetNumber(), )) } - callInfo.ResponseHeader().Set(handlerHeader, headerValue) - callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) + stream.Conn().ResponseHeader().Set(handlerHeader, headerValue) + stream.Conn().ResponseTrailer().Set(handlerTrailer, trailerValue) for i := range request.GetNumber() { if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { return err diff --git a/handler.go b/handler.go index 9bbd8d8e..3612600e 100644 --- a/handler.go +++ b/handler.go @@ -81,6 +81,7 @@ func NewUnaryHandler[Req, Res any]( if err != nil { return err } + // Add response headers/trailers into the conn mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) From 39380da6f601b9d5c092e26dd521b6b1e4e9800a Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 16:22:13 -0400 Subject: [PATCH 33/57] Remove print Signed-off-by: Steve Ayers --- connect.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/connect.go b/connect.go index 5843f62a..596bc7fd 100644 --- a/connect.go +++ b/connect.go @@ -392,8 +392,6 @@ func receiveUnaryResponse[T any](conn StreamingClientConn, initializer maybeInit if err != nil { return nil, err } - fmt.Printf("Header %+v", conn.ResponseHeader()) - fmt.Printf("trailer %+v", conn.ResponseTrailer()) return &Response[T]{ Msg: msg, header: conn.ResponseHeader(), From 36bc4bc3048c2b368f43e63b10f2cd322f67c3c6 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 16:36:13 -0400 Subject: [PATCH 34/57] Simplify Signed-off-by: Steve Ayers --- client_ext_test.go | 10 +- connect_ext_test.go | 256 ++++++++++++++++++++-------------------- example_init_test.go | 2 +- interceptor_ext_test.go | 6 +- 4 files changed, 137 insertions(+), 137 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index 14ae2773..4e5b8351 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -89,7 +89,7 @@ func TestNewClient_InitFailure(t *testing.T) { func TestClientPeer(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) run := func(t *testing.T, unaryHTTPMethod string, opts ...connect.ClientOption) { @@ -205,7 +205,7 @@ func TestGetNoContentHeaders(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(&pingServerGenerics{})) + mux.Handle(pingv1connect.NewPingServiceHandler(&pingServer{})) server := memhttptest.NewServer(t, http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) { if len(req.Header.Values("content-type")) > 0 || len(req.Header.Values("content-encoding")) > 0 || @@ -283,7 +283,7 @@ func TestSpecSchema(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler( - pingServerGenerics{}, + pingServer{}, connect.WithInterceptors(&assertSchemaInterceptor{t}), )) server := memhttptest.NewServer(t, mux) @@ -320,7 +320,7 @@ func TestSpecSchema(t *testing.T) { func TestDynamicClient(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) ctx := context.Background() initializer := func(spec connect.Spec, msg any) error { @@ -494,7 +494,7 @@ func TestClientDeadlineHandling(t *testing.T) { // detector enabled. That's partly why the makefile only runs "slow" // tests with the race detector disabled. - _, handler := pingv1connect.NewPingServiceHandler(pingServerGenerics{}) + _, handler := pingv1connect.NewPingServiceHandler(pingServer{}) svr := httptest.NewUnstartedServer(http.HandlerFunc(func(respWriter http.ResponseWriter, req *http.Request) { if req.Context().Err() != nil { return diff --git a/connect_ext_test.go b/connect_ext_test.go index 2a738536..6327d5de 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -39,9 +39,9 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" - pingv1connectgenerics "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" - "connectrpc.com/connect/internal/gen/simple/connect/import/v1/importv1connect" - "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/generics/connect/import/v1/importv1connect" + "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + pingv1connectsimple "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/proto" @@ -68,11 +68,11 @@ func TestCallInfo(t *testing.T) { t.Run("simple_api", func(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler( - pingServer{checkMetadata: true}, + mux.Handle(pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{checkMetadata: true}, )) server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) t.Run("unary", func(t *testing.T) { num := int64(42) ctx, callInfo := connect.NewOutgoingContext(context.Background()) @@ -117,11 +117,11 @@ func TestCallInfo(t *testing.T) { t.Run("generics_api", func(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler( - pingServerGenerics{checkMetadata: true}, + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{checkMetadata: true}, )) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) t.Run("unary", func(t *testing.T) { num := int64(42) request := connect.NewRequest(&pingv1.PingRequest{Number: num}) @@ -172,7 +172,7 @@ func TestCallInfo(t *testing.T) { func TestServer(t *testing.T) { t.Parallel() - testPing := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper + testPing := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("ping", func(t *testing.T) { num := int64(42) request := connect.NewRequest(&pingv1.PingRequest{Number: num}) @@ -226,7 +226,7 @@ func TestServer(t *testing.T) { assert.Equal(t, connect.CodeOf(err), connect.CodeDeadlineExceeded) }) } - testSum := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper + testSum := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("sum", func(t *testing.T) { const ( upTo = 10 @@ -262,7 +262,7 @@ func TestServer(t *testing.T) { assert.Equal(t, got.Header().Values(handlerHeader), []string{headerValue}) }) } - testCountUp := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper + testCountUp := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("count_up", func(t *testing.T) { const upTo = 5 got := make([]int64, 0, upTo) @@ -320,7 +320,7 @@ func TestServer(t *testing.T) { assert.Nil(t, stream.Close()) }) } - testCumSum := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, expectSuccess bool) { //nolint:thelper + testCumSum := func(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper t.Run("cumsum", func(t *testing.T) { send := []int64{3, 5, 1} expect := []int64{3, 8, 9} @@ -435,7 +435,7 @@ func TestServer(t *testing.T) { assert.Nil(t, stream.CloseResponse()) }) } - testErrors := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { //nolint:thelper + testErrors := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper assertIsHTTPMiddlewareError := func(tb testing.TB, err error) { tb.Helper() assert.NotNil(tb, err) @@ -486,7 +486,7 @@ func TestServer(t *testing.T) { testMatrix := func(t *testing.T, client *http.Client, url string, bidi bool) { //nolint:thelper run := func(t *testing.T, opts ...connect.ClientOption) { t.Helper() - client := pingv1connectgenerics.NewPingServiceClient(client, url, opts...) + client := pingv1connect.NewPingServiceClient(client, url, opts...) testPing(t, client) testSum(t, client) testCountUp(t, client) @@ -551,8 +551,8 @@ func TestServer(t *testing.T) { } mux := http.NewServeMux() - pingRoute, pingHandler := pingv1connectgenerics.NewPingServiceHandler( - pingServerGenerics{checkMetadata: true}, + pingRoute, pingHandler := pingv1connect.NewPingServiceHandler( + pingServer{checkMetadata: true}, ) errorWriter := connect.NewErrorWriter() // Add net/http middleware to the ping service to evaluate HTTP state. @@ -573,15 +573,15 @@ func TestServer(t *testing.T) { } // Check Content-Length is set correctly. switch request.URL.Path { - case pingv1connectgenerics.PingServicePingProcedure, - pingv1connectgenerics.PingServiceFailProcedure, - pingv1connectgenerics.PingServiceCountUpProcedure: + case pingv1connect.PingServicePingProcedure, + pingv1connect.PingServiceFailProcedure, + pingv1connect.PingServiceCountUpProcedure: // Unary requests set Content-Length to the length of the request body. if request.ContentLength < 0 { t.Errorf("%s: expected Content-Length >= 0, got %d", request.URL.Path, request.ContentLength) } - case pingv1connectgenerics.PingServiceSumProcedure, - pingv1connectgenerics.PingServiceCumSumProcedure: + case pingv1connect.PingServiceSumProcedure, + pingv1connect.PingServiceCumSumProcedure: // Streaming requests set Content-Length to -1 or 0 on empty requests. if request.ContentLength > 0 { t.Errorf("%s: expected Content-Length -1 or 0, got %d", request.URL.Path, request.ContentLength) @@ -612,7 +612,7 @@ func TestConcurrentStreams(t *testing.T) { } t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) var done, start sync.WaitGroup start.Add(1) @@ -620,7 +620,7 @@ func TestConcurrentStreams(t *testing.T) { done.Add(1) go func() { defer done.Done() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) var total int64 sum := client.CumSum(context.Background()) start.Wait() @@ -684,7 +684,7 @@ func TestErrorHeaderPropagation(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) assertError := func(t *testing.T, err error, allowCustomHeaders bool) { @@ -721,7 +721,7 @@ func TestErrorHeaderPropagation(t *testing.T) { assert.Equal(t, meta.Values("X-Test"), []string(nil)) } } - testServices := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { + testServices := func(t *testing.T, client pingv1connect.PingServiceClient) { t.Helper() t.Run("unary", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) @@ -769,17 +769,17 @@ func TestErrorHeaderPropagation(t *testing.T) { } t.Run("connect", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) testServices(t, client) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) testServices(t, client) }) t.Run("grpc-web", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) testServices(t, client) }) } @@ -801,10 +801,10 @@ func TestHeaderBasic(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) request := connect.NewRequest(&pingv1.PingRequest{}) request.Header().Set(key, cval) response, err := client.Ping(context.Background(), request) @@ -830,12 +830,12 @@ func TestHeaderHost(t *testing.T) { newHTTP2Server := func(t *testing.T) *memhttp.Server { t.Helper() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) return server } - callWithHost := func(t *testing.T, client pingv1connectgenerics.PingServiceClient) { + callWithHost := func(t *testing.T, client pingv1connect.PingServiceClient) { t.Helper() request := connect.NewRequest(&pingv1.PingRequest{}) @@ -848,21 +848,21 @@ func TestHeaderHost(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) callWithHost(t, client) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) callWithHost(t, client) }) t.Run("grpc-web", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) callWithHost(t, client) }) } @@ -881,12 +881,12 @@ func TestTimeoutParsing(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) } @@ -895,7 +895,7 @@ func TestFailCodec(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) server := memhttptest.NewServer(t, handler) - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), connect.WithCodec(failCodec{}), @@ -912,7 +912,7 @@ func TestContextError(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) server := memhttptest.NewServer(t, handler) - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), ) @@ -931,8 +931,8 @@ func TestGRPCMarshalStatusError(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler( - pingServerGenerics{ + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{ // Include error details in the response, so that the Status protobuf will be marshaled. includeErrorDetails: true, }, @@ -943,7 +943,7 @@ func TestGRPCMarshalStatusError(t *testing.T) { assertInternalError := func(tb testing.TB, opts ...connect.ClientOption) { tb.Helper() - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), opts...) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), opts...) request := connect.NewRequest(&pingv1.FailRequest{Code: int32(connect.CodeResourceExhausted)}) _, err := client.Fail(context.Background(), request) tb.Log(err) @@ -980,7 +980,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { pingServer{checkMetadata: true}, )) server := memhttptest.NewServer(t, trimTrailers(mux)) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) assertErrorNoTrailers := func(t *testing.T, err error) { t.Helper() @@ -1044,7 +1044,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { func TestUnavailableIfHostInvalid(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( http.DefaultClient, "https://api.invalid/", ) @@ -1064,7 +1064,7 @@ func TestBidiRequiresHTTP2(t *testing.T) { assert.Nil(t, err) }) server := memhttptest.NewServer(t, handler) - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( &http.Client{Transport: server.TransportHTTP1()}, server.URL(), ) @@ -1095,7 +1095,7 @@ func TestCompressMinBytesClient(t *testing.T) { assert.Equal(tb, request.Header.Get("Content-Encoding"), expect) })) server := memhttptest.NewServer(t, mux) - _, err := pingv1connectgenerics.NewPingServiceClient( + _, err := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), connect.WithSendGzip(), @@ -1184,7 +1184,7 @@ func TestCustomCompression(t *testing.T) { connect.WithCompression(compressionName, decompressor, compressor), )) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithAcceptCompression(compressionName, decompressor, compressor), connect.WithSendCompression(compressionName), @@ -1203,7 +1203,7 @@ func TestClientWithoutGzipSupport(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithAcceptCompression("gzip", nil, nil), connect.WithSendGzip(), @@ -1253,7 +1253,7 @@ func TestInterceptorReturnsWrongType(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { if _, err := next(ctx, request); err != nil { return nil, err @@ -1285,7 +1285,7 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { return options }), )) - readMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { + readMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_read_max", func(t *testing.T) { t.Parallel() @@ -1336,37 +1336,37 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) } @@ -1377,9 +1377,9 @@ func TestHandlerWithHTTPMaxBytes(t *testing.T) { t.Parallel() const readMaxBytes = 128 mux := http.NewServeMux() - pingRoute, pingHandler := pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}) + pingRoute, pingHandler := pingv1connect.NewPingServiceHandler(pingServer{}) mux.Handle(pingRoute, http.MaxBytesHandler(pingHandler, readMaxBytes)) - run := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { + run := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { t.Helper() t.Run("below_read_max", func(t *testing.T) { t.Parallel() @@ -1417,37 +1417,37 @@ func TestHandlerWithHTTPMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) run(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) run(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) run(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) run(t, client, true) }) } @@ -1463,14 +1463,14 @@ func TestClientWithReadMaxBytes(t *testing.T) { } else { compressionOption = connect.WithCompressMinBytes(math.MaxInt) } - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}, compressionOption)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, compressionOption)) server := memhttptest.NewServer(t, mux) return server } serverUncompressed := createServer(t, false) serverCompressed := createServer(t, true) readMaxBytes := 1024 - readMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { + readMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_read_max", func(t *testing.T) { t.Parallel() @@ -1512,32 +1512,32 @@ func TestClientWithReadMaxBytes(t *testing.T) { } t.Run("connect", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, true) }) } @@ -1545,7 +1545,7 @@ func TestClientWithReadMaxBytes(t *testing.T) { func TestHandlerWithSendMaxBytes(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - sendMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, compressed bool) { + sendMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) { t.Helper() t.Run("equal_send_max", func(t *testing.T) { t.Parallel() @@ -1615,37 +1615,37 @@ func TestHandlerWithSendMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, true) }) } @@ -1655,7 +1655,7 @@ func TestClientWithSendMaxBytes(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) - sendMaxBytesMatrix := func(t *testing.T, client pingv1connectgenerics.PingServiceClient, sendMaxBytes int, compressed bool) { + sendMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, sendMaxBytes int, compressed bool) { t.Helper() t.Run("equal_send_max", func(t *testing.T) { t.Parallel() @@ -1706,37 +1706,37 @@ func TestClientWithSendMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes)) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes)) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) } @@ -1753,9 +1753,9 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), connect.WithClientOptions(opts...), @@ -1789,12 +1789,12 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { func TestStreamForServer(t *testing.T) { t.Parallel() - newPingClient := func(t *testing.T, pingServer pingv1connectgenerics.PingServiceHandler) pingv1connectgenerics.PingServiceClient { + newPingClient := func(t *testing.T, pingServer pingv1connect.PingServiceHandler) pingv1connect.PingServiceClient { t.Helper() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), ) @@ -1960,7 +1960,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { return nil, connect.NewError(connectCode, errors.New("error")) }, } - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pluggableServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pluggableServer)) server := memhttptest.NewServer(t, mux) req, err := http.NewRequestWithContext( context.Background(), @@ -1974,7 +1974,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { assert.Nil(t, err) defer resp.Body.Close() assert.Equal(t, wantHttpStatus, resp.StatusCode) - connectClient := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) connectResp, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) assert.Nil(t, connectResp) @@ -2066,7 +2066,7 @@ func TestFailCompression(t *testing.T) { ), ) server := memhttptest.NewServer(t, mux) - pingclient := pingv1connectgenerics.NewPingServiceClient( + pingclient := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), connect.WithAcceptCompression(compressorName, decompressor, compressor), @@ -2115,7 +2115,7 @@ func TestUnflushableResponseWriter(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - pingclient := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), tt.options...) + pingclient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), tt.options...) stream, err := pingclient.CountUp( context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 5}), @@ -2171,10 +2171,10 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { func TestConnectProtocolHeaderSentByDefault(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServerGenerics{}, connect.WithRequireConnectProtocolHeader())) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithRequireConnectProtocolHeader())) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) @@ -2223,7 +2223,7 @@ func TestAllowCustomUserAgent(t *testing.T) { const customAgent = "custom" mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ ping: func(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { agent := req.Header().Get("User-Agent") assert.Equal(t, agent, customAgent) @@ -2242,7 +2242,7 @@ func TestAllowCustomUserAgent(t *testing.T) { {"grpcweb", []connect.ClientOption{connect.WithGRPCWeb()}}, } for _, testCase := range tests { - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), testCase.opts...) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), testCase.opts...) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) req.Header().Set("User-Agent", customAgent) _, err := client.Ping(context.Background(), req) @@ -2254,7 +2254,7 @@ func TestWebXUserAgent(t *testing.T) { t.Parallel() mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ ping: func(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { agent := req.Header().Get("User-Agent") assert.NotZero(t, agent) @@ -2268,7 +2268,7 @@ func TestWebXUserAgent(t *testing.T) { })) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) _, err := client.Ping(context.Background(), req) assert.Nil(t, err) @@ -2283,7 +2283,7 @@ func TestBidiOverHTTP1(t *testing.T) { // Clients expecting a full-duplex connection that end up with a simplex // HTTP/1.1 connection shouldn't hang. Instead, the server should close the // TCP connection. - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( &http.Client{Transport: server.TransportHTTP1()}, server.URL(), ) @@ -2319,7 +2319,7 @@ func TestHandlerReturnsNilResponse(t *testing.T) { } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(&pluggablePingServer{ + mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ ping: func(ctx context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { return nil, nil //nolint: nilnil }, @@ -2328,7 +2328,7 @@ func TestHandlerReturnsNilResponse(t *testing.T) { }, }, connect.WithRecover(recoverPanic))) server := memhttptest.NewServer(t, mux) - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) @@ -2574,7 +2574,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { for _, testcase := range testcases { t.Run(testcase.name, func(t *testing.T) { t.Parallel() - client := pingv1connectgenerics.NewPingServiceClient( + client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), testcase.options..., @@ -2648,12 +2648,12 @@ func TestClientDisconnect(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) var clientConn net.Conn transport := captureTransport(server, &clientConn, gotRequest) serverClient := &http.Client{Transport: transport} - client := pingv1connectgenerics.NewPingServiceClient(serverClient, server.URL()) + client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) stream := client.Sum(context.Background()) // Send header. assert.Nil(t, stream.Send(nil)) @@ -2691,12 +2691,12 @@ func TestClientDisconnect(t *testing.T) { }, } mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) var clientConn net.Conn transport := captureTransport(server, &clientConn, gotRequest) serverClient := &http.Client{Transport: transport} - client := pingv1connectgenerics.NewPingServiceClient(serverClient, server.URL()) + client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) if !assert.Nil(t, err) { return @@ -2761,9 +2761,9 @@ func TestSetProtocolHeaders(t *testing.T) { testcase := tt t.Run(testcase.name, func(t *testing.T) { t.Parallel() - pingServer := &pingServerGenerics{} + pingServer := &pingServer{} mux := http.NewServeMux() - mux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingServer)) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) clientOpts := []connect.ClientOption{} @@ -2771,7 +2771,7 @@ func TestSetProtocolHeaders(t *testing.T) { // Use a different protocol to test the override. clientOpts = append(clientOpts, connect.WithGRPC()) } - client := pingv1connectgenerics.NewPingServiceClient(server.Client(), server.URL(), clientOpts...) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), clientOpts...) pingProxyServer := &pluggablePingServer{ ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { @@ -2779,14 +2779,14 @@ func TestSetProtocolHeaders(t *testing.T) { }, } proxyMux := http.NewServeMux() - proxyMux.Handle(pingv1connectgenerics.NewPingServiceHandler(pingProxyServer)) + proxyMux.Handle(pingv1connect.NewPingServiceHandler(pingProxyServer)) proxyServer := memhttptest.NewServer(t, proxyMux) proxyClientOpts := []connect.ClientOption{} if testcase.clientOption != nil { proxyClientOpts = append(proxyClientOpts, testcase.clientOption) } - proxyClient := pingv1connectgenerics.NewPingServiceClient(proxyServer.Client(), proxyServer.URL(), proxyClientOpts...) + proxyClient := pingv1connect.NewPingServiceClient(proxyServer.Client(), proxyServer.URL(), proxyClientOpts...) request := connect.NewRequest(&pingv1.PingRequest{Number: 42}) request.Header().Set("X-Test", t.Name()) @@ -2840,7 +2840,7 @@ func (c failCodec) Unmarshal(data []byte, message any) error { } type pluggablePingServer struct { - pingv1connectgenerics.UnimplementedPingServiceHandler + pingv1connect.UnimplementedPingServiceHandler ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) sum func(context.Context, *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) @@ -2914,14 +2914,14 @@ func expectMetadata(meta http.Header, metaType, key, value string) error { //nol return nil } -type pingServerGenerics struct { - pingv1connectgenerics.UnimplementedPingServiceHandler +type pingServer struct { + pingv1connect.UnimplementedPingServiceHandler checkMetadata bool includeErrorDetails bool } -func (p pingServerGenerics) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { +func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2942,7 +2942,7 @@ func (p pingServerGenerics) Ping(ctx context.Context, request *connect.Request[p return response, nil } -func (p pingServerGenerics) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { +func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2968,7 +2968,7 @@ func (p pingServerGenerics) Fail(ctx context.Context, request *connect.Request[p return nil, err } -func (p pingServerGenerics) Sum( +func (p pingServer) Sum( ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest], ) (*connect.Response[pingv1.SumResponse], error) { @@ -2996,7 +2996,7 @@ func (p pingServerGenerics) Sum( return response, nil } -func (p pingServerGenerics) CountUp( +func (p pingServer) CountUp( ctx context.Context, request *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse], @@ -3026,7 +3026,7 @@ func (p pingServerGenerics) CountUp( return nil } -func (p pingServerGenerics) CumSum( +func (p pingServer) CumSum( ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], ) error { @@ -3065,14 +3065,14 @@ func expectClientHeaderInCallInfo(check bool, callInfo connect.CallInfo) error { return expectMetadata(callInfo.RequestHeader(), "header", clientHeader, headerValue) } -type pingServer struct { - pingv1connect.UnimplementedPingServiceHandler +type pingServerSimple struct { + pingv1connectsimple.UnimplementedPingServiceHandler checkMetadata bool includeErrorDetails bool } -func (p pingServer) Ping(ctx context.Context, request *pingv1.PingRequest) (*pingv1.PingResponse, error) { +func (p pingServerSimple) Ping(ctx context.Context, request *pingv1.PingRequest) (*pingv1.PingResponse, error) { callInfo, ok := connect.CallInfoFromContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) @@ -3095,7 +3095,7 @@ func (p pingServer) Ping(ctx context.Context, request *pingv1.PingRequest) (*pin return response, nil } -func (p pingServer) CountUp( +func (p pingServerSimple) CountUp( ctx context.Context, request *pingv1.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], @@ -3129,7 +3129,7 @@ func (p pingServer) CountUp( return nil } -func (p pingServer) Fail(ctx context.Context, request *pingv1.FailRequest) (*pingv1.FailResponse, error) { +func (p pingServerSimple) Fail(ctx context.Context, request *pingv1.FailRequest) (*pingv1.FailResponse, error) { callInfo, ok := connect.CallInfoFromContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) diff --git a/example_init_test.go b/example_init_test.go index d14275b6..e79d95b8 100644 --- a/example_init_test.go +++ b/example_init_test.go @@ -32,6 +32,6 @@ func init() { // deadlock, see: // (https://github.com/golang/go/issues/48394) mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServerSimple{})) examplePingServer = memhttp.NewServer(mux) } diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index a630cff9..b5892fde 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -123,7 +123,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( - pingServerGenerics{}, + pingServer{}, handlerOnion, ), ) @@ -171,7 +171,7 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { return next(ctx, request) } }) - mux.Handle(pingv1connect.NewPingServiceHandler(pingServerGenerics{}, connect.WithInterceptors(interceptor))) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithInterceptors(interceptor))) server := memhttptest.NewServer(t, mux) connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(interceptor)) _, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) @@ -197,7 +197,7 @@ func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( - pingServerGenerics{}, + pingServer{}, connect.WithInterceptors(handlerChecker), ), ) From c4b878f1dc229825fe131499a829dd950923f48a Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 20:37:30 -0400 Subject: [PATCH 35/57] Feedback Signed-off-by: Steve Ayers --- client.go | 44 ++++++++++------- connect_ext_test.go | 21 +++++--- context.go | 78 +++++++++++++++++++++++++----- error_not_modified_example_test.go | 4 +- handler.go | 31 +++++------- 5 files changed, 122 insertions(+), 56 deletions(-) diff --git a/client.go b/client.go index a2d0240e..2f542557 100644 --- a/client.go +++ b/client.go @@ -19,7 +19,6 @@ import ( "errors" "fmt" "io" - "maps" "net/http" "net/url" "strings" @@ -128,23 +127,26 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) if c.err != nil { return nil, c.err } - ctx, ci := NewOutgoingContext(ctx) - call, ok := ci.(*callInfo) - if ok { - call.requestHeader = request.Header() - } + ctx, callInfo := newOutgoingContext(ctx) + callInfo.requestHeader = request.Header() resp, err := c.callUnary(ctx, request) if err != nil { return nil, err } - if ok { - call.peer = request.Peer() - call.spec = request.Spec() - call.method = request.HTTPMethod() - maps.Copy(call.ResponseHeader(), resp.Header()) - maps.Copy(call.ResponseTrailer(), resp.Trailer()) + callInfo.peer = request.Peer() + callInfo.spec = request.Spec() + callInfo.method = request.HTTPMethod() + if callInfo.responseHeader == nil { + callInfo.responseHeader = resp.Header() + } else { + mergeHeaders(callInfo.ResponseHeader(), resp.Header()) + } + if callInfo.responseTrailer == nil { + callInfo.responseTrailer = resp.Trailer() + } else { + mergeHeaders(callInfo.ResponseTrailer(), resp.Trailer()) } return resp, nil @@ -175,21 +177,21 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if c.err != nil { return nil, c.err } - ctx, ctxCallInfo := NewOutgoingContext(ctx) - // Note we don't need to check ok here because it should always be in context - // because of the above call to NewOutgoingContext - info, _ := ctxCallInfo.(*callInfo) conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method - info.method = r.Method }) request.spec = conn.Spec() request.peer = conn.Peer() mergeHeaders(conn.RequestHeader(), request.header) + ctx, ctxCallInfo := NewOutgoingContext(ctx) + // Note we don't need to check ok here because it should always be in context + // because of the above call to NewOutgoingContext + info, _ := ctxCallInfo.(*callInfo) info.peer = conn.Peer() info.spec = conn.Spec() mergeHeaders(info.RequestHeader(), request.header) + // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. @@ -198,15 +200,23 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques _ = conn.CloseResponse() return nil, err } + info.responseHeader = conn.ResponseHeader() + info.responseTrailer = conn.ResponseTrailer() if err := conn.CloseRequest(); err != nil { return nil, err } + return &ServerStreamForClient[Res]{ conn: conn, initializer: c.config.Initializer, }, nil } +// CallServerStreamSimple calls a server streaming procedure using the function signature +// associated with the "simple" generation option. +// +// This option eliminates the [Request] wrapper, and instead uses the context.Context to +// propagate information such as headers. func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { return c.CallServerStream(ctx, requestFromContext(ctx, requestMsg)) } diff --git a/connect_ext_test.go b/connect_ext_test.go index 6327d5de..dbb1e392 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -111,7 +111,7 @@ func TestCallInfo(t *testing.T) { assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + assert.Equal(t, callInfo.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) }) }) t.Run("generics_api", func(t *testing.T) { @@ -156,7 +156,8 @@ func TestCallInfo(t *testing.T) { Number: 1, }) req.Header().Set(clientHeader, headerValue) - stream, err := client.CountUp(context.Background(), req) + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + stream, err := client.CountUp(ctx, req) assert.Nil(t, err) assert.True(t, stream.Receive()) assert.Nil(t, stream.Err()) @@ -165,7 +166,10 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, msg.GetNumber(), 1) assert.Nil(t, stream.Close()) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + // assert.Equal(t, stream.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) + // assert.Equal(t, stream.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) + assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + // assert.Equal(t, callInfo.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) }) }) } @@ -3073,7 +3077,7 @@ type pingServerSimple struct { } func (p pingServerSimple) Ping(ctx context.Context, request *pingv1.PingRequest) (*pingv1.PingResponse, error) { - callInfo, ok := connect.CallInfoFromContext(ctx) + callInfo, ok := connect.CallInfoFromIncomingContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } @@ -3100,7 +3104,8 @@ func (p pingServerSimple) CountUp( request *pingv1.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { - callInfo, ok := connect.CallInfoFromContext(ctx) + fmt.Println("Count Up server") + callInfo, ok := connect.CallInfoFromIncomingContext(ctx) if !ok { return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } @@ -3119,8 +3124,8 @@ func (p pingServerSimple) CountUp( request.GetNumber(), )) } - stream.Conn().ResponseHeader().Set(handlerHeader, headerValue) - stream.Conn().ResponseTrailer().Set(handlerTrailer, trailerValue) + callInfo.ResponseHeader().Set(handlerHeader, headerValue) + callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) for i := range request.GetNumber() { if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { return err @@ -3130,7 +3135,7 @@ func (p pingServerSimple) CountUp( } func (p pingServerSimple) Fail(ctx context.Context, request *pingv1.FailRequest) (*pingv1.FailResponse, error) { - callInfo, ok := connect.CallInfoFromContext(ctx) + callInfo, ok := connect.CallInfoFromIncomingContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } diff --git a/context.go b/context.go index b99e7430..72ef37a2 100644 --- a/context.go +++ b/context.go @@ -20,10 +20,7 @@ import ( ) type CallInfo interface { - // Spec returns a description of this call. - Spec() Spec - // Peer describes the other party for this call. - Peer() Peer + StreamCallInfo // HTTPMethod returns the HTTP method for this request. This is nearly always // POST, but side-effect-free unary RPCs could be made via a GET. // @@ -35,6 +32,13 @@ type CallInfo interface { // if the request was never actually sent to the server (and thus no // determination ever made about the HTTP method). HTTPMethod() string +} + +type StreamCallInfo interface { + // Spec returns a description of this call. + Spec() Spec + // Peer describes the other party for this call. + Peer() Peer // RequestHeader returns the HTTP headers for this request. Headers beginning with // "Connect-" and "Grpc-" are reserved for use by the Connect and gRPC // protocols: applications may read them but shouldn't write them. @@ -100,7 +104,40 @@ func (c *callInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *callInfo) internalOnly() {} -type callInfoContextKey struct{} +type streamCallInfo struct { + conn StreamingHandlerConn +} + +func (c *streamCallInfo) Spec() Spec { + return c.conn.Spec() +} + +func (c *streamCallInfo) Peer() Peer { + return c.conn.Peer() +} + +func (c *streamCallInfo) RequestHeader() http.Header { + return c.conn.RequestHeader() +} + +func (c *streamCallInfo) ResponseHeader() http.Header { + return c.conn.ResponseHeader() +} + +func (c *streamCallInfo) ResponseTrailer() http.Header { + return c.conn.ResponseHeader() +} + +func (c *streamCallInfo) HTTPMethod() string { + // All stream calls are POSTs + return http.MethodPost +} + +// internalOnly implements CallInfo. +func (c *streamCallInfo) internalOnly() {} + +type outgoingCallInfoContextKey struct{} +type incomingCallInfoContextKey struct{} // Create a new request context for use from a client. When the returned // context is passed to RPCs, the returned call info can be used to set @@ -113,23 +150,42 @@ type callInfoContextKey struct{} // If the given context is already associated with an outgoing CallInfo, then // ctx and the existing CallInfo are returned. func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { - info, ok := ctx.Value(callInfoContextKey{}).(CallInfo) + info, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) + if !ok { + info = &callInfo{} + return context.WithValue(ctx, outgoingCallInfoContextKey{}, info), info + } + return ctx, info +} + +func newOutgoingContext(ctx context.Context) (context.Context, *callInfo) { + info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*callInfo) if !ok { info = &callInfo{} - return context.WithValue(ctx, callInfoContextKey{}, info), info + return context.WithValue(ctx, outgoingCallInfoContextKey{}, info), info } return ctx, info } -// CallInfoFromContext returns the CallInfo for the given context, if there is one. -func CallInfoFromContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(callInfoContextKey{}).(CallInfo) +func newIncomingContext(ctx context.Context, info CallInfo) context.Context { + return context.WithValue(ctx, incomingCallInfoContextKey{}, info) +} + +// CallInfoFromOutgoingContext returns the CallInfo for the given context, if there is one. +func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) + return value, ok +} + +// CallInfoFromIncomingContext returns the CallInfo for the given context, if there is one. +func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) return value, ok } func requestFromContext[T any](ctx context.Context, message *T) *Request[T] { request := NewRequest(message) - callInfo, ok := CallInfoFromContext(ctx) + callInfo, ok := CallInfoFromOutgoingContext(ctx) if ok { request.setHeader(callInfo.RequestHeader()) } diff --git a/error_not_modified_example_test.go b/error_not_modified_example_test.go index fc9c9925..3daf8223 100644 --- a/error_not_modified_example_test.go +++ b/error_not_modified_example_test.go @@ -42,9 +42,9 @@ func (*ExampleCachingPingServer) Ping( resp := &pingv1.PingResponse{ Number: req.GetNumber(), } - callInfo, ok := connect.CallInfoFromContext(ctx) + callInfo, ok := connect.CallInfoFromIncomingContext(ctx) if !ok { - return nil, errors.New("not call info found in context") + return nil, errors.New("no call info found in context") } // Our hashing logic is simple: we use the number in the PingResponse. diff --git a/handler.go b/handler.go index 3612600e..dc4ac3d6 100644 --- a/handler.go +++ b/handler.go @@ -16,7 +16,6 @@ package connect import ( "context" - "maps" "net/http" ) @@ -69,14 +68,13 @@ func NewUnaryHandler[Req, Res any]( } // Add the request header to the context, and store the response header // and trailer to propagate back to the caller. - ctx, ci := NewOutgoingContext(ctx) - call, ok := ci.(*callInfo) - if ok { - call.peer = request.Peer() - call.spec = request.Spec() - call.method = request.HTTPMethod() - call.requestHeader = request.Header() + info := &callInfo{ + peer: request.Peer(), + spec: request.Spec(), + method: request.HTTPMethod(), + requestHeader: request.Header(), } + ctx = newIncomingContext(ctx, info) response, err := untyped(ctx, request) if err != nil { return err @@ -87,8 +85,8 @@ func NewUnaryHandler[Req, Res any]( mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) // Add response headers/trailers into the context callinfo also - mergeNonProtocolHeaders(call.ResponseHeader(), response.Header()) - mergeNonProtocolHeaders(call.ResponseTrailer(), response.Trailer()) + mergeNonProtocolHeaders(info.ResponseHeader(), response.Header()) + mergeNonProtocolHeaders(info.ResponseTrailer(), response.Trailer()) return conn.Send(response.Any()) } @@ -121,7 +119,7 @@ func NewUnaryHandlerSimple[Req, Res any]( return nil, err } response := NewResponse(responseMsg) - callInfo, ok := CallInfoFromContext(ctx) + callInfo, ok := CallInfoFromIncomingContext(ctx) if ok { response.setHeader(callInfo.ResponseHeader()) response.setTrailer(callInfo.ResponseTrailer()) @@ -176,13 +174,10 @@ func NewServerStreamHandler[Req, Res any]( if err != nil { return err } - ctx, ci := NewOutgoingContext(ctx) - callInfo, _ := ci.(*callInfo) - callInfo.peer = req.Peer() - callInfo.spec = req.Spec() - callInfo.method = req.HTTPMethod() - maps.Copy(callInfo.RequestHeader(), req.Header()) - + info := &streamCallInfo{ + conn: conn, + } + ctx = newIncomingContext(ctx, info) return implementation(ctx, req, &ServerStream[Res]{conn: conn}) }, ) From 6c5253b88bed023c0a2178b6075fc8ffddf7e249 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 20:56:39 -0400 Subject: [PATCH 36/57] Feedback Signed-off-by: Steve Ayers --- client.go | 27 +++++++++++++-------------- connect_ext_test.go | 14 -------------- context.go | 6 +++--- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/client.go b/client.go index 2f542557..bdc471e7 100644 --- a/client.go +++ b/client.go @@ -154,7 +154,7 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) // CallUnary calls a request-response procedure. func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, request *Req) (*Res, error) { - response, err := c.CallUnary(ctx, requestFromContext(ctx, request)) + response, err := c.CallUnary(ctx, requestFromOutgoingContext(ctx, request)) if response != nil { return response.Msg, err } @@ -180,17 +180,16 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method }) - request.spec = conn.Spec() + _, callInfo := newOutgoingContext(ctx) + callInfo.peer = conn.Peer() + callInfo.spec = conn.Spec() request.peer = conn.Peer() - mergeHeaders(conn.RequestHeader(), request.header) + request.spec = conn.Spec() - ctx, ctxCallInfo := NewOutgoingContext(ctx) - // Note we don't need to check ok here because it should always be in context - // because of the above call to NewOutgoingContext - info, _ := ctxCallInfo.(*callInfo) - info.peer = conn.Peer() - info.spec = conn.Spec() - mergeHeaders(info.RequestHeader(), request.header) + // Merge any callInfo request headers first, then do the request. + // so that context headers show first in the list of headers + mergeHeaders(conn.RequestHeader(), callInfo.RequestHeader()) + mergeHeaders(conn.RequestHeader(), request.header) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the @@ -200,12 +199,12 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques _ = conn.CloseResponse() return nil, err } - info.responseHeader = conn.ResponseHeader() - info.responseTrailer = conn.ResponseTrailer() + callInfo.responseHeader = conn.ResponseHeader() + callInfo.responseTrailer = conn.ResponseTrailer() + if err := conn.CloseRequest(); err != nil { return nil, err } - return &ServerStreamForClient[Res]{ conn: conn, initializer: c.config.Initializer, @@ -218,7 +217,7 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques // This option eliminates the [Request] wrapper, and instead uses the context.Context to // propagate information such as headers. func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { - return c.CallServerStream(ctx, requestFromContext(ctx, requestMsg)) + return c.CallServerStream(ctx, requestFromOutgoingContext(ctx, requestMsg)) } // CallBidiStream calls a bidirectional streaming procedure. diff --git a/connect_ext_test.go b/connect_ext_test.go index dbb1e392..f38dbe3a 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -82,8 +82,6 @@ func TestCallInfo(t *testing.T) { response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) assert.Equal(t, response, expect) assert.Nil(t, err) - - // Assert call info values are correctly populated assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) assert.True(t, callInfo.Spec().IsClient) @@ -104,8 +102,6 @@ func TestCallInfo(t *testing.T) { assert.NotNil(t, msg) assert.Equal(t, msg.GetNumber(), 1) assert.Nil(t, stream.Close()) - - // Assert call info values are correctly populated assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeServer) assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) assert.True(t, callInfo.Spec().IsClient) @@ -132,24 +128,18 @@ func TestCallInfo(t *testing.T) { response, err := client.Ping(ctx, request) assert.Nil(t, err) assert.Equal(t, response.Msg, expect) - assert.Equal(t, request.Spec().StreamType, connect.StreamTypeUnary) assert.Equal(t, request.Spec().Procedure, pingv1connect.PingServicePingProcedure) assert.True(t, request.Spec().IsClient) assert.Equal(t, request.Peer().Addr, httptest.DefaultRemoteAddr) - - // Verify that spec and peer on the callInfo are the same as the request wrapper assert.Equal(t, callInfo.Spec().StreamType, request.Spec().StreamType) assert.Equal(t, callInfo.Spec().Procedure, request.Spec().Procedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, request.Peer().Addr) - - // Verify that the response headers and trailers are the same on callInfo and the response assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) - }) t.Run("server_stream", func(t *testing.T) { req := connect.NewRequest(&pingv1.CountUpRequest{ @@ -166,10 +156,7 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, msg.GetNumber(), 1) assert.Nil(t, stream.Close()) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) - // assert.Equal(t, stream.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) - // assert.Equal(t, stream.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - // assert.Equal(t, callInfo.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) }) }) } @@ -3104,7 +3091,6 @@ func (p pingServerSimple) CountUp( request *pingv1.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { - fmt.Println("Count Up server") callInfo, ok := connect.CallInfoFromIncomingContext(ctx) if !ok { return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) diff --git a/context.go b/context.go index 72ef37a2..ff5895f5 100644 --- a/context.go +++ b/context.go @@ -171,19 +171,19 @@ func newIncomingContext(ctx context.Context, info CallInfo) context.Context { return context.WithValue(ctx, incomingCallInfoContextKey{}, info) } -// CallInfoFromOutgoingContext returns the CallInfo for the given context, if there is one. +// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) return value, ok } -// CallInfoFromIncomingContext returns the CallInfo for the given context, if there is one. +// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) return value, ok } -func requestFromContext[T any](ctx context.Context, message *T) *Request[T] { +func requestFromOutgoingContext[T any](ctx context.Context, message *T) *Request[T] { request := NewRequest(message) callInfo, ok := CallInfoFromOutgoingContext(ctx) if ok { From 3f509d8c91261c6f68e78a9ce198c12022cdc6b6 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 21:20:07 -0400 Subject: [PATCH 37/57] Feedback Signed-off-by: Steve Ayers --- connect.go | 10 ---------- connect_ext_test.go | 4 +++- context.go | 3 ++- handler.go | 21 +++++++-------------- 4 files changed, 12 insertions(+), 26 deletions(-) diff --git a/connect.go b/connect.go index 596bc7fd..caaf838b 100644 --- a/connect.go +++ b/connect.go @@ -287,16 +287,6 @@ func (r *Response[_]) Trailer() http.Header { return r.trailer } -// setHeader sets the response header. -func (r *Response[_]) setHeader(header http.Header) { - r.header = header -} - -// setTrailer sets the response trailer. -func (r *Response[_]) setTrailer(trailer http.Header) { - r.trailer = trailer -} - // internalOnly implements AnyResponse. func (r *Response[_]) internalOnly() {} diff --git a/connect_ext_test.go b/connect_ext_test.go index f38dbe3a..07190db2 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -107,7 +107,7 @@ func TestCallInfo(t *testing.T) { assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, callInfo.ResponseHeader().Values(handlerTrailer), []string{trailerValue}) + // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) t.Run("generics_api", func(t *testing.T) { @@ -157,6 +157,8 @@ func TestCallInfo(t *testing.T) { assert.Nil(t, stream.Close()) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + // assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) } diff --git a/context.go b/context.go index ff5895f5..38087a07 100644 --- a/context.go +++ b/context.go @@ -104,6 +104,7 @@ func (c *callInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *callInfo) internalOnly() {} +// streamCallInfo is a CallInfo implementation used for streaming RPCs. type streamCallInfo struct { conn StreamingHandlerConn } @@ -125,7 +126,7 @@ func (c *streamCallInfo) ResponseHeader() http.Header { } func (c *streamCallInfo) ResponseTrailer() http.Header { - return c.conn.ResponseHeader() + return c.conn.ResponseTrailer() } func (c *streamCallInfo) HTTPMethod() string { diff --git a/handler.go b/handler.go index dc4ac3d6..5a7884a0 100644 --- a/handler.go +++ b/handler.go @@ -80,14 +80,14 @@ func NewUnaryHandler[Req, Res any]( return err } + // Add response headers/trailers into the context callinfo + mergeNonProtocolHeaders(conn.ResponseHeader(), info.ResponseHeader()) + mergeNonProtocolHeaders(conn.ResponseTrailer(), info.ResponseTrailer()) + // Add response headers/trailers into the conn mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) - // Add response headers/trailers into the context callinfo also - mergeNonProtocolHeaders(info.ResponseHeader(), response.Header()) - mergeNonProtocolHeaders(info.ResponseTrailer(), response.Trailer()) - return conn.Send(response.Any()) } @@ -118,13 +118,7 @@ func NewUnaryHandlerSimple[Req, Res any]( if err != nil { return nil, err } - response := NewResponse(responseMsg) - callInfo, ok := CallInfoFromIncomingContext(ctx) - if ok { - response.setHeader(callInfo.ResponseHeader()) - response.setTrailer(callInfo.ResponseTrailer()) - } - return response, err + return NewResponse(responseMsg), nil }, options..., ) @@ -174,10 +168,9 @@ func NewServerStreamHandler[Req, Res any]( if err != nil { return err } - info := &streamCallInfo{ + ctx = newIncomingContext(ctx, &streamCallInfo{ conn: conn, - } - ctx = newIncomingContext(ctx, info) + }) return implementation(ctx, req, &ServerStream[Res]{conn: conn}) }, ) From c78eb945c9736a481604fb51c41c5c93e23d539e Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 30 Jun 2025 21:23:16 -0400 Subject: [PATCH 38/57] Cleanup Signed-off-by: Steve Ayers --- client.go | 6 +++++- connect_ext_test.go | 3 --- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index bdc471e7..5b0a1e84 100644 --- a/client.go +++ b/client.go @@ -152,7 +152,11 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) return resp, nil } -// CallUnary calls a request-response procedure. +// CallUnarySimple calls a request-response procedure using the function signature +// associated with the "simple" generation option. +// +// This option eliminates the [Request] and [Response] wrappers, and instead uses the +// context.Context to propagate information such as headers. func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, request *Req) (*Res, error) { response, err := c.CallUnary(ctx, requestFromOutgoingContext(ctx, request)) if response != nil { diff --git a/connect_ext_test.go b/connect_ext_test.go index 07190db2..918a3c33 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -107,7 +107,6 @@ func TestCallInfo(t *testing.T) { assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) t.Run("generics_api", func(t *testing.T) { @@ -157,8 +156,6 @@ func TestCallInfo(t *testing.T) { assert.Nil(t, stream.Close()) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - // assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) - // assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) } From 279b452cf4da5cbb9af1cd55d66c2691e60eb236 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 11:20:12 -0400 Subject: [PATCH 39/57] Update context.go Co-authored-by: Joshua Humphries <2035234+jhump@users.noreply.github.com> Signed-off-by: Steve Ayers Signed-off-by: Steve Ayers --- context.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/context.go b/context.go index 38087a07..4e0ecf1c 100644 --- a/context.go +++ b/context.go @@ -151,12 +151,7 @@ type incomingCallInfoContextKey struct{} // If the given context is already associated with an outgoing CallInfo, then // ctx and the existing CallInfo are returned. func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { - info, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) - if !ok { - info = &callInfo{} - return context.WithValue(ctx, outgoingCallInfoContextKey{}, info), info - } - return ctx, info + return newOutgoingContext(ctx) } func newOutgoingContext(ctx context.Context) (context.Context, *callInfo) { From 7baf7be12abe938ba9c37e1dcdef191d1f20f7bd Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 11:20:44 -0400 Subject: [PATCH 40/57] Feedback Signed-off-by: Steve Ayers --- context.go | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/context.go b/context.go index 4e0ecf1c..71dfdf40 100644 --- a/context.go +++ b/context.go @@ -20,21 +20,6 @@ import ( ) type CallInfo interface { - StreamCallInfo - // HTTPMethod returns the HTTP method for this request. This is nearly always - // POST, but side-effect-free unary RPCs could be made via a GET. - // - // On a newly created request, via NewRequest, this will return the empty - // string until the actual request is actually sent and the HTTP method - // determined. This means that client interceptor functions will see the - // empty string until *after* they delegate to the handler they wrapped. It - // is even possible for this to return the empty string after such delegation, - // if the request was never actually sent to the server (and thus no - // determination ever made about the HTTP method). - HTTPMethod() string -} - -type StreamCallInfo interface { // Spec returns a description of this call. Spec() Spec // Peer describes the other party for this call. @@ -57,6 +42,17 @@ type StreamCallInfo interface { ResponseTrailer() http.Header internalOnly() + // HTTPMethod returns the HTTP method for this request. This is nearly always + // POST, but side-effect-free unary RPCs could be made via a GET. + // + // On a newly created request, via NewRequest, this will return the empty + // string until the actual request is actually sent and the HTTP method + // determined. This means that client interceptor functions will see the + // empty string until *after* they delegate to the handler they wrapped. It + // is even possible for this to return the empty string after such delegation, + // if the request was never actually sent to the server (and thus no + // determination ever made about the HTTP method). + HTTPMethod() string } type callInfo struct { From 8d92baef8dd4f627f6678c8ababd236f0a51f5d8 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 13:53:15 -0400 Subject: [PATCH 41/57] Interceptors Signed-off-by: Steve Ayers --- client.go | 40 ++++++++++++------- context.go | 88 +++++++++++++++++++++++++++++++++++------ handler.go | 14 ++++--- interceptor_ext_test.go | 38 ++++++++++++++---- 4 files changed, 140 insertions(+), 40 deletions(-) diff --git a/client.go b/client.go index 5b0a1e84..51065c35 100644 --- a/client.go +++ b/client.go @@ -76,6 +76,8 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien // once at client creation. unarySpec := config.newSpec(StreamTypeUnary) unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { + ctx, callInfo := newOutgoingContext(ctx) + fmt.Printf("unary func call info: %+v\n\n", callInfo) conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) conn.onRequestSend(func(r *http.Request) { request.setRequestMethod(r.Method) @@ -109,6 +111,14 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien request.spec = unarySpec request.peer = client.protocolClient.Peer() protocolClient.WriteRequestHeader(StreamTypeUnary, request.Header()) + + // Also set them in the context so interceptors can inspect context for this information. + ctx, callInfo := newOutgoingContext(ctx) + callInfo.peer = request.Peer() + callInfo.spec = request.Spec() + + fmt.Printf("call unary call info: %+v\n\n", callInfo) + response, err := unaryFunc(ctx, request) if err != nil { return nil, err @@ -122,6 +132,18 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien return client } +type wrapper[Res any] struct { + response *Response[Res] +} + +func (w *wrapper[Res]) ResponseHeader() http.Header { + return w.response.Header() +} + +func (w *wrapper[Res]) ResponseTrailer() http.Header { + return w.response.Trailer() +} + // CallUnary calls a request-response procedure. func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) (*Response[Res], error) { if c.err != nil { @@ -135,18 +157,9 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) return nil, err } - callInfo.peer = request.Peer() - callInfo.spec = request.Spec() callInfo.method = request.HTTPMethod() - if callInfo.responseHeader == nil { - callInfo.responseHeader = resp.Header() - } else { - mergeHeaders(callInfo.ResponseHeader(), resp.Header()) - } - if callInfo.responseTrailer == nil { - callInfo.responseTrailer = resp.Trailer() - } else { - mergeHeaders(callInfo.ResponseTrailer(), resp.Trailer()) + callInfo.responseSource = &wrapper[Res]{ + response: resp, } return resp, nil @@ -187,6 +200,8 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques _, callInfo := newOutgoingContext(ctx) callInfo.peer = conn.Peer() callInfo.spec = conn.Spec() + callInfo.responseSource = conn + request.peer = conn.Peer() request.spec = conn.Spec() @@ -203,9 +218,6 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques _ = conn.CloseResponse() return nil, err } - callInfo.responseHeader = conn.ResponseHeader() - callInfo.responseTrailer = conn.ResponseTrailer() - if err := conn.CloseRequest(); err != nil { return nil, err } diff --git a/context.go b/context.go index 71dfdf40..0db47eae 100644 --- a/context.go +++ b/context.go @@ -19,6 +19,10 @@ import ( "net/http" ) +// CallInfo represents information relevant to an RPC call. +// Values returned by these methods are not thread-safe. Users should expect +// data races if they create an outgoing CallInfo in context and then pass that +// CallInfo to another goroutine and try to call methods on it concurrent with the RPC. type CallInfo interface { // Spec returns a description of this call. Spec() Spec @@ -31,6 +35,9 @@ type CallInfo interface { // ResponseHeader returns the HTTP headers for this response. Headers beginning with // "Connect-" and "Grpc-" are reserved for use by the Connect and gRPC // protocols: applications may read them but shouldn't write them. + // On the client side, this method returns nil before + // the call is actually made. After the call is made, for streaming operations, + // this method will block for the server to actually return response headers. ResponseHeader() http.Header // ResponseTrailer returns the trailers for this response. Depending on the underlying // RPC protocol, trailers may be sent as HTTP trailers or a protocol-specific @@ -39,9 +46,11 @@ type CallInfo interface { // Trailers beginning with "Connect-" and "Grpc-" are reserved for use by the // Connect and gRPC protocols: applications may read them but shouldn't write // them. + // + // On the client side, this method returns nil before + // the call is actually made. After the call is made, for streaming operations, + // this method will block for the server to actually return response trailers. ResponseTrailer() http.Header - - internalOnly() // HTTPMethod returns the HTTP method for this request. This is nearly always // POST, but side-effect-free unary RPCs could be made via a GET. // @@ -53,9 +62,12 @@ type CallInfo interface { // if the request was never actually sent to the server (and thus no // determination ever made about the HTTP method). HTTPMethod() string + + internalOnly() } -type callInfo struct { +// handlerCallInfo is a CallInfo implementation used for handlers. +type handlerCallInfo struct { spec Spec peer Peer method string @@ -64,41 +76,41 @@ type callInfo struct { responseTrailer http.Header } -func (c *callInfo) Spec() Spec { +func (c *handlerCallInfo) Spec() Spec { return c.spec } -func (c *callInfo) Peer() Peer { +func (c *handlerCallInfo) Peer() Peer { return c.peer } -func (c *callInfo) RequestHeader() http.Header { +func (c *handlerCallInfo) RequestHeader() http.Header { if c.requestHeader == nil { c.requestHeader = make(http.Header) } return c.requestHeader } -func (c *callInfo) ResponseHeader() http.Header { +func (c *handlerCallInfo) ResponseHeader() http.Header { if c.responseHeader == nil { c.responseHeader = make(http.Header) } return c.responseHeader } -func (c *callInfo) ResponseTrailer() http.Header { +func (c *handlerCallInfo) ResponseTrailer() http.Header { if c.responseTrailer == nil { c.responseTrailer = make(http.Header) } return c.responseTrailer } -func (c *callInfo) HTTPMethod() string { +func (c *handlerCallInfo) HTTPMethod() string { return c.method } // internalOnly implements CallInfo. -func (c *callInfo) internalOnly() {} +func (c *handlerCallInfo) internalOnly() {} // streamCallInfo is a CallInfo implementation used for streaming RPCs. type streamCallInfo struct { @@ -133,6 +145,56 @@ func (c *streamCallInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *streamCallInfo) internalOnly() {} +type responseSource interface { + ResponseHeader() http.Header + ResponseTrailer() http.Header +} + +// clientCallInfo is a CallInfo implementation used for clients. +type clientCallInfo struct { + responseSource + spec Spec + peer Peer + method string + requestHeader http.Header +} + +func (c *clientCallInfo) Spec() Spec { + return c.spec +} + +func (c *clientCallInfo) Peer() Peer { + return c.peer +} + +func (c *clientCallInfo) RequestHeader() http.Header { + if c.requestHeader == nil { + c.requestHeader = make(http.Header) + } + return c.requestHeader +} + +func (c *clientCallInfo) ResponseHeader() http.Header { + if c.responseSource == nil { + return nil + } + return c.responseSource.ResponseHeader() +} + +func (c *clientCallInfo) ResponseTrailer() http.Header { + if c.responseSource == nil { + return nil + } + return c.responseSource.ResponseTrailer() +} + +func (c *clientCallInfo) HTTPMethod() string { + return c.method +} + +// internalOnly implements CallInfo. +func (c *clientCallInfo) internalOnly() {} + type outgoingCallInfoContextKey struct{} type incomingCallInfoContextKey struct{} @@ -150,10 +212,10 @@ func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { return newOutgoingContext(ctx) } -func newOutgoingContext(ctx context.Context) (context.Context, *callInfo) { - info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*callInfo) +func newOutgoingContext(ctx context.Context) (context.Context, *clientCallInfo) { + info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*clientCallInfo) if !ok { - info = &callInfo{} + info = &clientCallInfo{} return context.WithValue(ctx, outgoingCallInfoContextKey{}, info), info } return ctx, info diff --git a/handler.go b/handler.go index 5a7884a0..fe11bc45 100644 --- a/handler.go +++ b/handler.go @@ -68,7 +68,7 @@ func NewUnaryHandler[Req, Res any]( } // Add the request header to the context, and store the response header // and trailer to propagate back to the caller. - info := &callInfo{ + info := &handlerCallInfo{ peer: request.Peer(), spec: request.Spec(), method: request.HTTPMethod(), @@ -80,11 +80,15 @@ func NewUnaryHandler[Req, Res any]( return err } - // Add response headers/trailers into the context callinfo - mergeNonProtocolHeaders(conn.ResponseHeader(), info.ResponseHeader()) - mergeNonProtocolHeaders(conn.ResponseTrailer(), info.ResponseTrailer()) + // Add response headers/trailers from the context callinfo into the conn if they exist + if info.responseHeader != nil { + mergeNonProtocolHeaders(conn.ResponseHeader(), info.responseHeader) + } + if info.responseTrailer != nil { + mergeNonProtocolHeaders(conn.ResponseTrailer(), info.responseTrailer) + } - // Add response headers/trailers into the conn + // Add response headers/trailers from the response into the conn mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index b5892fde..46e6f173 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -191,8 +191,8 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { t.Parallel() - clientChecker := &httpMethodChecker{client: true} - handlerChecker := &httpMethodChecker{} + clientChecker := &callInfoChecker{client: true} + handlerChecker := &callInfoChecker{} mux := http.NewServeMux() mux.Handle( @@ -344,25 +344,39 @@ func (cc *headerInspectingClientConn) Receive(msg any) error { return err } -type httpMethodChecker struct { +type callInfoChecker struct { client bool count atomic.Int32 } -func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { +func (h *callInfoChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { h.count.Add(1) if h.client { + outgoingCallInfo, ok := connect.CallInfoFromOutgoingContext(ctx) + if !ok { + return nil, fmt.Errorf("no call info found in outgoing context") + } // should be blank until after we make request + if outgoingCallInfo.HTTPMethod() != "" { + return nil, fmt.Errorf("expected blank HTTP method in outgoing context but instead got %q", outgoingCallInfo.HTTPMethod()) + } if req.HTTPMethod() != "" { - return nil, fmt.Errorf("expected blank HTTP method but instead got %q", req.HTTPMethod()) + return nil, fmt.Errorf("expected blank HTTP method in request but instead got %q", req.HTTPMethod()) } } else { + incomingCallInfo, ok := connect.CallInfoFromIncomingContext(ctx) + if !ok { + return nil, fmt.Errorf("no call info found in incoming context") + } // server interceptors see method from the start // NB: In theory, the method could also be GET, not just POST. But for the // configuration under test, it will always be POST. + if incomingCallInfo.HTTPMethod() != http.MethodPost { + return nil, fmt.Errorf("expected HTTP method %s in incoming context but instead got %q", http.MethodPost, incomingCallInfo.HTTPMethod()) + } if req.HTTPMethod() != http.MethodPost { - return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) + return nil, fmt.Errorf("expected HTTP method %s in request but instead got %q", http.MethodPost, req.HTTPMethod()) } } resp, err := unaryFunc(ctx, req) @@ -371,11 +385,19 @@ func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.Unary if req.HTTPMethod() != http.MethodPost { return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) } + // Method should now be set on the outgoing context + // callInfo, ok := connect.CallInfoFromOutgoingContext(ctx) + // if !ok { + // return nil, fmt.Errorf("no call info found in outgoing context after request") + // } + // if callInfo.HTTPMethod() != http.MethodPost { + // return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, callInfo.HTTPMethod()) + // } return resp, err } } -func (h *httpMethodChecker) WrapStreamingClient(clientFunc connect.StreamingClientFunc) connect.StreamingClientFunc { +func (h *callInfoChecker) WrapStreamingClient(clientFunc connect.StreamingClientFunc) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { // method not exposed to streaming interceptor, but that's okay because it's always POST for streams h.count.Add(1) @@ -383,7 +405,7 @@ func (h *httpMethodChecker) WrapStreamingClient(clientFunc connect.StreamingClie } } -func (h *httpMethodChecker) WrapStreamingHandler(handlerFunc connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (h *callInfoChecker) WrapStreamingHandler(handlerFunc connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { // method not exposed to streaming interceptor, but that's okay because it's always POST for streams h.count.Add(1) From 329c9e547d1bcd9288b54a0e34cbfbedeb480732 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 14:43:44 -0400 Subject: [PATCH 42/57] Interceptor tests Signed-off-by: Steve Ayers --- client.go | 12 ++--- interceptor_ext_test.go | 104 +++++++++++++++++++++++++++------------- 2 files changed, 75 insertions(+), 41 deletions(-) diff --git a/client.go b/client.go index 51065c35..1b8bd665 100644 --- a/client.go +++ b/client.go @@ -77,10 +77,10 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien unarySpec := config.newSpec(StreamTypeUnary) unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { ctx, callInfo := newOutgoingContext(ctx) - fmt.Printf("unary func call info: %+v\n\n", callInfo) conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) conn.onRequestSend(func(r *http.Request) { request.setRequestMethod(r.Method) + callInfo.method = r.Method }) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the @@ -117,8 +117,6 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien callInfo.peer = request.Peer() callInfo.spec = request.Spec() - fmt.Printf("call unary call info: %+v\n\n", callInfo) - response, err := unaryFunc(ctx, request) if err != nil { return nil, err @@ -127,6 +125,9 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien if !ok { return nil, errorf(CodeInternal, "unexpected client response type %T", response) } + callInfo.responseSource = &wrapper[Res]{ + response: typed, + } return typed, nil } return client @@ -157,11 +158,6 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) return nil, err } - callInfo.method = request.HTTPMethod() - callInfo.responseSource = &wrapper[Res]{ - response: resp, - } - return resp, nil } diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 46e6f173..9201dcc6 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -16,6 +16,7 @@ package connect_test import ( "context" + "errors" "fmt" "net/http" "sync/atomic" @@ -349,50 +350,87 @@ type callInfoChecker struct { count atomic.Int32 } +func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connect.AnyRequest, prerequest bool) error { + // method should be blank until after we make request + if prerequest { //nolint:nestif + if callInfo.HTTPMethod() != "" { + return fmt.Errorf("expected blank HTTP method in outgoing context but instead got %q", callInfo.HTTPMethod()) + } + if req.HTTPMethod() != "" { + return fmt.Errorf("expected blank HTTP method in request but instead got %q", req.HTTPMethod()) + } + } else { + // server interceptors see method from the start + // NB: In theory, the method could also be GET, not just POST. But for the + // configuration under test, it will always be POST. + if callInfo.HTTPMethod() != http.MethodPost { + return fmt.Errorf("expected HTTP method %s in outgoing context but instead got %q", http.MethodPost, callInfo.HTTPMethod()) + } + if req.HTTPMethod() != http.MethodPost { + return fmt.Errorf("expected HTTP method %s in request but instead got %q", http.MethodPost, req.HTTPMethod()) + } + } + if callInfo.Peer().Addr == "" { + return errors.New("no peer set on call info") + } + if callInfo.Spec().Procedure != pingv1connect.PingServicePingProcedure { + return fmt.Errorf("expected spec procedure %s but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) + } + return nil +} + +func (h *callInfoChecker) getCallInfo(ctx context.Context) (connect.CallInfo, error) { + var callInfo connect.CallInfo + if h.client { + info, ok := connect.CallInfoFromOutgoingContext(ctx) + if !ok { + return nil, errors.New("no call info found in outgoing context") + } + callInfo = info + } else { + info, ok := connect.CallInfoFromIncomingContext(ctx) + if !ok { + return nil, errors.New("no call info found in incoming context") + } + callInfo = info + } + return callInfo, nil +} + func (h *callInfoChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { h.count.Add(1) + + callInfo, err := h.getCallInfo(ctx) + if err != nil { + return nil, err + } + if h.client { - outgoingCallInfo, ok := connect.CallInfoFromOutgoingContext(ctx) - if !ok { - return nil, fmt.Errorf("no call info found in outgoing context") - } - // should be blank until after we make request - if outgoingCallInfo.HTTPMethod() != "" { - return nil, fmt.Errorf("expected blank HTTP method in outgoing context but instead got %q", outgoingCallInfo.HTTPMethod()) - } - if req.HTTPMethod() != "" { - return nil, fmt.Errorf("expected blank HTTP method in request but instead got %q", req.HTTPMethod()) + if err := h.validateCallInfo(callInfo, req, true); err != nil { + return nil, err } } else { - incomingCallInfo, ok := connect.CallInfoFromIncomingContext(ctx) - if !ok { - return nil, fmt.Errorf("no call info found in incoming context") - } - // server interceptors see method from the start - // NB: In theory, the method could also be GET, not just POST. But for the - // configuration under test, it will always be POST. - if incomingCallInfo.HTTPMethod() != http.MethodPost { - return nil, fmt.Errorf("expected HTTP method %s in incoming context but instead got %q", http.MethodPost, incomingCallInfo.HTTPMethod()) - } - if req.HTTPMethod() != http.MethodPost { - return nil, fmt.Errorf("expected HTTP method %s in request but instead got %q", http.MethodPost, req.HTTPMethod()) + if err := h.validateCallInfo(callInfo, req, false); err != nil { + return nil, err } } + resp, err := unaryFunc(ctx, req) - // NB: In theory, the method could also be GET, not just POST. But for the - // configuration under test, it will always be POST. - if req.HTTPMethod() != http.MethodPost { - return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) + if err != nil { + return nil, err } + // Method should now be set on the outgoing context - // callInfo, ok := connect.CallInfoFromOutgoingContext(ctx) - // if !ok { - // return nil, fmt.Errorf("no call info found in outgoing context after request") - // } - // if callInfo.HTTPMethod() != http.MethodPost { - // return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, callInfo.HTTPMethod()) - // } + callInfo, err = h.getCallInfo(ctx) + if err != nil { + return nil, err + } + + if err := h.validateCallInfo(callInfo, req, false); err != nil { + return nil, err + } + return resp, err } } From c6741489b3e26a55094972a5a9f1fcfdfd73ba05 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:01:44 -0400 Subject: [PATCH 43/57] Feedback Signed-off-by: Steve Ayers --- interceptor_ext_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 9201dcc6..b27943c7 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -190,7 +190,7 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { assert.Nil(t, countUpStream.Close()) } -func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { +func TestInterceptorFuncAccessingCallInfo(t *testing.T) { t.Parallel() clientChecker := &callInfoChecker{client: true} handlerChecker := &callInfoChecker{} @@ -350,9 +350,9 @@ type callInfoChecker struct { count atomic.Int32 } -func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connect.AnyRequest, prerequest bool) error { +func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connect.AnyRequest, expectMethod bool) error { // method should be blank until after we make request - if prerequest { //nolint:nestif + if !expectMethod { //nolint:nestif if callInfo.HTTPMethod() != "" { return fmt.Errorf("expected blank HTTP method in outgoing context but instead got %q", callInfo.HTTPMethod()) } @@ -407,11 +407,11 @@ func (h *callInfoChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFu } if h.client { - if err := h.validateCallInfo(callInfo, req, true); err != nil { + if err := h.validateCallInfo(callInfo, req, false); err != nil { return nil, err } } else { - if err := h.validateCallInfo(callInfo, req, false); err != nil { + if err := h.validateCallInfo(callInfo, req, true); err != nil { return nil, err } } @@ -427,7 +427,7 @@ func (h *callInfoChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFu return nil, err } - if err := h.validateCallInfo(callInfo, req, false); err != nil { + if err := h.validateCallInfo(callInfo, req, true); err != nil { return nil, err } From 3935be9d1ce2f6cabe090fc25f45632b76dac7ce Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:04:13 -0400 Subject: [PATCH 44/57] Feedback Signed-off-by: Steve Ayers --- client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client.go b/client.go index 1b8bd665..3c342970 100644 --- a/client.go +++ b/client.go @@ -116,6 +116,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien ctx, callInfo := newOutgoingContext(ctx) callInfo.peer = request.Peer() callInfo.spec = request.Spec() + callInfo.requestHeader = request.Header() response, err := unaryFunc(ctx, request) if err != nil { From a295d105111c101de28cd6b237d4a8f430bd7204 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:13:41 -0400 Subject: [PATCH 45/57] Update header setting Signed-off-by: Steve Ayers --- handler.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/handler.go b/handler.go index fe11bc45..9975712d 100644 --- a/handler.go +++ b/handler.go @@ -88,9 +88,13 @@ func NewUnaryHandler[Req, Res any]( mergeNonProtocolHeaders(conn.ResponseTrailer(), info.responseTrailer) } - // Add response headers/trailers from the response into the conn - mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) - mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) + // Add response headers/trailers from the response into the conn if they exist + if len(response.Header()) != 0 { + mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) + } + if len(response.Trailer()) != 0 { + mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) + } return conn.Send(response.Any()) } From 1d30ca6b91d5c0f9e26e2af43de08943bb88dea1 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:19:40 -0400 Subject: [PATCH 46/57] Fix responseWrapper docs Signed-off-by: Steve Ayers --- client.go | 14 +------------- context.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 3c342970..cdd676c4 100644 --- a/client.go +++ b/client.go @@ -126,7 +126,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien if !ok { return nil, errorf(CodeInternal, "unexpected client response type %T", response) } - callInfo.responseSource = &wrapper[Res]{ + callInfo.responseSource = &responseWrapper[Res]{ response: typed, } return typed, nil @@ -134,18 +134,6 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien return client } -type wrapper[Res any] struct { - response *Response[Res] -} - -func (w *wrapper[Res]) ResponseHeader() http.Header { - return w.response.Header() -} - -func (w *wrapper[Res]) ResponseTrailer() http.Header { - return w.response.Trailer() -} - // CallUnary calls a request-response procedure. func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) (*Response[Res], error) { if c.err != nil { diff --git a/context.go b/context.go index 0db47eae..00716158 100644 --- a/context.go +++ b/context.go @@ -150,6 +150,19 @@ type responseSource interface { ResponseTrailer() http.Header } +// responseWrapper wraps a Response object so that it can implement the responseSource interface. +type responseWrapper[Res any] struct { + response *Response[Res] +} + +func (w *responseWrapper[Res]) ResponseHeader() http.Header { + return w.response.Header() +} + +func (w *responseWrapper[Res]) ResponseTrailer() http.Header { + return w.response.Trailer() +} + // clientCallInfo is a CallInfo implementation used for clients. type clientCallInfo struct { responseSource From a2e8814e3e333dfb154182e854e3e9a50cf942e6 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:21:48 -0400 Subject: [PATCH 47/57] Fix again Signed-off-by: Steve Ayers --- client.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/client.go b/client.go index cdd676c4..4dd629d9 100644 --- a/client.go +++ b/client.go @@ -139,9 +139,6 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) if c.err != nil { return nil, c.err } - ctx, callInfo := newOutgoingContext(ctx) - callInfo.requestHeader = request.Header() - resp, err := c.callUnary(ctx, request) if err != nil { return nil, err From d8102c09fc81471dd542298c0f042861d966d4e9 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 15:25:10 -0400 Subject: [PATCH 48/57] Update tests Signed-off-by: Steve Ayers --- interceptor_ext_test.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index b27943c7..949a6afc 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -354,7 +354,7 @@ func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connec // method should be blank until after we make request if !expectMethod { //nolint:nestif if callInfo.HTTPMethod() != "" { - return fmt.Errorf("expected blank HTTP method in outgoing context but instead got %q", callInfo.HTTPMethod()) + return fmt.Errorf("expected blank HTTP method in context but instead got %q", callInfo.HTTPMethod()) } if req.HTTPMethod() != "" { return fmt.Errorf("expected blank HTTP method in request but instead got %q", req.HTTPMethod()) @@ -364,7 +364,7 @@ func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connec // NB: In theory, the method could also be GET, not just POST. But for the // configuration under test, it will always be POST. if callInfo.HTTPMethod() != http.MethodPost { - return fmt.Errorf("expected HTTP method %s in outgoing context but instead got %q", http.MethodPost, callInfo.HTTPMethod()) + return fmt.Errorf("expected HTTP method %s in context but instead got %q", http.MethodPost, callInfo.HTTPMethod()) } if req.HTTPMethod() != http.MethodPost { return fmt.Errorf("expected HTTP method %s in request but instead got %q", http.MethodPost, req.HTTPMethod()) @@ -373,8 +373,14 @@ func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connec if callInfo.Peer().Addr == "" { return errors.New("no peer set on call info") } + if req.Peer().Addr == "" { + return errors.New("no peer set on request") + } if callInfo.Spec().Procedure != pingv1connect.PingServicePingProcedure { - return fmt.Errorf("expected spec procedure %s but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) + return fmt.Errorf("expected spec procedure %s in context but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) + } + if req.Spec().Procedure != pingv1connect.PingServicePingProcedure { + return fmt.Errorf("expected spec procedure %s on request but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) } return nil } From 8de83900228eefe0865314d3ac499ffe296c46e1 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Tue, 1 Jul 2025 19:13:15 -0400 Subject: [PATCH 49/57] Style Signed-off-by: Steve Ayers --- client.go | 8 +--- connect_ext_test.go | 114 ++++++++++++++++++++++---------------------- context.go | 68 +++++++++++++------------- 3 files changed, 94 insertions(+), 96 deletions(-) diff --git a/client.go b/client.go index 4dd629d9..78365943 100644 --- a/client.go +++ b/client.go @@ -126,6 +126,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien if !ok { return nil, errorf(CodeInternal, "unexpected client response type %T", response) } + // Wrap the response and set it into the context callinfo callInfo.responseSource = &responseWrapper[Res]{ response: typed, } @@ -139,12 +140,7 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) if c.err != nil { return nil, c.err } - resp, err := c.callUnary(ctx, request) - if err != nil { - return nil, err - } - - return resp, nil + return c.callUnary(ctx, request) } // CallUnarySimple calls a request-response procedure using the function signature diff --git a/connect_ext_test.go b/connect_ext_test.go index 918a3c33..6cbfa9fa 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2867,43 +2867,6 @@ func (p *pluggablePingServer) CumSum( return p.cumSum(ctx, stream) } -func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { - tb.Helper() - if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { - assert.ErrorIs(tb, err, io.EOF) - assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown) - } - assert.Nil(tb, stream.CloseRequest()) - _, err := stream.Receive() - assert.NotNil(tb, err) // should be 505 - assert.True( - tb, - strings.Contains(err.Error(), "HTTP status 505"), - assert.Sprintf("expected 505, got %v", err), - ) - assert.Nil(tb, stream.CloseResponse()) -} - -func expectClientHeader(check bool, req connect.AnyRequest) error { - if !check { - return nil - } - return expectMetadata(req.Header(), "header", clientHeader, headerValue) -} - -func expectMetadata(meta http.Header, metaType, key, value string) error { //nolint:unparam - if got := meta.Get(key); got != value { - return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( - "%s %q: got %q, expected %q", - metaType, - key, - got, - value, - )) - } - return nil -} - type pingServer struct { pingv1connect.UnimplementedPingServiceHandler @@ -2912,8 +2875,10 @@ type pingServer struct { } func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { - if err := expectClientHeader(p.checkMetadata, request); err != nil { - return nil, err + if p.checkMetadata { + if err := expectMetadata(request.Header()); err != nil { + return nil, err + } } if request.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -2933,8 +2898,10 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi } func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { - if err := expectClientHeader(p.checkMetadata, request); err != nil { - return nil, err + if p.checkMetadata { + if err := expectMetadata(request.Header()); err != nil { + return nil, err + } } if request.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -2963,7 +2930,7 @@ func (p pingServer) Sum( stream *connect.ClientStream[pingv1.SumRequest], ) (*connect.Response[pingv1.SumResponse], error) { if p.checkMetadata { - if err := expectMetadata(stream.RequestHeader(), "header", clientHeader, headerValue); err != nil { + if err := expectMetadata(stream.RequestHeader()); err != nil { return nil, err } } @@ -2991,8 +2958,10 @@ func (p pingServer) CountUp( request *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse], ) error { - if err := expectClientHeader(p.checkMetadata, request); err != nil { - return err + if p.checkMetadata { + if err := expectMetadata(request.Header()); err != nil { + return err + } } if request.Peer().Addr == "" { return connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -3022,7 +2991,7 @@ func (p pingServer) CumSum( ) error { var sum int64 if p.checkMetadata { - if err := expectMetadata(stream.RequestHeader(), "header", clientHeader, headerValue); err != nil { + if err := expectMetadata(stream.RequestHeader()); err != nil { return err } } @@ -3048,13 +3017,6 @@ func (p pingServer) CumSum( } } -func expectClientHeaderInCallInfo(check bool, callInfo connect.CallInfo) error { - if !check { - return nil - } - return expectMetadata(callInfo.RequestHeader(), "header", clientHeader, headerValue) -} - type pingServerSimple struct { pingv1connectsimple.UnimplementedPingServiceHandler @@ -3067,8 +3029,10 @@ func (p pingServerSimple) Ping(ctx context.Context, request *pingv1.PingRequest) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } - if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { - return nil, err + if p.checkMetadata { + if err := expectMetadata(callInfo.RequestHeader()); err != nil { + return nil, err + } } if callInfo.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -3094,8 +3058,10 @@ func (p pingServerSimple) CountUp( if !ok { return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } - if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { - return err + if p.checkMetadata { + if err := expectMetadata(callInfo.RequestHeader()); err != nil { + return err + } } if callInfo.Peer().Addr == "" { return connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -3124,8 +3090,10 @@ func (p pingServerSimple) Fail(ctx context.Context, request *pingv1.FailRequest) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } - if err := expectClientHeaderInCallInfo(p.checkMetadata, callInfo); err != nil { - return nil, err + if p.checkMetadata { + if err := expectMetadata(callInfo.RequestHeader()); err != nil { + return nil, err + } } if callInfo.Peer().Addr == "" { return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) @@ -3235,3 +3203,33 @@ func (failCompressor) Close() error { } func (failCompressor) Reset(io.Writer) {} + +func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { + tb.Helper() + if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { + assert.ErrorIs(tb, err, io.EOF) + assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown) + } + assert.Nil(tb, stream.CloseRequest()) + _, err := stream.Receive() + assert.NotNil(tb, err) // should be 505 + assert.True( + tb, + strings.Contains(err.Error(), "HTTP status 505"), + assert.Sprintf("expected 505, got %v", err), + ) + assert.Nil(tb, stream.CloseResponse()) +} + +func expectMetadata(meta http.Header) error { + if got := meta.Get(clientHeader); got != headerValue { + return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( + "%s %q: got %q, expected %q", + "header", + clientHeader, + got, + headerValue, + )) + } + return nil +} diff --git a/context.go b/context.go index 00716158..7a1eb4ad 100644 --- a/context.go +++ b/context.go @@ -112,7 +112,7 @@ func (c *handlerCallInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *handlerCallInfo) internalOnly() {} -// streamCallInfo is a CallInfo implementation used for streaming RPCs. +// streamCallInfo is a CallInfo implementation used for streaming RPC handlers. type streamCallInfo struct { conn StreamingHandlerConn } @@ -145,24 +145,6 @@ func (c *streamCallInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *streamCallInfo) internalOnly() {} -type responseSource interface { - ResponseHeader() http.Header - ResponseTrailer() http.Header -} - -// responseWrapper wraps a Response object so that it can implement the responseSource interface. -type responseWrapper[Res any] struct { - response *Response[Res] -} - -func (w *responseWrapper[Res]) ResponseHeader() http.Header { - return w.response.Header() -} - -func (w *responseWrapper[Res]) ResponseTrailer() http.Header { - return w.response.Trailer() -} - // clientCallInfo is a CallInfo implementation used for clients. type clientCallInfo struct { responseSource @@ -211,7 +193,26 @@ func (c *clientCallInfo) internalOnly() {} type outgoingCallInfoContextKey struct{} type incomingCallInfoContextKey struct{} -// Create a new request context for use from a client. When the returned +// responseSource indicates a type that manage response headers and trailers. +type responseSource interface { + ResponseHeader() http.Header + ResponseTrailer() http.Header +} + +// responseWrapper wraps a Response object so that it can implement the responseSource interface. +type responseWrapper[Res any] struct { + response *Response[Res] +} + +func (w *responseWrapper[Res]) ResponseHeader() http.Header { + return w.response.Header() +} + +func (w *responseWrapper[Res]) ResponseTrailer() http.Header { + return w.response.Trailer() +} + +// Create a new outgoing context for use from a client. When the returned // context is passed to RPCs, the returned call info can be used to set // request metadata before the RPC is invoked and to inspect response // metadata after the RPC completes. @@ -225,6 +226,19 @@ func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { return newOutgoingContext(ctx) } +// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. +func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) + return value, ok +} + +// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. +func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) + return value, ok +} + +// Creates a new outgoing context or returns the existing one in context. func newOutgoingContext(ctx context.Context) (context.Context, *clientCallInfo) { info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*clientCallInfo) if !ok { @@ -234,22 +248,12 @@ func newOutgoingContext(ctx context.Context) (context.Context, *clientCallInfo) return ctx, info } +// newOutgoingContext creates a new incoming context. func newIncomingContext(ctx context.Context, info CallInfo) context.Context { return context.WithValue(ctx, incomingCallInfoContextKey{}, info) } -// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. -func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) - return value, ok -} - -// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. -func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) - return value, ok -} - +// requestFromOutgoingContext creates a new Request using the given context and message. func requestFromOutgoingContext[T any](ctx context.Context, message *T) *Request[T] { request := NewRequest(message) callInfo, ok := CallInfoFromOutgoingContext(ctx) From 32520e7f782d769f7b694c6bfb2e057653cb1589 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Wed, 2 Jul 2025 12:46:31 -0400 Subject: [PATCH 50/57] Move func Signed-off-by: Steve Ayers --- context.go | 52 ++++++++++++++++++++++++++-------------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/context.go b/context.go index 7a1eb4ad..41ea749d 100644 --- a/context.go +++ b/context.go @@ -66,6 +66,32 @@ type CallInfo interface { internalOnly() } +// Create a new outgoing context for use from a client. When the returned +// context is passed to RPCs, the returned call info can be used to set +// request metadata before the RPC is invoked and to inspect response +// metadata after the RPC completes. +// +// The returned context may be re-used across RPCs as long as they are +// not concurrent. Results of all CallInfo methods other than +// RequestHeader() are undefined if the context is used with concurrent RPCs. +// If the given context is already associated with an outgoing CallInfo, then +// ctx and the existing CallInfo are returned. +func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { + return newOutgoingContext(ctx) +} + +// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. +func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) + return value, ok +} + +// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. +func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) + return value, ok +} + // handlerCallInfo is a CallInfo implementation used for handlers. type handlerCallInfo struct { spec Spec @@ -212,32 +238,6 @@ func (w *responseWrapper[Res]) ResponseTrailer() http.Header { return w.response.Trailer() } -// Create a new outgoing context for use from a client. When the returned -// context is passed to RPCs, the returned call info can be used to set -// request metadata before the RPC is invoked and to inspect response -// metadata after the RPC completes. -// -// The returned context may be re-used across RPCs as long as they are -// not concurrent. Results of all CallInfo methods other than -// RequestHeader() are undefined if the context is used with concurrent RPCs. -// If the given context is already associated with an outgoing CallInfo, then -// ctx and the existing CallInfo are returned. -func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { - return newOutgoingContext(ctx) -} - -// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. -func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) - return value, ok -} - -// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. -func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) - return value, ok -} - // Creates a new outgoing context or returns the existing one in context. func newOutgoingContext(ctx context.Context) (context.Context, *clientCallInfo) { info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*clientCallInfo) From d5ccf16b322163d630fea86239739768ab442b56 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Thu, 3 Jul 2025 15:29:51 -0400 Subject: [PATCH 51/57] Fix server stream tests Signed-off-by: Steve Ayers --- connect_ext_test.go | 60 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index 6cbfa9fa..3e52b5ec 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -92,21 +92,32 @@ func TestCallInfo(t *testing.T) { t.Run("server_stream", func(t *testing.T) { ctx, callInfo := connect.NewOutgoingContext(context.Background()) callInfo.RequestHeader().Set(clientHeader, headerValue) + + val := 3 stream, err := client.CountUp(ctx, &pingv1.CountUpRequest{ - Number: 1, + Number: int64(val), }) assert.Nil(t, err) - assert.True(t, stream.Receive()) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) assert.Nil(t, stream.Err()) - msg := stream.Msg() - assert.NotNil(t, msg) - assert.Equal(t, msg.GetNumber(), 1) assert.Nil(t, stream.Close()) assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeServer) assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) - assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) t.Run("generics_api", func(t *testing.T) { @@ -141,21 +152,44 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) t.Run("server_stream", func(t *testing.T) { + ctx, callInfo := connect.NewOutgoingContext(context.Background()) + callInfo.RequestHeader().Set(clientHeader, headerValue) + + val := 3 req := connect.NewRequest(&pingv1.CountUpRequest{ - Number: 1, + Number: int64(val), }) - req.Header().Set(clientHeader, headerValue) - ctx, callInfo := connect.NewOutgoingContext(context.Background()) stream, err := client.CountUp(ctx, req) assert.Nil(t, err) - assert.True(t, stream.Receive()) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) assert.Nil(t, stream.Err()) - msg := stream.Msg() - assert.NotNil(t, msg) - assert.Equal(t, msg.GetNumber(), 1) assert.Nil(t, stream.Close()) + // Assert values on request and stream + assert.Equal(t, req.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, req.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, req.Spec().IsClient) + assert.Equal(t, req.Peer().Addr, httptest.DefaultRemoteAddr) assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, req.Spec().StreamType) + assert.Equal(t, callInfo.Spec().Procedure, req.Spec().Procedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, req.Peer().Addr) assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) + assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) }) } From cd7dc9231b8cf41a5545ea97909396e6e29fae8e Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Wed, 16 Jul 2025 14:30:24 -0400 Subject: [PATCH 52/57] Rename context methods and always create a new call info when using exported NewClientContextAPI Signed-off-by: Steve Ayers --- client.go | 50 ++++++++------ connect_ext_test.go | 14 ++-- context.go | 53 +++++++-------- error_example_test.go | 2 +- error_not_modified_example_test.go | 2 +- handler.go | 4 +- interceptor_ext_test.go | 104 ++++++----------------------- 7 files changed, 84 insertions(+), 145 deletions(-) diff --git a/client.go b/client.go index 78365943..9e45deaa 100644 --- a/client.go +++ b/client.go @@ -76,11 +76,13 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien // once at client creation. unarySpec := config.newSpec(StreamTypeUnary) unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { - ctx, callInfo := newOutgoingContext(ctx) conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) conn.onRequestSend(func(r *http.Request) { request.setRequestMethod(r.Method) - callInfo.method = r.Method + callInfo, ok := getClientCallInfoFromContext(ctx) + if ok { + callInfo.method = r.Method + } }) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the @@ -112,11 +114,13 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien request.peer = client.protocolClient.Peer() protocolClient.WriteRequestHeader(StreamTypeUnary, request.Header()) - // Also set them in the context so interceptors can inspect context for this information. - ctx, callInfo := newOutgoingContext(ctx) - callInfo.peer = request.Peer() - callInfo.spec = request.Spec() - callInfo.requestHeader = request.Header() + // Also set them in the context if there's a call info present + callInfo, callInfoOk := getClientCallInfoFromContext(ctx) + if callInfoOk { + callInfo.peer = request.Peer() + callInfo.spec = request.Spec() + callInfo.requestHeader = request.Header() + } response, err := unaryFunc(ctx, request) if err != nil { @@ -126,9 +130,11 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien if !ok { return nil, errorf(CodeInternal, "unexpected client response type %T", response) } - // Wrap the response and set it into the context callinfo - callInfo.responseSource = &responseWrapper[Res]{ - response: typed, + if callInfoOk { + // Wrap the response and set it into the context callinfo + callInfo.responseSource = &responseWrapper[Res]{ + response: typed, + } } return typed, nil } @@ -149,7 +155,7 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) // This option eliminates the [Request] and [Response] wrappers, and instead uses the // context.Context to propagate information such as headers. func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, request *Req) (*Res, error) { - response, err := c.CallUnary(ctx, requestFromOutgoingContext(ctx, request)) + response, err := c.CallUnary(ctx, requestFromClientContext(ctx, request)) if response != nil { return response.Msg, err } @@ -175,17 +181,21 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method }) - _, callInfo := newOutgoingContext(ctx) - callInfo.peer = conn.Peer() - callInfo.spec = conn.Spec() - callInfo.responseSource = conn - request.peer = conn.Peer() request.spec = conn.Spec() - // Merge any callInfo request headers first, then do the request. - // so that context headers show first in the list of headers - mergeHeaders(conn.RequestHeader(), callInfo.RequestHeader()) + callInfo, ok := getClientCallInfoFromContext(ctx) + // Set values in the context if there's a call info present + if ok { + callInfo.peer = conn.Peer() + callInfo.spec = conn.Spec() + callInfo.responseSource = conn + + // Merge any callInfo request headers first, then do the request. + // so that context headers show first in the list of headers + mergeHeaders(conn.RequestHeader(), callInfo.RequestHeader()) + } + mergeHeaders(conn.RequestHeader(), request.header) // Send always returns an io.EOF unless the error is from the client-side. @@ -211,7 +221,7 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques // This option eliminates the [Request] wrapper, and instead uses the context.Context to // propagate information such as headers. func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { - return c.CallServerStream(ctx, requestFromOutgoingContext(ctx, requestMsg)) + return c.CallServerStream(ctx, requestFromClientContext(ctx, requestMsg)) } // CallBidiStream calls a bidirectional streaming procedure. diff --git a/connect_ext_test.go b/connect_ext_test.go index 3e52b5ec..9492b119 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -75,7 +75,7 @@ func TestCallInfo(t *testing.T) { client := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) t.Run("unary", func(t *testing.T) { num := int64(42) - ctx, callInfo := connect.NewOutgoingContext(context.Background()) + ctx, callInfo := connect.NewClientContext(context.Background()) callInfo.RequestHeader().Set(clientHeader, headerValue) expect := &pingv1.PingResponse{Number: num} @@ -90,7 +90,7 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) t.Run("server_stream", func(t *testing.T) { - ctx, callInfo := connect.NewOutgoingContext(context.Background()) + ctx, callInfo := connect.NewClientContext(context.Background()) callInfo.RequestHeader().Set(clientHeader, headerValue) val := 3 @@ -134,7 +134,7 @@ func TestCallInfo(t *testing.T) { request.Header().Set(clientHeader, headerValue) expect := &pingv1.PingResponse{Number: num} - ctx, callInfo := connect.NewOutgoingContext(context.Background()) + ctx, callInfo := connect.NewClientContext(context.Background()) response, err := client.Ping(ctx, request) assert.Nil(t, err) assert.Equal(t, response.Msg, expect) @@ -152,7 +152,7 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) }) t.Run("server_stream", func(t *testing.T) { - ctx, callInfo := connect.NewOutgoingContext(context.Background()) + ctx, callInfo := connect.NewClientContext(context.Background()) callInfo.RequestHeader().Set(clientHeader, headerValue) val := 3 @@ -3059,7 +3059,7 @@ type pingServerSimple struct { } func (p pingServerSimple) Ping(ctx context.Context, request *pingv1.PingRequest) (*pingv1.PingResponse, error) { - callInfo, ok := connect.CallInfoFromIncomingContext(ctx) + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } @@ -3088,7 +3088,7 @@ func (p pingServerSimple) CountUp( request *pingv1.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { - callInfo, ok := connect.CallInfoFromIncomingContext(ctx) + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) if !ok { return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } @@ -3120,7 +3120,7 @@ func (p pingServerSimple) CountUp( } func (p pingServerSimple) Fail(ctx context.Context, request *pingv1.FailRequest) (*pingv1.FailResponse, error) { - callInfo, ok := connect.CallInfoFromIncomingContext(ctx) + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } diff --git a/context.go b/context.go index 41ea749d..b34a6255 100644 --- a/context.go +++ b/context.go @@ -21,7 +21,7 @@ import ( // CallInfo represents information relevant to an RPC call. // Values returned by these methods are not thread-safe. Users should expect -// data races if they create an outgoing CallInfo in context and then pass that +// data races if they create an outgoing client CallInfo in context and then pass that // CallInfo to another goroutine and try to call methods on it concurrent with the RPC. type CallInfo interface { // Spec returns a description of this call. @@ -66,29 +66,28 @@ type CallInfo interface { internalOnly() } -// Create a new outgoing context for use from a client. When the returned -// context is passed to RPCs, the returned call info can be used to set +// Create a new client (i.e. outgoing) context for use from a client. When the +// returned context is passed to RPCs, the returned call info can be used to set // request metadata before the RPC is invoked and to inspect response // metadata after the RPC completes. // // The returned context may be re-used across RPCs as long as they are // not concurrent. Results of all CallInfo methods other than // RequestHeader() are undefined if the context is used with concurrent RPCs. -// If the given context is already associated with an outgoing CallInfo, then -// ctx and the existing CallInfo are returned. -func NewOutgoingContext(ctx context.Context) (context.Context, CallInfo) { - return newOutgoingContext(ctx) +func NewClientContext(ctx context.Context) (context.Context, CallInfo) { + info := &clientCallInfo{} + return context.WithValue(ctx, clientCallInfoContextKey{}, info), info } -// CallInfoFromOutgoingContext returns the CallInfo for the given outgoing context, if there is one. -func CallInfoFromOutgoingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(outgoingCallInfoContextKey{}).(CallInfo) +// CallInfoFromClientContext returns the CallInfo for the given client context, if there is one. +func CallInfoFromClientContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(clientCallInfoContextKey{}).(CallInfo) return value, ok } -// CallInfoFromIncomingContext returns the CallInfo for the given incoming context, if there is one. -func CallInfoFromIncomingContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(incomingCallInfoContextKey{}).(CallInfo) +// CallInfoFromHandlerContext returns the CallInfo for the given handler (i.e. incoming) context, if there is one. +func CallInfoFromHandlerContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(handlerCallInfoContextKey{}).(CallInfo) return value, ok } @@ -216,8 +215,8 @@ func (c *clientCallInfo) HTTPMethod() string { // internalOnly implements CallInfo. func (c *clientCallInfo) internalOnly() {} -type outgoingCallInfoContextKey struct{} -type incomingCallInfoContextKey struct{} +type clientCallInfoContextKey struct{} +type handlerCallInfoContextKey struct{} // responseSource indicates a type that manage response headers and trailers. type responseSource interface { @@ -238,25 +237,21 @@ func (w *responseWrapper[Res]) ResponseTrailer() http.Header { return w.response.Trailer() } -// Creates a new outgoing context or returns the existing one in context. -func newOutgoingContext(ctx context.Context) (context.Context, *clientCallInfo) { - info, ok := ctx.Value(outgoingCallInfoContextKey{}).(*clientCallInfo) - if !ok { - info = &clientCallInfo{} - return context.WithValue(ctx, outgoingCallInfoContextKey{}, info), info - } - return ctx, info +// Gets a client (i.e. outgoing) call info from context. +func getClientCallInfoFromContext(ctx context.Context) (*clientCallInfo, bool) { + info, ok := ctx.Value(clientCallInfoContextKey{}).(*clientCallInfo) + return info, ok } -// newOutgoingContext creates a new incoming context. -func newIncomingContext(ctx context.Context, info CallInfo) context.Context { - return context.WithValue(ctx, incomingCallInfoContextKey{}, info) +// newHandlerContext creates a new handler (i.e. incoming) context. +func newHandlerContext(ctx context.Context, info CallInfo) context.Context { + return context.WithValue(ctx, handlerCallInfoContextKey{}, info) } -// requestFromOutgoingContext creates a new Request using the given context and message. -func requestFromOutgoingContext[T any](ctx context.Context, message *T) *Request[T] { +// requestFromClientContext creates a new Request using the given context and message. +func requestFromClientContext[T any](ctx context.Context, message *T) *Request[T] { request := NewRequest(message) - callInfo, ok := CallInfoFromOutgoingContext(ctx) + callInfo, ok := CallInfoFromClientContext(ctx) if ok { request.setHeader(callInfo.RequestHeader()) } diff --git a/error_example_test.go b/error_example_test.go index 30930a97..1bcc68c3 100644 --- a/error_example_test.go +++ b/error_example_test.go @@ -48,7 +48,7 @@ func ExampleIsNotModifiedError() { connect.WithHTTPGet(), ) req := &pingv1.PingRequest{Number: 42} - ctx, callInfo := connect.NewOutgoingContext(context.Background()) + ctx, callInfo := connect.NewClientContext(context.Background()) _, err := client.Ping(ctx, req) if err != nil { fmt.Println(err) diff --git a/error_not_modified_example_test.go b/error_not_modified_example_test.go index 3daf8223..e87cd07a 100644 --- a/error_not_modified_example_test.go +++ b/error_not_modified_example_test.go @@ -42,7 +42,7 @@ func (*ExampleCachingPingServer) Ping( resp := &pingv1.PingResponse{ Number: req.GetNumber(), } - callInfo, ok := connect.CallInfoFromIncomingContext(ctx) + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) if !ok { return nil, errors.New("no call info found in context") } diff --git a/handler.go b/handler.go index 9975712d..5eb5e1e0 100644 --- a/handler.go +++ b/handler.go @@ -74,7 +74,7 @@ func NewUnaryHandler[Req, Res any]( method: request.HTTPMethod(), requestHeader: request.Header(), } - ctx = newIncomingContext(ctx, info) + ctx = newHandlerContext(ctx, info) response, err := untyped(ctx, request) if err != nil { return err @@ -176,7 +176,7 @@ func NewServerStreamHandler[Req, Res any]( if err != nil { return err } - ctx = newIncomingContext(ctx, &streamCallInfo{ + ctx = newHandlerContext(ctx, &streamCallInfo{ conn: conn, }) return implementation(ctx, req, &ServerStream[Res]{conn: conn}) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 949a6afc..b5892fde 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -16,7 +16,6 @@ package connect_test import ( "context" - "errors" "fmt" "net/http" "sync/atomic" @@ -190,10 +189,10 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { assert.Nil(t, countUpStream.Close()) } -func TestInterceptorFuncAccessingCallInfo(t *testing.T) { +func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { t.Parallel() - clientChecker := &callInfoChecker{client: true} - handlerChecker := &callInfoChecker{} + clientChecker := &httpMethodChecker{client: true} + handlerChecker := &httpMethodChecker{} mux := http.NewServeMux() mux.Handle( @@ -345,103 +344,38 @@ func (cc *headerInspectingClientConn) Receive(msg any) error { return err } -type callInfoChecker struct { +type httpMethodChecker struct { client bool count atomic.Int32 } -func (h *callInfoChecker) validateCallInfo(callInfo connect.CallInfo, req connect.AnyRequest, expectMethod bool) error { - // method should be blank until after we make request - if !expectMethod { //nolint:nestif - if callInfo.HTTPMethod() != "" { - return fmt.Errorf("expected blank HTTP method in context but instead got %q", callInfo.HTTPMethod()) - } - if req.HTTPMethod() != "" { - return fmt.Errorf("expected blank HTTP method in request but instead got %q", req.HTTPMethod()) - } - } else { - // server interceptors see method from the start - // NB: In theory, the method could also be GET, not just POST. But for the - // configuration under test, it will always be POST. - if callInfo.HTTPMethod() != http.MethodPost { - return fmt.Errorf("expected HTTP method %s in context but instead got %q", http.MethodPost, callInfo.HTTPMethod()) - } - if req.HTTPMethod() != http.MethodPost { - return fmt.Errorf("expected HTTP method %s in request but instead got %q", http.MethodPost, req.HTTPMethod()) - } - } - if callInfo.Peer().Addr == "" { - return errors.New("no peer set on call info") - } - if req.Peer().Addr == "" { - return errors.New("no peer set on request") - } - if callInfo.Spec().Procedure != pingv1connect.PingServicePingProcedure { - return fmt.Errorf("expected spec procedure %s in context but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) - } - if req.Spec().Procedure != pingv1connect.PingServicePingProcedure { - return fmt.Errorf("expected spec procedure %s on request but got %s", pingv1connect.PingServicePingProcedure, callInfo.Spec().Procedure) - } - return nil -} - -func (h *callInfoChecker) getCallInfo(ctx context.Context) (connect.CallInfo, error) { - var callInfo connect.CallInfo - if h.client { - info, ok := connect.CallInfoFromOutgoingContext(ctx) - if !ok { - return nil, errors.New("no call info found in outgoing context") - } - callInfo = info - } else { - info, ok := connect.CallInfoFromIncomingContext(ctx) - if !ok { - return nil, errors.New("no call info found in incoming context") - } - callInfo = info - } - return callInfo, nil -} - -func (h *callInfoChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { +func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { h.count.Add(1) - - callInfo, err := h.getCallInfo(ctx) - if err != nil { - return nil, err - } - if h.client { - if err := h.validateCallInfo(callInfo, req, false); err != nil { - return nil, err + // should be blank until after we make request + if req.HTTPMethod() != "" { + return nil, fmt.Errorf("expected blank HTTP method but instead got %q", req.HTTPMethod()) } } else { - if err := h.validateCallInfo(callInfo, req, true); err != nil { - return nil, err + // server interceptors see method from the start + // NB: In theory, the method could also be GET, not just POST. But for the + // configuration under test, it will always be POST. + if req.HTTPMethod() != http.MethodPost { + return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) } } - resp, err := unaryFunc(ctx, req) - if err != nil { - return nil, err - } - - // Method should now be set on the outgoing context - callInfo, err = h.getCallInfo(ctx) - if err != nil { - return nil, err - } - - if err := h.validateCallInfo(callInfo, req, true); err != nil { - return nil, err + // NB: In theory, the method could also be GET, not just POST. But for the + // configuration under test, it will always be POST. + if req.HTTPMethod() != http.MethodPost { + return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) } - return resp, err } } -func (h *callInfoChecker) WrapStreamingClient(clientFunc connect.StreamingClientFunc) connect.StreamingClientFunc { +func (h *httpMethodChecker) WrapStreamingClient(clientFunc connect.StreamingClientFunc) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { // method not exposed to streaming interceptor, but that's okay because it's always POST for streams h.count.Add(1) @@ -449,7 +383,7 @@ func (h *callInfoChecker) WrapStreamingClient(clientFunc connect.StreamingClient } } -func (h *callInfoChecker) WrapStreamingHandler(handlerFunc connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (h *httpMethodChecker) WrapStreamingHandler(handlerFunc connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { // method not exposed to streaming interceptor, but that's okay because it's always POST for streams h.count.Add(1) From 45ee1ce1c8ce0c7a98df5ea3f18ec2433ba2da13 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 21 Jul 2025 10:11:25 -0400 Subject: [PATCH 53/57] Interceptor tests Signed-off-by: Steve Ayers --- client.go | 24 +++++- connect.go | 41 ++++++++++ connect_ext_test.go | 1 - context.go | 9 +-- interceptor.go | 39 ++++++++- interceptor_ext_test.go | 173 +++++++++++++++++++++++++++++++++++++++- 6 files changed, 272 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index 9e45deaa..2a7af459 100644 --- a/client.go +++ b/client.go @@ -104,6 +104,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien return response, conn.CloseResponse() }) if interceptor := config.Interceptor; interceptor != nil { + // interceptor here is the chain unaryFunc = interceptor.WrapUnary(unaryFunc) } client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) { @@ -119,7 +120,19 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien if callInfoOk { callInfo.peer = request.Peer() callInfo.spec = request.Spec() + // A client could have set request headers in the call info OR the request wrapper + // So if a callInfo exists in context, merge any headers from there into the request wrapper + // so that all headers are sent in the request + mergeHeaders(request.Header(), callInfo.requestHeader) + // Then, set the full list of merged headers into the call info so users can query the context + // for this information + // TODO - Does this necessarily need done? callInfo.requestHeader = request.Header() + + // Copy the call info into a sentinel value. This is so we can compare + // the sentinel value against the call info in context. If they're different, + // we can stop the request. This protects against changing the context in interceptors. + ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo) } response, err := unaryFunc(ctx, request) @@ -178,15 +191,22 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if c.err != nil { return nil, c.err } + callInfo, callInfoOk := getClientCallInfoFromContext(ctx) + // Set values in the context if there's a call info present + if callInfoOk { + // Copy the call info into a sentinel value. This is so we can compare + // the sentinel value against the call info in context. If they're different, + // we can stop the request. This protects against changing the context in interceptors. + ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo) + } conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method }) request.peer = conn.Peer() request.spec = conn.Spec() - callInfo, ok := getClientCallInfoFromContext(ctx) // Set values in the context if there's a call info present - if ok { + if callInfoOk { callInfo.peer = conn.Peer() callInfo.spec = conn.Spec() callInfo.responseSource = conn diff --git a/connect.go b/connect.go index caaf838b..963c7998 100644 --- a/connect.go +++ b/connect.go @@ -373,6 +373,47 @@ type hasHTTPMethod interface { getHTTPMethod() string } +type errStreamingClientConn struct { + StreamingClientConn + err error +} + +func (c *errStreamingClientConn) Receive(msg any) error { + return c.err +} + +func (c *errStreamingClientConn) Spec() Spec { + return Spec{} +} + +func (c *errStreamingClientConn) Peer() Peer { + return Peer{} +} + +func (c *errStreamingClientConn) Send(msg any) error { + return c.err +} + +func (c *errStreamingClientConn) CloseRequest() error { + return c.err +} + +func (c *errStreamingClientConn) CloseResponse() error { + return c.err +} + +func (c *errStreamingClientConn) RequestHeader() http.Header { + return make(http.Header) +} + +func (c *errStreamingClientConn) ResponseHeader() http.Header { + return make(http.Header) +} + +func (c *errStreamingClientConn) ResponseTrailer() http.Header { + return make(http.Header) +} + // receiveUnaryResponse unmarshals a message from a StreamingClientConn, then // envelopes the message and attaches headers and trailers. It attempts to // consume the response stream and isn't appropriate when receiving multiple diff --git a/connect_ext_test.go b/connect_ext_test.go index 9492b119..eddec15a 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -78,7 +78,6 @@ func TestCallInfo(t *testing.T) { ctx, callInfo := connect.NewClientContext(context.Background()) callInfo.RequestHeader().Set(clientHeader, headerValue) expect := &pingv1.PingResponse{Number: num} - response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) assert.Equal(t, response, expect) assert.Nil(t, err) diff --git a/context.go b/context.go index b34a6255..004be9ea 100644 --- a/context.go +++ b/context.go @@ -79,12 +79,6 @@ func NewClientContext(ctx context.Context) (context.Context, CallInfo) { return context.WithValue(ctx, clientCallInfoContextKey{}, info), info } -// CallInfoFromClientContext returns the CallInfo for the given client context, if there is one. -func CallInfoFromClientContext(ctx context.Context) (CallInfo, bool) { - value, ok := ctx.Value(clientCallInfoContextKey{}).(CallInfo) - return value, ok -} - // CallInfoFromHandlerContext returns the CallInfo for the given handler (i.e. incoming) context, if there is one. func CallInfoFromHandlerContext(ctx context.Context) (CallInfo, bool) { value, ok := ctx.Value(handlerCallInfoContextKey{}).(CallInfo) @@ -216,6 +210,7 @@ func (c *clientCallInfo) HTTPMethod() string { func (c *clientCallInfo) internalOnly() {} type clientCallInfoContextKey struct{} +type sentinelContextKey struct{} type handlerCallInfoContextKey struct{} // responseSource indicates a type that manage response headers and trailers. @@ -251,7 +246,7 @@ func newHandlerContext(ctx context.Context, info CallInfo) context.Context { // requestFromClientContext creates a new Request using the given context and message. func requestFromClientContext[T any](ctx context.Context, message *T) *Request[T] { request := NewRequest(message) - callInfo, ok := CallInfoFromClientContext(ctx) + callInfo, ok := getClientCallInfoFromContext(ctx) if ok { request.setHeader(callInfo.RequestHeader()) } diff --git a/interceptor.go b/interceptor.go index f0c3620a..f5aac3e1 100644 --- a/interceptor.go +++ b/interceptor.go @@ -16,6 +16,13 @@ package connect import ( "context" + "errors" +) + +var ( + // errNewClientContextProhibited signals that a new client context was created + // in an interceptor, which is prohibited. + errNewClientContextProhibited = errors.New("creating a new context in an interceptor is prohibited") ) // UnaryFunc is the generic signature of a unary RPC. Interceptors may wrap @@ -36,9 +43,8 @@ type StreamingHandlerFunc func(context.Context, StreamingHandlerConn) error // An Interceptor adds logic to a generated handler or client, like the // decorators or middleware you may have seen in other libraries. Interceptors -// may replace the context, mutate requests and responses, handle errors, -// retry, recover from panics, emit logs and metrics, or do nearly anything -// else. +// may mutate requests and responses, handle errors, retry, recover from panics, +// emit logs and metrics, or do nearly anything else. // // The returned functions must be safe to call concurrently. type Interceptor interface { @@ -85,6 +91,7 @@ func newChain(interceptors []Interceptor) *chain { func (c *chain) WrapUnary(next UnaryFunc) UnaryFunc { for _, interceptor := range c.interceptors { + next = unaryThunk(next) next = interceptor.WrapUnary(next) } return next @@ -92,6 +99,7 @@ func (c *chain) WrapUnary(next UnaryFunc) UnaryFunc { func (c *chain) WrapStreamingClient(next StreamingClientFunc) StreamingClientFunc { for _, interceptor := range c.interceptors { + next = streamingClientThunk(next) next = interceptor.WrapStreamingClient(next) } return next @@ -103,3 +111,28 @@ func (c *chain) WrapStreamingHandler(next StreamingHandlerFunc) StreamingHandler } return next } + +func unaryThunk(next UnaryFunc) UnaryFunc { + return func(ctx context.Context, req AnyRequest) (AnyResponse, error) { + if !checkSentinel(ctx) { + return nil, errNewClientContextProhibited + } + return next(ctx, req) + } +} + +func streamingClientThunk(next StreamingClientFunc) StreamingClientFunc { + return func(ctx context.Context, spec Spec) StreamingClientConn { + if !checkSentinel(ctx) { + return &errStreamingClientConn{err: errNewClientContextProhibited} + } + return next(ctx, spec) + } +} + +func checkSentinel(ctx context.Context) bool { + callInfo, _ := ctx.Value(clientCallInfoContextKey{}).(*clientCallInfo) + sentinel, _ := ctx.Value(sentinelContextKey{}).(*clientCallInfo) + // Only verify if there's a sentinel call info to compare it to + return callInfo == sentinel +} diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index b5892fde..bc022c7c 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -28,6 +28,139 @@ import ( "connectrpc.com/connect/internal/memhttp/memhttptest" ) +func TestNewClientContextFails(t *testing.T) { + // Verifies that calling NewClientContext in an interceptor fails when sending the new context downstream + t.Parallel() + t.Run("unary", func(t *testing.T) { + t.Parallel() + t.Run("first_interceptor", func(t *testing.T) { + t.Parallel() + createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + return connect.WithInterceptors( + &contextInterceptor{client: true, count: client1, createNewContext: true}, + &contextInterceptor{client: true, count: client2}, + ) + } + var client1, client2 atomic.Int32 + mux := http.NewServeMux() + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + createInterceptors(&client1, &client2), + ) + _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + + // Since we are creating a new client context, an error will be returned from the invocation + assert.NotNil(t, err) + assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") + // And because we're creating it in the first interceptor, only the first interceptor fires + assert.Equal(t, int32(1), client1.Load()) + assert.Equal(t, int32(0), client2.Load()) + }) + t.Run("subsequent_interceptor", func(t *testing.T) { + t.Parallel() + createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + return connect.WithInterceptors( + &contextInterceptor{client: true, count: client1}, + &contextInterceptor{client: true, count: client2, createNewContext: true}, + ) + } + var client1, client2 atomic.Int32 + mux := http.NewServeMux() + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + createInterceptors(&client1, &client2), + ) + _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + + // Since we are creating a new client context, an error will be returned from the invocation + assert.NotNil(t, err) + assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") + // And because we're creating it in the second interceptor, they both fire + assert.Equal(t, int32(1), client1.Load()) + assert.Equal(t, int32(1), client2.Load()) + }) + }) + t.Run("server_streaming", func(t *testing.T) { + t.Parallel() + t.Run("first_interceptor", func(t *testing.T) { + t.Parallel() + createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + return connect.WithInterceptors( + &contextInterceptor{client: true, count: client1, createNewContext: true}, + &contextInterceptor{client: true, count: client2}, + ) + } + var client1, client2 atomic.Int32 + mux := http.NewServeMux() + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + createInterceptors(&client1, &client2), + ) + responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) + + // Since we are creating a new client context, an error will be returned from the invocation + assert.Nil(t, responses) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") + // And because we're creating it in the first interceptor, only the first interceptor fires + assert.Equal(t, int32(1), client1.Load()) + assert.Equal(t, int32(0), client2.Load()) + }) + t.Run("subsequent_interceptor", func(t *testing.T) { + t.Parallel() + createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + return connect.WithInterceptors( + &contextInterceptor{client: true, count: client1}, + &contextInterceptor{client: true, count: client2, createNewContext: true}, + ) + } + var client1, client2 atomic.Int32 + mux := http.NewServeMux() + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + createInterceptors(&client1, &client2), + ) + responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) + + // Since we are creating a new client context, an error will be returned from the invocation + assert.Nil(t, responses) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") + // And because we're creating it in the second interceptor, all interceptors fire + assert.Equal(t, int32(1), client1.Load()) + assert.Equal(t, int32(1), client2.Load()) + }) + }) +} + func TestOnionOrderingEndToEnd(t *testing.T) { t.Parallel() // Helper function: returns a function that asserts that there's some value @@ -349,7 +482,7 @@ type httpMethodChecker struct { count atomic.Int32 } -func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { +func (h *httpMethodChecker) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { h.count.Add(1) if h.client { @@ -365,7 +498,7 @@ func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.Unary return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) } } - resp, err := unaryFunc(ctx, req) + resp, err := next(ctx, req) // NB: In theory, the method could also be GET, not just POST. But for the // configuration under test, it will always be POST. if req.HTTPMethod() != http.MethodPost { @@ -390,3 +523,39 @@ func (h *httpMethodChecker) WrapStreamingHandler(handlerFunc connect.StreamingHa return handlerFunc(ctx, conn) } } + +type contextInterceptor struct { + client bool + count *atomic.Int32 + // Whether the interceptor should attempt to create a new context (which will cause next() to return an error) + createNewContext bool +} + +func (h *contextInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + h.count.Add(1) + if h.createNewContext { + // This will cause next to return an error + ctx, _ = connect.NewClientContext(ctx) + } + return next(ctx, req) + } +} + +func (h *contextInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + h.count.Add(1) + if h.createNewContext { + // This will cause next to return an error + ctx, _ = connect.NewClientContext(ctx) + } + return next(ctx, spec) + } +} + +func (h *contextInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + h.count.Add(1) + return next(ctx, conn) + } +} From 101179907b70a0f481ddd9be317c5c4a1a0729fd Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 21 Jul 2025 11:10:02 -0400 Subject: [PATCH 54/57] Side quest tests Signed-off-by: Steve Ayers Fix names Signed-off-by: Steve Ayers --- interceptor_ext_test.go | 186 ++++++++++++++++++++++++++++++---------- 1 file changed, 139 insertions(+), 47 deletions(-) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index bc022c7c..9668aff8 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -25,9 +25,44 @@ import ( "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" ) +func TestSideQuestInInterceptor(t *testing.T) { + t.Parallel() + t.Run("unary", func(t *testing.T) { + t.Parallel() + t.Run("sidequest_succeeds", func(t *testing.T) { + t.Parallel() + createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32, server *memhttp.Server) connect.Option { + return connect.WithInterceptors( + newSideQuestInterceptor(t, clientCounter1, server), + newSideQuestInterceptor(t, clientCounter2, server), + ) + } + var clientCounter1, clientCounter2 atomic.Int32 + mux := http.NewServeMux() + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + createInterceptors(&clientCounter1, &clientCounter2, server), + ) + _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + }) +} + func TestNewClientContextFails(t *testing.T) { // Verifies that calling NewClientContext in an interceptor fails when sending the new context downstream t.Parallel() @@ -35,13 +70,13 @@ func TestNewClientContextFails(t *testing.T) { t.Parallel() t.Run("first_interceptor", func(t *testing.T) { t.Parallel() - createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { return connect.WithInterceptors( - &contextInterceptor{client: true, count: client1, createNewContext: true}, - &contextInterceptor{client: true, count: client2}, + &contextInterceptor{client: true, count: clientCounter1, createNewContext: true}, + &contextInterceptor{client: true, count: clientCounter2}, ) } - var client1, client2 atomic.Int32 + var clientCounter1, clientCounter2 atomic.Int32 mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( @@ -52,7 +87,7 @@ func TestNewClientContextFails(t *testing.T) { client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), - createInterceptors(&client1, &client2), + createInterceptors(&clientCounter1, &clientCounter2), ) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) @@ -60,18 +95,18 @@ func TestNewClientContextFails(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") // And because we're creating it in the first interceptor, only the first interceptor fires - assert.Equal(t, int32(1), client1.Load()) - assert.Equal(t, int32(0), client2.Load()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) }) t.Run("subsequent_interceptor", func(t *testing.T) { t.Parallel() - createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { return connect.WithInterceptors( - &contextInterceptor{client: true, count: client1}, - &contextInterceptor{client: true, count: client2, createNewContext: true}, + &contextInterceptor{client: true, count: clientCounter1}, + &contextInterceptor{client: true, count: clientCounter2, createNewContext: true}, ) } - var client1, client2 atomic.Int32 + var clientCounter1, clientCounter2 atomic.Int32 mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( @@ -82,7 +117,7 @@ func TestNewClientContextFails(t *testing.T) { client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), - createInterceptors(&client1, &client2), + createInterceptors(&clientCounter1, &clientCounter2), ) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) @@ -90,21 +125,21 @@ func TestNewClientContextFails(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") // And because we're creating it in the second interceptor, they both fire - assert.Equal(t, int32(1), client1.Load()) - assert.Equal(t, int32(1), client2.Load()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) }) }) t.Run("server_streaming", func(t *testing.T) { t.Parallel() t.Run("first_interceptor", func(t *testing.T) { t.Parallel() - createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { return connect.WithInterceptors( - &contextInterceptor{client: true, count: client1, createNewContext: true}, - &contextInterceptor{client: true, count: client2}, + &contextInterceptor{client: true, count: clientCounter1, createNewContext: true}, + &contextInterceptor{client: true, count: clientCounter2}, ) } - var client1, client2 atomic.Int32 + var clientCounter1, clientCounter2 atomic.Int32 mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( @@ -115,7 +150,7 @@ func TestNewClientContextFails(t *testing.T) { client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), - createInterceptors(&client1, &client2), + createInterceptors(&clientCounter1, &clientCounter2), ) responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) @@ -124,18 +159,18 @@ func TestNewClientContextFails(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") // And because we're creating it in the first interceptor, only the first interceptor fires - assert.Equal(t, int32(1), client1.Load()) - assert.Equal(t, int32(0), client2.Load()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) }) t.Run("subsequent_interceptor", func(t *testing.T) { t.Parallel() - createInterceptors := func(client1 *atomic.Int32, client2 *atomic.Int32) connect.Option { + createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { return connect.WithInterceptors( - &contextInterceptor{client: true, count: client1}, - &contextInterceptor{client: true, count: client2, createNewContext: true}, + &contextInterceptor{client: true, count: clientCounter1}, + &contextInterceptor{client: true, count: clientCounter2, createNewContext: true}, ) } - var client1, client2 atomic.Int32 + var clientCounter1, clientCounter2 atomic.Int32 mux := http.NewServeMux() mux.Handle( pingv1connect.NewPingServiceHandler( @@ -146,7 +181,7 @@ func TestNewClientContextFails(t *testing.T) { client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), - createInterceptors(&client1, &client2), + createInterceptors(&clientCounter1, &clientCounter2), ) responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) @@ -155,8 +190,8 @@ func TestNewClientContextFails(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") // And because we're creating it in the second interceptor, all interceptors fire - assert.Equal(t, int32(1), client1.Load()) - assert.Equal(t, int32(1), client2.Load()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) }) }) } @@ -201,7 +236,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { } } - var client1, client2, client3, handler1, handler2, handler3 atomic.Int32 + var clientCounter1, clientCounter2, clientCounter3, handlerCounter1, handlerCounter2, handlerCounter3 atomic.Int32 // The client and handler interceptor onions are the meat of the test. The // order of interceptor execution must be the same for unary and streaming @@ -216,7 +251,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { // intended order clear. clientOnion := connect.WithInterceptors( newHeaderInterceptor( - &client1, + &clientCounter1, // 1 (start). request: should see protocol-related headers func(_ connect.Spec, h http.Header) { assert.NotZero(t, h.Get("Content-Type")) @@ -225,29 +260,29 @@ func TestOnionOrderingEndToEnd(t *testing.T) { assertAllPresent, ), newHeaderInterceptor( - &client2, + &clientCounter2, newInspector("", "one"), // 2. request: add header "one" newInspector("three", "four"), // 11. response: check "three", add "four" ), newHeaderInterceptor( - &client3, + &clientCounter3, newInspector("one", "two"), // 3. request: check "one", add "two" newInspector("two", "three"), // 10. response: check "two", add "three" ), ) handlerOnion := connect.WithInterceptors( newHeaderInterceptor( - &handler1, + &handlerCounter1, newInspector("two", "three"), // 4. request: check "two", add "three" newInspector("one", "two"), // 9. response: check "one", add "two" ), newHeaderInterceptor( - &handler2, + &handlerCounter2, newInspector("three", "four"), // 5. request: check "three", add "four" newInspector("", "one"), // 8. response: add "one" ), newHeaderInterceptor( - &handler3, + &handlerCounter3, assertAllPresent, // 6. request: check "one"-"four" nil, // 7. response: no-op ), @@ -271,12 +306,12 @@ func TestOnionOrderingEndToEnd(t *testing.T) { assert.Nil(t, err) // make sure the interceptors were actually invoked - assert.Equal(t, int32(1), client1.Load()) - assert.Equal(t, int32(1), client2.Load()) - assert.Equal(t, int32(1), client3.Load()) - assert.Equal(t, int32(1), handler1.Load()) - assert.Equal(t, int32(1), handler2.Load()) - assert.Equal(t, int32(1), handler3.Load()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + assert.Equal(t, int32(1), clientCounter3.Load()) + assert.Equal(t, int32(1), handlerCounter1.Load()) + assert.Equal(t, int32(1), handlerCounter2.Load()) + assert.Equal(t, int32(1), handlerCounter3.Load()) responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) assert.Nil(t, err) @@ -288,12 +323,12 @@ func TestOnionOrderingEndToEnd(t *testing.T) { assert.Nil(t, responses.Close()) // make sure the interceptors were invoked again - assert.Equal(t, int32(2), client1.Load()) - assert.Equal(t, int32(2), client2.Load()) - assert.Equal(t, int32(2), client3.Load()) - assert.Equal(t, int32(2), handler1.Load()) - assert.Equal(t, int32(2), handler2.Load()) - assert.Equal(t, int32(2), handler3.Load()) + assert.Equal(t, int32(2), clientCounter1.Load()) + assert.Equal(t, int32(2), clientCounter2.Load()) + assert.Equal(t, int32(2), clientCounter3.Load()) + assert.Equal(t, int32(2), handlerCounter1.Load()) + assert.Equal(t, int32(2), handlerCounter2.Load()) + assert.Equal(t, int32(2), handlerCounter3.Load()) } func TestEmptyUnaryInterceptorFunc(t *testing.T) { @@ -559,3 +594,60 @@ func (h *contextInterceptor) WrapStreamingHandler(next connect.StreamingHandlerF return next(ctx, conn) } } + +type sideQuestInterceptor struct { + count *atomic.Int32 + client pingv1connect.PingServiceClient + t *testing.T +} + +func newSideQuestInterceptor( //nolint:thelper + t *testing.T, + counter *atomic.Int32, + server *memhttp.Server, +) *sideQuestInterceptor { + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + ) + return &sideQuestInterceptor{t: t, client: client, count: counter} +} + +func (h *sideQuestInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + h.count.Add(1) + num := int64(42) + // Create a new client context for the side quest Ping. This should succeed because we aren't + // sending this on through the interceptor chain and reusing this context + newCtx, _ := connect.NewClientContext(ctx) + resp, err := h.client.Ping(newCtx, connect.NewRequest(&pingv1.PingRequest{Number: num})) + assert.Nil(h.t, err) + assert.Equal(h.t, resp.Msg.Number, num) + + return next(ctx, req) + } +} + +func (h *sideQuestInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + h.count.Add(1) + // Create a new context for the side quest CountUp. This should succeed because we aren't + // sending this on through the interceptor chain and reusing this context + newCtx, _ := connect.NewClientContext(ctx) + responses, err := h.client.CountUp(newCtx, connect.NewRequest(&pingv1.CountUpRequest{Number: 3})) + assert.Nil(h.t, err) + var sum int64 + for responses.Receive() { + sum += responses.Msg().GetNumber() + } + assert.Equal(h.t, sum, 6) + assert.Nil(h.t, responses.Close()) + return next(ctx, spec) + } +} + +func (h *sideQuestInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + return next(ctx, conn) + } +} From 9143ba050c7f390e6c33d34540854167b5bc4eaa Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Mon, 21 Jul 2025 17:20:03 -0400 Subject: [PATCH 55/57] Extensive testing for simple and generic APIs using callinfo Signed-off-by: Steve Ayers --- client.go | 10 +- connect.go | 6 +- connect_ext_test.go | 453 +++++++++++++++++++++++++++++++++----------- context.go | 10 - interceptor.go | 4 +- 5 files changed, 346 insertions(+), 137 deletions(-) diff --git a/client.go b/client.go index 2a7af459..e78f8cd3 100644 --- a/client.go +++ b/client.go @@ -104,7 +104,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien return response, conn.CloseResponse() }) if interceptor := config.Interceptor; interceptor != nil { - // interceptor here is the chain + // interceptor is the full chain of all interceptors provided unaryFunc = interceptor.WrapUnary(unaryFunc) } client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) { @@ -124,10 +124,6 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien // So if a callInfo exists in context, merge any headers from there into the request wrapper // so that all headers are sent in the request mergeHeaders(request.Header(), callInfo.requestHeader) - // Then, set the full list of merged headers into the call info so users can query the context - // for this information - // TODO - Does this necessarily need done? - callInfo.requestHeader = request.Header() // Copy the call info into a sentinel value. This is so we can compare // the sentinel value against the call info in context. If they're different, @@ -168,7 +164,7 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) // This option eliminates the [Request] and [Response] wrappers, and instead uses the // context.Context to propagate information such as headers. func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, request *Req) (*Res, error) { - response, err := c.CallUnary(ctx, requestFromClientContext(ctx, request)) + response, err := c.CallUnary(ctx, NewRequest(request)) if response != nil { return response.Msg, err } @@ -241,7 +237,7 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques // This option eliminates the [Request] wrapper, and instead uses the context.Context to // propagate information such as headers. func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { - return c.CallServerStream(ctx, requestFromClientContext(ctx, requestMsg)) + return c.CallServerStream(ctx, NewRequest(requestMsg)) } // CallBidiStream calls a bidirectional streaming procedure. diff --git a/connect.go b/connect.go index 963c7998..8ab20c5d 100644 --- a/connect.go +++ b/connect.go @@ -211,11 +211,6 @@ func (r *Request[_]) setRequestMethod(method string) { r.method = method } -// setHeader sets the request header to the given value. -func (r *Request[_]) setHeader(header http.Header) { - r.header = header -} - // AnyRequest is the common method set of every [Request], regardless of type // parameter. It's used in unary interceptors. // @@ -373,6 +368,7 @@ type hasHTTPMethod interface { getHTTPMethod() string } +// errStreamingClientConn is a sentinel error implementation of StreamingClientConn type errStreamingClientConn struct { StreamingClientConn err error diff --git a/connect_ext_test.go b/connect_ext_test.go index eddec15a..fd3d06d1 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -31,6 +31,7 @@ import ( "net/http" "net/http/httptest" "runtime" + "sort" "strings" "sync" "testing" @@ -63,20 +64,26 @@ const ( clientMiddlewareErrorHeader = "Connect-Trigger-HTTP-Error" ) +var ( + expectedHeaderValues = []string{"foo", "bar"} //nolint:gochecknoglobals +) + func TestCallInfo(t *testing.T) { t.Parallel() t.Run("simple_api", func(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connectsimple.NewPingServiceHandler( - pingServerSimple{checkMetadata: true}, + pingServerSimple{}, )) server := memhttptest.NewServer(t, mux) client := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) t.Run("unary", func(t *testing.T) { num := int64(42) ctx, callInfo := connect.NewClientContext(context.Background()) - callInfo.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + callInfo.RequestHeader().Add(clientHeader, el) + } expect := &pingv1.PingResponse{Number: num} response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) assert.Equal(t, response, expect) @@ -85,13 +92,23 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) - assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + // When using the simple API for unary calls, users can only access response headers and trailers + // from the call info in context. + assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + }) + t.Run("unary_no_callinfo", func(t *testing.T) { + num := int64(42) + expect := &pingv1.PingResponse{Number: num} + response, err := client.Ping(context.Background(), &pingv1.PingRequest{Number: num}) + assert.Equal(t, response, expect) + assert.Nil(t, err) }) t.Run("server_stream", func(t *testing.T) { ctx, callInfo := connect.NewClientContext(context.Background()) - callInfo.RequestHeader().Set(clientHeader, headerValue) - + for _, el := range expectedHeaderValues { + callInfo.RequestHeader().Add(clientHeader, el) + } val := 3 stream, err := client.CountUp(ctx, &pingv1.CountUpRequest{ Number: int64(val), @@ -115,8 +132,35 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) - assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + + // On server-streaming calls, users can access response headers and trailers + // either from the call info in context or from the stream itself. + // This verifies that the both the stream and the call info have the same information + assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + }) + t.Run("server_stream_no_callinfo", func(t *testing.T) { + val := 3 + stream, err := client.CountUp(context.Background(), &pingv1.CountUpRequest{ + Number: int64(val), + }) + assert.Nil(t, err) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) + assert.Nil(t, stream.Err()) + assert.Nil(t, stream.Close()) }) }) t.Run("generics_api", func(t *testing.T) { @@ -130,10 +174,14 @@ func TestCallInfo(t *testing.T) { t.Run("unary", func(t *testing.T) { num := int64(42) request := connect.NewRequest(&pingv1.PingRequest{Number: num}) - request.Header().Set(clientHeader, headerValue) - expect := &pingv1.PingResponse{Number: num} ctx, callInfo := connect.NewClientContext(context.Background()) + // With the generics API, a user can use the call info or request wrapper or both to set request headers. + // The resulting headers should be combined and sent in the request. + request.Header().Add(clientHeader, "foo") + callInfo.RequestHeader().Add(clientHeader, "bar") + expect := &pingv1.PingResponse{Number: num} + response, err := client.Ping(ctx, request) assert.Nil(t, err) assert.Equal(t, response.Msg, expect) @@ -145,19 +193,42 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, callInfo.Spec().Procedure, request.Spec().Procedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, request.Peer().Addr) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) - assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + + // When using the generics API for unary calls, users can access response headers and trailers + // either from the call info in context or the response wrapper. This verifies both have the same information. + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) }) - t.Run("server_stream", func(t *testing.T) { - ctx, callInfo := connect.NewClientContext(context.Background()) - callInfo.RequestHeader().Set(clientHeader, headerValue) + t.Run("unary_no_callinfo", func(t *testing.T) { + num := int64(42) + request := connect.NewRequest(&pingv1.PingRequest{Number: num}) + request.Header().Add(clientHeader, "foo") + request.Header().Add(clientHeader, "bar") + expect := &pingv1.PingResponse{Number: num} + response, err := client.Ping(context.Background(), request) + assert.Nil(t, err) + assert.Equal(t, response.Msg, expect) + assert.Equal(t, request.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, request.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, request.Spec().IsClient) + assert.Equal(t, request.Peer().Addr, httptest.DefaultRemoteAddr) + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) + }) + t.Run("server_stream", func(t *testing.T) { val := 3 req := connect.NewRequest(&pingv1.CountUpRequest{ Number: int64(val), }) + ctx, callInfo := connect.NewClientContext(context.Background()) + // With the generics API, A user can use the call info or request wrapper or both to set request headers. + // The resulting headers should be combined and sent in the request. + callInfo.RequestHeader().Set(clientHeader, "foo") + req.Header().Add(clientHeader, "bar") + stream, err := client.CountUp(ctx, req) assert.Nil(t, err) // Receive expected messages @@ -174,48 +245,88 @@ func TestCallInfo(t *testing.T) { assert.False(t, stream.Receive()) assert.Nil(t, stream.Err()) assert.Nil(t, stream.Close()) - // Assert values on request and stream + // Assert values on request assert.Equal(t, req.Spec().StreamType, connect.StreamTypeServer) assert.Equal(t, req.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) assert.True(t, req.Spec().IsClient) assert.Equal(t, req.Peer().Addr, httptest.DefaultRemoteAddr) - assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) // Assert the same values are in the call info assert.Equal(t, callInfo.Spec().StreamType, req.Spec().StreamType) assert.Equal(t, callInfo.Spec().Procedure, req.Spec().Procedure) assert.True(t, callInfo.Spec().IsClient) assert.Equal(t, callInfo.Peer().Addr, req.Peer().Addr) - assert.Equal(t, callInfo.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, callInfo.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + + // On server-streaming calls, users can access response headers and trailers + // either from the call info in context or from the stream itself. + // This verifies that the both the stream and the call info have the same information + assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + }) + t.Run("server_stream_no_callinfo", func(t *testing.T) { + val := 3 + req := connect.NewRequest(&pingv1.CountUpRequest{ + Number: int64(val), + }) + req.Header().Set(clientHeader, "foo") + req.Header().Add(clientHeader, "bar") + + stream, err := client.CountUp(context.Background(), req) + assert.Nil(t, err) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) + assert.Nil(t, stream.Err()) + assert.Nil(t, stream.Close()) + // Assert values on request + assert.Equal(t, req.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, req.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, req.Spec().IsClient) + assert.Equal(t, req.Peer().Addr, httptest.DefaultRemoteAddr) + assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) }) }) } -func TestServer(t *testing.T) { +func TestServer(t *testing.T) { //nolint:gocyclo t.Parallel() testPing := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("ping", func(t *testing.T) { num := int64(42) request := connect.NewRequest(&pingv1.PingRequest{Number: num}) - request.Header().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Add(clientHeader, el) + } expect := &pingv1.PingResponse{Number: num} response, err := client.Ping(context.Background(), request) assert.Nil(t, err) assert.Equal(t, response.Msg, expect) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("zero_ping", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) - request.Header().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Add(clientHeader, el) + } response, err := client.Ping(context.Background(), request) assert.Nil(t, err) var expect pingv1.PingResponse assert.Equal(t, response.Msg, &expect) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("large_ping", func(t *testing.T) { // Using a large payload splits the request and response over multiple @@ -226,12 +337,14 @@ func TestServer(t *testing.T) { } hellos := strings.Repeat("hello", 1024*1024) // ~5mb request := connect.NewRequest(&pingv1.PingRequest{Text: hellos}) - request.Header().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Add(clientHeader, el) + } response, err := client.Ping(context.Background(), request) assert.Nil(t, err) assert.Equal(t, response.Msg.GetText(), hellos) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("ping_error", func(t *testing.T) { _, err := client.Ping( @@ -244,7 +357,7 @@ func TestServer(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) defer cancel() request := connect.NewRequest(&pingv1.PingRequest{}) - request.Header().Set(clientHeader, headerValue) + request.Header().Set(clientHeader, "foo") _, err := client.Ping(ctx, request) assert.Equal(t, connect.CodeOf(err), connect.CodeDeadlineExceeded) }) @@ -256,7 +369,9 @@ func TestServer(t *testing.T) { expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 ) stream := client.Sum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } for i := range upTo { err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) assert.Nil(t, err, assert.Sprintf("send %d", i)) @@ -264,8 +379,8 @@ func TestServer(t *testing.T) { response, err := stream.CloseAndReceive() assert.Nil(t, err) assert.Equal(t, response.Msg.GetSum(), expect) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("sum_error", func(t *testing.T) { stream := client.Sum(context.Background()) @@ -278,11 +393,14 @@ func TestServer(t *testing.T) { }) t.Run("sum_close_and_receive_without_send", func(t *testing.T) { stream := client.Sum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } got, err := stream.CloseAndReceive() assert.Nil(t, err) assert.Equal(t, got.Msg, &pingv1.SumResponse{}) // receive header only stream - assert.Equal(t, got.Header().Values(handlerHeader), []string{headerValue}) + assert.True(t, compareValues(got.Header().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(got.Trailer().Values(handlerTrailer), expectedHeaderValues)) }) } testCountUp := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper @@ -294,7 +412,8 @@ func TestServer(t *testing.T) { expect = append(expect, int64(i+1)) } request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo}) - request.Header().Set(clientHeader, headerValue) + request.Header().Add(clientHeader, "foo") + request.Header().Add(clientHeader, "bar") stream, err := client.CountUp(context.Background(), request) assert.Nil(t, err) for stream.Receive() { @@ -332,7 +451,8 @@ func TestServer(t *testing.T) { t.Run("count_up_cancel_after_first_response", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) request := connect.NewRequest(&pingv1.CountUpRequest{Number: 5}) - request.Header().Set(clientHeader, headerValue) + request.Header().Add(clientHeader, "foo") + request.Header().Add(clientHeader, "bar") stream, err := client.CountUp(ctx, request) assert.Nil(t, err) assert.True(t, stream.Receive()) @@ -349,7 +469,9 @@ func TestServer(t *testing.T) { expect := []int64{3, 8, 9} var got []int64 stream := client.CumSum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } if !expectSuccess { // server doesn't support HTTP/2 failNoHTTP2(t, stream) return @@ -378,8 +500,8 @@ func TestServer(t *testing.T) { }() wg.Wait() assert.Equal(t, got, expect) - assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("cumsum_error", func(t *testing.T) { stream := client.CumSum(context.Background()) @@ -399,7 +521,9 @@ func TestServer(t *testing.T) { }) t.Run("cumsum_empty_stream", func(t *testing.T) { stream := client.CumSum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } if !expectSuccess { // server doesn't support HTTP/2 failNoHTTP2(t, stream) return @@ -416,7 +540,9 @@ func TestServer(t *testing.T) { t.Run("cumsum_cancel_after_first_response", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) stream := client.CumSum(ctx) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } if !expectSuccess { // server doesn't support HTTP/2 failNoHTTP2(t, stream) cancel() @@ -446,7 +572,9 @@ func TestServer(t *testing.T) { cancel() return } - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 8})) cancel() // On a subsequent send, ensure that we are still catching context @@ -476,7 +604,9 @@ func TestServer(t *testing.T) { request := connect.NewRequest(&pingv1.FailRequest{ Code: int32(connect.CodeResourceExhausted), }) - request.Header().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Add(clientHeader, el) + } response, err := client.Fail(context.Background(), request) assert.Nil(t, response) @@ -488,8 +618,8 @@ func TestServer(t *testing.T) { assert.Equal(t, connectErr.Code(), connect.CodeResourceExhausted) assert.Equal(t, connectErr.Error(), "resource_exhausted: "+errorMessage) assert.Zero(t, connectErr.Details()) - assert.Equal(t, connectErr.Meta().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, connectErr.Meta().Values(handlerTrailer), []string{trailerValue}) + assert.True(t, compareValues(connectErr.Meta().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(connectErr.Meta().Values(handlerTrailer), expectedHeaderValues)) }) t.Run("middleware_errors_unary", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) @@ -2176,6 +2306,9 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { ) assert.Nil(t, err) req.Header.Set("Content-Type", "application/grpc") + for _, el := range expectedHeaderValues { + req.Header.Add(clientHeader, el) + } res, err := server.Client().Do(req) assert.Nil(t, err) assert.Equal(t, res.StatusCode, http.StatusOK) @@ -2903,51 +3036,57 @@ func (p *pluggablePingServer) CumSum( type pingServer struct { pingv1connect.UnimplementedPingServiceHandler + // Whether to verify metadata sent to the server. Can be used to force an error returned from the server + // by intentionally sending no metadata. checkMetadata bool includeErrorDetails bool } func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + if err := validateRequestInfo(request); err != nil { + return nil, err + } + if err := compareContextAndRequest(ctx, request, request.Header()); err != nil { + return nil, err + } if p.checkMetadata { if err := expectMetadata(request.Header()); err != nil { return nil, err } } - if request.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if request.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } response := connect.NewResponse( &pingv1.PingResponse{ Number: request.Msg.GetNumber(), Text: request.Msg.GetText(), }, ) - response.Header().Set(handlerHeader, headerValue) - response.Trailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := request.Header().Values(clientHeader) + for _, el := range reqHeader { + response.Header().Add(handlerHeader, el) + response.Trailer().Add(handlerTrailer, el) + } + return response, nil } func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { - if p.checkMetadata { - if err := expectMetadata(request.Header()); err != nil { - return nil, err - } + if err := validateRequestInfo(request); err != nil { + return nil, err } - if request.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if request.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + if err := compareContextAndRequest(ctx, request, request.Header()); err != nil { + return nil, err } err := connect.NewError( connect.Code(request.Msg.GetCode()), errors.New(errorMessage), ) - err.Meta().Set(handlerHeader, headerValue) - err.Meta().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the error metadata headers and trailers + reqHeader := request.Header().Values(clientHeader) + for _, el := range reqHeader { + err.Meta().Add(handlerHeader, el) + err.Meta().Add(handlerTrailer, el) + } if p.includeErrorDetails { detail, derr := connect.NewErrorDetail(&pingv1.FailRequest{Code: request.Msg.GetCode()}) if derr != nil { @@ -2962,17 +3101,14 @@ func (p pingServer) Sum( ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest], ) (*connect.Response[pingv1.SumResponse], error) { + if err := validateRequestInfo(stream); err != nil { + return nil, err + } if p.checkMetadata { if err := expectMetadata(stream.RequestHeader()); err != nil { return nil, err } } - if stream.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if stream.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } var sum int64 for stream.Receive() { sum += stream.Msg().GetNumber() @@ -2981,8 +3117,12 @@ func (p pingServer) Sum( return nil, stream.Err() } response := connect.NewResponse(&pingv1.SumResponse{Sum: sum}) - response.Header().Set(handlerHeader, headerValue) - response.Trailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := stream.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + response.Header().Add(handlerHeader, el) + response.Trailer().Add(handlerTrailer, el) + } return response, nil } @@ -2991,25 +3131,29 @@ func (p pingServer) CountUp( request *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse], ) error { + if err := validateRequestInfo(stream.Conn()); err != nil { + return err + } + if err := compareContextAndRequest(ctx, request, request.Header()); err != nil { + return err + } if p.checkMetadata { if err := expectMetadata(request.Header()); err != nil { return err } } - if request.Peer().Addr == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if request.Peer().Protocol == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } if request.Msg.GetNumber() <= 0 { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( "number must be positive: got %v", request.Msg.GetNumber(), )) } - stream.ResponseHeader().Set(handlerHeader, headerValue) - stream.ResponseTrailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := request.Header().Values(clientHeader) + for _, el := range reqHeader { + stream.ResponseHeader().Add(handlerHeader, el) + stream.ResponseTrailer().Add(handlerTrailer, el) + } for i := range request.Msg.GetNumber() { if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { return err @@ -3028,14 +3172,11 @@ func (p pingServer) CumSum( return err } } - if stream.Peer().Addr == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if stream.Peer().Protocol == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + reqHeader := stream.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + stream.ResponseHeader().Add(handlerHeader, el) + stream.ResponseTrailer().Add(handlerTrailer, el) } - stream.ResponseHeader().Set(handlerHeader, headerValue) - stream.ResponseTrailer().Set(handlerTrailer, trailerValue) for { msg, err := stream.Receive() if errors.Is(err, io.EOF) { @@ -3062,23 +3203,24 @@ func (p pingServerSimple) Ping(ctx context.Context, request *pingv1.PingRequest) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } + if err := validateRequestInfo(callInfo); err != nil { + return nil, err + } if p.checkMetadata { if err := expectMetadata(callInfo.RequestHeader()); err != nil { return nil, err } } - if callInfo.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if callInfo.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } response := &pingv1.PingResponse{ Number: request.GetNumber(), Text: request.GetText(), } - callInfo.ResponseHeader().Set(handlerHeader, headerValue) - callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := callInfo.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + callInfo.ResponseHeader().Add(handlerHeader, el) + callInfo.ResponseTrailer().Add(handlerTrailer, el) + } return response, nil } @@ -3091,25 +3233,26 @@ func (p pingServerSimple) CountUp( if !ok { return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } + if err := validateRequestInfo(callInfo); err != nil { + return err + } if p.checkMetadata { if err := expectMetadata(callInfo.RequestHeader()); err != nil { return err } } - if callInfo.Peer().Addr == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if callInfo.Peer().Protocol == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } if request.GetNumber() <= 0 { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( "number must be positive: got %v", request.GetNumber(), )) } - callInfo.ResponseHeader().Set(handlerHeader, headerValue) - callInfo.ResponseTrailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := callInfo.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + callInfo.ResponseHeader().Add(handlerHeader, el) + callInfo.ResponseTrailer().Add(handlerTrailer, el) + } for i := range request.GetNumber() { if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { return err @@ -3254,15 +3397,101 @@ func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSu assert.Nil(tb, stream.CloseResponse()) } +type requestInfo interface { + Peer() connect.Peer + Spec() connect.Spec +} + +// Validates that the peer and spec information is set correctly in a request. +func validateRequestInfo(request requestInfo) error { + if request.Peer().Addr == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } + if request.Peer().Protocol == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } + if request.Spec().Procedure == "" { + return connect.NewError(connect.CodeInternal, errors.New("no procedure name")) + } + return nil +} + +// Compares the information in the call info in context with the given request information to verify they match. +func compareContextAndRequest(ctx context.Context, request requestInfo, requestHeaders http.Header) error { + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) + if !ok { + return connect.NewError(connect.CodeInternal, errors.New("no call info in handler context")) + } + if request.Peer().Addr != callInfo.Peer().Addr { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched peer address. found %s in request and %s in call info", request.Peer().Addr, callInfo.Peer().Addr)) + } + if request.Peer().Protocol != callInfo.Peer().Protocol { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched peer protocol. found %s in request and %s in call info", request.Peer().Addr, callInfo.Peer().Addr)) + } + if request.Spec().Procedure != callInfo.Spec().Procedure { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched procedure name. found %s in request and %s in call info", request.Spec().Procedure, request.Spec().Procedure)) + } + if valid := compareHeaders(callInfo.RequestHeader(), requestHeaders); !valid { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched request headers. found %+v in request and %+v in call info", callInfo.RequestHeader(), requestHeaders)) + } + return nil +} + +// Returns an error if the given http headers don't contain the expected header values. +// Typically, most methods in the pingServer implementations just read the request headers +// and copy those to the response headers and trailers and let the client verify that way. +// However, this method can be used in conjunction with the server's verifyMetadata setting +// to force an error to be returned if metadata isn't set. For example, see +// TestGRPCMissingTrailersError tests. func expectMetadata(meta http.Header) error { - if got := meta.Get(clientHeader); got != headerValue { + vals := meta.Values(clientHeader) + if ok := compareValues(vals, expectedHeaderValues); !ok { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( - "%s %q: got %q, expected %q", - "header", + "header %q: got %q, expected %q", clientHeader, - got, - headerValue, + vals, + expectedHeaderValues, )) } return nil } + +// Compares two http Header objects to verify they contain the exact same information. +func compareHeaders(hdr1 http.Header, hdr2 http.Header) bool { + if len(hdr1) != len(hdr2) { + return false + } + for key, hdr1Val := range hdr1 { + hdr2Val, ok := hdr2[key] + if !ok || len(hdr1Val) != len(hdr2Val) { + return false + } + + if equal := compareValues(hdr1Val, hdr2Val); !equal { + return false + } + } + return true +} + +// Compares two string slices of header values to verify they are the same, ignoring order. +func compareValues(hdr1 []string, hdr2 []string) bool { + if len(hdr1) != len(hdr2) { + return false + } + // Copy slices to avoid race conditions with other tests trying to read the headers + sorted1 := make([]string, len(hdr1)) + copy(sorted1, hdr1) + sorted2 := make([]string, len(hdr2)) + copy(sorted2, hdr2) + + sort.Strings(sorted1) + sort.Strings(sorted2) + + for idx, el := range sorted1 { + if el != sorted2[idx] { + return false + } + } + return true +} diff --git a/context.go b/context.go index 004be9ea..10aa8d0e 100644 --- a/context.go +++ b/context.go @@ -242,13 +242,3 @@ func getClientCallInfoFromContext(ctx context.Context) (*clientCallInfo, bool) { func newHandlerContext(ctx context.Context, info CallInfo) context.Context { return context.WithValue(ctx, handlerCallInfoContextKey{}, info) } - -// requestFromClientContext creates a new Request using the given context and message. -func requestFromClientContext[T any](ctx context.Context, message *T) *Request[T] { - request := NewRequest(message) - callInfo, ok := getClientCallInfoFromContext(ctx) - if ok { - request.setHeader(callInfo.RequestHeader()) - } - return request -} diff --git a/interceptor.go b/interceptor.go index f5aac3e1..75be7dac 100644 --- a/interceptor.go +++ b/interceptor.go @@ -131,8 +131,6 @@ func streamingClientThunk(next StreamingClientFunc) StreamingClientFunc { } func checkSentinel(ctx context.Context) bool { - callInfo, _ := ctx.Value(clientCallInfoContextKey{}).(*clientCallInfo) - sentinel, _ := ctx.Value(sentinelContextKey{}).(*clientCallInfo) // Only verify if there's a sentinel call info to compare it to - return callInfo == sentinel + return ctx.Value(clientCallInfoContextKey{}) == ctx.Value(sentinelContextKey{}) } From fa4e176fca38090b17c94d5936167fcddd82026d Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Wed, 23 Jul 2025 12:02:23 -0400 Subject: [PATCH 56/57] Feedback Signed-off-by: Steve Ayers --- client.go | 28 ++----------------- .../simple/gen/genconnect/simple.connect.go | 10 +++++-- cmd/protoc-gen-connect-go/main.go | 8 ++++-- connect.go | 3 +- context.go | 6 ++-- interceptor.go | 16 ++++++----- .../v1/collidev1connect/collide.connect.go | 6 +++- .../ping/v1/pingv1connect/ping.connect.go | 14 ++++++++-- 8 files changed, 45 insertions(+), 46 deletions(-) diff --git a/client.go b/client.go index e78f8cd3..8dbbc093 100644 --- a/client.go +++ b/client.go @@ -79,7 +79,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) conn.onRequestSend(func(r *http.Request) { request.setRequestMethod(r.Method) - callInfo, ok := getClientCallInfoFromContext(ctx) + callInfo, ok := clientCallInfoFromContext(ctx) if ok { callInfo.method = r.Method } @@ -116,7 +116,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien protocolClient.WriteRequestHeader(StreamTypeUnary, request.Header()) // Also set them in the context if there's a call info present - callInfo, callInfoOk := getClientCallInfoFromContext(ctx) + callInfo, callInfoOk := clientCallInfoFromContext(ctx) if callInfoOk { callInfo.peer = request.Peer() callInfo.spec = request.Spec() @@ -158,19 +158,6 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) return c.callUnary(ctx, request) } -// CallUnarySimple calls a request-response procedure using the function signature -// associated with the "simple" generation option. -// -// This option eliminates the [Request] and [Response] wrappers, and instead uses the -// context.Context to propagate information such as headers. -func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, request *Req) (*Res, error) { - response, err := c.CallUnary(ctx, NewRequest(request)) - if response != nil { - return response.Msg, err - } - return nil, err -} - // CallClientStream calls a client streaming procedure. func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamForClient[Req, Res] { if c.err != nil { @@ -187,7 +174,7 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if c.err != nil { return nil, c.err } - callInfo, callInfoOk := getClientCallInfoFromContext(ctx) + callInfo, callInfoOk := clientCallInfoFromContext(ctx) // Set values in the context if there's a call info present if callInfoOk { // Copy the call info into a sentinel value. This is so we can compare @@ -231,15 +218,6 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques }, nil } -// CallServerStreamSimple calls a server streaming procedure using the function signature -// associated with the "simple" generation option. -// -// This option eliminates the [Request] wrapper, and instead uses the context.Context to -// propagate information such as headers. -func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { - return c.CallServerStream(ctx, NewRequest(requestMsg)) -} - // CallBidiStream calls a bidirectional streaming procedure. func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForClient[Req, Res] { if c.err != nil { diff --git a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go index 68149cc7..4ee0b4dc 100644 --- a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go +++ b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go @@ -116,7 +116,11 @@ type testServiceClient struct { // Method calls connect.test.simple.TestService.Method. func (c *testServiceClient) Method(ctx context.Context, req *gen.Request) (*gen.Response, error) { - return c.method.CallUnarySimple(ctx, req) + response, err := c.method.CallUnary(ctx, connect.NewRequest(req)) + if response != nil { + return response.Msg, err + } + return nil, err } // MethodClientStream calls connect.test.simple.TestService.MethodClientStream. @@ -126,12 +130,12 @@ func (c *testServiceClient) MethodClientStream(ctx context.Context) *connect.Cli // MethodServerStream calls connect.test.simple.TestService.MethodServerStream. func (c *testServiceClient) MethodServerStream(ctx context.Context, req *gen.Request) (*connect.ServerStreamForClient[gen.Response], error) { - return c.methodServerStream.CallServerStreamSimple(ctx, req) + return c.methodServerStream.CallServerStream(ctx, connect.NewRequest(req)) } // MethodBidiStream calls connect.test.simple.TestService.MethodBidiStream. func (c *testServiceClient) MethodBidiStream(ctx context.Context, req *gen.Request) (*connect.ServerStreamForClient[gen.Response], error) { - return c.methodBidiStream.CallServerStreamSimple(ctx, req) + return c.methodBidiStream.CallServerStream(ctx, connect.NewRequest(req)) } // TestServiceHandler is an implementation of the connect.test.simple.TestService service. diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index b9018ce1..e65c696a 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -384,7 +384,7 @@ func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, na g.P("return c.", unexport(method.GoName), ".CallClientStream(ctx)") case !isStreamingClient && isStreamingServer: if simple { - g.P("return c.", unexport(method.GoName), ".CallServerStreamSimple(ctx, req)") + g.P("return c.", unexport(method.GoName), ".CallServerStream(ctx, ", connectPackage.Ident("NewRequest"), "(req))") } else { g.P("return c.", unexport(method.GoName), ".CallServerStream(ctx, req)") } @@ -392,7 +392,11 @@ func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, na g.P("return c.", unexport(method.GoName), ".CallBidiStream(ctx)") default: if simple { - g.P("return c.", unexport(method.GoName), ".CallUnarySimple(ctx, req)") + g.P("response, err := c.", unexport(method.GoName), ".CallUnary(ctx, ", connectPackage.Ident("NewRequest"), "(req))") + g.P("if response != nil {") + g.P("return response.Msg, err") + g.P("}") + g.P("return nil, err") } else { g.P("return c.", unexport(method.GoName), ".CallUnary(ctx, req)") } diff --git a/connect.go b/connect.go index 8ab20c5d..4f8349c1 100644 --- a/connect.go +++ b/connect.go @@ -368,9 +368,8 @@ type hasHTTPMethod interface { getHTTPMethod() string } -// errStreamingClientConn is a sentinel error implementation of StreamingClientConn +// errStreamingClientConn is a sentinel error implementation of StreamingClientConn. type errStreamingClientConn struct { - StreamingClientConn err error } diff --git a/context.go b/context.go index 10aa8d0e..137c8c90 100644 --- a/context.go +++ b/context.go @@ -232,13 +232,13 @@ func (w *responseWrapper[Res]) ResponseTrailer() http.Header { return w.response.Trailer() } -// Gets a client (i.e. outgoing) call info from context. -func getClientCallInfoFromContext(ctx context.Context) (*clientCallInfo, bool) { +// clientCallInfoFromContext gets the call info from a client/outgoing context. +func clientCallInfoFromContext(ctx context.Context) (*clientCallInfo, bool) { info, ok := ctx.Value(clientCallInfoContextKey{}).(*clientCallInfo) return info, ok } -// newHandlerContext creates a new handler (i.e. incoming) context. +// newHandlerContext creates a new handler/incoming context. func newHandlerContext(ctx context.Context, info CallInfo) context.Context { return context.WithValue(ctx, handlerCallInfoContextKey{}, info) } diff --git a/interceptor.go b/interceptor.go index 75be7dac..d3bc1374 100644 --- a/interceptor.go +++ b/interceptor.go @@ -114,8 +114,8 @@ func (c *chain) WrapStreamingHandler(next StreamingHandlerFunc) StreamingHandler func unaryThunk(next UnaryFunc) UnaryFunc { return func(ctx context.Context, req AnyRequest) (AnyResponse, error) { - if !checkSentinel(ctx) { - return nil, errNewClientContextProhibited + if err := checkSentinel(ctx); err != nil { + return nil, err } return next(ctx, req) } @@ -123,14 +123,16 @@ func unaryThunk(next UnaryFunc) UnaryFunc { func streamingClientThunk(next StreamingClientFunc) StreamingClientFunc { return func(ctx context.Context, spec Spec) StreamingClientConn { - if !checkSentinel(ctx) { - return &errStreamingClientConn{err: errNewClientContextProhibited} + if err := checkSentinel(ctx); err != nil { + return &errStreamingClientConn{err: err} } return next(ctx, spec) } } -func checkSentinel(ctx context.Context) bool { - // Only verify if there's a sentinel call info to compare it to - return ctx.Value(clientCallInfoContextKey{}) == ctx.Value(sentinelContextKey{}) +func checkSentinel(ctx context.Context) error { + if ctx.Value(clientCallInfoContextKey{}) != ctx.Value(sentinelContextKey{}) { + return errNewClientContextProhibited + } + return nil } diff --git a/internal/gen/simple/connect/collide/v1/collidev1connect/collide.connect.go b/internal/gen/simple/connect/collide/v1/collidev1connect/collide.connect.go index ba3a7f0c..5081e3ad 100644 --- a/internal/gen/simple/connect/collide/v1/collidev1connect/collide.connect.go +++ b/internal/gen/simple/connect/collide/v1/collidev1connect/collide.connect.go @@ -83,7 +83,11 @@ type collideServiceClient struct { // Import calls connect.collide.v1.CollideService.Import. func (c *collideServiceClient) Import(ctx context.Context, req *v1.ImportRequest) (*v1.ImportResponse, error) { - return c._import.CallUnarySimple(ctx, req) + response, err := c._import.CallUnary(ctx, connect.NewRequest(req)) + if response != nil { + return response.Msg, err + } + return nil, err } // CollideServiceHandler is an implementation of the connect.collide.v1.CollideService service. diff --git a/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go index 9fe1d615..2ac64313 100644 --- a/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go @@ -134,12 +134,20 @@ type pingServiceClient struct { // Ping calls connect.ping.v1.PingService.Ping. func (c *pingServiceClient) Ping(ctx context.Context, req *v1.PingRequest) (*v1.PingResponse, error) { - return c.ping.CallUnarySimple(ctx, req) + response, err := c.ping.CallUnary(ctx, connect.NewRequest(req)) + if response != nil { + return response.Msg, err + } + return nil, err } // Fail calls connect.ping.v1.PingService.Fail. func (c *pingServiceClient) Fail(ctx context.Context, req *v1.FailRequest) (*v1.FailResponse, error) { - return c.fail.CallUnarySimple(ctx, req) + response, err := c.fail.CallUnary(ctx, connect.NewRequest(req)) + if response != nil { + return response.Msg, err + } + return nil, err } // Sum calls connect.ping.v1.PingService.Sum. @@ -149,7 +157,7 @@ func (c *pingServiceClient) Sum(ctx context.Context) *connect.ClientStreamForCli // CountUp calls connect.ping.v1.PingService.CountUp. func (c *pingServiceClient) CountUp(ctx context.Context, req *v1.CountUpRequest) (*connect.ServerStreamForClient[v1.CountUpResponse], error) { - return c.countUp.CallServerStreamSimple(ctx, req) + return c.countUp.CallServerStream(ctx, connect.NewRequest(req)) } // CumSum calls connect.ping.v1.PingService.CumSum. From f131a462c25a585b4d383f7546c1ed23cace8bb5 Mon Sep 17 00:00:00 2001 From: Steve Ayers Date: Thu, 24 Jul 2025 14:10:00 -0400 Subject: [PATCH 57/57] Add full host of tests for all RPC types and simple vs. generics API. Also fix client stream API --- bench_test.go | 2 +- client.go | 54 +- client_ext_test.go | 2 +- client_stream.go | 62 ++ .../simple/gen/genconnect/simple.connect.go | 4 +- cmd/protoc-gen-connect-go/main.go | 2 +- connect_ext_test.go | 959 +++++++++++++----- handler.go | 6 + interceptor_ext_test.go | 584 ++++++++--- .../ping/v1/pingv1connect/ping.connect.go | 4 +- 10 files changed, 1245 insertions(+), 434 deletions(-) diff --git a/bench_test.go b/bench_test.go index d5ed9416..e9b3e4b4 100644 --- a/bench_test.go +++ b/bench_test.go @@ -124,7 +124,7 @@ func BenchmarkConnect(b *testing.B) { response, err := stream.CloseAndReceive() if err != nil { b.Error(err) - } else if got := response.Msg.GetSum(); got != expect { + } else if got := response.GetSum(); got != expect { b.Errorf("expected %d, got %d", expect, got) } } diff --git a/client.go b/client.go index 6a25144a..f9e32daf 100644 --- a/client.go +++ b/client.go @@ -170,10 +170,14 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo } // CallClientStream calls a client streaming procedure in simple mode. -func (c *Client[Req, Res]) CallClientStreamSimple(ctx context.Context) (*ClientStreamForClient[Req, Res], error) { - stream := c.CallClientStream(ctx) - if stream.err != nil { - return nil, stream.err +func (c *Client[Req, Res]) CallClientStreamSimple(ctx context.Context) (*ClientStreamForClientSimple[Req, Res], error) { + if c.err != nil { + return &ClientStreamForClientSimple[Req, Res]{err: c.err}, c.err + } + + stream := &ClientStreamForClientSimple[Req, Res]{ + conn: c.newConn(ctx, StreamTypeClient, nil), + initializer: c.config.Initializer, } if err := stream.Send(nil); err != nil { return nil, err @@ -186,31 +190,12 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if c.err != nil { return nil, c.err } - callInfo, callInfoOk := clientCallInfoFromContext(ctx) - // Set values in the context if there's a call info present - if callInfoOk { - // Copy the call info into a sentinel value. This is so we can compare - // the sentinel value against the call info in context. If they're different, - // we can stop the request. This protects against changing the context in interceptors. - ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo) - } conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method }) request.peer = conn.Peer() request.spec = conn.Spec() - // Set values in the context if there's a call info present - if callInfoOk { - callInfo.peer = conn.Peer() - callInfo.spec = conn.Spec() - callInfo.responseSource = conn - - // Merge any callInfo request headers first, then do the request. - // so that context headers show first in the list of headers - mergeHeaders(conn.RequestHeader(), callInfo.RequestHeader()) - } - mergeHeaders(conn.RequestHeader(), request.header) // Send always returns an io.EOF unless the error is from the client-side. @@ -254,6 +239,14 @@ func (c *Client[Req, Res]) CallBidiStreamSimple(ctx context.Context) (*BidiStrea } func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn { + callInfo, callInfoOk := clientCallInfoFromContext(ctx) + // Set values in the context if there's a call info present + if callInfoOk { + // Copy the call info into a sentinel value. This is so we can compare + // the sentinel value against the call info in context. If they're different, + // we can stop the request. This protects against changing the context in interceptors. + ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo) + } newConn := func(ctx context.Context, spec Spec) StreamingClientConn { header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing c.protocolClient.WriteRequestHeader(streamType, header) @@ -264,7 +257,20 @@ func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, o if interceptor := c.config.Interceptor; interceptor != nil { newConn = interceptor.WrapStreamingClient(newConn) } - return newConn(ctx, c.config.newSpec(streamType)) + conn := newConn(ctx, c.config.newSpec(streamType)) + + // Set values in the context if there's a call info present + if callInfoOk { + callInfo.peer = conn.Peer() + callInfo.spec = conn.Spec() + callInfo.responseSource = conn + + // Merge any callInfo request headers first, then do the request. + // so that context headers show first in the list of headers + mergeHeaders(conn.RequestHeader(), callInfo.RequestHeader()) + } + + return conn } type clientConfig struct { diff --git a/client_ext_test.go b/client_ext_test.go index f2825588..639751be 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -431,7 +431,7 @@ func TestDynamicClient(t *testing.T) { if !assert.Nil(t, err) { return } - got := rsp.Msg.Get(methodDesc.Output().Fields().ByName("sum")).Int() + got := rsp.Get(methodDesc.Output().Fields().ByName("sum")).Int() assert.Equal(t, got, 42*2) }) t.Run("serverStream", func(t *testing.T) { diff --git a/client_stream.go b/client_stream.go index bacbca8a..9038d273 100644 --- a/client_stream.go +++ b/client_stream.go @@ -20,6 +20,68 @@ import ( "net/http" ) +// ClientStreamForClientsimple is the client's view of a client streaming RPC. +// for the simple API. +// +// It's returned from [Client].CallClientStreamSimple, but doesn't currently have an +// exported constructor function. +type ClientStreamForClientSimple[Req, Res any] struct { + conn StreamingClientConn + initializer maybeInitializer + // Error from client construction. If non-nil, return for all calls. + err error +} + +// Spec returns the specification for the RPC. +func (c *ClientStreamForClientSimple[_, _]) Spec() Spec { + return c.conn.Spec() +} + +// Peer describes the server for the RPC. +func (c *ClientStreamForClientSimple[_, _]) Peer() Peer { + return c.conn.Peer() +} + +// Send a message to the server. The first call to Send also sends the request +// headers. +// +// If the server returns an error, Send returns an error that wraps [io.EOF]. +// Clients should check for case using the standard library's [errors.Is] and +// unmarshal the error using CloseAndReceive. +func (c *ClientStreamForClientSimple[Req, Res]) Send(request *Req) error { + if c.err != nil { + return c.err + } + if request == nil { + return c.conn.Send(nil) + } + return c.conn.Send(request) +} + +// CloseAndReceive closes the send side of the stream and waits for the +// response. +func (c *ClientStreamForClientSimple[Req, Res]) CloseAndReceive() (*Res, error) { + if c.err != nil { + return nil, c.err + } + if err := c.conn.CloseRequest(); err != nil { + _ = c.conn.CloseResponse() + return nil, err + } + response, err := receiveUnaryResponse[Res](c.conn, c.initializer) + if err != nil { + _ = c.conn.CloseResponse() + return nil, err + } + return response.Msg, c.conn.CloseResponse() +} + +// Conn exposes the underlying StreamingClientConn. This may be useful if +// you'd prefer to wrap the connection in a different high-level API. +func (c *ClientStreamForClientSimple[Req, Res]) Conn() (StreamingClientConn, error) { + return c.conn, c.err +} + // ClientStreamForClient is the client's view of a client streaming RPC. // // It's returned from [Client].CallClientStream, but doesn't currently have an diff --git a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go index 6e1f426c..bbbaa05c 100644 --- a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go +++ b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go @@ -63,7 +63,7 @@ const ( // TestServiceClient is a client for the connect.test.simple.TestService service. type TestServiceClient interface { Method(context.Context, *gen.Request) (*gen.Response, error) - MethodClientStream(context.Context) (*connect.ClientStreamForClient[gen.Request, gen.Response], error) + MethodClientStream(context.Context) (*connect.ClientStreamForClientSimple[gen.Request, gen.Response], error) MethodServerStream(context.Context, *gen.Request) (*connect.ServerStreamForClient[gen.Response], error) MethodBidiStream(context.Context, *gen.Request) (*connect.ServerStreamForClient[gen.Response], error) } @@ -124,7 +124,7 @@ func (c *testServiceClient) Method(ctx context.Context, req *gen.Request) (*gen. } // MethodClientStream calls connect.test.simple.TestService.MethodClientStream. -func (c *testServiceClient) MethodClientStream(ctx context.Context) (*connect.ClientStreamForClient[gen.Request, gen.Response], error) { +func (c *testServiceClient) MethodClientStream(ctx context.Context) (*connect.ClientStreamForClientSimple[gen.Request, gen.Response], error) { return c.methodClientStream.CallClientStreamSimple(ctx) } diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index 3da65725..99b0cc90 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -435,7 +435,7 @@ func clientSignature(g *protogen.GeneratedFile, method *protogen.Method, named b // client streaming if simple { return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + - "(*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStreamForClient")) + + "(*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStreamForClientSimple")) + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + ", error)" } diff --git a/connect_ext_test.go b/connect_ext_test.go index fd3d06d1..20d9e764 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -56,8 +56,6 @@ const errorMessage = "oh no" // client doesn't set a header, and the server sets headers and trailers on the // response. const ( - headerValue = "some header value" - trailerValue = "some trailer value" clientHeader = "Connect-Client-Header" handlerHeader = "Connect-Handler-Header" handlerTrailer = "Connect-Handler-Trailer" @@ -79,23 +77,7 @@ func TestCallInfo(t *testing.T) { server := memhttptest.NewServer(t, mux) client := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) t.Run("unary", func(t *testing.T) { - num := int64(42) - ctx, callInfo := connect.NewClientContext(context.Background()) - for _, el := range expectedHeaderValues { - callInfo.RequestHeader().Add(clientHeader, el) - } - expect := &pingv1.PingResponse{Number: num} - response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) - assert.Equal(t, response, expect) - assert.Nil(t, err) - assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) - assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) - assert.True(t, callInfo.Spec().IsClient) - assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) - // When using the simple API for unary calls, users can only access response headers and trailers - // from the call info in context. - assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + testUnarySimple(t, client) }) t.Run("unary_no_callinfo", func(t *testing.T) { num := int64(42) @@ -104,13 +86,21 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, response, expect) assert.Nil(t, err) }) + t.Run("unary_generics_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + )) + server := memhttptest.NewServer(t, mux) + simpleClient := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) + testUnarySimple(t, simpleClient) + }) t.Run("server_stream", func(t *testing.T) { - ctx, callInfo := connect.NewClientContext(context.Background()) - for _, el := range expectedHeaderValues { - callInfo.RequestHeader().Add(clientHeader, el) - } + testServerStreamSimple(t, client) + }) + t.Run("server_stream_no_callinfo", func(t *testing.T) { val := 3 - stream, err := client.CountUp(ctx, &pingv1.CountUpRequest{ + stream, err := client.CountUp(context.Background(), &pingv1.CountUpRequest{ Number: int64(val), }) assert.Nil(t, err) @@ -128,78 +118,116 @@ func TestCallInfo(t *testing.T) { assert.False(t, stream.Receive()) assert.Nil(t, stream.Err()) assert.Nil(t, stream.Close()) - assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeServer) - assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) - assert.True(t, callInfo.Spec().IsClient) - assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) - - // On server-streaming calls, users can access response headers and trailers - // either from the call info in context or from the stream itself. - // This verifies that the both the stream and the call info have the same information - assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) - assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) }) - t.Run("server_stream_no_callinfo", func(t *testing.T) { - val := 3 - stream, err := client.CountUp(context.Background(), &pingv1.CountUpRequest{ - Number: int64(val), - }) + t.Run("server_stream_generics_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + )) + server := memhttptest.NewServer(t, mux) + simpleClient := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) + testServerStreamSimple(t, simpleClient) + }) + t.Run("client_stream", func(t *testing.T) { + testClientStreamSimple(t, client) + }) + t.Run("client_stream_generics_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + )) + server := memhttptest.NewServer(t, mux) + simpleClient := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) + testClientStreamSimple(t, simpleClient) + }) + t.Run("client_stream_no_callinfo", func(t *testing.T) { + const ( + upTo = 10 + expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 + ) + stream, err := client.Sum(context.Background()) assert.Nil(t, err) - // Receive expected messages - for idx := range val { - expected := int64(idx + 1) - assert.True(t, stream.Receive()) - assert.Nil(t, stream.Err()) - msg := stream.Msg() - assert.NotNil(t, msg) - assert.Equal(t, msg.GetNumber(), expected) + + // Send messages + for i := range upTo { + err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) + assert.Nil(t, err, assert.Sprintf("send %d", i)) } - // Stream should be done. Expect false on receive and close stream - assert.False(t, stream.Receive()) - assert.Nil(t, stream.Err()) - assert.Nil(t, stream.Close()) + response, err := stream.CloseAndReceive() + assert.Nil(t, err) + assert.Equal(t, response.GetSum(), expect) + }) + t.Run("bidi_stream", func(t *testing.T) { + testBidiStreamSimple(t, client, true) + }) + t.Run("bidi_stream_generics_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + )) + server := memhttptest.NewServer(t, mux) + simpleClient := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) + testBidiStreamSimple(t, simpleClient, true) + }) + t.Run("bidi_stream_no_callinfo", func(t *testing.T) { + send := []int64{3, 5, 1} + expect := []int64{3, 8, 9} + var got []int64 + stream, err := client.CumSum(context.Background()) + assert.Nil(t, err) + assert.NotNil(t, stream) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i, n := range send { + err := stream.Send(&pingv1.CumSumRequest{Number: n}) + assert.Nil(t, err, assert.Sprintf("send error #%d", i)) + } + assert.Nil(t, stream.CloseRequest()) + }() + go func() { + defer wg.Done() + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.Nil(t, err) + got = append(got, msg.GetSum()) + } + assert.Nil(t, stream.CloseResponse()) + }() + wg.Wait() + assert.Equal(t, got, expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) }) }) t.Run("generics_api", func(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler( - pingServer{checkMetadata: true}, + pingServer{}, )) server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) t.Run("unary", func(t *testing.T) { - num := int64(42) - request := connect.NewRequest(&pingv1.PingRequest{Number: num}) - - ctx, callInfo := connect.NewClientContext(context.Background()) - // With the generics API, a user can use the call info or request wrapper or both to set request headers. - // The resulting headers should be combined and sent in the request. - request.Header().Add(clientHeader, "foo") - callInfo.RequestHeader().Add(clientHeader, "bar") - expect := &pingv1.PingResponse{Number: num} - - response, err := client.Ping(ctx, request) - assert.Nil(t, err) - assert.Equal(t, response.Msg, expect) - assert.Equal(t, request.Spec().StreamType, connect.StreamTypeUnary) - assert.Equal(t, request.Spec().Procedure, pingv1connect.PingServicePingProcedure) - assert.True(t, request.Spec().IsClient) - assert.Equal(t, request.Peer().Addr, httptest.DefaultRemoteAddr) - assert.Equal(t, callInfo.Spec().StreamType, request.Spec().StreamType) - assert.Equal(t, callInfo.Spec().Procedure, request.Spec().Procedure) - assert.True(t, callInfo.Spec().IsClient) - assert.Equal(t, callInfo.Peer().Addr, request.Peer().Addr) - - // When using the generics API for unary calls, users can access response headers and trailers - // either from the call info in context or the response wrapper. This verifies both have the same information. - assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) - assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + testUnaryGenerics(t, client) + }) + t.Run("unary_simple_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + )) + server := memhttptest.NewServer(t, mux) + genericsClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + testUnaryGenerics(t, genericsClient) }) t.Run("unary_no_callinfo", func(t *testing.T) { num := int64(42) @@ -215,55 +243,22 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, request.Spec().Procedure, pingv1connect.PingServicePingProcedure) assert.True(t, request.Spec().IsClient) assert.Equal(t, request.Peer().Addr, httptest.DefaultRemoteAddr) - assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.PingResponse]{response: response} + assertResponseHeadersAndTrailers(t, wrapper) }) t.Run("server_stream", func(t *testing.T) { - val := 3 - req := connect.NewRequest(&pingv1.CountUpRequest{ - Number: int64(val), - }) - ctx, callInfo := connect.NewClientContext(context.Background()) - // With the generics API, A user can use the call info or request wrapper or both to set request headers. - // The resulting headers should be combined and sent in the request. - callInfo.RequestHeader().Set(clientHeader, "foo") - req.Header().Add(clientHeader, "bar") - - stream, err := client.CountUp(ctx, req) - assert.Nil(t, err) - // Receive expected messages - for idx := range val { - expected := int64(idx + 1) - assert.True(t, stream.Receive()) - assert.Nil(t, stream.Err()) - msg := stream.Msg() - assert.NotNil(t, msg) - assert.Equal(t, msg.GetNumber(), expected) - } - - // Stream should be done. Expect false on receive and close stream - assert.False(t, stream.Receive()) - assert.Nil(t, stream.Err()) - assert.Nil(t, stream.Close()) - // Assert values on request - assert.Equal(t, req.Spec().StreamType, connect.StreamTypeServer) - assert.Equal(t, req.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) - assert.True(t, req.Spec().IsClient) - assert.Equal(t, req.Peer().Addr, httptest.DefaultRemoteAddr) - - // Assert the same values are in the call info - assert.Equal(t, callInfo.Spec().StreamType, req.Spec().StreamType) - assert.Equal(t, callInfo.Spec().Procedure, req.Spec().Procedure) - assert.True(t, callInfo.Spec().IsClient) - assert.Equal(t, callInfo.Peer().Addr, req.Peer().Addr) - - // On server-streaming calls, users can access response headers and trailers - // either from the call info in context or from the stream itself. - // This verifies that the both the stream and the call info have the same information - assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) - assert.True(t, compareValues(callInfo.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(callInfo.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + testServerStreamGenerics(t, client) + }) + t.Run("server_stream_simple_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + )) + server := memhttptest.NewServer(t, mux) + genericsClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + testServerStreamGenerics(t, genericsClient) }) t.Run("server_stream_no_callinfo", func(t *testing.T) { val := 3 @@ -294,27 +289,99 @@ func TestCallInfo(t *testing.T) { assert.Equal(t, req.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) assert.True(t, req.Spec().IsClient) assert.Equal(t, req.Peer().Addr, httptest.DefaultRemoteAddr) - assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + assertResponseHeadersAndTrailers(t, stream) + }) + t.Run("client_stream", func(t *testing.T) { + testClientStreamGenerics(t, client) + }) + t.Run("client_stream_simple_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + )) + server := memhttptest.NewServer(t, mux) + genericsClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + testClientStreamGenerics(t, genericsClient) + }) + t.Run("client_stream_no_callinfo", func(t *testing.T) { + const ( + upTo = 10 + expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 + ) + stream := client.Sum(context.Background()) + stream.RequestHeader().Add(clientHeader, "foo") + stream.RequestHeader().Add(clientHeader, "bar") + + // Send messages + for i := range upTo { + err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) + assert.Nil(t, err, assert.Sprintf("send %d", i)) + } + + response, err := stream.CloseAndReceive() + assert.Nil(t, err) + assert.Equal(t, response.Msg.GetSum(), expect) + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.SumResponse]{response: response} + assertResponseHeadersAndTrailers(t, wrapper) + }) + t.Run("bidi_stream", func(t *testing.T) { + testBidiStreamGenerics(t, client, true) + }) + t.Run("bidi_stream_simple_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + )) + server := memhttptest.NewServer(t, mux) + genericsClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + testBidiStreamGenerics(t, genericsClient, true) + }) + t.Run("bidi_stream_no_callinfo", func(t *testing.T) { + send := []int64{3, 5, 1} + expect := []int64{3, 8, 9} + var got []int64 + stream := client.CumSum(context.Background()) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i, n := range send { + err := stream.Send(&pingv1.CumSumRequest{Number: n}) + assert.Nil(t, err, assert.Sprintf("send error #%d", i)) + } + assert.Nil(t, stream.CloseRequest()) + }() + go func() { + defer wg.Done() + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.Nil(t, err) + got = append(got, msg.GetSum()) + } + assert.Nil(t, stream.CloseResponse()) + }() + wg.Wait() + assert.Equal(t, got, expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) }) }) } -func TestServer(t *testing.T) { //nolint:gocyclo +func TestServer(t *testing.T) { t.Parallel() testPing := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("ping", func(t *testing.T) { - num := int64(42) - request := connect.NewRequest(&pingv1.PingRequest{Number: num}) - for _, el := range expectedHeaderValues { - request.Header().Add(clientHeader, el) - } - expect := &pingv1.PingResponse{Number: num} - response, err := client.Ping(context.Background(), request) - assert.Nil(t, err) - assert.Equal(t, response.Msg, expect) - assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) + testUnaryGenerics(t, client) }) t.Run("zero_ping", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) @@ -325,8 +392,10 @@ func TestServer(t *testing.T) { //nolint:gocyclo assert.Nil(t, err) var expect pingv1.PingResponse assert.Equal(t, response.Msg, &expect) - assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.PingResponse]{response: response} + assertResponseHeadersAndTrailers(t, wrapper) }) t.Run("large_ping", func(t *testing.T) { // Using a large payload splits the request and response over multiple @@ -343,8 +412,10 @@ func TestServer(t *testing.T) { //nolint:gocyclo response, err := client.Ping(context.Background(), request) assert.Nil(t, err) assert.Equal(t, response.Msg.GetText(), hellos) - assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.PingResponse]{response: response} + assertResponseHeadersAndTrailers(t, wrapper) }) t.Run("ping_error", func(t *testing.T) { _, err := client.Ping( @@ -364,23 +435,7 @@ func TestServer(t *testing.T) { //nolint:gocyclo } testSum := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("sum", func(t *testing.T) { - const ( - upTo = 10 - expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 - ) - stream := client.Sum(context.Background()) - for _, el := range expectedHeaderValues { - stream.RequestHeader().Add(clientHeader, el) - } - for i := range upTo { - err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) - assert.Nil(t, err, assert.Sprintf("send %d", i)) - } - response, err := stream.CloseAndReceive() - assert.Nil(t, err) - assert.Equal(t, response.Msg.GetSum(), expect) - assert.True(t, compareValues(response.Header().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(response.Trailer().Values(handlerTrailer), expectedHeaderValues)) + testClientStreamGenerics(t, client) }) t.Run("sum_error", func(t *testing.T) { stream := client.Sum(context.Background()) @@ -399,29 +454,15 @@ func TestServer(t *testing.T) { //nolint:gocyclo got, err := stream.CloseAndReceive() assert.Nil(t, err) assert.Equal(t, got.Msg, &pingv1.SumResponse{}) // receive header only stream - assert.True(t, compareValues(got.Header().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(got.Trailer().Values(handlerTrailer), expectedHeaderValues)) + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.SumResponse]{response: got} + assertResponseHeadersAndTrailers(t, wrapper) }) } testCountUp := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("count_up", func(t *testing.T) { - const upTo = 5 - got := make([]int64, 0, upTo) - expect := make([]int64, 0, upTo) - for i := range upTo { - expect = append(expect, int64(i+1)) - } - request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo}) - request.Header().Add(clientHeader, "foo") - request.Header().Add(clientHeader, "bar") - stream, err := client.CountUp(context.Background(), request) - assert.Nil(t, err) - for stream.Receive() { - got = append(got, stream.Msg().GetNumber()) - } - assert.Nil(t, stream.Err()) - assert.Nil(t, stream.Close()) - assert.Equal(t, got, expect) + testServerStreamGenerics(t, client) }) t.Run("count_up_error", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -465,43 +506,7 @@ func TestServer(t *testing.T) { //nolint:gocyclo } testCumSum := func(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper t.Run("cumsum", func(t *testing.T) { - send := []int64{3, 5, 1} - expect := []int64{3, 8, 9} - var got []int64 - stream := client.CumSum(context.Background()) - for _, el := range expectedHeaderValues { - stream.RequestHeader().Add(clientHeader, el) - } - if !expectSuccess { // server doesn't support HTTP/2 - failNoHTTP2(t, stream) - return - } - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - for i, n := range send { - err := stream.Send(&pingv1.CumSumRequest{Number: n}) - assert.Nil(t, err, assert.Sprintf("send error #%d", i)) - } - assert.Nil(t, stream.CloseRequest()) - }() - go func() { - defer wg.Done() - for { - msg, err := stream.Receive() - if errors.Is(err, io.EOF) { - break - } - assert.Nil(t, err) - got = append(got, msg.GetSum()) - } - assert.Nil(t, stream.CloseResponse()) - }() - wg.Wait() - assert.Equal(t, got, expect) - assert.True(t, compareValues(stream.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(stream.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) + testBidiStreamGenerics(t, client, expectSuccess) }) t.Run("cumsum_error", func(t *testing.T) { stream := client.CumSum(context.Background()) @@ -618,18 +623,24 @@ func TestServer(t *testing.T) { //nolint:gocyclo assert.Equal(t, connectErr.Code(), connect.CodeResourceExhausted) assert.Equal(t, connectErr.Error(), "resource_exhausted: "+errorMessage) assert.Zero(t, connectErr.Details()) - assert.True(t, compareValues(connectErr.Meta().Values(handlerHeader), expectedHeaderValues)) - assert.True(t, compareValues(connectErr.Meta().Values(handlerTrailer), expectedHeaderValues)) + // Wrap the connect error so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using a single function + wrapper := &errorWrapper{err: connectErr} + assertResponseHeadersAndTrailers(t, wrapper) }) t.Run("middleware_errors_unary", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) - request.Header().Set(clientMiddlewareErrorHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Set(clientMiddlewareErrorHeader, el) + } _, err := client.Ping(context.Background(), request) assertIsHTTPMiddlewareError(t, err) }) t.Run("middleware_errors_streaming", func(t *testing.T) { request := connect.NewRequest(&pingv1.CountUpRequest{Number: 10}) - request.Header().Set(clientMiddlewareErrorHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Set(clientMiddlewareErrorHeader, el) + } stream, err := client.CountUp(context.Background(), request) assert.Nil(t, err) assert.False(t, stream.Receive()) @@ -3104,6 +3115,9 @@ func (p pingServer) Sum( if err := validateRequestInfo(stream); err != nil { return nil, err } + if err := compareContextAndRequest(ctx, stream, stream.RequestHeader()); err != nil { + return nil, err + } if p.checkMetadata { if err := expectMetadata(stream.RequestHeader()); err != nil { return nil, err @@ -3166,29 +3180,7 @@ func (p pingServer) CumSum( ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], ) error { - var sum int64 - if p.checkMetadata { - if err := expectMetadata(stream.RequestHeader()); err != nil { - return err - } - } - reqHeader := stream.RequestHeader().Values(clientHeader) - for _, el := range reqHeader { - stream.ResponseHeader().Add(handlerHeader, el) - stream.ResponseTrailer().Add(handlerTrailer, el) - } - for { - msg, err := stream.Receive() - if errors.Is(err, io.EOF) { - return nil - } else if err != nil { - return err - } - sum += msg.GetNumber() - if err := stream.Send(&pingv1.CumSumResponse{Sum: sum}); err != nil { - return err - } - } + return handleCumSum(ctx, stream, p.checkMetadata) } type pingServerSimple struct { @@ -3266,23 +3258,24 @@ func (p pingServerSimple) Fail(ctx context.Context, request *pingv1.FailRequest) if !ok { return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) } + if err := validateRequestInfo(callInfo); err != nil { + return nil, err + } if p.checkMetadata { if err := expectMetadata(callInfo.RequestHeader()); err != nil { return nil, err } } - if callInfo.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if callInfo.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } err := connect.NewError( connect.Code(request.GetCode()), errors.New(errorMessage), ) - err.Meta().Set(handlerHeader, headerValue) - err.Meta().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the error metadata headers and trailers + reqHeader := callInfo.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + err.Meta().Add(handlerHeader, el) + err.Meta().Add(handlerTrailer, el) + } if p.includeErrorDetails { detail, derr := connect.NewErrorDetail(&pingv1.FailRequest{Code: request.GetCode()}) if derr != nil { @@ -3293,6 +3286,49 @@ func (p pingServerSimple) Fail(ctx context.Context, request *pingv1.FailRequest) return nil, err } +func (p pingServerSimple) Sum( + ctx context.Context, + stream *connect.ClientStream[pingv1.SumRequest], +) (*pingv1.SumResponse, error) { + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) + if !ok { + return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) + } + if err := validateRequestInfo(callInfo); err != nil { + return nil, err + } + if err := compareContextAndRequest(ctx, stream, stream.RequestHeader()); err != nil { + return nil, err + } + if p.checkMetadata { + if err := expectMetadata(callInfo.RequestHeader()); err != nil { + return nil, err + } + } + var sum int64 + for stream.Receive() { + sum += stream.Msg().GetNumber() + } + if stream.Err() != nil { + return nil, stream.Err() + } + response := &pingv1.SumResponse{Sum: sum} + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := stream.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + callInfo.ResponseHeader().Add(handlerHeader, el) + callInfo.ResponseTrailer().Add(handlerTrailer, el) + } + return response, nil +} + +func (p pingServerSimple) CumSum( + ctx context.Context, + stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], +) error { + return handleCumSum(ctx, stream, p.checkMetadata) +} + type deflateReader struct { r io.ReadCloser } @@ -3380,6 +3416,80 @@ func (failCompressor) Close() error { func (failCompressor) Reset(io.Writer) {} +type requestInfo interface { + Peer() connect.Peer + Spec() connect.Spec +} + +type responseInfo interface { + ResponseHeader() http.Header + ResponseTrailer() http.Header +} + +// responseWrapper wraps a Response object so that it can implement the responseInfo interface. +type responseWrapper[Res any] struct { + response *connect.Response[Res] +} + +func (w *responseWrapper[Res]) ResponseHeader() http.Header { + return w.response.Header() +} + +func (w *responseWrapper[Res]) ResponseTrailer() http.Header { + return w.response.Trailer() +} + +// errorWrapper wraps a Connect error so that it can implement the responseInfo interface. +type errorWrapper struct { + err *connect.Error +} + +func (w *errorWrapper) ResponseHeader() http.Header { + return w.err.Meta() +} + +func (w *errorWrapper) ResponseTrailer() http.Header { + return w.err.Meta() +} + +// handleCumSum handles the bidi endpoint CumSum for both pingServer and pingServerSimple. +// The API for bidi-streaming does not change for simple vs. generics API on the server. +func handleCumSum( + ctx context.Context, + stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], + checkMetadata bool, +) error { + if err := validateRequestInfo(stream); err != nil { + return err + } + if err := compareContextAndRequest(ctx, stream, stream.RequestHeader()); err != nil { + return err + } + if checkMetadata { + if err := expectMetadata(stream.RequestHeader()); err != nil { + return err + } + } + var sum int64 + reqHeader := stream.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + stream.ResponseHeader().Add(handlerHeader, el) + stream.ResponseTrailer().Add(handlerTrailer, el) + } + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return err + } + sum += msg.GetNumber() + if err := stream.Send(&pingv1.CumSumResponse{Sum: sum}); err != nil { + return err + } + } +} + func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { tb.Helper() if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { @@ -3397,9 +3507,330 @@ func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSu assert.Nil(tb, stream.CloseResponse()) } -type requestInfo interface { - Peer() connect.Peer - Spec() connect.Spec +func testUnaryGenerics(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + num := int64(42) + request := connect.NewRequest(&pingv1.PingRequest{Number: num}) + + ctx, callInfo := connect.NewClientContext(context.Background()) + // With the generics API, a user can use the call info or request wrapper or both to set request headers. + // The resulting headers should be combined and sent in the request. + request.Header().Add(clientHeader, "foo") + callInfo.RequestHeader().Add(clientHeader, "bar") + expect := &pingv1.PingResponse{Number: num} + + response, err := client.Ping(ctx, request) + assert.Nil(t, err) + assert.Equal(t, response.Msg, expect) + // When using the generics API for unary calls, users can access request info such as spec and peer + // either from the call info in context or the request wrapper. This verifies both have the same information. + assert.Equal(t, request.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, request.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, request.Spec().IsClient) + assert.Equal(t, request.Peer().Addr, httptest.DefaultRemoteAddr) + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.PingResponse]{response: response} + + // When using the generics API for unary calls, users can access response headers and trailers + // either from the call info in context or the response wrapper. This verifies both have the same information. + assertResponseHeadersAndTrailers(t, callInfo) + assertResponseHeadersAndTrailers(t, wrapper) +} + +func testUnarySimple(t *testing.T, client pingv1connectsimple.PingServiceClient) { //nolint:thelper + num := int64(42) + ctx, callInfo := connect.NewClientContext(context.Background()) + for _, el := range expectedHeaderValues { + callInfo.RequestHeader().Add(clientHeader, el) + } + expect := &pingv1.PingResponse{Number: num} + response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) + assert.Equal(t, response, expect) + assert.Nil(t, err) + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + // When using the simple API for unary calls, users can only access response headers and trailers + // from the call info in context. + assertResponseHeadersAndTrailers(t, callInfo) +} + +func testServerStreamSimple(t *testing.T, client pingv1connectsimple.PingServiceClient) { //nolint:thelper + ctx, callInfo := connect.NewClientContext(context.Background()) + for _, el := range expectedHeaderValues { + callInfo.RequestHeader().Add(clientHeader, el) + } + val := 3 + stream, err := client.CountUp(ctx, &pingv1.CountUpRequest{ + Number: int64(val), + }) + assert.Nil(t, err) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) + assert.Nil(t, stream.Err()) + assert.Nil(t, stream.Close()) + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + // On server-streaming calls, users can access response headers and trailers + // either from the call info in context or from the stream itself. + // This verifies that the both the stream and the call info have the same information + assertResponseHeadersAndTrailers(t, callInfo) + assertResponseHeadersAndTrailers(t, stream) +} + +func testServerStreamGenerics(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + val := 3 + req := connect.NewRequest(&pingv1.CountUpRequest{ + Number: int64(val), + }) + ctx, callInfo := connect.NewClientContext(context.Background()) + // With the generics API, A user can use the call info or request wrapper or both to set request headers. + // The resulting headers should be combined and sent in the request. + callInfo.RequestHeader().Set(clientHeader, "foo") + req.Header().Add(clientHeader, "bar") + + stream, err := client.CountUp(ctx, req) + assert.Nil(t, err) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) + assert.Nil(t, stream.Err()) + assert.Nil(t, stream.Close()) + // Assert values on request + assert.Equal(t, req.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, req.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, req.Spec().IsClient) + assert.Equal(t, req.Peer().Addr, httptest.DefaultRemoteAddr) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + // On server-streaming calls, users can access response headers and trailers + // either from the call info in context or from the stream itself. + // This verifies that the both the stream and the call info have the same information + assertResponseHeadersAndTrailers(t, callInfo) + assertResponseHeadersAndTrailers(t, stream) +} + +func testClientStreamSimple(t *testing.T, client pingv1connectsimple.PingServiceClient) { //nolint:thelper + ctx, callInfo := connect.NewClientContext(context.Background()) + for _, el := range expectedHeaderValues { + callInfo.RequestHeader().Add(clientHeader, el) + } + + const ( + upTo = 10 + expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 + ) + stream, err := client.Sum(ctx) + assert.Nil(t, err) + + // Send messages + for i := range upTo { + err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) + assert.Nil(t, err, assert.Sprintf("send %d", i)) + } + + response, err := stream.CloseAndReceive() + assert.Nil(t, err) + assert.Equal(t, response.GetSum(), expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeClient) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceSumProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + assertResponseHeadersAndTrailers(t, callInfo) +} + +func testClientStreamGenerics(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + ctx, callInfo := connect.NewClientContext(context.Background()) + callInfo.RequestHeader().Add(clientHeader, "foo") + const ( + upTo = 10 + expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 + ) + stream := client.Sum(ctx) + stream.RequestHeader().Add(clientHeader, "bar") + + // Send messages + for i := range upTo { + err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) + assert.Nil(t, err, assert.Sprintf("send %d", i)) + } + + response, err := stream.CloseAndReceive() + assert.Nil(t, err) + assert.Equal(t, response.Msg.GetSum(), expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeClient) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceSumProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.SumResponse]{response: response} + assertResponseHeadersAndTrailers(t, wrapper) + assertResponseHeadersAndTrailers(t, callInfo) +} + +func testBidiStreamSimple(t *testing.T, client pingv1connectsimple.PingServiceClient, expectSuccess bool) { //nolint:thelper + send := []int64{3, 5, 1} + expect := []int64{3, 8, 9} + var got []int64 + ctx, callInfo := connect.NewClientContext(context.Background()) + callInfo.RequestHeader().Add(clientHeader, "foo") + callInfo.RequestHeader().Add(clientHeader, "bar") + + stream, err := client.CumSum(ctx) + assert.Nil(t, err) + assert.NotNil(t, stream) + + if !expectSuccess { // server doesn't support HTTP/2 + failNoHTTP2(t, stream) + return + } + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i, n := range send { + err := stream.Send(&pingv1.CumSumRequest{Number: n}) + assert.Nil(t, err, assert.Sprintf("send error #%d", i)) + } + assert.Nil(t, stream.CloseRequest()) + }() + go func() { + defer wg.Done() + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.Nil(t, err) + got = append(got, msg.GetSum()) + } + assert.Nil(t, stream.CloseResponse()) + }() + wg.Wait() + assert.Equal(t, got, expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + assertResponseHeadersAndTrailers(t, callInfo) + assertResponseHeadersAndTrailers(t, stream) +} + +func testBidiStreamGenerics(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper + send := []int64{3, 5, 1} + expect := []int64{3, 8, 9} + var got []int64 + ctx, callInfo := connect.NewClientContext(context.Background()) + // With the generics API, A user can use the call info or request wrapper or both to set request headers. + // The resulting headers should be combined and sent in the request. + callInfo.RequestHeader().Add(clientHeader, "foo") + stream := client.CumSum(ctx) + stream.RequestHeader().Add(clientHeader, "bar") + + if !expectSuccess { // server doesn't support HTTP/2 + failNoHTTP2(t, stream) + return + } + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i, n := range send { + err := stream.Send(&pingv1.CumSumRequest{Number: n}) + assert.Nil(t, err, assert.Sprintf("send error #%d", i)) + } + assert.Nil(t, stream.CloseRequest()) + }() + go func() { + defer wg.Done() + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.Nil(t, err) + got = append(got, msg.GetSum()) + } + assert.Nil(t, stream.CloseResponse()) + }() + wg.Wait() + assert.Equal(t, got, expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + assertResponseHeadersAndTrailers(t, callInfo) + assertResponseHeadersAndTrailers(t, stream) } // Validates that the peer and spec information is set correctly in a request. @@ -3437,7 +3868,7 @@ func compareContextAndRequest(ctx context.Context, request requestInfo, requestH return nil } -// Returns an error if the given http headers don't contain the expected header values. +// expectMetadata returns an error if the given http headers don't contain the expected header values. // Typically, most methods in the pingServer implementations just read the request headers // and copy those to the response headers and trailers and let the client verify that way. // However, this method can be used in conjunction with the server's verifyMetadata setting @@ -3456,7 +3887,13 @@ func expectMetadata(meta http.Header) error { return nil } -// Compares two http Header objects to verify they contain the exact same information. +// assertResponseHeadersAndTrailers verifies that the given response info contains the expected headers and trailers. +func assertResponseHeadersAndTrailers(t *testing.T, resp responseInfo) { //nolint:thelper + assert.True(t, compareValues(resp.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(resp.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) +} + +// compareHeaders compares two http Header objects to verify they contain the exact same information. func compareHeaders(hdr1 http.Header, hdr2 http.Header) bool { if len(hdr1) != len(hdr2) { return false @@ -3474,7 +3911,7 @@ func compareHeaders(hdr1 http.Header, hdr2 http.Header) bool { return true } -// Compares two string slices of header values to verify they are the same, ignoring order. +// compareValues compares two string slices of header values to verify they are the same, ignoring order. func compareValues(hdr1 []string, hdr2 []string) bool { if len(hdr1) != len(hdr2) { return false diff --git a/handler.go b/handler.go index ba25eb64..424225b5 100644 --- a/handler.go +++ b/handler.go @@ -146,6 +146,9 @@ func NewClientStreamHandler[Req, Res any]( conn: conn, initializer: config.Initializer, } + ctx = newHandlerContext(ctx, &streamCallInfo{ + conn: conn, + }) res, err := implementation(ctx, stream) if err != nil { return err @@ -236,6 +239,9 @@ func NewBidiStreamHandler[Req, Res any]( return newStreamHandler( config, func(ctx context.Context, conn StreamingHandlerConn) error { + ctx = newHandlerContext(ctx, &streamCallInfo{ + conn: conn, + }) return implementation( ctx, &BidiStream[Req, Res]{ diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 9668aff8..1e6003c2 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -16,7 +16,9 @@ package connect_test import ( "context" + "errors" "fmt" + "io" "net/http" "sync/atomic" "testing" @@ -25,173 +27,471 @@ import ( "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + pingv1connectsimple "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" ) -func TestSideQuestInInterceptor(t *testing.T) { - t.Parallel() - t.Run("unary", func(t *testing.T) { - t.Parallel() - t.Run("sidequest_succeeds", func(t *testing.T) { - t.Parallel() - createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32, server *memhttp.Server) connect.Option { - return connect.WithInterceptors( - newSideQuestInterceptor(t, clientCounter1, server), - newSideQuestInterceptor(t, clientCounter2, server), - ) - } - var clientCounter1, clientCounter2 atomic.Int32 - mux := http.NewServeMux() - mux.Handle( - pingv1connect.NewPingServiceHandler( - pingServer{}, - ), - ) - server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient( - server.Client(), - server.URL(), - createInterceptors(&clientCounter1, &clientCounter2, server), - ) - _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) - - assert.Nil(t, err) - assert.Equal(t, int32(1), clientCounter1.Load()) - assert.Equal(t, int32(1), clientCounter2.Load()) - }) - }) -} +const expectedContextErrorMessage = "creating a new context in an interceptor is prohibited" -func TestNewClientContextFails(t *testing.T) { - // Verifies that calling NewClientContext in an interceptor fails when sending the new context downstream +func TestNewClientContextInInterceptor(t *testing.T) { t.Parallel() - t.Run("unary", func(t *testing.T) { + t.Run("simple_api", func(t *testing.T) { t.Parallel() + mux := http.NewServeMux() + mux.Handle( + pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + ), + ) + server := memhttptest.NewServer(t, mux) t.Run("first_interceptor", func(t *testing.T) { - t.Parallel() - createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { - return connect.WithInterceptors( - &contextInterceptor{client: true, count: clientCounter1, createNewContext: true}, - &contextInterceptor{client: true, count: clientCounter2}, + // Because we're creating a new context in the first interceptor, only the first interceptor should fire + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connectsimple.PingServiceClient { + opts := connect.WithInterceptors( + &contextInterceptor{client: true, count: counter1, createNewContext: true}, + &contextInterceptor{client: true, count: counter2}, + ) + return pingv1connectsimple.NewPingServiceClient( + server.Client(), + server.URL(), + opts, ) } - var clientCounter1, clientCounter2 atomic.Int32 - mux := http.NewServeMux() - mux.Handle( - pingv1connect.NewPingServiceHandler( - pingServer{}, - ), - ) - server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient( - server.Client(), - server.URL(), - createInterceptors(&clientCounter1, &clientCounter2), - ) - _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) - - // Since we are creating a new client context, an error will be returned from the invocation - assert.NotNil(t, err) - assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") - // And because we're creating it in the first interceptor, only the first interceptor fires - assert.Equal(t, int32(1), clientCounter1.Load()) - assert.Equal(t, int32(0), clientCounter2.Load()) + t.Run("unary", func(t *testing.T) { + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + resp, err := client.Ping(context.Background(), &pingv1.PingRequest{Number: 10}) + + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream, err := client.CountUp(context.Background(), &pingv1.CountUpRequest{Number: 10}) + + assert.Nil(t, stream) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + // With client-streaming and the simple API, the initial call fails. This differs from + // the generics API which requires a call to stream.Send first to receive an error. + stream, err := client.Sum(context.Background()) + assert.NotNil(t, err) + assert.Nil(t, stream) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + // With bidi-streaming and the simple API, the initial call fails. This differs from + // the generics API which requires a call to stream.Send first to receive an error. + stream, err := client.CumSum(context.Background()) + assert.NotNil(t, err) + assert.Nil(t, stream) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) }) t.Run("subsequent_interceptor", func(t *testing.T) { + // Because we're creating a new context in the last interceptor, all interceptors should fire + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connectsimple.PingServiceClient { + opts := connect.WithInterceptors( + &contextInterceptor{client: true, count: counter1}, + &contextInterceptor{client: true, count: counter2, createNewContext: true}, + ) + return pingv1connectsimple.NewPingServiceClient( + server.Client(), + server.URL(), + opts, + ) + } + t.Run("unary", func(t *testing.T) { + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + resp, err := client.Ping(context.Background(), &pingv1.PingRequest{Number: 10}) + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CountUp(context.Background(), &pingv1.CountUpRequest{Number: 10}) + assert.Nil(t, stream) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + // With client-streaming and the simple API, the initial call fails. This differs from + // the generics API which requires a call to stream.Send first to receive an error. + stream, err := client.Sum(context.Background()) + assert.NotNil(t, err) + assert.Nil(t, stream) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + // With bidi-streaming and the simple API, the initial call fails. This differs from + // the generics API which requires a call to stream.Send first to receive an error. + stream, err := client.CumSum(context.Background()) + assert.NotNil(t, err) + assert.Nil(t, stream) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + }) + t.Run("sidequest_succeeds", func(t *testing.T) { + // These tests create a new context but it is used to issue a separate/new request and not reused in the + // interceptor chain. So, all interceptors should fire and no errors should be returned. t.Parallel() - createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { - return connect.WithInterceptors( - &contextInterceptor{client: true, count: clientCounter1}, - &contextInterceptor{client: true, count: clientCounter2, createNewContext: true}, + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connectsimple.PingServiceClient { + opts := connect.WithInterceptors( + newSideQuestInterceptor(t, counter1, server), + newSideQuestInterceptor(t, counter2, server), + ) + return pingv1connectsimple.NewPingServiceClient( + server.Client(), + server.URL(), + opts, ) } - var clientCounter1, clientCounter2 atomic.Int32 - mux := http.NewServeMux() - mux.Handle( - pingv1connect.NewPingServiceHandler( - pingServer{}, - ), - ) - server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient( - server.Client(), - server.URL(), - createInterceptors(&clientCounter1, &clientCounter2), - ) - _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) - - // Since we are creating a new client context, an error will be returned from the invocation - assert.NotNil(t, err) - assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") - // And because we're creating it in the second interceptor, they both fire - assert.Equal(t, int32(1), clientCounter1.Load()) - assert.Equal(t, int32(1), clientCounter2.Load()) + t.Run("unary", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + resp, err := client.Ping(context.Background(), &pingv1.PingRequest{Number: 10}) + assert.NotNil(t, resp) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CountUp(context.Background(), &pingv1.CountUpRequest{Number: 10}) + assert.NotNil(t, stream) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.Sum(context.Background()) + assert.NotNil(t, stream) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CumSum(context.Background()) + assert.Nil(t, err) + assert.NotNil(t, stream) + + assert.Nil(t, stream.CloseRequest()) + assert.Nil(t, stream.CloseResponse()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) }) }) - t.Run("server_streaming", func(t *testing.T) { + t.Run("generics_api", func(t *testing.T) { t.Parallel() + mux := http.NewServeMux() + mux.Handle( + pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + ), + ) + server := memhttptest.NewServer(t, mux) t.Run("first_interceptor", func(t *testing.T) { - t.Parallel() - createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { - return connect.WithInterceptors( - &contextInterceptor{client: true, count: clientCounter1, createNewContext: true}, - &contextInterceptor{client: true, count: clientCounter2}, + // Because we're creating a new context in the first interceptor, only the first interceptor should fire + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connect.PingServiceClient { + opts := connect.WithInterceptors( + &contextInterceptor{client: true, count: counter1, createNewContext: true}, + &contextInterceptor{client: true, count: counter2}, + ) + return pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + opts, ) } - var clientCounter1, clientCounter2 atomic.Int32 - mux := http.NewServeMux() - mux.Handle( - pingv1connect.NewPingServiceHandler( - pingServer{}, - ), - ) - server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient( - server.Client(), - server.URL(), - createInterceptors(&clientCounter1, &clientCounter2), - ) - responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) - - // Since we are creating a new client context, an error will be returned from the invocation - assert.Nil(t, responses) - assert.NotNil(t, err) - assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") - // And because we're creating it in the first interceptor, only the first interceptor fires - assert.Equal(t, int32(1), clientCounter1.Load()) - assert.Equal(t, int32(0), clientCounter2.Load()) + t.Run("unary", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + resp, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) + assert.Nil(t, stream) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.Sum(context.Background()) + assert.NotNil(t, stream) + + // With client-streaming and the generics API, a call to stream.Send is required to receive an error. + err := stream.Send(&pingv1.SumRequest{Number: int64(1)}) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + // We should receive the same error when we try to close the stream + resp, err := stream.CloseAndReceive() + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + //nolint:dupl // the test logic for bidi w/r/t generic and simple api looks the same, but it's subtly different + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.CumSum(context.Background()) + assert.NotNil(t, stream) + + // With bidi-streaming and the generics API, a call to stream.Send is required to receive an error. + err := stream.Send(&pingv1.CumSumRequest{Number: 1}) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + // We should receive the same error when we try to close the send and receive parts of the stream + err = stream.CloseRequest() + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + err = stream.CloseResponse() + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) }) + t.Run("subsequent_interceptor", func(t *testing.T) { - t.Parallel() - createInterceptors := func(clientCounter1 *atomic.Int32, clientCounter2 *atomic.Int32) connect.Option { - return connect.WithInterceptors( - &contextInterceptor{client: true, count: clientCounter1}, - &contextInterceptor{client: true, count: clientCounter2, createNewContext: true}, + // Because we're creating a new context in the last interceptor, all interceptors should fire + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connect.PingServiceClient { + opts := connect.WithInterceptors( + &contextInterceptor{client: true, count: counter1}, + &contextInterceptor{client: true, count: counter2, createNewContext: true}, + ) + return pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + opts, ) } - var clientCounter1, clientCounter2 atomic.Int32 - mux := http.NewServeMux() - mux.Handle( - pingv1connect.NewPingServiceHandler( - pingServer{}, - ), - ) - server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient( - server.Client(), - server.URL(), - createInterceptors(&clientCounter1, &clientCounter2), - ) - responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) - - // Since we are creating a new client context, an error will be returned from the invocation - assert.Nil(t, responses) - assert.NotNil(t, err) - assert.Equal(t, err.Error(), "creating a new context in an interceptor is prohibited") - // And because we're creating it in the second interceptor, all interceptors fire - assert.Equal(t, int32(1), clientCounter1.Load()) - assert.Equal(t, int32(1), clientCounter2.Load()) + t.Run("unary", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + resp, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) + assert.Nil(t, stream) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.Sum(context.Background()) + assert.NotNil(t, stream) + + // With client-streaming and the generics API, a call to stream.Send is required to receive an error. + err := stream.Send(&pingv1.SumRequest{Number: int64(1)}) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + // We should receive the same error when we try to close the stream + resp, err := stream.CloseAndReceive() + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + //nolint:dupl // the test logic for bidi w/r/t generic and simple api looks the same, but it's subtly different + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.CumSum(context.Background()) + assert.NotNil(t, stream) + + // With bidi-streaming and the generics API, a call to stream.Send is required to receive an error. + err := stream.Send(&pingv1.CumSumRequest{Number: 1}) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + // We should receive the same error when we try to close the send and receive parts of the stream + err = stream.CloseRequest() + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + err = stream.CloseResponse() + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + }) + t.Run("sidequest_succeeds", func(t *testing.T) { + // These tests create a new context but it is used to issue a separate/new request and not reused in the + // interceptor chain. So, all interceptors should fire and no errors should be returned. + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connect.PingServiceClient { + opts := connect.WithInterceptors( + newSideQuestInterceptor(t, counter1, server), + newSideQuestInterceptor(t, counter2, server), + ) + return pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + opts, + ) + } + t.Run("unary", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + resp, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + assert.NotNil(t, resp) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) + assert.NotNil(t, stream) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.Sum(context.Background()) + assert.NotNil(t, stream) + + err := stream.Send(&pingv1.SumRequest{Number: int64(1)}) + assert.Nil(t, err) + resp, err := stream.CloseAndReceive() + assert.NotNil(t, resp) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.CumSum(context.Background()) + assert.NotNil(t, stream) + + // The initial send should succeed + err := stream.Send(&pingv1.CumSumRequest{Number: 1}) + assert.Nil(t, err) + + // We should be able to successfully close the send part of the stream + assert.Nil(t, stream.CloseRequest()) + + // All receives should succeed + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.NotNil(t, msg) + assert.Nil(t, err) + } + // We should be able to successfully close the receive part of the stream + assert.Nil(t, stream.CloseResponse()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) }) }) } diff --git a/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go index 3cdf6f80..5b0cfaa3 100644 --- a/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go @@ -71,7 +71,7 @@ type PingServiceClient interface { // Fail always fails. Fail(context.Context, *v1.FailRequest) (*v1.FailResponse, error) // Sum calculates the sum of the numbers sent on the stream. - Sum(context.Context) (*connect.ClientStreamForClient[v1.SumRequest, v1.SumResponse], error) + Sum(context.Context) (*connect.ClientStreamForClientSimple[v1.SumRequest, v1.SumResponse], error) // CountUp returns a stream of the numbers up to the given request. CountUp(context.Context, *v1.CountUpRequest) (*connect.ServerStreamForClient[v1.CountUpResponse], error) // CumSum determines the cumulative sum of all the numbers sent on the stream. @@ -151,7 +151,7 @@ func (c *pingServiceClient) Fail(ctx context.Context, req *v1.FailRequest) (*v1. } // Sum calls connect.ping.v1.PingService.Sum. -func (c *pingServiceClient) Sum(ctx context.Context) (*connect.ClientStreamForClient[v1.SumRequest, v1.SumResponse], error) { +func (c *pingServiceClient) Sum(ctx context.Context) (*connect.ClientStreamForClientSimple[v1.SumRequest, v1.SumResponse], error) { return c.sum.CallClientStreamSimple(ctx) }