diff --git a/adapter/outbound/snell.go b/adapter/outbound/snell.go
index c740ab4965..1bef77aa0a 100644
--- a/adapter/outbound/snell.go
+++ b/adapter/outbound/snell.go
@@ -82,16 +82,29 @@ func (s *Snell) writeHeaderContext(ctx context.Context, c net.Conn, metadata *C.
 // DialContext implements C.ProxyAdapter
 func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
 	if s.reuse {
-		c, err := s.pool.Get()
-		if err != nil {
-			return nil, err
-		}
-
-		if err = s.writeHeaderContext(ctx, c, metadata); err != nil {
-			_ = c.Close()
-			return nil, err
+		// A pooled conn may be stale (server-side half-closed it after the
+		// previous session but before our maxAge fires). Detect by retrying
+		// the CONNECT write once — the second pool.Get either yields another
+		// idle conn or falls through to the factory and dials fresh.
+		for attempts := 0; attempts < 2; attempts++ {
+			c, gerr := s.pool.GetContext(ctx)
+			if gerr != nil {
+				return nil, gerr
+			}
+			if err = s.writeHeaderContext(ctx, c, metadata); err != nil {
+				_ = c.Close()
+				// A cancelled or expired ctx fails every further attempt;
+				// don't burn a second pooled conn or a fresh dial on it.
+				if ctx.Err() != nil {
+					return nil, err
+				}
+				continue
+			}
+			return NewConn(c, s), nil
 		}
-		return NewConn(c, s), err
+		// Both pool attempts yielded stale conns. Fall through to a fresh
+		// dial below; it bypasses the pool but still writes a reuse-capable
+		// header so the next call can pool the conn after one session.
 	}
 
 	c, err := s.dialer.DialContext(ctx, "tcp", s.addr)
diff --git a/transport/snell/pool.go b/transport/snell/pool.go
index 602f92799e..42f9c555c6 100644
--- a/transport/snell/pool.go
+++ b/transport/snell/pool.go
@@ -2,6 +2,7 @@ package snell
 
 import (
 	"context"
+	"errors"
 	"io"
 	"net"
 	"sync"
@@ -11,8 +12,31 @@ import (
 	"github.com/metacubex/mihomo/transport/shadowsocks/shadowaead"
 )
 
+// defaultMaxUsesPerConn caps how many CONNECT sessions a single reuse-mode
+// TCP connection may serve before we close it instead of returning to the
+// pool. Surge's snell-server (v5.0.1 observed) closes a reuse-mode TCP
+// after the second session, so anything beyond the cap is a soon-to-be
+// dead socket from the client's perspective.
+const defaultMaxUsesPerConn = 2
+
+// drainReuseTimeout bounds how long PoolConn.Close waits for the server's
+// zero-chunk half-close before declaring the conn unreusable. The server
+// emits its zero chunk shortly after seeing ours, so one RTT plus slack is
+// plenty; if the upstream is still streaming data we'd rather drop the
+// conn than block the caller.
+const drainReuseTimeout = 500 * time.Millisecond
+
 type Pool struct {
-	pool *pool.Pool[*Snell]
+	pool           *pool.Pool[*pooledEntry]
+	maxUsesPerConn int
+}
+
+// pooledEntry is the internal pool element. It tracks how many CONNECT
+// sessions the wrapped TCP has served so far so the pool can evict it
+// before the server-side use cap kicks in.
+type pooledEntry struct {
+	conn *Snell
+	uses int
 }
 
 func (p *Pool) Get() (net.Conn, error) {
@@ -20,30 +44,32 @@ func (p *Pool) Get() (net.Conn, error) {
 }
 
 func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) {
-	elm, err := p.pool.GetContext(ctx)
+	entry, err := p.pool.GetContext(ctx)
 	if err != nil {
 		return nil, err
 	}
-
-	return &PoolConn{Snell: elm, pool: p}, nil
+	entry.uses++
+	return &PoolConn{Snell: entry.conn, pool: p, uses: entry.uses}, nil
 }
 
-func (p *Pool) Put(conn *Snell) {
-	if err := HalfClose(conn); err != nil {
+// atUseCap reports whether a conn that has served uses sessions must be
+// retired instead of pooled. A non-positive cap disables the limit.
+func (p *Pool) atUseCap(uses int) bool {
+	return p.maxUsesPerConn > 0 && uses >= p.maxUsesPerConn
+}
+
+func (p *Pool) put(conn *Snell, uses int) {
+	if p.atUseCap(uses) {
 		_ = conn.Close()
 		return
 	}
-
-	p.put(conn)
-}
-
-func (p *Pool) put(conn *Snell) {
-	p.pool.Put(conn)
+	p.pool.Put(&pooledEntry{conn: conn, uses: uses})
 }
 
 type PoolConn struct {
 	*Snell
 	pool           *Pool
+	uses           int
 	closeWriteOnce sync.Once
 	closeWriteErr  error
 	closeOnce      sync.Once
@@ -77,26 +103,69 @@ func (pc *PoolConn) Close() error {
 			return
 		}
 
-		// mihomo use SetReadDeadline to break bidirectional copy between client and server.
-		// reset it before reuse connection to avoid io timeout error.
+		// A conn that has reached its use cap is discarded by put() no
+		// matter how the drain ends, so skip the drain instead of making
+		// the caller wait up to drainReuseTimeout for an unused result.
+		if pc.pool.atUseCap(pc.uses) {
+			_ = pc.Snell.Close()
+			return
+		}
+
+		// If the relay terminated because the local side closed first (typical
+		// for short HTTP/1 responses), the server may still have data queued
+		// in the TCP receive buffer — most importantly its own zero-chunk
+		// half-close. Reusing this conn without draining would surface that
+		// stale data on the next session's first read. Drain until we see the
+		// server's zero chunk or hit drainReuseTimeout; only then put back.
+		if !pc.drainPendingForReuse() {
+			_ = pc.Snell.Close()
+			return
+		}
+
+		// mihomo uses SetReadDeadline to break bidirectional copy between
+		// client and server; reset it before reuse to avoid io timeout error.
 		_ = pc.Snell.Conn.SetReadDeadline(time.Time{})
 		pc.Snell.reply = false
-		pc.pool.put(pc.Snell)
+		pc.pool.put(pc.Snell, pc.uses)
 	})
 
 	return pc.closeErr
 }
 
+// drainPendingForReuse reads remaining frames from the snell stream until
+// the server's zero-chunk half-close is observed, with a short deadline.
+// Returns true if the conn is in a clean state suitable for pool reuse.
+func (pc *PoolConn) drainPendingForReuse() bool {
+	if err := pc.Snell.Conn.SetReadDeadline(time.Now().Add(drainReuseTimeout)); err != nil {
+		return false
+	}
+	scratch := make([]byte, 4096)
+	for {
+		_, err := pc.Snell.Read(scratch)
+		if errors.Is(err, shadowaead.ErrZeroChunk) {
+			return true
+		}
+		if err != nil {
+			return false
+		}
+	}
+}
+
 func NewPool(factory func(context.Context) (*Snell, error)) *Pool {
-	p := pool.New[*Snell](
-		func(ctx context.Context) (*Snell, error) {
-			return factory(ctx)
+	p := &Pool{maxUsesPerConn: defaultMaxUsesPerConn}
+	p.pool = pool.New[*pooledEntry](
+		func(ctx context.Context) (*pooledEntry, error) {
+			c, err := factory(ctx)
+			if err != nil {
+				return nil, err
+			}
+			return &pooledEntry{conn: c}, nil
 		},
-		pool.WithAge[*Snell](15000),
-		pool.WithSize[*Snell](10),
-		pool.WithEvict[*Snell](func(item *Snell) {
-			_ = item.Close()
+		pool.WithAge[*pooledEntry](15000),
+		pool.WithSize[*pooledEntry](10),
+		pool.WithEvict[*pooledEntry](func(item *pooledEntry) {
+			_ = item.conn.Close()
 		}),
 	)
-	return &Pool{pool: p}
+	return p
 }
diff --git a/transport/snell/pool_test.go b/transport/snell/pool_test.go
index 1fa5983ac9..564596333f 100644
--- a/transport/snell/pool_test.go
+++ b/transport/snell/pool_test.go
@@ -11,13 +11,13 @@ import (
 	"github.com/metacubex/mihomo/transport/shadowsocks/shadowaead"
 )
 
-func TestPoolConnCloseIsIdempotent(t *testing.T) {
-	rawConn := &recordingConn{}
-	pooledConn := &Snell{Conn: rawConn}
+func TestPoolConnCloseDrainsAndReturnsToPool(t *testing.T) {
+	rawConn := &drainConn{}
+	pooledConn := &Snell{Conn: rawConn, reply: true}
 	pool := NewPool(func(context.Context) (*Snell, error) {
 		return nil, errors.New("factory should not be called")
 	})
-	conn := &PoolConn{Snell: pooledConn, pool: pool}
+	conn := &PoolConn{Snell: pooledConn, pool: pool, uses: 1}
 
 	if err := conn.Close(); err != nil {
 		t.Fatal(err)
@@ -29,24 +29,123 @@ func TestPoolConnCloseIsIdempotent(t *testing.T) {
 	if rawConn.writes != 1 {
 		t.Fatalf("close should send one half-close record, got %d", rawConn.writes)
 	}
+	if !rawConn.readDeadlineCleared {
+		t.Fatal("close should clear read deadline before returning to pool")
+	}
+	if rawConn.closed {
+		t.Fatal("close should not close the underlying conn when drain succeeds")
+	}
 
 	got, err := pool.pool.Get()
 	if err != nil {
 		t.Fatal(err)
 	}
-	if got != pooledConn {
+	if got.conn != pooledConn {
 		t.Fatal("pooled connection mismatch")
 	}
+	if got.uses != 1 {
+		t.Fatalf("uses count not preserved: got %d, want 1", got.uses)
+	}
+}
+
+func TestPoolConnCloseDiscardsWhenDrainFails(t *testing.T) {
+	rawConn := &recordingConn{} // Read returns io.EOF, not ErrZeroChunk
+	pooledConn := &Snell{Conn: rawConn, reply: true}
+	factoryConn := &Snell{Conn: &recordingConn{}}
+	pool := NewPool(func(context.Context) (*Snell, error) {
+		return factoryConn, nil
+	})
+	conn := &PoolConn{Snell: pooledConn, pool: pool, uses: 1}
+
+	if err := conn.Close(); err != nil {
+		t.Fatal(err)
+	}
+	if !rawConn.closed {
+		t.Fatal("close should close the underlying conn when drain fails")
+	}
+
+	got, err := pool.pool.Get()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got.conn != factoryConn {
+		t.Fatal("drained-failed conn should not be returned to the pool")
+	}
+}
+
+func TestPoolConnCloseDiscardsAtMaxUsesPerConn(t *testing.T) {
+	rawConn := &drainConn{}
+	pooledConn := &Snell{Conn: rawConn, reply: true}
+	factoryConn := &Snell{Conn: &recordingConn{}}
+	pool := NewPool(func(context.Context) (*Snell, error) {
+		return factoryConn, nil
+	})
+	// uses == defaultMaxUsesPerConn — this Close must not return to the pool.
+	conn := &PoolConn{Snell: pooledConn, pool: pool, uses: defaultMaxUsesPerConn}
+
+	if err := conn.Close(); err != nil {
+		t.Fatal(err)
+	}
+	if !rawConn.closed {
+		t.Fatal("close should close the underlying conn at the per-conn use cap")
+	}
+	if rawConn.reads != 0 {
+		t.Fatalf("close should skip the drain at the per-conn use cap, got %d reads", rawConn.reads)
+	}
+
+	got, err := pool.pool.Get()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got.conn != factoryConn {
+		t.Fatal("capped conn should not be returned to the pool")
+	}
+}
+
+func TestPoolGetContextIncrementsUses(t *testing.T) {
+	factoryCalls := 0
+	factoryConn := &Snell{Conn: &drainConn{}, reply: true}
+	pool := NewPool(func(context.Context) (*Snell, error) {
+		factoryCalls++
+		return factoryConn, nil
+	})
+
+	first, err := pool.Get()
+	if err != nil {
+		t.Fatal(err)
+	}
+	pcFirst := first.(*PoolConn)
+	if pcFirst.uses != 1 {
+		t.Fatalf("first checkout uses: got %d, want 1", pcFirst.uses)
+	}
+	if err := pcFirst.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	second, err := pool.Get()
+	if err != nil {
+		t.Fatal(err)
+	}
+	pcSecond := second.(*PoolConn)
+	if pcSecond.uses != 2 {
+		t.Fatalf("second checkout uses: got %d, want 2", pcSecond.uses)
+	}
+	if pcSecond.Snell != factoryConn {
+		t.Fatal("second checkout should reuse the same TCP conn")
+	}
+	if factoryCalls != 1 {
+		t.Fatalf("factory should be called once across both checkouts, got %d", factoryCalls)
+	}
 }
 
 func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) {
-	rawConn := &recordingConn{}
+	rawConn := &drainConn{}
 	pooledConn := &Snell{Conn: rawConn, reply: true}
 	factoryConn := &Snell{Conn: &recordingConn{}}
 	pool := NewPool(func(context.Context) (*Snell, error) {
 		return factoryConn, nil
 	})
-	conn := &PoolConn{Snell: pooledConn, pool: pool}
+	conn := &PoolConn{Snell: pooledConn, pool: pool, uses: 1}
 
 	if err := conn.CloseWrite(); err != nil {
 		t.Fatal(err)
@@ -56,7 +155,7 @@ func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	if got != factoryConn {
+	if got.conn != factoryConn {
 		t.Fatal("CloseWrite should not put the active connection back into the pool")
 	}
 	if rawConn.writes != 1 {
@@ -73,7 +172,7 @@ func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	if got != pooledConn {
+	if got.conn != pooledConn {
 		t.Fatal("Close should return the connection to the pool after CloseWrite")
 	}
 	if rawConn.writes != 1 {
@@ -145,6 +244,55 @@ func (a recordingAddr) String() string {
 	return string(a)
 }
 
+// drainConn lets Close drain to completion: Read immediately surfaces the
+// server's zero-chunk half-close (and counts calls, so tests can tell whether
+// a drain was attempted), and SetReadDeadline records whether the caller
+// cleared the drain deadline on the way out.
+type drainConn struct {
+	writes              int
+	reads               int
+	closed              bool
+	readDeadlineCleared bool
+}
+
+func (c *drainConn) Read([]byte) (int, error) {
+	c.reads++
+	return 0, shadowaead.ErrZeroChunk
+}
+
+func (c *drainConn) Write(b []byte) (int, error) {
+	c.writes++
+	return len(b), nil
+}
+
+func (c *drainConn) Close() error {
+	c.closed = true
+	return nil
+}
+
+func (*drainConn) LocalAddr() net.Addr {
+	return recordingAddr("local")
+}
+
+func (*drainConn) RemoteAddr() net.Addr {
+	return recordingAddr("remote")
+}
+
+func (*drainConn) SetDeadline(time.Time) error {
+	return nil
+}
+
+func (c *drainConn) SetReadDeadline(t time.Time) error {
+	if t.IsZero() {
+		c.readDeadlineCleared = true
+	}
+	return nil
+}
+
+func (*drainConn) SetWriteDeadline(time.Time) error {
+	return nil
+}
+
 type zeroChunkConn struct{}
 
 func (zeroChunkConn) Read([]byte) (int, error) {