diff --git a/adapter/outbound/snell.go b/adapter/outbound/snell.go index c740ab4965..418d6f0a3a 100644 --- a/adapter/outbound/snell.go +++ b/adapter/outbound/snell.go @@ -91,6 +91,9 @@ func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn _ = c.Close() return nil, err } + if pc, ok := c.(*snell.PoolConn); ok { + pc.MarkReusable() + } return NewConn(c, s), err } diff --git a/transport/snell/pool.go b/transport/snell/pool.go index 602f92799e..99de60bbab 100644 --- a/transport/snell/pool.go +++ b/transport/snell/pool.go @@ -5,6 +5,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/metacubex/mihomo/component/pool" @@ -15,6 +16,12 @@ type Pool struct { pool *pool.Pool[*Snell] } +const ( + poolConnNew int32 = iota + poolConnReusable + poolConnClosedBeforeReuse +) + func (p *Pool) Get() (net.Conn, error) { return p.GetContext(context.Background()) } @@ -48,6 +55,8 @@ type PoolConn struct { closeWriteErr error closeOnce sync.Once closeErr error + requestStarted atomic.Bool + reusableState atomic.Int32 } func (pc *PoolConn) Read(b []byte) (int, error) { @@ -59,11 +68,27 @@ func (pc *PoolConn) Read(b []byte) (int, error) { } func (pc *PoolConn) Write(b []byte) (int, error) { - return pc.Snell.Write(b) + n, err := pc.Snell.Write(b) + if err == nil && n == len(b) && len(b) > 0 { + pc.requestStarted.Store(true) + } + return n, err +} + +// MarkReusable allows Close to return this request to the pool after DialContext succeeds. +func (pc *PoolConn) MarkReusable() { + if pc.requestStarted.Load() { + pc.reusableState.CompareAndSwap(poolConnNew, poolConnReusable) + } } func (pc *PoolConn) CloseWrite() error { pc.closeWriteOnce.Do(func() { + if pc.reusableState.Load() != poolConnReusable { + pc.reusableState.CompareAndSwap(poolConnNew, poolConnClosedBeforeReuse) + pc.closeWriteErr = pc.Snell.Close() + return + } pc.closeWriteErr = writeZeroChunk(pc.Snell) }) return pc.closeWriteErr @@ -71,6 +96,10 @@ func (pc *PoolConn) CloseWrite() error { func (pc *PoolConn) Close() error { pc.closeOnce.Do(func() { + if pc.reusableState.Load() != poolConnReusable { + pc.closeErr = pc.CloseWrite() + return + } if err := pc.CloseWrite(); err != nil { pc.closeErr = err _ = pc.Snell.Close() diff --git a/transport/snell/pool_test.go b/transport/snell/pool_test.go index 1fa5983ac9..3e42851cf5 100644 --- a/transport/snell/pool_test.go +++ b/transport/snell/pool_test.go @@ -19,6 +19,10 @@ func TestPoolConnCloseIsIdempotent(t *testing.T) { }) conn := &PoolConn{Snell: pooledConn, pool: pool} + if _, err := conn.Write([]byte{Version, CommandConnectV2, 0}); err != nil { + t.Fatal(err) + } + conn.MarkReusable() if err := conn.Close(); err != nil { t.Fatal(err) } @@ -26,8 +30,8 @@ func TestPoolConnCloseIsIdempotent(t *testing.T) { t.Fatal(err) } - if rawConn.writes != 1 { - t.Fatalf("close should send one half-close record, got %d", rawConn.writes) + if rawConn.writes != 2 { + t.Fatalf("close should send the request and one half-close record, got %d writes", rawConn.writes) } got, err := pool.pool.Get() @@ -39,6 +43,104 @@ func TestPoolConnCloseIsIdempotent(t *testing.T) { } } +func TestPoolConnCloseBeforeRequestClosesRawConnection(t *testing.T) { + rawConn := &recordingConn{} + pooledConn := &Snell{Conn: rawConn} + factoryConn := &Snell{Conn: &recordingConn{}} + pool := NewPool(func(context.Context) (*Snell, error) { + return factoryConn, nil + }) + conn := &PoolConn{Snell: pooledConn, pool: pool} + + if err := conn.Close(); err != nil { + t.Fatal(err) + } + if err := conn.Close(); err != nil { + t.Fatal(err) + } + + if rawConn.writes != 0 { + t.Fatalf("close before request should not send half-close record, got %d writes", rawConn.writes) + } + if !rawConn.closed { + t.Fatal("close before request should close the raw connection") + } + + got, err := pool.pool.Get() + if err != nil { + t.Fatal(err) + } + if got != factoryConn { + t.Fatal("unstarted connection should not be returned to the pool") + } +} + +func TestPoolConnCloseAfterRequestBeforeReusableClosesRawConnection(t *testing.T) { + rawConn := &recordingConn{} + pooledConn := &Snell{Conn: rawConn} + factoryConn := &Snell{Conn: &recordingConn{}} + pool := NewPool(func(context.Context) (*Snell, error) { + return factoryConn, nil + }) + conn := &PoolConn{Snell: pooledConn, pool: pool} + + if _, err := conn.Write([]byte{Version, CommandConnectV2, 0}); err != nil { + t.Fatal(err) + } + if err := conn.Close(); err != nil { + t.Fatal(err) + } + + if rawConn.writes != 1 { + t.Fatalf("close before reusable should only send the request, got %d writes", rawConn.writes) + } + if !rawConn.closed { + t.Fatal("close before reusable should close the raw connection") + } + + got, err := pool.pool.Get() + if err != nil { + t.Fatal(err) + } + if got != factoryConn { + t.Fatal("connection closed before reusable should not be returned to the pool") + } +} + +func TestPoolConnCloseWriteBeforeRequestClosesRawConnection(t *testing.T) { + rawConn := &recordingConn{} + pooledConn := &Snell{Conn: rawConn} + factoryConn := &Snell{Conn: &recordingConn{}} + pool := NewPool(func(context.Context) (*Snell, error) { + return factoryConn, nil + }) + conn := &PoolConn{Snell: pooledConn, pool: pool} + + if err := conn.CloseWrite(); err != nil { + t.Fatal(err) + } + if rawConn.writes != 0 { + t.Fatalf("CloseWrite before request should not send half-close record, got %d writes", rawConn.writes) + } + if !rawConn.closed { + t.Fatal("CloseWrite before request should close the raw connection") + } + + // Simulate a concurrent write reporting success after the unstarted close path won. + conn.requestStarted.Store(true) + conn.MarkReusable() + if err := conn.Close(); err != nil { + t.Fatal(err) + } + got, err := pool.pool.Get() + if err != nil { + t.Fatal(err) + } + if got != factoryConn { + t.Fatal("unstarted connection should not be returned to the pool") + } +} + func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) { rawConn := &recordingConn{} pooledConn := &Snell{Conn: rawConn, reply: true} @@ -48,6 +150,10 @@ func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) { }) conn := &PoolConn{Snell: pooledConn, pool: pool} + if _, err := conn.Write([]byte{Version, CommandConnectV2, 0}); err != nil { + t.Fatal(err) + } + conn.MarkReusable() if err := conn.CloseWrite(); err != nil { t.Fatal(err) } @@ -59,8 +165,8 @@ func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) { if got != factoryConn { t.Fatal("CloseWrite should not put the active connection back into the pool") } - if rawConn.writes != 1 { - t.Fatalf("CloseWrite should send one half-close record, got %d", rawConn.writes) + if rawConn.writes != 2 { + t.Fatalf("CloseWrite should send the request and one half-close record, got %d writes", rawConn.writes) } if !pooledConn.reply { t.Fatal("CloseWrite should not reset reply while the read side may still be active") @@ -76,8 +182,8 @@ func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) { if got != pooledConn { t.Fatal("Close should return the connection to the pool after CloseWrite") } - if rawConn.writes != 1 { - t.Fatalf("Close after CloseWrite should not send another half-close record, got %d", rawConn.writes) + if rawConn.writes != 2 { + t.Fatalf("Close after CloseWrite should not send another half-close record, got %d writes", rawConn.writes) } if pooledConn.reply { t.Fatal("Close should reset reply before returning the connection to the pool")