diff --git a/bench_test.go b/bench_test.go index 333b5bbb..e9b3e4b4 100644 --- a/bench_test.go +++ b/bench_test.go @@ -112,7 +112,10 @@ func BenchmarkConnect(b *testing.B) { upTo = 1 expect = 1 ) - stream := client.Sum(ctx) + stream, err := client.Sum(ctx) + if err != nil { + b.Error(err) + } for number := int64(1); number <= upTo; number++ { if err := stream.Send(&pingv1.SumRequest{Number: number}); err != nil { b.Error(err) @@ -121,7 +124,7 @@ func BenchmarkConnect(b *testing.B) { response, err := stream.CloseAndReceive() if err != nil { b.Error(err) - } else if got := response.Msg.GetSum(); got != expect { + } else if got := response.GetSum(); got != expect { b.Errorf("expected %d, got %d", expect, got) } } @@ -159,7 +162,10 @@ func BenchmarkConnect(b *testing.B) { const ( upTo = 1 ) - stream := client.CumSum(ctx) + stream, err := client.CumSum(ctx) + if err != nil { + b.Error(err) + } number := int64(1) for ; number <= upTo; number++ { if err := stream.Send(&pingv1.CumSumRequest{Number: number}); err != nil { diff --git a/client.go b/client.go index ffb9336f..f9e32daf 100644 --- a/client.go +++ b/client.go @@ -79,6 +79,10 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) conn.onRequestSend(func(r *http.Request) { request.setRequestMethod(r.Method) + callInfo, ok := clientCallInfoFromContext(ctx) + if ok { + callInfo.method = r.Method + } }) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the @@ -100,6 +104,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien return response, conn.CloseResponse() }) if interceptor := config.Interceptor; interceptor != nil { + // interceptor is the full chain of all interceptors provided unaryFunc = interceptor.WrapUnary(unaryFunc) } client.callUnary = func(ctx context.Context, request *Request[Req]) (*Response[Res], error) { @@ -109,6 +114,23 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien request.spec = unarySpec request.peer = client.protocolClient.Peer() protocolClient.WriteRequestHeader(StreamTypeUnary, request.Header()) + + // Also set them in the context if there's a call info present + callInfo, callInfoOk := clientCallInfoFromContext(ctx) + if callInfoOk { + callInfo.peer = request.Peer() + callInfo.spec = request.Spec() + // A client could have set request headers in the call info OR the request wrapper + // So if a callInfo exists in context, merge any headers from there into the request wrapper + // so that all headers are sent in the request + mergeHeaders(request.Header(), callInfo.requestHeader) + + // Copy the call info into a sentinel value. This is so we can compare + // the sentinel value against the call info in context. If they're different, + // we can stop the request. This protects against changing the context in interceptors. + ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo) + } + response, err := unaryFunc(ctx, request) if err != nil { return nil, err @@ -117,6 +139,12 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien if !ok { return nil, errorf(CodeInternal, "unexpected client response type %T", response) } + if callInfoOk { + // Wrap the response and set it into the context callinfo + callInfo.responseSource = &responseWrapper[Res]{ + response: typed, + } + } return typed, nil } return client @@ -130,19 +158,6 @@ func (c *Client[Req, Res]) CallUnary(ctx context.Context, request *Request[Req]) return c.callUnary(ctx, request) } -// CallUnarySimple calls a request-response procedure using the function signature -// associated with the "simple" generation option. -// -// This option eliminates the [Request] and [Response] wrappers, and instead uses the -// context.Context to propagate information such as headers. -func (c *Client[Req, Res]) CallUnarySimple(ctx context.Context, requestMsg *Req) (*Res, error) { - response, err := c.CallUnary(ctx, requestFromContext(ctx, requestMsg)) - if response != nil { - return response.Msg, err - } - return nil, err -} - // CallClientStream calls a client streaming procedure. func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamForClient[Req, Res] { if c.err != nil { @@ -154,6 +169,22 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo } } +// CallClientStream calls a client streaming procedure in simple mode. +func (c *Client[Req, Res]) CallClientStreamSimple(ctx context.Context) (*ClientStreamForClientSimple[Req, Res], error) { + if c.err != nil { + return &ClientStreamForClientSimple[Req, Res]{err: c.err}, c.err + } + + stream := &ClientStreamForClientSimple[Req, Res]{ + conn: c.newConn(ctx, StreamTypeClient, nil), + initializer: c.config.Initializer, + } + if err := stream.Send(nil); err != nil { + return nil, err + } + return stream, nil +} + // CallServerStream calls a server streaming procedure. func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Request[Req]) (*ServerStreamForClient[Res], error) { if c.err != nil { @@ -162,9 +193,11 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { request.method = r.Method }) - request.spec = conn.Spec() request.peer = conn.Peer() + request.spec = conn.Spec() + mergeHeaders(conn.RequestHeader(), request.header) + // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. @@ -182,15 +215,6 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques }, nil } -// CallServerStreamSimple calls a server streaming procedure using the function signature -// associated with the "simple" generation option. -// -// This option eliminates the [Request] wrapper, and instead uses the context.Context to -// propagate information such as headers. -func (c *Client[Req, Res]) CallServerStreamSimple(ctx context.Context, requestMsg *Req) (*ServerStreamForClient[Res], error) { - return c.CallServerStream(ctx, requestFromContext(ctx, requestMsg)) -} - // CallBidiStream calls a bidirectional streaming procedure. func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForClient[Req, Res] { if c.err != nil { @@ -202,7 +226,27 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli } } +// CallBidiStreamSimple calls a bidirectional streaming procedure in simple mode. +func (c *Client[Req, Res]) CallBidiStreamSimple(ctx context.Context) (*BidiStreamForClient[Req, Res], error) { + stream := c.CallBidiStream(ctx) + if stream.err != nil { + return nil, stream.err + } + if err := stream.Send(nil); err != nil { + return nil, err + } + return stream, nil +} + func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn { + callInfo, callInfoOk := clientCallInfoFromContext(ctx) + // Set values in the context if there's a call info present + if callInfoOk { + // Copy the call info into a sentinel value. This is so we can compare + // the sentinel value against the call info in context. If they're different, + // we can stop the request. This protects against changing the context in interceptors. + ctx = context.WithValue(ctx, sentinelContextKey{}, callInfo) + } newConn := func(ctx context.Context, spec Spec) StreamingClientConn { header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing c.protocolClient.WriteRequestHeader(streamType, header) @@ -213,7 +257,20 @@ func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, o if interceptor := c.config.Interceptor; interceptor != nil { newConn = interceptor.WrapStreamingClient(newConn) } - return newConn(ctx, c.config.newSpec(streamType)) + conn := newConn(ctx, c.config.newSpec(streamType)) + + // Set values in the context if there's a call info present + if callInfoOk { + callInfo.peer = conn.Peer() + callInfo.spec = conn.Spec() + callInfo.responseSource = conn + + // Merge any callInfo request headers first, then do the request. + // so that context headers show first in the list of headers + mergeHeaders(conn.RequestHeader(), callInfo.RequestHeader()) + } + + return conn } type clientConfig struct { diff --git a/client_ext_test.go b/client_ext_test.go index 4e5b8351..639751be 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -18,10 +18,12 @@ import ( "bytes" "context" "crypto/rand" + "crypto/tls" "errors" "fmt" "io" "log" + "net" "net/http" "net/http/httptest" "runtime" @@ -36,6 +38,7 @@ import ( pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp/memhttptest" + "golang.org/x/net/http2" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" @@ -389,6 +392,48 @@ func TestDynamicClient(t *testing.T) { got := rsp.Msg.Get(methodDesc.Output().Fields().ByName("sum")).Int() assert.Equal(t, got, 42*2) }) + t.Run("clientStreamSimple", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Sum") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + connected := make(chan struct{}) + transport := &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + close(connected) + return server.Transport().DialTLSContext(ctx, network, addr, cfg) + }, + AllowHTTP: true, + } + client := connect.NewClient[dynamicpb.Message, dynamicpb.Message]( + &http.Client{Transport: transport}, + server.URL()+"/connect.ping.v1.PingService/Sum", + connect.WithSchema(methodDesc), + connect.WithResponseInitializer(initializer), + ) + stream, err := client.CallClientStreamSimple(ctx) + assert.Nil(t, err) + select { + case <-connected: + break + case <-time.After(time.Second): + t.Error("CallClientStreamSimple did not eagerly send headers") + } + msg := dynamicpb.NewMessage(methodDesc.Input()) + msg.Set( + methodDesc.Input().Fields().ByName("number"), + protoreflect.ValueOfInt64(42), + ) + assert.Nil(t, stream.Send(msg)) + assert.Nil(t, stream.Send(msg)) + rsp, err := stream.CloseAndReceive() + if !assert.Nil(t, err) { + return + } + got := rsp.Get(methodDesc.Output().Fields().ByName("sum")).Int() + assert.Equal(t, got, 42*2) + }) t.Run("serverStream", func(t *testing.T) { t.Parallel() desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CountUp") @@ -445,6 +490,48 @@ func TestDynamicClient(t *testing.T) { got := out.Get(methodDesc.Output().Fields().ByName("number")).Int() assert.Equal(t, got, 42) }) + t.Run("bidiSimple", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CumSum") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + connected := make(chan struct{}) + transport := &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + close(connected) + return server.Transport().DialTLSContext(ctx, network, addr, cfg) + }, + AllowHTTP: true, + } + client := connect.NewClient[dynamicpb.Message, dynamicpb.Message]( + &http.Client{Transport: transport}, + server.URL()+"/connect.ping.v1.PingService/CumSum", + connect.WithSchema(methodDesc), + connect.WithResponseInitializer(initializer), + ) + stream, err := client.CallBidiStreamSimple(ctx) + assert.Nil(t, err) + select { + case <-connected: + break + case <-time.After(time.Second): + t.Error("CallBidiStreamSimple did not eagerly send headers") + } + msg := dynamicpb.NewMessage(methodDesc.Input()) + msg.Set( + methodDesc.Input().Fields().ByName("number"), + protoreflect.ValueOfInt64(42), + ) + assert.Nil(t, stream.Send(msg)) + assert.Nil(t, stream.CloseRequest()) + out, err := stream.Receive() + if assert.Nil(t, err) { + return + } + got := out.Get(methodDesc.Output().Fields().ByName("number")).Int() + assert.Equal(t, got, 42) + }) t.Run("option", func(t *testing.T) { t.Parallel() desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping") diff --git a/client_stream.go b/client_stream.go index bacbca8a..9038d273 100644 --- a/client_stream.go +++ b/client_stream.go @@ -20,6 +20,68 @@ import ( "net/http" ) +// ClientStreamForClientsimple is the client's view of a client streaming RPC. +// for the simple API. +// +// It's returned from [Client].CallClientStreamSimple, but doesn't currently have an +// exported constructor function. +type ClientStreamForClientSimple[Req, Res any] struct { + conn StreamingClientConn + initializer maybeInitializer + // Error from client construction. If non-nil, return for all calls. + err error +} + +// Spec returns the specification for the RPC. +func (c *ClientStreamForClientSimple[_, _]) Spec() Spec { + return c.conn.Spec() +} + +// Peer describes the server for the RPC. +func (c *ClientStreamForClientSimple[_, _]) Peer() Peer { + return c.conn.Peer() +} + +// Send a message to the server. The first call to Send also sends the request +// headers. +// +// If the server returns an error, Send returns an error that wraps [io.EOF]. +// Clients should check for case using the standard library's [errors.Is] and +// unmarshal the error using CloseAndReceive. +func (c *ClientStreamForClientSimple[Req, Res]) Send(request *Req) error { + if c.err != nil { + return c.err + } + if request == nil { + return c.conn.Send(nil) + } + return c.conn.Send(request) +} + +// CloseAndReceive closes the send side of the stream and waits for the +// response. +func (c *ClientStreamForClientSimple[Req, Res]) CloseAndReceive() (*Res, error) { + if c.err != nil { + return nil, c.err + } + if err := c.conn.CloseRequest(); err != nil { + _ = c.conn.CloseResponse() + return nil, err + } + response, err := receiveUnaryResponse[Res](c.conn, c.initializer) + if err != nil { + _ = c.conn.CloseResponse() + return nil, err + } + return response.Msg, c.conn.CloseResponse() +} + +// Conn exposes the underlying StreamingClientConn. This may be useful if +// you'd prefer to wrap the connection in a different high-level API. +func (c *ClientStreamForClientSimple[Req, Res]) Conn() (StreamingClientConn, error) { + return c.conn, c.err +} + // ClientStreamForClient is the client's view of a client streaming RPC. // // It's returned from [Client].CallClientStream, but doesn't currently have an diff --git a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go index 68149cc7..bbbaa05c 100644 --- a/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go +++ b/cmd/protoc-gen-connect-go/internal/testdata/simple/gen/genconnect/simple.connect.go @@ -63,7 +63,7 @@ const ( // TestServiceClient is a client for the connect.test.simple.TestService service. type TestServiceClient interface { Method(context.Context, *gen.Request) (*gen.Response, error) - MethodClientStream(context.Context) *connect.ClientStreamForClient[gen.Request, gen.Response] + MethodClientStream(context.Context) (*connect.ClientStreamForClientSimple[gen.Request, gen.Response], error) MethodServerStream(context.Context, *gen.Request) (*connect.ServerStreamForClient[gen.Response], error) MethodBidiStream(context.Context, *gen.Request) (*connect.ServerStreamForClient[gen.Response], error) } @@ -116,28 +116,32 @@ type testServiceClient struct { // Method calls connect.test.simple.TestService.Method. func (c *testServiceClient) Method(ctx context.Context, req *gen.Request) (*gen.Response, error) { - return c.method.CallUnarySimple(ctx, req) + response, err := c.method.CallUnary(ctx, connect.NewRequest(req)) + if response != nil { + return response.Msg, err + } + return nil, err } // MethodClientStream calls connect.test.simple.TestService.MethodClientStream. -func (c *testServiceClient) MethodClientStream(ctx context.Context) *connect.ClientStreamForClient[gen.Request, gen.Response] { - return c.methodClientStream.CallClientStream(ctx) +func (c *testServiceClient) MethodClientStream(ctx context.Context) (*connect.ClientStreamForClientSimple[gen.Request, gen.Response], error) { + return c.methodClientStream.CallClientStreamSimple(ctx) } // MethodServerStream calls connect.test.simple.TestService.MethodServerStream. func (c *testServiceClient) MethodServerStream(ctx context.Context, req *gen.Request) (*connect.ServerStreamForClient[gen.Response], error) { - return c.methodServerStream.CallServerStreamSimple(ctx, req) + return c.methodServerStream.CallServerStream(ctx, connect.NewRequest(req)) } // MethodBidiStream calls connect.test.simple.TestService.MethodBidiStream. func (c *testServiceClient) MethodBidiStream(ctx context.Context, req *gen.Request) (*connect.ServerStreamForClient[gen.Response], error) { - return c.methodBidiStream.CallServerStreamSimple(ctx, req) + return c.methodBidiStream.CallServerStream(ctx, connect.NewRequest(req)) } // TestServiceHandler is an implementation of the connect.test.simple.TestService service. type TestServiceHandler interface { Method(context.Context, *gen.Request) (*gen.Response, error) - MethodClientStream(context.Context, *connect.ClientStream[gen.Request]) (*connect.Response[gen.Response], error) + MethodClientStream(context.Context, *connect.ClientStream[gen.Request]) (*gen.Response, error) MethodServerStream(context.Context, *gen.Request, *connect.ServerStream[gen.Response]) error MethodBidiStream(context.Context, *gen.Request, *connect.ServerStream[gen.Response]) error } @@ -155,7 +159,7 @@ func NewTestServiceHandler(svc TestServiceHandler, opts ...connect.HandlerOption connect.WithSchema(testServiceMethods.ByName("Method")), connect.WithHandlerOptions(opts...), ) - testServiceMethodClientStreamHandler := connect.NewClientStreamHandler( + testServiceMethodClientStreamHandler := connect.NewClientStreamHandlerSimple( TestServiceMethodClientStreamProcedure, svc.MethodClientStream, connect.WithSchema(testServiceMethods.ByName("MethodClientStream")), @@ -196,7 +200,7 @@ func (UnimplementedTestServiceHandler) Method(context.Context, *gen.Request) (*g return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.test.simple.TestService.Method is not implemented")) } -func (UnimplementedTestServiceHandler) MethodClientStream(context.Context, *connect.ClientStream[gen.Request]) (*connect.Response[gen.Response], error) { +func (UnimplementedTestServiceHandler) MethodClientStream(context.Context, *connect.ClientStream[gen.Request]) (*gen.Response, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.test.simple.TestService.MethodClientStream is not implemented")) } diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index b9018ce1..99b0cc90 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -381,18 +381,30 @@ func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, na switch { case isStreamingClient && !isStreamingServer: - g.P("return c.", unexport(method.GoName), ".CallClientStream(ctx)") + if simple { + g.P("return c.", unexport(method.GoName), ".CallClientStreamSimple(ctx)") + } else { + g.P("return c.", unexport(method.GoName), ".CallClientStream(ctx)") + } case !isStreamingClient && isStreamingServer: if simple { - g.P("return c.", unexport(method.GoName), ".CallServerStreamSimple(ctx, req)") + g.P("return c.", unexport(method.GoName), ".CallServerStream(ctx, ", connectPackage.Ident("NewRequest"), "(req))") } else { g.P("return c.", unexport(method.GoName), ".CallServerStream(ctx, req)") } case isStreamingClient && isStreamingServer: - g.P("return c.", unexport(method.GoName), ".CallBidiStream(ctx)") + if simple { + g.P("return c.", unexport(method.GoName), ".CallBidiStreamSimple(ctx)") + } else { + g.P("return c.", unexport(method.GoName), ".CallBidiStream(ctx)") + } default: if simple { - g.P("return c.", unexport(method.GoName), ".CallUnarySimple(ctx, req)") + g.P("response, err := c.", unexport(method.GoName), ".CallUnary(ctx, ", connectPackage.Ident("NewRequest"), "(req))") + g.P("if response != nil {") + g.P("return response.Msg, err") + g.P("}") + g.P("return nil, err") } else { g.P("return c.", unexport(method.GoName), ".CallUnary(ctx, req)") } @@ -409,12 +421,24 @@ func clientSignature(g *protogen.GeneratedFile, method *protogen.Method, named b } if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() { // bidi streaming + if simple { + return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + + "(*" + g.QualifiedGoIdent(connectPackage.Ident("BidiStreamForClient")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + + ", error)" + } return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + "*" + g.QualifiedGoIdent(connectPackage.Ident("BidiStreamForClient")) + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" } if method.Desc.IsStreamingClient() { // client streaming + if simple { + return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + + "(*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStreamForClientSimple")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + + ", error)" + } return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStreamForClient")) + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" @@ -480,7 +504,11 @@ func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, s idempotency := methodIdempotency(method) switch { case isStreamingClient && !isStreamingServer: - g.P(procedureHandlerName(method), ` := `, connectPackage.Ident("NewClientStreamHandler"), "(") + if simple { + g.P(procedureHandlerName(method), ` := `, connectPackage.Ident("NewClientStreamHandlerSimple"), "(") + } else { + g.P(procedureHandlerName(method), ` := `, connectPackage.Ident("NewClientStreamHandler"), "(") + } case !isStreamingClient && isStreamingServer: if simple { g.P(procedureHandlerName(method), ` := `, connectPackage.Ident("NewServerStreamHandlerSimple"), "(") @@ -564,6 +592,12 @@ func serverSignatureParams(g *protogen.GeneratedFile, method *protogen.Method, n } if method.Desc.IsStreamingClient() { // client streaming + if simple { + return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " + + streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStream")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + + ") (*" + g.QualifiedGoIdent(method.Output.GoIdent) + " ,error)" + } return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " + streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStream")) + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + diff --git a/connect.go b/connect.go index 596bc7fd..4f8349c1 100644 --- a/connect.go +++ b/connect.go @@ -211,11 +211,6 @@ func (r *Request[_]) setRequestMethod(method string) { r.method = method } -// setHeader sets the request header to the given value. -func (r *Request[_]) setHeader(header http.Header) { - r.header = header -} - // AnyRequest is the common method set of every [Request], regardless of type // parameter. It's used in unary interceptors. // @@ -287,16 +282,6 @@ func (r *Response[_]) Trailer() http.Header { return r.trailer } -// setHeader sets the response header. -func (r *Response[_]) setHeader(header http.Header) { - r.header = header -} - -// setTrailer sets the response trailer. -func (r *Response[_]) setTrailer(trailer http.Header) { - r.trailer = trailer -} - // internalOnly implements AnyResponse. func (r *Response[_]) internalOnly() {} @@ -383,6 +368,47 @@ type hasHTTPMethod interface { getHTTPMethod() string } +// errStreamingClientConn is a sentinel error implementation of StreamingClientConn. +type errStreamingClientConn struct { + err error +} + +func (c *errStreamingClientConn) Receive(msg any) error { + return c.err +} + +func (c *errStreamingClientConn) Spec() Spec { + return Spec{} +} + +func (c *errStreamingClientConn) Peer() Peer { + return Peer{} +} + +func (c *errStreamingClientConn) Send(msg any) error { + return c.err +} + +func (c *errStreamingClientConn) CloseRequest() error { + return c.err +} + +func (c *errStreamingClientConn) CloseResponse() error { + return c.err +} + +func (c *errStreamingClientConn) RequestHeader() http.Header { + return make(http.Header) +} + +func (c *errStreamingClientConn) ResponseHeader() http.Header { + return make(http.Header) +} + +func (c *errStreamingClientConn) ResponseTrailer() http.Header { + return make(http.Header) +} + // receiveUnaryResponse unmarshals a message from a StreamingClientConn, then // envelopes the message and attaches headers and trailers. It attempts to // consume the response stream and isn't appropriate when receiving multiple diff --git a/connect_ext_test.go b/connect_ext_test.go index 22e392e3..20d9e764 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -29,7 +29,9 @@ import ( rand "math/rand/v2" "net" "net/http" + "net/http/httptest" "runtime" + "sort" "strings" "sync" "testing" @@ -40,6 +42,7 @@ import ( pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/generics/connect/import/v1/importv1connect" "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + pingv1connectsimple "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/proto" @@ -53,37 +56,346 @@ const errorMessage = "oh no" // client doesn't set a header, and the server sets headers and trailers on the // response. const ( - headerValue = "some header value" - trailerValue = "some trailer value" clientHeader = "Connect-Client-Header" handlerHeader = "Connect-Handler-Header" handlerTrailer = "Connect-Handler-Trailer" clientMiddlewareErrorHeader = "Connect-Trigger-HTTP-Error" ) -func TestServer(t *testing.T) { +var ( + expectedHeaderValues = []string{"foo", "bar"} //nolint:gochecknoglobals +) + +func TestCallInfo(t *testing.T) { t.Parallel() - testPing := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper - t.Run("ping", func(t *testing.T) { + t.Run("simple_api", func(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + )) + server := memhttptest.NewServer(t, mux) + client := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) + t.Run("unary", func(t *testing.T) { + testUnarySimple(t, client) + }) + t.Run("unary_no_callinfo", func(t *testing.T) { + num := int64(42) + expect := &pingv1.PingResponse{Number: num} + response, err := client.Ping(context.Background(), &pingv1.PingRequest{Number: num}) + assert.Equal(t, response, expect) + assert.Nil(t, err) + }) + t.Run("unary_generics_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + )) + server := memhttptest.NewServer(t, mux) + simpleClient := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) + testUnarySimple(t, simpleClient) + }) + t.Run("server_stream", func(t *testing.T) { + testServerStreamSimple(t, client) + }) + t.Run("server_stream_no_callinfo", func(t *testing.T) { + val := 3 + stream, err := client.CountUp(context.Background(), &pingv1.CountUpRequest{ + Number: int64(val), + }) + assert.Nil(t, err) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) + assert.Nil(t, stream.Err()) + assert.Nil(t, stream.Close()) + }) + t.Run("server_stream_generics_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + )) + server := memhttptest.NewServer(t, mux) + simpleClient := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) + testServerStreamSimple(t, simpleClient) + }) + t.Run("client_stream", func(t *testing.T) { + testClientStreamSimple(t, client) + }) + t.Run("client_stream_generics_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + )) + server := memhttptest.NewServer(t, mux) + simpleClient := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) + testClientStreamSimple(t, simpleClient) + }) + t.Run("client_stream_no_callinfo", func(t *testing.T) { + const ( + upTo = 10 + expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 + ) + stream, err := client.Sum(context.Background()) + assert.Nil(t, err) + + // Send messages + for i := range upTo { + err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) + assert.Nil(t, err, assert.Sprintf("send %d", i)) + } + + response, err := stream.CloseAndReceive() + assert.Nil(t, err) + assert.Equal(t, response.GetSum(), expect) + }) + t.Run("bidi_stream", func(t *testing.T) { + testBidiStreamSimple(t, client, true) + }) + t.Run("bidi_stream_generics_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + )) + server := memhttptest.NewServer(t, mux) + simpleClient := pingv1connectsimple.NewPingServiceClient(server.Client(), server.URL()) + testBidiStreamSimple(t, simpleClient, true) + }) + t.Run("bidi_stream_no_callinfo", func(t *testing.T) { + send := []int64{3, 5, 1} + expect := []int64{3, 8, 9} + var got []int64 + stream, err := client.CumSum(context.Background()) + assert.Nil(t, err) + assert.NotNil(t, stream) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i, n := range send { + err := stream.Send(&pingv1.CumSumRequest{Number: n}) + assert.Nil(t, err, assert.Sprintf("send error #%d", i)) + } + assert.Nil(t, stream.CloseRequest()) + }() + go func() { + defer wg.Done() + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.Nil(t, err) + got = append(got, msg.GetSum()) + } + assert.Nil(t, stream.CloseResponse()) + }() + wg.Wait() + assert.Equal(t, got, expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) + }) + }) + t.Run("generics_api", func(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + )) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + t.Run("unary", func(t *testing.T) { + testUnaryGenerics(t, client) + }) + t.Run("unary_simple_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + )) + server := memhttptest.NewServer(t, mux) + genericsClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + testUnaryGenerics(t, genericsClient) + }) + t.Run("unary_no_callinfo", func(t *testing.T) { num := int64(42) request := connect.NewRequest(&pingv1.PingRequest{Number: num}) - request.Header().Set(clientHeader, headerValue) + request.Header().Add(clientHeader, "foo") + request.Header().Add(clientHeader, "bar") expect := &pingv1.PingResponse{Number: num} + response, err := client.Ping(context.Background(), request) assert.Nil(t, err) assert.Equal(t, response.Msg, expect) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + assert.Equal(t, request.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, request.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, request.Spec().IsClient) + assert.Equal(t, request.Peer().Addr, httptest.DefaultRemoteAddr) + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.PingResponse]{response: response} + assertResponseHeadersAndTrailers(t, wrapper) + }) + t.Run("server_stream", func(t *testing.T) { + testServerStreamGenerics(t, client) + }) + t.Run("server_stream_simple_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + )) + server := memhttptest.NewServer(t, mux) + genericsClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + testServerStreamGenerics(t, genericsClient) + }) + t.Run("server_stream_no_callinfo", func(t *testing.T) { + val := 3 + req := connect.NewRequest(&pingv1.CountUpRequest{ + Number: int64(val), + }) + req.Header().Set(clientHeader, "foo") + req.Header().Add(clientHeader, "bar") + + stream, err := client.CountUp(context.Background(), req) + assert.Nil(t, err) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) + assert.Nil(t, stream.Err()) + assert.Nil(t, stream.Close()) + // Assert values on request + assert.Equal(t, req.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, req.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, req.Spec().IsClient) + assert.Equal(t, req.Peer().Addr, httptest.DefaultRemoteAddr) + assertResponseHeadersAndTrailers(t, stream) + }) + t.Run("client_stream", func(t *testing.T) { + testClientStreamGenerics(t, client) + }) + t.Run("client_stream_simple_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + )) + server := memhttptest.NewServer(t, mux) + genericsClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + testClientStreamGenerics(t, genericsClient) + }) + t.Run("client_stream_no_callinfo", func(t *testing.T) { + const ( + upTo = 10 + expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 + ) + stream := client.Sum(context.Background()) + stream.RequestHeader().Add(clientHeader, "foo") + stream.RequestHeader().Add(clientHeader, "bar") + + // Send messages + for i := range upTo { + err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) + assert.Nil(t, err, assert.Sprintf("send %d", i)) + } + + response, err := stream.CloseAndReceive() + assert.Nil(t, err) + assert.Equal(t, response.Msg.GetSum(), expect) + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.SumResponse]{response: response} + assertResponseHeadersAndTrailers(t, wrapper) + }) + t.Run("bidi_stream", func(t *testing.T) { + testBidiStreamGenerics(t, client, true) + }) + t.Run("bidi_stream_simple_server", func(t *testing.T) { + mux := http.NewServeMux() + mux.Handle(pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + )) + server := memhttptest.NewServer(t, mux) + genericsClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + testBidiStreamGenerics(t, genericsClient, true) + }) + t.Run("bidi_stream_no_callinfo", func(t *testing.T) { + send := []int64{3, 5, 1} + expect := []int64{3, 8, 9} + var got []int64 + stream := client.CumSum(context.Background()) + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i, n := range send { + err := stream.Send(&pingv1.CumSumRequest{Number: n}) + assert.Nil(t, err, assert.Sprintf("send error #%d", i)) + } + assert.Nil(t, stream.CloseRequest()) + }() + go func() { + defer wg.Done() + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.Nil(t, err) + got = append(got, msg.GetSum()) + } + assert.Nil(t, stream.CloseResponse()) + }() + wg.Wait() + assert.Equal(t, got, expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) + }) + }) +} + +func TestServer(t *testing.T) { + t.Parallel() + testPing := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + t.Run("ping", func(t *testing.T) { + testUnaryGenerics(t, client) }) t.Run("zero_ping", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) - request.Header().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Add(clientHeader, el) + } response, err := client.Ping(context.Background(), request) assert.Nil(t, err) var expect pingv1.PingResponse assert.Equal(t, response.Msg, &expect) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.PingResponse]{response: response} + assertResponseHeadersAndTrailers(t, wrapper) }) t.Run("large_ping", func(t *testing.T) { // Using a large payload splits the request and response over multiple @@ -94,12 +406,16 @@ func TestServer(t *testing.T) { } hellos := strings.Repeat("hello", 1024*1024) // ~5mb request := connect.NewRequest(&pingv1.PingRequest{Text: hellos}) - request.Header().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Add(clientHeader, el) + } response, err := client.Ping(context.Background(), request) assert.Nil(t, err) assert.Equal(t, response.Msg.GetText(), hellos) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.PingResponse]{response: response} + assertResponseHeadersAndTrailers(t, wrapper) }) t.Run("ping_error", func(t *testing.T) { _, err := client.Ping( @@ -112,28 +428,14 @@ func TestServer(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) defer cancel() request := connect.NewRequest(&pingv1.PingRequest{}) - request.Header().Set(clientHeader, headerValue) + request.Header().Set(clientHeader, "foo") _, err := client.Ping(ctx, request) assert.Equal(t, connect.CodeOf(err), connect.CodeDeadlineExceeded) }) } testSum := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("sum", func(t *testing.T) { - const ( - upTo = 10 - expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 - ) - stream := client.Sum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) - for i := range upTo { - err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) - assert.Nil(t, err, assert.Sprintf("send %d", i)) - } - response, err := stream.CloseAndReceive() - assert.Nil(t, err) - assert.Equal(t, response.Msg.GetSum(), expect) - assert.Equal(t, response.Header().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, response.Trailer().Values(handlerTrailer), []string{trailerValue}) + testClientStreamGenerics(t, client) }) t.Run("sum_error", func(t *testing.T) { stream := client.Sum(context.Background()) @@ -146,31 +448,21 @@ func TestServer(t *testing.T) { }) t.Run("sum_close_and_receive_without_send", func(t *testing.T) { stream := client.Sum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } got, err := stream.CloseAndReceive() assert.Nil(t, err) assert.Equal(t, got.Msg, &pingv1.SumResponse{}) // receive header only stream - assert.Equal(t, got.Header().Values(handlerHeader), []string{headerValue}) + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.SumResponse]{response: got} + assertResponseHeadersAndTrailers(t, wrapper) }) } testCountUp := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper t.Run("count_up", func(t *testing.T) { - const upTo = 5 - got := make([]int64, 0, upTo) - expect := make([]int64, 0, upTo) - for i := range upTo { - expect = append(expect, int64(i+1)) - } - request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo}) - request.Header().Set(clientHeader, headerValue) - stream, err := client.CountUp(context.Background(), request) - assert.Nil(t, err) - for stream.Receive() { - got = append(got, stream.Msg().GetNumber()) - } - assert.Nil(t, stream.Err()) - assert.Nil(t, stream.Close()) - assert.Equal(t, got, expect) + testServerStreamGenerics(t, client) }) t.Run("count_up_error", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -200,7 +492,8 @@ func TestServer(t *testing.T) { t.Run("count_up_cancel_after_first_response", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) request := connect.NewRequest(&pingv1.CountUpRequest{Number: 5}) - request.Header().Set(clientHeader, headerValue) + request.Header().Add(clientHeader, "foo") + request.Header().Add(clientHeader, "bar") stream, err := client.CountUp(ctx, request) assert.Nil(t, err) assert.True(t, stream.Receive()) @@ -213,41 +506,7 @@ func TestServer(t *testing.T) { } testCumSum := func(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper t.Run("cumsum", func(t *testing.T) { - send := []int64{3, 5, 1} - expect := []int64{3, 8, 9} - var got []int64 - stream := client.CumSum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) - if !expectSuccess { // server doesn't support HTTP/2 - failNoHTTP2(t, stream) - return - } - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - for i, n := range send { - err := stream.Send(&pingv1.CumSumRequest{Number: n}) - assert.Nil(t, err, assert.Sprintf("send error #%d", i)) - } - assert.Nil(t, stream.CloseRequest()) - }() - go func() { - defer wg.Done() - for { - msg, err := stream.Receive() - if errors.Is(err, io.EOF) { - break - } - assert.Nil(t, err) - got = append(got, msg.GetSum()) - } - assert.Nil(t, stream.CloseResponse()) - }() - wg.Wait() - assert.Equal(t, got, expect) - assert.Equal(t, stream.ResponseHeader().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, stream.ResponseTrailer().Values(handlerTrailer), []string{trailerValue}) + testBidiStreamGenerics(t, client, expectSuccess) }) t.Run("cumsum_error", func(t *testing.T) { stream := client.CumSum(context.Background()) @@ -267,7 +526,9 @@ func TestServer(t *testing.T) { }) t.Run("cumsum_empty_stream", func(t *testing.T) { stream := client.CumSum(context.Background()) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } if !expectSuccess { // server doesn't support HTTP/2 failNoHTTP2(t, stream) return @@ -284,7 +545,9 @@ func TestServer(t *testing.T) { t.Run("cumsum_cancel_after_first_response", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) stream := client.CumSum(ctx) - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } if !expectSuccess { // server doesn't support HTTP/2 failNoHTTP2(t, stream) cancel() @@ -314,7 +577,9 @@ func TestServer(t *testing.T) { cancel() return } - stream.RequestHeader().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + stream.RequestHeader().Add(clientHeader, el) + } assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 8})) cancel() // On a subsequent send, ensure that we are still catching context @@ -344,7 +609,9 @@ func TestServer(t *testing.T) { request := connect.NewRequest(&pingv1.FailRequest{ Code: int32(connect.CodeResourceExhausted), }) - request.Header().Set(clientHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Add(clientHeader, el) + } response, err := client.Fail(context.Background(), request) assert.Nil(t, response) @@ -356,18 +623,24 @@ func TestServer(t *testing.T) { assert.Equal(t, connectErr.Code(), connect.CodeResourceExhausted) assert.Equal(t, connectErr.Error(), "resource_exhausted: "+errorMessage) assert.Zero(t, connectErr.Details()) - assert.Equal(t, connectErr.Meta().Values(handlerHeader), []string{headerValue}) - assert.Equal(t, connectErr.Meta().Values(handlerTrailer), []string{trailerValue}) + // Wrap the connect error so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using a single function + wrapper := &errorWrapper{err: connectErr} + assertResponseHeadersAndTrailers(t, wrapper) }) t.Run("middleware_errors_unary", func(t *testing.T) { request := connect.NewRequest(&pingv1.PingRequest{}) - request.Header().Set(clientMiddlewareErrorHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Set(clientMiddlewareErrorHeader, el) + } _, err := client.Ping(context.Background(), request) assertIsHTTPMiddlewareError(t, err) }) t.Run("middleware_errors_streaming", func(t *testing.T) { request := connect.NewRequest(&pingv1.CountUpRequest{Number: 10}) - request.Header().Set(clientMiddlewareErrorHeader, headerValue) + for _, el := range expectedHeaderValues { + request.Header().Set(clientMiddlewareErrorHeader, el) + } stream, err := client.CountUp(context.Background(), request) assert.Nil(t, err) assert.False(t, stream.Receive()) @@ -2044,6 +2317,9 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { ) assert.Nil(t, err) req.Header.Set("Content-Type", "application/grpc") + for _, el := range expectedHeaderValues { + req.Header.Add(clientHeader, el) + } res, err := server.Client().Do(req) assert.Nil(t, err) assert.Equal(t, res.StatusCode, http.StatusOK) @@ -2768,59 +3044,26 @@ func (p *pluggablePingServer) CumSum( return p.cumSum(ctx, stream) } -func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { - tb.Helper() - if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { - assert.ErrorIs(tb, err, io.EOF) - assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown) +type pingServer struct { + pingv1connect.UnimplementedPingServiceHandler + + // Whether to verify metadata sent to the server. Can be used to force an error returned from the server + // by intentionally sending no metadata. + checkMetadata bool + includeErrorDetails bool +} + +func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + if err := validateRequestInfo(request); err != nil { + return nil, err } - assert.Nil(tb, stream.CloseRequest()) - _, err := stream.Receive() - assert.NotNil(tb, err) // should be 505 - assert.True( - tb, - strings.Contains(err.Error(), "HTTP status 505"), - assert.Sprintf("expected 505, got %v", err), - ) - assert.Nil(tb, stream.CloseResponse()) -} - -func expectClientHeader(check bool, req connect.AnyRequest) error { - if !check { - return nil - } - return expectMetadata(req.Header(), "header", clientHeader, headerValue) -} - -func expectMetadata(meta http.Header, metaType, key, value string) error { - if got := meta.Get(key); got != value { - return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( - "%s %q: got %q, expected %q", - metaType, - key, - got, - value, - )) - } - return nil -} - -type pingServer struct { - pingv1connect.UnimplementedPingServiceHandler - - checkMetadata bool - includeErrorDetails bool -} - -func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { - if err := expectClientHeader(p.checkMetadata, request); err != nil { + if err := compareContextAndRequest(ctx, request, request.Header()); err != nil { return nil, err } - if request.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if request.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + if p.checkMetadata { + if err := expectMetadata(request.Header()); err != nil { + return nil, err + } } response := connect.NewResponse( &pingv1.PingResponse{ @@ -2828,27 +3071,33 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi Text: request.Msg.GetText(), }, ) - response.Header().Set(handlerHeader, headerValue) - response.Trailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := request.Header().Values(clientHeader) + for _, el := range reqHeader { + response.Header().Add(handlerHeader, el) + response.Trailer().Add(handlerTrailer, el) + } + return response, nil } func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { - if err := expectClientHeader(p.checkMetadata, request); err != nil { + if err := validateRequestInfo(request); err != nil { return nil, err } - if request.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if request.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + if err := compareContextAndRequest(ctx, request, request.Header()); err != nil { + return nil, err } err := connect.NewError( connect.Code(request.Msg.GetCode()), errors.New(errorMessage), ) - err.Meta().Set(handlerHeader, headerValue) - err.Meta().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the error metadata headers and trailers + reqHeader := request.Header().Values(clientHeader) + for _, el := range reqHeader { + err.Meta().Add(handlerHeader, el) + err.Meta().Add(handlerTrailer, el) + } if p.includeErrorDetails { detail, derr := connect.NewErrorDetail(&pingv1.FailRequest{Code: request.Msg.GetCode()}) if derr != nil { @@ -2863,17 +3112,17 @@ func (p pingServer) Sum( ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest], ) (*connect.Response[pingv1.SumResponse], error) { + if err := validateRequestInfo(stream); err != nil { + return nil, err + } + if err := compareContextAndRequest(ctx, stream, stream.RequestHeader()); err != nil { + return nil, err + } if p.checkMetadata { - if err := expectMetadata(stream.RequestHeader(), "header", clientHeader, headerValue); err != nil { + if err := expectMetadata(stream.RequestHeader()); err != nil { return nil, err } } - if stream.Peer().Addr == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer address")) - } - if stream.Peer().Protocol == "" { - return nil, connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) - } var sum int64 for stream.Receive() { sum += stream.Msg().GetNumber() @@ -2882,8 +3131,12 @@ func (p pingServer) Sum( return nil, stream.Err() } response := connect.NewResponse(&pingv1.SumResponse{Sum: sum}) - response.Header().Set(handlerHeader, headerValue) - response.Trailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := stream.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + response.Header().Add(handlerHeader, el) + response.Trailer().Add(handlerTrailer, el) + } return response, nil } @@ -2892,14 +3145,16 @@ func (p pingServer) CountUp( request *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse], ) error { - if err := expectClientHeader(p.checkMetadata, request); err != nil { + if err := validateRequestInfo(stream.Conn()); err != nil { return err } - if request.Peer().Addr == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + if err := compareContextAndRequest(ctx, request, request.Header()); err != nil { + return err } - if request.Peer().Protocol == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + if p.checkMetadata { + if err := expectMetadata(request.Header()); err != nil { + return err + } } if request.Msg.GetNumber() <= 0 { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( @@ -2907,8 +3162,12 @@ func (p pingServer) CountUp( request.Msg.GetNumber(), )) } - stream.ResponseHeader().Set(handlerHeader, headerValue) - stream.ResponseTrailer().Set(handlerTrailer, trailerValue) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := request.Header().Values(clientHeader) + for _, el := range reqHeader { + stream.ResponseHeader().Add(handlerHeader, el) + stream.ResponseTrailer().Add(handlerTrailer, el) + } for i := range request.Msg.GetNumber() { if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { return err @@ -2921,32 +3180,153 @@ func (p pingServer) CumSum( ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], ) error { - var sum int64 + return handleCumSum(ctx, stream, p.checkMetadata) +} + +type pingServerSimple struct { + pingv1connectsimple.UnimplementedPingServiceHandler + + checkMetadata bool + includeErrorDetails bool +} + +func (p pingServerSimple) Ping(ctx context.Context, request *pingv1.PingRequest) (*pingv1.PingResponse, error) { + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) + if !ok { + return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) + } + if err := validateRequestInfo(callInfo); err != nil { + return nil, err + } if p.checkMetadata { - if err := expectMetadata(stream.RequestHeader(), "header", clientHeader, headerValue); err != nil { - return err + if err := expectMetadata(callInfo.RequestHeader()); err != nil { + return nil, err } } - if stream.Peer().Addr == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + response := &pingv1.PingResponse{ + Number: request.GetNumber(), + Text: request.GetText(), } - if stream.Peer().Protocol == "" { - return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := callInfo.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + callInfo.ResponseHeader().Add(handlerHeader, el) + callInfo.ResponseTrailer().Add(handlerTrailer, el) } - stream.ResponseHeader().Set(handlerHeader, headerValue) - stream.ResponseTrailer().Set(handlerTrailer, trailerValue) - for { - msg, err := stream.Receive() - if errors.Is(err, io.EOF) { - return nil - } else if err != nil { + return response, nil +} + +func (p pingServerSimple) CountUp( + ctx context.Context, + request *pingv1.CountUpRequest, + stream *connect.ServerStream[pingv1.CountUpResponse], +) error { + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) + if !ok { + return connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) + } + if err := validateRequestInfo(callInfo); err != nil { + return err + } + if p.checkMetadata { + if err := expectMetadata(callInfo.RequestHeader()); err != nil { return err } - sum += msg.GetNumber() - if err := stream.Send(&pingv1.CumSumResponse{Sum: sum}); err != nil { + } + if request.GetNumber() <= 0 { + return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( + "number must be positive: got %v", + request.GetNumber(), + )) + } + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := callInfo.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + callInfo.ResponseHeader().Add(handlerHeader, el) + callInfo.ResponseTrailer().Add(handlerTrailer, el) + } + for i := range request.GetNumber() { + if err := stream.Send(&pingv1.CountUpResponse{Number: i + 1}); err != nil { return err } } + return nil +} + +func (p pingServerSimple) Fail(ctx context.Context, request *pingv1.FailRequest) (*pingv1.FailResponse, error) { + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) + if !ok { + return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) + } + if err := validateRequestInfo(callInfo); err != nil { + return nil, err + } + if p.checkMetadata { + if err := expectMetadata(callInfo.RequestHeader()); err != nil { + return nil, err + } + } + err := connect.NewError( + connect.Code(request.GetCode()), + errors.New(errorMessage), + ) + // Copy the values sent in the client request header to the error metadata headers and trailers + reqHeader := callInfo.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + err.Meta().Add(handlerHeader, el) + err.Meta().Add(handlerTrailer, el) + } + if p.includeErrorDetails { + detail, derr := connect.NewErrorDetail(&pingv1.FailRequest{Code: request.GetCode()}) + if derr != nil { + return nil, derr + } + err.AddDetail(detail) + } + return nil, err +} + +func (p pingServerSimple) Sum( + ctx context.Context, + stream *connect.ClientStream[pingv1.SumRequest], +) (*pingv1.SumResponse, error) { + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) + if !ok { + return nil, connect.NewError(connect.CodeInternal, errors.New("no call info found in context")) + } + if err := validateRequestInfo(callInfo); err != nil { + return nil, err + } + if err := compareContextAndRequest(ctx, stream, stream.RequestHeader()); err != nil { + return nil, err + } + if p.checkMetadata { + if err := expectMetadata(callInfo.RequestHeader()); err != nil { + return nil, err + } + } + var sum int64 + for stream.Receive() { + sum += stream.Msg().GetNumber() + } + if stream.Err() != nil { + return nil, stream.Err() + } + response := &pingv1.SumResponse{Sum: sum} + // Copy the values sent in the client request header to the response headers and trailers + reqHeader := stream.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + callInfo.ResponseHeader().Add(handlerHeader, el) + callInfo.ResponseTrailer().Add(handlerTrailer, el) + } + return response, nil +} + +func (p pingServerSimple) CumSum( + ctx context.Context, + stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], +) error { + return handleCumSum(ctx, stream, p.checkMetadata) } type deflateReader struct { @@ -3035,3 +3415,520 @@ func (failCompressor) Close() error { } func (failCompressor) Reset(io.Writer) {} + +type requestInfo interface { + Peer() connect.Peer + Spec() connect.Spec +} + +type responseInfo interface { + ResponseHeader() http.Header + ResponseTrailer() http.Header +} + +// responseWrapper wraps a Response object so that it can implement the responseInfo interface. +type responseWrapper[Res any] struct { + response *connect.Response[Res] +} + +func (w *responseWrapper[Res]) ResponseHeader() http.Header { + return w.response.Header() +} + +func (w *responseWrapper[Res]) ResponseTrailer() http.Header { + return w.response.Trailer() +} + +// errorWrapper wraps a Connect error so that it can implement the responseInfo interface. +type errorWrapper struct { + err *connect.Error +} + +func (w *errorWrapper) ResponseHeader() http.Header { + return w.err.Meta() +} + +func (w *errorWrapper) ResponseTrailer() http.Header { + return w.err.Meta() +} + +// handleCumSum handles the bidi endpoint CumSum for both pingServer and pingServerSimple. +// The API for bidi-streaming does not change for simple vs. generics API on the server. +func handleCumSum( + ctx context.Context, + stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse], + checkMetadata bool, +) error { + if err := validateRequestInfo(stream); err != nil { + return err + } + if err := compareContextAndRequest(ctx, stream, stream.RequestHeader()); err != nil { + return err + } + if checkMetadata { + if err := expectMetadata(stream.RequestHeader()); err != nil { + return err + } + } + var sum int64 + reqHeader := stream.RequestHeader().Values(clientHeader) + for _, el := range reqHeader { + stream.ResponseHeader().Add(handlerHeader, el) + stream.ResponseTrailer().Add(handlerTrailer, el) + } + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return err + } + sum += msg.GetNumber() + if err := stream.Send(&pingv1.CumSumResponse{Sum: sum}); err != nil { + return err + } + } +} + +func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { + tb.Helper() + if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { + assert.ErrorIs(tb, err, io.EOF) + assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown) + } + assert.Nil(tb, stream.CloseRequest()) + _, err := stream.Receive() + assert.NotNil(tb, err) // should be 505 + assert.True( + tb, + strings.Contains(err.Error(), "HTTP status 505"), + assert.Sprintf("expected 505, got %v", err), + ) + assert.Nil(tb, stream.CloseResponse()) +} + +func testUnaryGenerics(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + num := int64(42) + request := connect.NewRequest(&pingv1.PingRequest{Number: num}) + + ctx, callInfo := connect.NewClientContext(context.Background()) + // With the generics API, a user can use the call info or request wrapper or both to set request headers. + // The resulting headers should be combined and sent in the request. + request.Header().Add(clientHeader, "foo") + callInfo.RequestHeader().Add(clientHeader, "bar") + expect := &pingv1.PingResponse{Number: num} + + response, err := client.Ping(ctx, request) + assert.Nil(t, err) + assert.Equal(t, response.Msg, expect) + // When using the generics API for unary calls, users can access request info such as spec and peer + // either from the call info in context or the request wrapper. This verifies both have the same information. + assert.Equal(t, request.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, request.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, request.Spec().IsClient) + assert.Equal(t, request.Peer().Addr, httptest.DefaultRemoteAddr) + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.PingResponse]{response: response} + + // When using the generics API for unary calls, users can access response headers and trailers + // either from the call info in context or the response wrapper. This verifies both have the same information. + assertResponseHeadersAndTrailers(t, callInfo) + assertResponseHeadersAndTrailers(t, wrapper) +} + +func testUnarySimple(t *testing.T, client pingv1connectsimple.PingServiceClient) { //nolint:thelper + num := int64(42) + ctx, callInfo := connect.NewClientContext(context.Background()) + for _, el := range expectedHeaderValues { + callInfo.RequestHeader().Add(clientHeader, el) + } + expect := &pingv1.PingResponse{Number: num} + response, err := client.Ping(ctx, &pingv1.PingRequest{Number: num}) + assert.Equal(t, response, expect) + assert.Nil(t, err) + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeUnary) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServicePingProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + // When using the simple API for unary calls, users can only access response headers and trailers + // from the call info in context. + assertResponseHeadersAndTrailers(t, callInfo) +} + +func testServerStreamSimple(t *testing.T, client pingv1connectsimple.PingServiceClient) { //nolint:thelper + ctx, callInfo := connect.NewClientContext(context.Background()) + for _, el := range expectedHeaderValues { + callInfo.RequestHeader().Add(clientHeader, el) + } + val := 3 + stream, err := client.CountUp(ctx, &pingv1.CountUpRequest{ + Number: int64(val), + }) + assert.Nil(t, err) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) + assert.Nil(t, stream.Err()) + assert.Nil(t, stream.Close()) + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + // On server-streaming calls, users can access response headers and trailers + // either from the call info in context or from the stream itself. + // This verifies that the both the stream and the call info have the same information + assertResponseHeadersAndTrailers(t, callInfo) + assertResponseHeadersAndTrailers(t, stream) +} + +func testServerStreamGenerics(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + val := 3 + req := connect.NewRequest(&pingv1.CountUpRequest{ + Number: int64(val), + }) + ctx, callInfo := connect.NewClientContext(context.Background()) + // With the generics API, A user can use the call info or request wrapper or both to set request headers. + // The resulting headers should be combined and sent in the request. + callInfo.RequestHeader().Set(clientHeader, "foo") + req.Header().Add(clientHeader, "bar") + + stream, err := client.CountUp(ctx, req) + assert.Nil(t, err) + // Receive expected messages + for idx := range val { + expected := int64(idx + 1) + assert.True(t, stream.Receive()) + assert.Nil(t, stream.Err()) + msg := stream.Msg() + assert.NotNil(t, msg) + assert.Equal(t, msg.GetNumber(), expected) + } + + // Stream should be done. Expect false on receive and close stream + assert.False(t, stream.Receive()) + assert.Nil(t, stream.Err()) + assert.Nil(t, stream.Close()) + // Assert values on request + assert.Equal(t, req.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, req.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, req.Spec().IsClient) + assert.Equal(t, req.Peer().Addr, httptest.DefaultRemoteAddr) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeServer) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + // On server-streaming calls, users can access response headers and trailers + // either from the call info in context or from the stream itself. + // This verifies that the both the stream and the call info have the same information + assertResponseHeadersAndTrailers(t, callInfo) + assertResponseHeadersAndTrailers(t, stream) +} + +func testClientStreamSimple(t *testing.T, client pingv1connectsimple.PingServiceClient) { //nolint:thelper + ctx, callInfo := connect.NewClientContext(context.Background()) + for _, el := range expectedHeaderValues { + callInfo.RequestHeader().Add(clientHeader, el) + } + + const ( + upTo = 10 + expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 + ) + stream, err := client.Sum(ctx) + assert.Nil(t, err) + + // Send messages + for i := range upTo { + err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) + assert.Nil(t, err, assert.Sprintf("send %d", i)) + } + + response, err := stream.CloseAndReceive() + assert.Nil(t, err) + assert.Equal(t, response.GetSum(), expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeClient) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceSumProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + assertResponseHeadersAndTrailers(t, callInfo) +} + +func testClientStreamGenerics(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper + ctx, callInfo := connect.NewClientContext(context.Background()) + callInfo.RequestHeader().Add(clientHeader, "foo") + const ( + upTo = 10 + expect = 55 // 1+10 + 2+9 + ... + 5+6 = 55 + ) + stream := client.Sum(ctx) + stream.RequestHeader().Add(clientHeader, "bar") + + // Send messages + for i := range upTo { + err := stream.Send(&pingv1.SumRequest{Number: int64(i + 1)}) + assert.Nil(t, err, assert.Sprintf("send %d", i)) + } + + response, err := stream.CloseAndReceive() + assert.Nil(t, err) + assert.Equal(t, response.Msg.GetSum(), expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeClient) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceSumProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + // Wrap the response object so that it can implement the responseInfo interface and we can verify its response + // headers and trailers using the same function callInfo does + wrapper := &responseWrapper[pingv1.SumResponse]{response: response} + assertResponseHeadersAndTrailers(t, wrapper) + assertResponseHeadersAndTrailers(t, callInfo) +} + +func testBidiStreamSimple(t *testing.T, client pingv1connectsimple.PingServiceClient, expectSuccess bool) { //nolint:thelper + send := []int64{3, 5, 1} + expect := []int64{3, 8, 9} + var got []int64 + ctx, callInfo := connect.NewClientContext(context.Background()) + callInfo.RequestHeader().Add(clientHeader, "foo") + callInfo.RequestHeader().Add(clientHeader, "bar") + + stream, err := client.CumSum(ctx) + assert.Nil(t, err) + assert.NotNil(t, stream) + + if !expectSuccess { // server doesn't support HTTP/2 + failNoHTTP2(t, stream) + return + } + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i, n := range send { + err := stream.Send(&pingv1.CumSumRequest{Number: n}) + assert.Nil(t, err, assert.Sprintf("send error #%d", i)) + } + assert.Nil(t, stream.CloseRequest()) + }() + go func() { + defer wg.Done() + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.Nil(t, err) + got = append(got, msg.GetSum()) + } + assert.Nil(t, stream.CloseResponse()) + }() + wg.Wait() + assert.Equal(t, got, expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + assertResponseHeadersAndTrailers(t, callInfo) + assertResponseHeadersAndTrailers(t, stream) +} + +func testBidiStreamGenerics(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper + send := []int64{3, 5, 1} + expect := []int64{3, 8, 9} + var got []int64 + ctx, callInfo := connect.NewClientContext(context.Background()) + // With the generics API, A user can use the call info or request wrapper or both to set request headers. + // The resulting headers should be combined and sent in the request. + callInfo.RequestHeader().Add(clientHeader, "foo") + stream := client.CumSum(ctx) + stream.RequestHeader().Add(clientHeader, "bar") + + if !expectSuccess { // server doesn't support HTTP/2 + failNoHTTP2(t, stream) + return + } + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i, n := range send { + err := stream.Send(&pingv1.CumSumRequest{Number: n}) + assert.Nil(t, err, assert.Sprintf("send error #%d", i)) + } + assert.Nil(t, stream.CloseRequest()) + }() + go func() { + defer wg.Done() + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.Nil(t, err) + got = append(got, msg.GetSum()) + } + assert.Nil(t, stream.CloseResponse()) + }() + wg.Wait() + assert.Equal(t, got, expect) + + // Assert values on stream + assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, stream.Spec().IsClient) + assert.Equal(t, stream.Peer().Addr, httptest.DefaultRemoteAddr) + + // Assert the same values are in the call info + assert.Equal(t, callInfo.Spec().StreamType, connect.StreamTypeBidi) + assert.Equal(t, callInfo.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) + assert.True(t, callInfo.Spec().IsClient) + assert.Equal(t, callInfo.Peer().Addr, httptest.DefaultRemoteAddr) + + assertResponseHeadersAndTrailers(t, callInfo) + assertResponseHeadersAndTrailers(t, stream) +} + +// Validates that the peer and spec information is set correctly in a request. +func validateRequestInfo(request requestInfo) error { + if request.Peer().Addr == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer address")) + } + if request.Peer().Protocol == "" { + return connect.NewError(connect.CodeInternal, errors.New("no peer protocol")) + } + if request.Spec().Procedure == "" { + return connect.NewError(connect.CodeInternal, errors.New("no procedure name")) + } + return nil +} + +// Compares the information in the call info in context with the given request information to verify they match. +func compareContextAndRequest(ctx context.Context, request requestInfo, requestHeaders http.Header) error { + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) + if !ok { + return connect.NewError(connect.CodeInternal, errors.New("no call info in handler context")) + } + if request.Peer().Addr != callInfo.Peer().Addr { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched peer address. found %s in request and %s in call info", request.Peer().Addr, callInfo.Peer().Addr)) + } + if request.Peer().Protocol != callInfo.Peer().Protocol { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched peer protocol. found %s in request and %s in call info", request.Peer().Addr, callInfo.Peer().Addr)) + } + if request.Spec().Procedure != callInfo.Spec().Procedure { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched procedure name. found %s in request and %s in call info", request.Spec().Procedure, request.Spec().Procedure)) + } + if valid := compareHeaders(callInfo.RequestHeader(), requestHeaders); !valid { + return connect.NewError(connect.CodeInternal, fmt.Errorf("mismatched request headers. found %+v in request and %+v in call info", callInfo.RequestHeader(), requestHeaders)) + } + return nil +} + +// expectMetadata returns an error if the given http headers don't contain the expected header values. +// Typically, most methods in the pingServer implementations just read the request headers +// and copy those to the response headers and trailers and let the client verify that way. +// However, this method can be used in conjunction with the server's verifyMetadata setting +// to force an error to be returned if metadata isn't set. For example, see +// TestGRPCMissingTrailersError tests. +func expectMetadata(meta http.Header) error { + vals := meta.Values(clientHeader) + if ok := compareValues(vals, expectedHeaderValues); !ok { + return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf( + "header %q: got %q, expected %q", + clientHeader, + vals, + expectedHeaderValues, + )) + } + return nil +} + +// assertResponseHeadersAndTrailers verifies that the given response info contains the expected headers and trailers. +func assertResponseHeadersAndTrailers(t *testing.T, resp responseInfo) { //nolint:thelper + assert.True(t, compareValues(resp.ResponseHeader().Values(handlerHeader), expectedHeaderValues)) + assert.True(t, compareValues(resp.ResponseTrailer().Values(handlerTrailer), expectedHeaderValues)) +} + +// compareHeaders compares two http Header objects to verify they contain the exact same information. +func compareHeaders(hdr1 http.Header, hdr2 http.Header) bool { + if len(hdr1) != len(hdr2) { + return false + } + for key, hdr1Val := range hdr1 { + hdr2Val, ok := hdr2[key] + if !ok || len(hdr1Val) != len(hdr2Val) { + return false + } + + if equal := compareValues(hdr1Val, hdr2Val); !equal { + return false + } + } + return true +} + +// compareValues compares two string slices of header values to verify they are the same, ignoring order. +func compareValues(hdr1 []string, hdr2 []string) bool { + if len(hdr1) != len(hdr2) { + return false + } + // Copy slices to avoid race conditions with other tests trying to read the headers + sorted1 := make([]string, len(hdr1)) + copy(sorted1, hdr1) + sorted2 := make([]string, len(hdr2)) + copy(sorted2, hdr2) + + sort.Strings(sorted1) + sort.Strings(sorted2) + + for idx, el := range sorted1 { + if el != sorted2[idx] { + return false + } + } + return true +} diff --git a/context.go b/context.go index 7a7bdbd5..137c8c90 100644 --- a/context.go +++ b/context.go @@ -19,68 +19,226 @@ import ( "net/http" ) -type requestIncomingHeaderContextKey struct{} -type requestOutgoingHeaderContextKey struct{} -type responseHeaderAddressContextKey struct{} -type responseTrailerAddressContextKey struct{} - -// HeaderFromIncomingContext gets the header from a request sent to a handler. -func HeaderFromIncomingContext(ctx context.Context) (http.Header, bool) { - value, ok := ctx.Value(requestIncomingHeaderContextKey{}).(http.Header) - return value, ok +// CallInfo represents information relevant to an RPC call. +// Values returned by these methods are not thread-safe. Users should expect +// data races if they create an outgoing client CallInfo in context and then pass that +// CallInfo to another goroutine and try to call methods on it concurrent with the RPC. +type CallInfo interface { + // Spec returns a description of this call. + Spec() Spec + // Peer describes the other party for this call. + Peer() Peer + // RequestHeader returns the HTTP headers for this request. Headers beginning with + // "Connect-" and "Grpc-" are reserved for use by the Connect and gRPC + // protocols: applications may read them but shouldn't write them. + RequestHeader() http.Header + // ResponseHeader returns the HTTP headers for this response. Headers beginning with + // "Connect-" and "Grpc-" are reserved for use by the Connect and gRPC + // protocols: applications may read them but shouldn't write them. + // On the client side, this method returns nil before + // the call is actually made. After the call is made, for streaming operations, + // this method will block for the server to actually return response headers. + ResponseHeader() http.Header + // ResponseTrailer returns the trailers for this response. Depending on the underlying + // RPC protocol, trailers may be sent as HTTP trailers or a protocol-specific + // block of in-body metadata. + // + // Trailers beginning with "Connect-" and "Grpc-" are reserved for use by the + // Connect and gRPC protocols: applications may read them but shouldn't write + // them. + // + // On the client side, this method returns nil before + // the call is actually made. After the call is made, for streaming operations, + // this method will block for the server to actually return response trailers. + ResponseTrailer() http.Header + // HTTPMethod returns the HTTP method for this request. This is nearly always + // POST, but side-effect-free unary RPCs could be made via a GET. + // + // On a newly created request, via NewRequest, this will return the empty + // string until the actual request is actually sent and the HTTP method + // determined. This means that client interceptor functions will see the + // empty string until *after* they delegate to the handler they wrapped. It + // is even possible for this to return the empty string after such delegation, + // if the request was never actually sent to the server (and thus no + // determination ever made about the HTTP method). + HTTPMethod() string + + internalOnly() } -// HeaderFromOutgoingContext gets the header from a request sent by a client. -func HeaderFromOutgoingContext(ctx context.Context) (http.Header, bool) { - value, ok := ctx.Value(requestOutgoingHeaderContextKey{}).(http.Header) +// Create a new client (i.e. outgoing) context for use from a client. When the +// returned context is passed to RPCs, the returned call info can be used to set +// request metadata before the RPC is invoked and to inspect response +// metadata after the RPC completes. +// +// The returned context may be re-used across RPCs as long as they are +// not concurrent. Results of all CallInfo methods other than +// RequestHeader() are undefined if the context is used with concurrent RPCs. +func NewClientContext(ctx context.Context) (context.Context, CallInfo) { + info := &clientCallInfo{} + return context.WithValue(ctx, clientCallInfoContextKey{}, info), info +} + +// CallInfoFromHandlerContext returns the CallInfo for the given handler (i.e. incoming) context, if there is one. +func CallInfoFromHandlerContext(ctx context.Context) (CallInfo, bool) { + value, ok := ctx.Value(handlerCallInfoContextKey{}).(CallInfo) return value, ok } -// WithIncomingHeader adds the header to the context from a request sent to a handler. -func WithIncomingHeader(ctx context.Context, header http.Header) context.Context { - return context.WithValue(ctx, requestIncomingHeaderContextKey{}, header) +// handlerCallInfo is a CallInfo implementation used for handlers. +type handlerCallInfo struct { + spec Spec + peer Peer + method string + requestHeader http.Header + responseHeader http.Header + responseTrailer http.Header +} + +func (c *handlerCallInfo) Spec() Spec { + return c.spec +} + +func (c *handlerCallInfo) Peer() Peer { + return c.peer +} + +func (c *handlerCallInfo) RequestHeader() http.Header { + if c.requestHeader == nil { + c.requestHeader = make(http.Header) + } + return c.requestHeader +} + +func (c *handlerCallInfo) ResponseHeader() http.Header { + if c.responseHeader == nil { + c.responseHeader = make(http.Header) + } + return c.responseHeader +} + +func (c *handlerCallInfo) ResponseTrailer() http.Header { + if c.responseTrailer == nil { + c.responseTrailer = make(http.Header) + } + return c.responseTrailer +} + +func (c *handlerCallInfo) HTTPMethod() string { + return c.method +} + +// internalOnly implements CallInfo. +func (c *handlerCallInfo) internalOnly() {} + +// streamCallInfo is a CallInfo implementation used for streaming RPC handlers. +type streamCallInfo struct { + conn StreamingHandlerConn +} + +func (c *streamCallInfo) Spec() Spec { + return c.conn.Spec() +} + +func (c *streamCallInfo) Peer() Peer { + return c.conn.Peer() +} + +func (c *streamCallInfo) RequestHeader() http.Header { + return c.conn.RequestHeader() } -// WithOutgoingHeader adds the header to the context from a request sent by a client. -func WithOutgoingHeader(ctx context.Context, header http.Header) context.Context { - return context.WithValue(ctx, requestOutgoingHeaderContextKey{}, header) +func (c *streamCallInfo) ResponseHeader() http.Header { + return c.conn.ResponseHeader() } -// WithStoreResponseHeader returns a new context to be given to a client when making a request -// that will result in the header pointer being set to the response header. -func WithStoreResponseHeader(ctx context.Context, header *http.Header) context.Context { - return context.WithValue(ctx, responseHeaderAddressContextKey{}, header) +func (c *streamCallInfo) ResponseTrailer() http.Header { + return c.conn.ResponseTrailer() } -// WithStoreResponseTrailer returns a new context to be given to a client when making a request -// that will result in the trailer pointer being set to the response trailer. -func WithStoreResponseTrailer(ctx context.Context, trailer *http.Header) context.Context { - return context.WithValue(ctx, responseTrailerAddressContextKey{}, trailer) +func (c *streamCallInfo) HTTPMethod() string { + // All stream calls are POSTs + return http.MethodPost } -// SetResponseHeader sets the response header within a simple handler implementation. -func SetResponseHeader(ctx context.Context, header http.Header) { - responseHeaderAddress, ok := ctx.Value(responseHeaderAddressContextKey{}).(*http.Header) - if !ok { - return +// internalOnly implements CallInfo. +func (c *streamCallInfo) internalOnly() {} + +// clientCallInfo is a CallInfo implementation used for clients. +type clientCallInfo struct { + responseSource + spec Spec + peer Peer + method string + requestHeader http.Header +} + +func (c *clientCallInfo) Spec() Spec { + return c.spec +} + +func (c *clientCallInfo) Peer() Peer { + return c.peer +} + +func (c *clientCallInfo) RequestHeader() http.Header { + if c.requestHeader == nil { + c.requestHeader = make(http.Header) } - *responseHeaderAddress = header + return c.requestHeader } -// SetResponseTrailer sets the response trailer within a simple handler implementation. -func SetResponseTrailer(ctx context.Context, trailer http.Header) { - responseTrailerAddress, ok := ctx.Value(responseTrailerAddressContextKey{}).(*http.Header) - if !ok { - return +func (c *clientCallInfo) ResponseHeader() http.Header { + if c.responseSource == nil { + return nil } - *responseTrailerAddress = trailer + return c.responseSource.ResponseHeader() } -func requestFromContext[T any](ctx context.Context, message *T) *Request[T] { - request := NewRequest[T](message) - header, ok := HeaderFromOutgoingContext(ctx) - if ok { - request.setHeader(header) +func (c *clientCallInfo) ResponseTrailer() http.Header { + if c.responseSource == nil { + return nil } - return request + return c.responseSource.ResponseTrailer() +} + +func (c *clientCallInfo) HTTPMethod() string { + return c.method +} + +// internalOnly implements CallInfo. +func (c *clientCallInfo) internalOnly() {} + +type clientCallInfoContextKey struct{} +type sentinelContextKey struct{} +type handlerCallInfoContextKey struct{} + +// responseSource indicates a type that manage response headers and trailers. +type responseSource interface { + ResponseHeader() http.Header + ResponseTrailer() http.Header +} + +// responseWrapper wraps a Response object so that it can implement the responseSource interface. +type responseWrapper[Res any] struct { + response *Response[Res] +} + +func (w *responseWrapper[Res]) ResponseHeader() http.Header { + return w.response.Header() +} + +func (w *responseWrapper[Res]) ResponseTrailer() http.Header { + return w.response.Trailer() +} + +// clientCallInfoFromContext gets the call info from a client/outgoing context. +func clientCallInfoFromContext(ctx context.Context) (*clientCallInfo, bool) { + info, ok := ctx.Value(clientCallInfoContextKey{}).(*clientCallInfo) + return info, ok +} + +// newHandlerContext creates a new handler/incoming context. +func newHandlerContext(ctx context.Context, info CallInfo) context.Context { + return context.WithValue(ctx, handlerCallInfoContextKey{}, info) } diff --git a/error_example_test.go b/error_example_test.go index d8155f75..1bcc68c3 100644 --- a/error_example_test.go +++ b/error_example_test.go @@ -22,7 +22,7 @@ import ( connect "connectrpc.com/connect" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" - "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" ) func ExampleError_Message() { @@ -47,14 +47,15 @@ func ExampleIsNotModifiedError() { // Enable client-side support for HTTP GETs. connect.WithHTTPGet(), ) - req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) - first, err := client.Ping(context.Background(), req) + req := &pingv1.PingRequest{Number: 42} + ctx, callInfo := connect.NewClientContext(context.Background()) + _, err := client.Ping(ctx, req) if err != nil { fmt.Println(err) return } // If the server set an Etag, we can use it to cache the response. - etag := first.Header().Get("Etag") + etag := callInfo.ResponseHeader().Get("Etag") if etag == "" { fmt.Println("no Etag in response headers") return @@ -62,7 +63,7 @@ func ExampleIsNotModifiedError() { fmt.Println("cached response with Etag", etag) // Now we'd like to make the same request again, but avoid re-fetching the // response if possible. - req.Header().Set("If-None-Match", etag) + callInfo.RequestHeader().Set("If-None-Match", etag) _, err = client.Ping(context.Background(), req) if connect.IsNotModifiedError(err) { fmt.Println("can reuse cached response") diff --git a/error_not_modified_example_test.go b/error_not_modified_example_test.go index fe520548..e87cd07a 100644 --- a/error_not_modified_example_test.go +++ b/error_not_modified_example_test.go @@ -16,12 +16,13 @@ package connect_test import ( "context" + "errors" "net/http" "strconv" connect "connectrpc.com/connect" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" - "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" ) // ExampleCachingServer is an example of how servers can take advantage the @@ -35,22 +36,27 @@ type ExampleCachingPingServer struct { // indicates this), so clients using the Connect protocol may call it with HTTP // GET requests. This implementation uses Etags to manage client-side caching. func (*ExampleCachingPingServer) Ping( - _ context.Context, - req *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { - resp := connect.NewResponse(&pingv1.PingResponse{ - Number: req.Msg.GetNumber(), - }) + ctx context.Context, + req *pingv1.PingRequest, +) (*pingv1.PingResponse, error) { + resp := &pingv1.PingResponse{ + Number: req.GetNumber(), + } + callInfo, ok := connect.CallInfoFromHandlerContext(ctx) + if !ok { + return nil, errors.New("no call info found in context") + } + // Our hashing logic is simple: we use the number in the PingResponse. - hash := strconv.FormatInt(resp.Msg.GetNumber(), 10) + hash := strconv.FormatInt(resp.GetNumber(), 10) // If the request was an HTTP GET, we'll need to check if the client already // has the response cached. - if req.HTTPMethod() == http.MethodGet && req.Header().Get("If-None-Match") == hash { + if callInfo.HTTPMethod() == http.MethodGet && callInfo.RequestHeader().Get("If-None-Match") == hash { return nil, connect.NewNotModifiedError(http.Header{ "Etag": []string{hash}, }) } - resp.Header().Set("Etag", hash) + callInfo.ResponseHeader().Set("Etag", hash) return resp, nil } diff --git a/example_init_test.go b/example_init_test.go index e7abee52..e79d95b8 100644 --- a/example_init_test.go +++ b/example_init_test.go @@ -17,7 +17,7 @@ package connect_test import ( "net/http" - "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp" ) @@ -32,6 +32,6 @@ func init() { // deadlock, see: // (https://github.com/golang/go/issues/48394) mux := http.NewServeMux() - mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + mux.Handle(pingv1connect.NewPingServiceHandler(pingServerSimple{})) examplePingServer = memhttp.NewServer(mux) } diff --git a/handler.go b/handler.go index 9fc9bb6e..424225b5 100644 --- a/handler.go +++ b/handler.go @@ -66,12 +66,36 @@ func NewUnaryHandler[Req, Res any]( if err != nil { return err } + // Add the request header to the context, and store the response header + // and trailer to propagate back to the caller. + info := &handlerCallInfo{ + peer: request.Peer(), + spec: request.Spec(), + method: request.HTTPMethod(), + requestHeader: request.Header(), + } + ctx = newHandlerContext(ctx, info) response, err := untyped(ctx, request) if err != nil { return err } - mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) - mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) + + // Add response headers/trailers from the context callinfo into the conn if they exist + if info.responseHeader != nil { + mergeNonProtocolHeaders(conn.ResponseHeader(), info.responseHeader) + } + if info.responseTrailer != nil { + mergeNonProtocolHeaders(conn.ResponseTrailer(), info.responseTrailer) + } + + // Add response headers/trailers from the response into the conn if they exist + if len(response.Header()) != 0 { + mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) + } + if len(response.Trailer()) != 0 { + mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) + } + return conn.Send(response.Any()) } @@ -98,28 +122,11 @@ func NewUnaryHandlerSimple[Req, Res any]( return NewUnaryHandler( procedure, func(ctx context.Context, request *Request[Req]) (*Response[Res], error) { - var responseHeader http.Header - var responseTrailer http.Header - // Add the request header to the context, and store the response header - // and trailer to propagate back to the caller. - ctx = WithIncomingHeader( - WithStoreResponseHeader( - WithStoreResponseTrailer( - ctx, - &responseTrailer, - ), - &responseHeader, - ), - request.Header(), - ) responseMsg, err := unary(ctx, request.Msg) - if responseMsg != nil { - response := NewResponse(responseMsg) - response.setHeader(responseHeader) - response.setTrailer(responseHeader) - return response, err + if err != nil { + return nil, err } - return nil, err + return NewResponse(responseMsg), nil }, options..., ) @@ -139,6 +146,9 @@ func NewClientStreamHandler[Req, Res any]( conn: conn, initializer: config.Initializer, } + ctx = newHandlerContext(ctx, &streamCallInfo{ + conn: conn, + }) res, err := implementation(ctx, stream) if err != nil { return err @@ -155,6 +165,29 @@ func NewClientStreamHandler[Req, Res any]( ) } +// NewClientStreamHandlerSimple constructs a [Handler] for a request-streaming procedure +// using the function signature associated with the "simple" generation option. +// +// This option eliminates the [Response] wrapper, and instead uses the context.Context +// to propagate information such as headers. +func NewClientStreamHandlerSimple[Req, Res any]( + procedure string, + implementation func(context.Context, *ClientStream[Req]) (*Res, error), + options ...HandlerOption, +) *Handler { + return NewClientStreamHandler( + procedure, + func(ctx context.Context, stream *ClientStream[Req]) (*Response[Res], error) { + responseMsg, err := implementation(ctx, stream) + if err != nil { + return nil, err + } + return NewResponse(responseMsg), nil + }, + options..., + ) +} + // NewServerStreamHandler constructs a [Handler] for a server streaming procedure. func NewServerStreamHandler[Req, Res any]( procedure string, @@ -169,6 +202,9 @@ func NewServerStreamHandler[Req, Res any]( if err != nil { return err } + ctx = newHandlerContext(ctx, &streamCallInfo{ + conn: conn, + }) return implementation(ctx, req, &ServerStream[Res]{conn: conn}) }, ) @@ -187,11 +223,6 @@ func NewServerStreamHandlerSimple[Req, Res any]( return NewServerStreamHandler( procedure, func(ctx context.Context, request *Request[Req], serverStream *ServerStream[Res]) error { - // Add the request header to the context. - ctx = WithIncomingHeader( - ctx, - request.Header(), - ) return implementation(ctx, request.Msg, serverStream) }, options..., @@ -208,6 +239,9 @@ func NewBidiStreamHandler[Req, Res any]( return newStreamHandler( config, func(ctx context.Context, conn StreamingHandlerConn) error { + ctx = newHandlerContext(ctx, &streamCallInfo{ + conn: conn, + }) return implementation( ctx, &BidiStream[Req, Res]{ diff --git a/handler_example_test.go b/handler_example_test.go index 9fe849dd..36843dff 100644 --- a/handler_example_test.go +++ b/handler_example_test.go @@ -43,7 +43,7 @@ func (*ExamplePingServer) Ping( } // Sum implements pingv1connect.PingServiceHandler. -func (p *ExamplePingServer) Sum(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { +func (p *ExamplePingServer) Sum(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*pingv1.SumResponse, error) { var sum int64 for stream.Receive() { sum += stream.Msg().GetNumber() @@ -51,7 +51,7 @@ func (p *ExamplePingServer) Sum(ctx context.Context, stream *connect.ClientStrea if stream.Err() != nil { return nil, stream.Err() } - return connect.NewResponse(&pingv1.SumResponse{Sum: sum}), nil + return &pingv1.SumResponse{Sum: sum}, nil } // CountUp implements pingv1connect.PingServiceHandler. diff --git a/handler_ext_test.go b/handler_ext_test.go index c5541222..0ac2deb9 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -380,6 +380,47 @@ func TestDynamicHandler(t *testing.T) { } assert.Equal(t, rsp.Msg.Sum, 42*2) }) + t.Run("clientStreamSimple", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Sum") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + dynamicSum := func(_ context.Context, stream *connect.ClientStream[dynamicpb.Message]) (*dynamicpb.Message, error) { + var sum int64 + for stream.Receive() { + got := stream.Msg().Get( + methodDesc.Input().Fields().ByName("number"), + ).Int() + sum += got + } + msg := dynamicpb.NewMessage(methodDesc.Output()) + msg.Set( + methodDesc.Output().Fields().ByName("sum"), + protoreflect.ValueOfInt64(sum), + ) + return msg, nil + } + mux := http.NewServeMux() + mux.Handle("/connect.ping.v1.PingService/Sum", + connect.NewClientStreamHandlerSimple( + "/connect.ping.v1.PingService/Sum", + dynamicSum, + connect.WithSchema(methodDesc), + connect.WithRequestInitializer(initializer), + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + stream := client.Sum(context.Background()) + assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 42})) + assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 42})) + rsp, err := stream.CloseAndReceive() + if !assert.Nil(t, err) { + return + } + assert.Equal(t, rsp.Msg.Sum, 42*2) + }) t.Run("serverStream", func(t *testing.T) { t.Parallel() desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CountUp") diff --git a/interceptor.go b/interceptor.go index f0c3620a..d3bc1374 100644 --- a/interceptor.go +++ b/interceptor.go @@ -16,6 +16,13 @@ package connect import ( "context" + "errors" +) + +var ( + // errNewClientContextProhibited signals that a new client context was created + // in an interceptor, which is prohibited. + errNewClientContextProhibited = errors.New("creating a new context in an interceptor is prohibited") ) // UnaryFunc is the generic signature of a unary RPC. Interceptors may wrap @@ -36,9 +43,8 @@ type StreamingHandlerFunc func(context.Context, StreamingHandlerConn) error // An Interceptor adds logic to a generated handler or client, like the // decorators or middleware you may have seen in other libraries. Interceptors -// may replace the context, mutate requests and responses, handle errors, -// retry, recover from panics, emit logs and metrics, or do nearly anything -// else. +// may mutate requests and responses, handle errors, retry, recover from panics, +// emit logs and metrics, or do nearly anything else. // // The returned functions must be safe to call concurrently. type Interceptor interface { @@ -85,6 +91,7 @@ func newChain(interceptors []Interceptor) *chain { func (c *chain) WrapUnary(next UnaryFunc) UnaryFunc { for _, interceptor := range c.interceptors { + next = unaryThunk(next) next = interceptor.WrapUnary(next) } return next @@ -92,6 +99,7 @@ func (c *chain) WrapUnary(next UnaryFunc) UnaryFunc { func (c *chain) WrapStreamingClient(next StreamingClientFunc) StreamingClientFunc { for _, interceptor := range c.interceptors { + next = streamingClientThunk(next) next = interceptor.WrapStreamingClient(next) } return next @@ -103,3 +111,28 @@ func (c *chain) WrapStreamingHandler(next StreamingHandlerFunc) StreamingHandler } return next } + +func unaryThunk(next UnaryFunc) UnaryFunc { + return func(ctx context.Context, req AnyRequest) (AnyResponse, error) { + if err := checkSentinel(ctx); err != nil { + return nil, err + } + return next(ctx, req) + } +} + +func streamingClientThunk(next StreamingClientFunc) StreamingClientFunc { + return func(ctx context.Context, spec Spec) StreamingClientConn { + if err := checkSentinel(ctx); err != nil { + return &errStreamingClientConn{err: err} + } + return next(ctx, spec) + } +} + +func checkSentinel(ctx context.Context) error { + if ctx.Value(clientCallInfoContextKey{}) != ctx.Value(sentinelContextKey{}) { + return errNewClientContextProhibited + } + return nil +} diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index b5892fde..1e6003c2 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -16,7 +16,9 @@ package connect_test import ( "context" + "errors" "fmt" + "io" "net/http" "sync/atomic" "testing" @@ -25,9 +27,475 @@ import ( "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/generics/connect/ping/v1/pingv1connect" + pingv1connectsimple "connectrpc.com/connect/internal/gen/simple/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" ) +const expectedContextErrorMessage = "creating a new context in an interceptor is prohibited" + +func TestNewClientContextInInterceptor(t *testing.T) { + t.Parallel() + t.Run("simple_api", func(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle( + pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + ), + ) + server := memhttptest.NewServer(t, mux) + t.Run("first_interceptor", func(t *testing.T) { + // Because we're creating a new context in the first interceptor, only the first interceptor should fire + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connectsimple.PingServiceClient { + opts := connect.WithInterceptors( + &contextInterceptor{client: true, count: counter1, createNewContext: true}, + &contextInterceptor{client: true, count: counter2}, + ) + return pingv1connectsimple.NewPingServiceClient( + server.Client(), + server.URL(), + opts, + ) + } + t.Run("unary", func(t *testing.T) { + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + resp, err := client.Ping(context.Background(), &pingv1.PingRequest{Number: 10}) + + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream, err := client.CountUp(context.Background(), &pingv1.CountUpRequest{Number: 10}) + + assert.Nil(t, stream) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + // With client-streaming and the simple API, the initial call fails. This differs from + // the generics API which requires a call to stream.Send first to receive an error. + stream, err := client.Sum(context.Background()) + assert.NotNil(t, err) + assert.Nil(t, stream) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + // With bidi-streaming and the simple API, the initial call fails. This differs from + // the generics API which requires a call to stream.Send first to receive an error. + stream, err := client.CumSum(context.Background()) + assert.NotNil(t, err) + assert.Nil(t, stream) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + }) + t.Run("subsequent_interceptor", func(t *testing.T) { + // Because we're creating a new context in the last interceptor, all interceptors should fire + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connectsimple.PingServiceClient { + opts := connect.WithInterceptors( + &contextInterceptor{client: true, count: counter1}, + &contextInterceptor{client: true, count: counter2, createNewContext: true}, + ) + return pingv1connectsimple.NewPingServiceClient( + server.Client(), + server.URL(), + opts, + ) + } + t.Run("unary", func(t *testing.T) { + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + resp, err := client.Ping(context.Background(), &pingv1.PingRequest{Number: 10}) + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CountUp(context.Background(), &pingv1.CountUpRequest{Number: 10}) + assert.Nil(t, stream) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + // With client-streaming and the simple API, the initial call fails. This differs from + // the generics API which requires a call to stream.Send first to receive an error. + stream, err := client.Sum(context.Background()) + assert.NotNil(t, err) + assert.Nil(t, stream) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + // With bidi-streaming and the simple API, the initial call fails. This differs from + // the generics API which requires a call to stream.Send first to receive an error. + stream, err := client.CumSum(context.Background()) + assert.NotNil(t, err) + assert.Nil(t, stream) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + }) + t.Run("sidequest_succeeds", func(t *testing.T) { + // These tests create a new context but it is used to issue a separate/new request and not reused in the + // interceptor chain. So, all interceptors should fire and no errors should be returned. + t.Parallel() + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connectsimple.PingServiceClient { + opts := connect.WithInterceptors( + newSideQuestInterceptor(t, counter1, server), + newSideQuestInterceptor(t, counter2, server), + ) + return pingv1connectsimple.NewPingServiceClient( + server.Client(), + server.URL(), + opts, + ) + } + t.Run("unary", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + resp, err := client.Ping(context.Background(), &pingv1.PingRequest{Number: 10}) + assert.NotNil(t, resp) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CountUp(context.Background(), &pingv1.CountUpRequest{Number: 10}) + assert.NotNil(t, stream) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.Sum(context.Background()) + assert.NotNil(t, stream) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CumSum(context.Background()) + assert.Nil(t, err) + assert.NotNil(t, stream) + + assert.Nil(t, stream.CloseRequest()) + assert.Nil(t, stream.CloseResponse()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + }) + }) + t.Run("generics_api", func(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle( + pingv1connectsimple.NewPingServiceHandler( + pingServerSimple{}, + ), + ) + server := memhttptest.NewServer(t, mux) + t.Run("first_interceptor", func(t *testing.T) { + // Because we're creating a new context in the first interceptor, only the first interceptor should fire + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connect.PingServiceClient { + opts := connect.WithInterceptors( + &contextInterceptor{client: true, count: counter1, createNewContext: true}, + &contextInterceptor{client: true, count: counter2}, + ) + return pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + opts, + ) + } + t.Run("unary", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + resp, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) + assert.Nil(t, stream) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.Sum(context.Background()) + assert.NotNil(t, stream) + + // With client-streaming and the generics API, a call to stream.Send is required to receive an error. + err := stream.Send(&pingv1.SumRequest{Number: int64(1)}) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + // We should receive the same error when we try to close the stream + resp, err := stream.CloseAndReceive() + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + //nolint:dupl // the test logic for bidi w/r/t generic and simple api looks the same, but it's subtly different + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.CumSum(context.Background()) + assert.NotNil(t, stream) + + // With bidi-streaming and the generics API, a call to stream.Send is required to receive an error. + err := stream.Send(&pingv1.CumSumRequest{Number: 1}) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + // We should receive the same error when we try to close the send and receive parts of the stream + err = stream.CloseRequest() + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + err = stream.CloseResponse() + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(0), clientCounter2.Load()) + }) + }) + + t.Run("subsequent_interceptor", func(t *testing.T) { + // Because we're creating a new context in the last interceptor, all interceptors should fire + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connect.PingServiceClient { + opts := connect.WithInterceptors( + &contextInterceptor{client: true, count: counter1}, + &contextInterceptor{client: true, count: counter2, createNewContext: true}, + ) + return pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + opts, + ) + } + t.Run("unary", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + resp, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) + assert.Nil(t, stream) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.Sum(context.Background()) + assert.NotNil(t, stream) + + // With client-streaming and the generics API, a call to stream.Send is required to receive an error. + err := stream.Send(&pingv1.SumRequest{Number: int64(1)}) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + // We should receive the same error when we try to close the stream + resp, err := stream.CloseAndReceive() + assert.Nil(t, resp) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + //nolint:dupl // the test logic for bidi w/r/t generic and simple api looks the same, but it's subtly different + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.CumSum(context.Background()) + assert.NotNil(t, stream) + + // With bidi-streaming and the generics API, a call to stream.Send is required to receive an error. + err := stream.Send(&pingv1.CumSumRequest{Number: 1}) + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + // We should receive the same error when we try to close the send and receive parts of the stream + err = stream.CloseRequest() + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + err = stream.CloseResponse() + assert.NotNil(t, err) + assert.Equal(t, err.Error(), expectedContextErrorMessage) + + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + }) + t.Run("sidequest_succeeds", func(t *testing.T) { + // These tests create a new context but it is used to issue a separate/new request and not reused in the + // interceptor chain. So, all interceptors should fire and no errors should be returned. + createClient := func(counter1 *atomic.Int32, counter2 *atomic.Int32) pingv1connect.PingServiceClient { + opts := connect.WithInterceptors( + newSideQuestInterceptor(t, counter1, server), + newSideQuestInterceptor(t, counter2, server), + ) + return pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + opts, + ) + } + t.Run("unary", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + resp, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) + assert.NotNil(t, resp) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("server_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) + assert.NotNil(t, stream) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("client_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.Sum(context.Background()) + assert.NotNil(t, stream) + + err := stream.Send(&pingv1.SumRequest{Number: int64(1)}) + assert.Nil(t, err) + resp, err := stream.CloseAndReceive() + assert.NotNil(t, resp) + assert.Nil(t, err) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + var clientCounter1, clientCounter2 atomic.Int32 + client := createClient(&clientCounter1, &clientCounter2) + stream := client.CumSum(context.Background()) + assert.NotNil(t, stream) + + // The initial send should succeed + err := stream.Send(&pingv1.CumSumRequest{Number: 1}) + assert.Nil(t, err) + + // We should be able to successfully close the send part of the stream + assert.Nil(t, stream.CloseRequest()) + + // All receives should succeed + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + break + } + assert.NotNil(t, msg) + assert.Nil(t, err) + } + // We should be able to successfully close the receive part of the stream + assert.Nil(t, stream.CloseResponse()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + }) + }) + }) +} + func TestOnionOrderingEndToEnd(t *testing.T) { t.Parallel() // Helper function: returns a function that asserts that there's some value @@ -68,7 +536,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { } } - var client1, client2, client3, handler1, handler2, handler3 atomic.Int32 + var clientCounter1, clientCounter2, clientCounter3, handlerCounter1, handlerCounter2, handlerCounter3 atomic.Int32 // The client and handler interceptor onions are the meat of the test. The // order of interceptor execution must be the same for unary and streaming @@ -83,7 +551,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { // intended order clear. clientOnion := connect.WithInterceptors( newHeaderInterceptor( - &client1, + &clientCounter1, // 1 (start). request: should see protocol-related headers func(_ connect.Spec, h http.Header) { assert.NotZero(t, h.Get("Content-Type")) @@ -92,29 +560,29 @@ func TestOnionOrderingEndToEnd(t *testing.T) { assertAllPresent, ), newHeaderInterceptor( - &client2, + &clientCounter2, newInspector("", "one"), // 2. request: add header "one" newInspector("three", "four"), // 11. response: check "three", add "four" ), newHeaderInterceptor( - &client3, + &clientCounter3, newInspector("one", "two"), // 3. request: check "one", add "two" newInspector("two", "three"), // 10. response: check "two", add "three" ), ) handlerOnion := connect.WithInterceptors( newHeaderInterceptor( - &handler1, + &handlerCounter1, newInspector("two", "three"), // 4. request: check "two", add "three" newInspector("one", "two"), // 9. response: check "one", add "two" ), newHeaderInterceptor( - &handler2, + &handlerCounter2, newInspector("three", "four"), // 5. request: check "three", add "four" newInspector("", "one"), // 8. response: add "one" ), newHeaderInterceptor( - &handler3, + &handlerCounter3, assertAllPresent, // 6. request: check "one"-"four" nil, // 7. response: no-op ), @@ -138,12 +606,12 @@ func TestOnionOrderingEndToEnd(t *testing.T) { assert.Nil(t, err) // make sure the interceptors were actually invoked - assert.Equal(t, int32(1), client1.Load()) - assert.Equal(t, int32(1), client2.Load()) - assert.Equal(t, int32(1), client3.Load()) - assert.Equal(t, int32(1), handler1.Load()) - assert.Equal(t, int32(1), handler2.Load()) - assert.Equal(t, int32(1), handler3.Load()) + assert.Equal(t, int32(1), clientCounter1.Load()) + assert.Equal(t, int32(1), clientCounter2.Load()) + assert.Equal(t, int32(1), clientCounter3.Load()) + assert.Equal(t, int32(1), handlerCounter1.Load()) + assert.Equal(t, int32(1), handlerCounter2.Load()) + assert.Equal(t, int32(1), handlerCounter3.Load()) responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) assert.Nil(t, err) @@ -155,12 +623,12 @@ func TestOnionOrderingEndToEnd(t *testing.T) { assert.Nil(t, responses.Close()) // make sure the interceptors were invoked again - assert.Equal(t, int32(2), client1.Load()) - assert.Equal(t, int32(2), client2.Load()) - assert.Equal(t, int32(2), client3.Load()) - assert.Equal(t, int32(2), handler1.Load()) - assert.Equal(t, int32(2), handler2.Load()) - assert.Equal(t, int32(2), handler3.Load()) + assert.Equal(t, int32(2), clientCounter1.Load()) + assert.Equal(t, int32(2), clientCounter2.Load()) + assert.Equal(t, int32(2), clientCounter3.Load()) + assert.Equal(t, int32(2), handlerCounter1.Load()) + assert.Equal(t, int32(2), handlerCounter2.Load()) + assert.Equal(t, int32(2), handlerCounter3.Load()) } func TestEmptyUnaryInterceptorFunc(t *testing.T) { @@ -349,7 +817,7 @@ type httpMethodChecker struct { count atomic.Int32 } -func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { +func (h *httpMethodChecker) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { h.count.Add(1) if h.client { @@ -365,7 +833,7 @@ func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.Unary return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) } } - resp, err := unaryFunc(ctx, req) + resp, err := next(ctx, req) // NB: In theory, the method could also be GET, not just POST. But for the // configuration under test, it will always be POST. if req.HTTPMethod() != http.MethodPost { @@ -390,3 +858,96 @@ func (h *httpMethodChecker) WrapStreamingHandler(handlerFunc connect.StreamingHa return handlerFunc(ctx, conn) } } + +type contextInterceptor struct { + client bool + count *atomic.Int32 + // Whether the interceptor should attempt to create a new context (which will cause next() to return an error) + createNewContext bool +} + +func (h *contextInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + h.count.Add(1) + if h.createNewContext { + // This will cause next to return an error + ctx, _ = connect.NewClientContext(ctx) + } + return next(ctx, req) + } +} + +func (h *contextInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + h.count.Add(1) + if h.createNewContext { + // This will cause next to return an error + ctx, _ = connect.NewClientContext(ctx) + } + return next(ctx, spec) + } +} + +func (h *contextInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + h.count.Add(1) + return next(ctx, conn) + } +} + +type sideQuestInterceptor struct { + count *atomic.Int32 + client pingv1connect.PingServiceClient + t *testing.T +} + +func newSideQuestInterceptor( //nolint:thelper + t *testing.T, + counter *atomic.Int32, + server *memhttp.Server, +) *sideQuestInterceptor { + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + ) + return &sideQuestInterceptor{t: t, client: client, count: counter} +} + +func (h *sideQuestInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + h.count.Add(1) + num := int64(42) + // Create a new client context for the side quest Ping. This should succeed because we aren't + // sending this on through the interceptor chain and reusing this context + newCtx, _ := connect.NewClientContext(ctx) + resp, err := h.client.Ping(newCtx, connect.NewRequest(&pingv1.PingRequest{Number: num})) + assert.Nil(h.t, err) + assert.Equal(h.t, resp.Msg.Number, num) + + return next(ctx, req) + } +} + +func (h *sideQuestInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + h.count.Add(1) + // Create a new context for the side quest CountUp. This should succeed because we aren't + // sending this on through the interceptor chain and reusing this context + newCtx, _ := connect.NewClientContext(ctx) + responses, err := h.client.CountUp(newCtx, connect.NewRequest(&pingv1.CountUpRequest{Number: 3})) + assert.Nil(h.t, err) + var sum int64 + for responses.Receive() { + sum += responses.Msg().GetNumber() + } + assert.Equal(h.t, sum, 6) + assert.Nil(h.t, responses.Close()) + return next(ctx, spec) + } +} + +func (h *sideQuestInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + return next(ctx, conn) + } +} diff --git a/internal/gen/simple/connect/collide/v1/collidev1connect/collide.connect.go b/internal/gen/simple/connect/collide/v1/collidev1connect/collide.connect.go index ba3a7f0c..5081e3ad 100644 --- a/internal/gen/simple/connect/collide/v1/collidev1connect/collide.connect.go +++ b/internal/gen/simple/connect/collide/v1/collidev1connect/collide.connect.go @@ -83,7 +83,11 @@ type collideServiceClient struct { // Import calls connect.collide.v1.CollideService.Import. func (c *collideServiceClient) Import(ctx context.Context, req *v1.ImportRequest) (*v1.ImportResponse, error) { - return c._import.CallUnarySimple(ctx, req) + response, err := c._import.CallUnary(ctx, connect.NewRequest(req)) + if response != nil { + return response.Msg, err + } + return nil, err } // CollideServiceHandler is an implementation of the connect.collide.v1.CollideService service. diff --git a/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go index 9fe1d615..5b0cfaa3 100644 --- a/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/simple/connect/ping/v1/pingv1connect/ping.connect.go @@ -71,11 +71,11 @@ type PingServiceClient interface { // Fail always fails. Fail(context.Context, *v1.FailRequest) (*v1.FailResponse, error) // Sum calculates the sum of the numbers sent on the stream. - Sum(context.Context) *connect.ClientStreamForClient[v1.SumRequest, v1.SumResponse] + Sum(context.Context) (*connect.ClientStreamForClientSimple[v1.SumRequest, v1.SumResponse], error) // CountUp returns a stream of the numbers up to the given request. CountUp(context.Context, *v1.CountUpRequest) (*connect.ServerStreamForClient[v1.CountUpResponse], error) // CumSum determines the cumulative sum of all the numbers sent on the stream. - CumSum(context.Context) *connect.BidiStreamForClient[v1.CumSumRequest, v1.CumSumResponse] + CumSum(context.Context) (*connect.BidiStreamForClient[v1.CumSumRequest, v1.CumSumResponse], error) } // NewPingServiceClient constructs a client for the connect.ping.v1.PingService service. By default, @@ -134,27 +134,35 @@ type pingServiceClient struct { // Ping calls connect.ping.v1.PingService.Ping. func (c *pingServiceClient) Ping(ctx context.Context, req *v1.PingRequest) (*v1.PingResponse, error) { - return c.ping.CallUnarySimple(ctx, req) + response, err := c.ping.CallUnary(ctx, connect.NewRequest(req)) + if response != nil { + return response.Msg, err + } + return nil, err } // Fail calls connect.ping.v1.PingService.Fail. func (c *pingServiceClient) Fail(ctx context.Context, req *v1.FailRequest) (*v1.FailResponse, error) { - return c.fail.CallUnarySimple(ctx, req) + response, err := c.fail.CallUnary(ctx, connect.NewRequest(req)) + if response != nil { + return response.Msg, err + } + return nil, err } // Sum calls connect.ping.v1.PingService.Sum. -func (c *pingServiceClient) Sum(ctx context.Context) *connect.ClientStreamForClient[v1.SumRequest, v1.SumResponse] { - return c.sum.CallClientStream(ctx) +func (c *pingServiceClient) Sum(ctx context.Context) (*connect.ClientStreamForClientSimple[v1.SumRequest, v1.SumResponse], error) { + return c.sum.CallClientStreamSimple(ctx) } // CountUp calls connect.ping.v1.PingService.CountUp. func (c *pingServiceClient) CountUp(ctx context.Context, req *v1.CountUpRequest) (*connect.ServerStreamForClient[v1.CountUpResponse], error) { - return c.countUp.CallServerStreamSimple(ctx, req) + return c.countUp.CallServerStream(ctx, connect.NewRequest(req)) } // CumSum calls connect.ping.v1.PingService.CumSum. -func (c *pingServiceClient) CumSum(ctx context.Context) *connect.BidiStreamForClient[v1.CumSumRequest, v1.CumSumResponse] { - return c.cumSum.CallBidiStream(ctx) +func (c *pingServiceClient) CumSum(ctx context.Context) (*connect.BidiStreamForClient[v1.CumSumRequest, v1.CumSumResponse], error) { + return c.cumSum.CallBidiStreamSimple(ctx) } // PingServiceHandler is an implementation of the connect.ping.v1.PingService service. @@ -164,7 +172,7 @@ type PingServiceHandler interface { // Fail always fails. Fail(context.Context, *v1.FailRequest) (*v1.FailResponse, error) // Sum calculates the sum of the numbers sent on the stream. - Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*connect.Response[v1.SumResponse], error) + Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*v1.SumResponse, error) // CountUp returns a stream of the numbers up to the given request. CountUp(context.Context, *v1.CountUpRequest, *connect.ServerStream[v1.CountUpResponse]) error // CumSum determines the cumulative sum of all the numbers sent on the stream. @@ -191,7 +199,7 @@ func NewPingServiceHandler(svc PingServiceHandler, opts ...connect.HandlerOption connect.WithSchema(pingServiceMethods.ByName("Fail")), connect.WithHandlerOptions(opts...), ) - pingServiceSumHandler := connect.NewClientStreamHandler( + pingServiceSumHandler := connect.NewClientStreamHandlerSimple( PingServiceSumProcedure, svc.Sum, connect.WithSchema(pingServiceMethods.ByName("Sum")), @@ -238,7 +246,7 @@ func (UnimplementedPingServiceHandler) Fail(context.Context, *v1.FailRequest) (* return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Fail is not implemented")) } -func (UnimplementedPingServiceHandler) Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*connect.Response[v1.SumResponse], error) { +func (UnimplementedPingServiceHandler) Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*v1.SumResponse, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Sum is not implemented")) }