Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions adapter/outbound/snell.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,24 @@ func (s *Snell) writeHeaderContext(ctx context.Context, c net.Conn, metadata *C.
// DialContext implements C.ProxyAdapter
func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
if s.reuse {
c, err := s.pool.Get()
if err != nil {
return nil, err
}

if err = s.writeHeaderContext(ctx, c, metadata); err != nil {
_ = c.Close()
return nil, err
// A pooled conn may be stale (server-side half-closed it after the
// previous session but before our maxAge fires). Detect by retrying
// the CONNECT write once — the second pool.Get either yields another
// idle conn or falls through to the factory and dials fresh.
for attempts := 0; attempts < 2; attempts++ {
c, gerr := s.pool.GetContext(ctx)
if gerr != nil {
return nil, gerr
}
if err = s.writeHeaderContext(ctx, c, metadata); err != nil {
_ = c.Close()
continue
}
return NewConn(c, s), nil
}
return NewConn(c, s), err
// Both pool attempts yielded stale conns. Fall through to a fresh
// dial below; it bypasses the pool but still writes a reuse-capable
// header so the next call can pool the conn after one session.
}

c, err := s.dialer.DialContext(ctx, "tcp", s.addr)
Expand Down
101 changes: 78 additions & 23 deletions transport/snell/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package snell

import (
"context"
"errors"
"io"
"net"
"sync"
Expand All @@ -11,39 +12,58 @@ import (
"github.com/metacubex/mihomo/transport/shadowsocks/shadowaead"
)

// defaultMaxUsesPerConn caps how many CONNECT sessions a single reuse-mode
// TCP connection may serve before we close it instead of returning to the
// pool. Surge's snell-server (v5.0.1 observed) closes a reuse-mode TCP
// after the second session, so anything beyond the cap is a soon-to-be
// dead socket from the client's perspective.
const defaultMaxUsesPerConn = 2

// drainReuseTimeout bounds how long PoolConn.Close waits for the server's
// zero-chunk half-close before declaring the conn unreusable. The server
// emits its zero chunk shortly after seeing ours, so one RTT plus slack is
// plenty; if the upstream is still streaming data we'd rather drop the
// conn than block the caller.
const drainReuseTimeout = 500 * time.Millisecond

type Pool struct {
pool *pool.Pool[*Snell]
pool *pool.Pool[*pooledEntry]
maxUsesPerConn int
}

// pooledEntry is the internal pool element. It tracks how many CONNECT
// sessions the wrapped TCP has served so far so the pool can evict it
// before the server-side use cap kicks in.
type pooledEntry struct {
conn *Snell
uses int
}

func (p *Pool) Get() (net.Conn, error) {
return p.GetContext(context.Background())
}

func (p *Pool) GetContext(ctx context.Context) (net.Conn, error) {
elm, err := p.pool.GetContext(ctx)
entry, err := p.pool.GetContext(ctx)
if err != nil {
return nil, err
}

return &PoolConn{Snell: elm, pool: p}, nil
entry.uses++
return &PoolConn{Snell: entry.conn, pool: p, uses: entry.uses}, nil
}

func (p *Pool) Put(conn *Snell) {
if err := HalfClose(conn); err != nil {
func (p *Pool) put(conn *Snell, uses int) {
if p.maxUsesPerConn > 0 && uses >= p.maxUsesPerConn {
_ = conn.Close()
return
}

p.put(conn)
}

func (p *Pool) put(conn *Snell) {
p.pool.Put(conn)
p.pool.Put(&pooledEntry{conn: conn, uses: uses})
}

type PoolConn struct {
*Snell
pool *Pool
uses int
closeWriteOnce sync.Once
closeWriteErr error
closeOnce sync.Once
Expand Down Expand Up @@ -77,26 +97,61 @@ func (pc *PoolConn) Close() error {
return
}

// mihomo use SetReadDeadline to break bidirectional copy between client and server.
// reset it before reuse connection to avoid io timeout error.
// If the relay terminated because the local side closed first (typical
// for short HTTP/1 responses), the server may still have data queued
// in the TCP receive buffer — most importantly its own zero-chunk
// half-close. Reusing this conn without draining would surface that
// stale data on the next session's first read. Drain until we see the
// server's zero chunk or hit drainReuseTimeout; only then put back.
if !pc.drainPendingForReuse() {
_ = pc.Snell.Close()
return
}

// mihomo uses SetReadDeadline to break bidirectional copy between
// client and server; reset it before reuse to avoid io timeout error.
_ = pc.Snell.Conn.SetReadDeadline(time.Time{})
pc.Snell.reply = false
pc.pool.put(pc.Snell)
pc.pool.put(pc.Snell, pc.uses)
})
return pc.closeErr
}

// drainPendingForReuse reads remaining frames from the snell stream until
// the server's zero-chunk half-close is observed, with a short deadline.
// Returns true if the conn is in a clean state suitable for pool reuse.
func (pc *PoolConn) drainPendingForReuse() bool {
if err := pc.Snell.Conn.SetReadDeadline(time.Now().Add(drainReuseTimeout)); err != nil {
return false
}
scratch := make([]byte, 4096)
for {
_, err := pc.Snell.Read(scratch)
if errors.Is(err, shadowaead.ErrZeroChunk) {
return true
}
if err != nil {
return false
}
}
}

func NewPool(factory func(context.Context) (*Snell, error)) *Pool {
p := pool.New[*Snell](
func(ctx context.Context) (*Snell, error) {
return factory(ctx)
p := &Pool{maxUsesPerConn: defaultMaxUsesPerConn}
p.pool = pool.New[*pooledEntry](
func(ctx context.Context) (*pooledEntry, error) {
c, err := factory(ctx)
if err != nil {
return nil, err
}
return &pooledEntry{conn: c}, nil
},
pool.WithAge[*Snell](15000),
pool.WithSize[*Snell](10),
pool.WithEvict[*Snell](func(item *Snell) {
_ = item.Close()
pool.WithAge[*pooledEntry](15000),
pool.WithSize[*pooledEntry](10),
pool.WithEvict[*pooledEntry](func(item *pooledEntry) {
_ = item.conn.Close()
}),
)

return &Pool{pool: p}
return p
}
160 changes: 151 additions & 9 deletions transport/snell/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import (
"github.com/metacubex/mihomo/transport/shadowsocks/shadowaead"
)

func TestPoolConnCloseIsIdempotent(t *testing.T) {
rawConn := &recordingConn{}
pooledConn := &Snell{Conn: rawConn}
func TestPoolConnCloseDrainsAndReturnsToPool(t *testing.T) {
rawConn := &drainConn{}
pooledConn := &Snell{Conn: rawConn, reply: true}
pool := NewPool(func(context.Context) (*Snell, error) {
return nil, errors.New("factory should not be called")
})
conn := &PoolConn{Snell: pooledConn, pool: pool}
conn := &PoolConn{Snell: pooledConn, pool: pool, uses: 1}

if err := conn.Close(); err != nil {
t.Fatal(err)
Expand All @@ -29,24 +29,120 @@ func TestPoolConnCloseIsIdempotent(t *testing.T) {
if rawConn.writes != 1 {
t.Fatalf("close should send one half-close record, got %d", rawConn.writes)
}
if !rawConn.readDeadlineCleared {
t.Fatal("close should clear read deadline before returning to pool")
}
if rawConn.closed {
t.Fatal("close should not close the underlying conn when drain succeeds")
}

got, err := pool.pool.Get()
if err != nil {
t.Fatal(err)
}
if got != pooledConn {
if got.conn != pooledConn {
t.Fatal("pooled connection mismatch")
}
if got.uses != 1 {
t.Fatalf("uses count not preserved: got %d, want 1", got.uses)
}
}

func TestPoolConnCloseDiscardsWhenDrainFails(t *testing.T) {
rawConn := &recordingConn{} // Read returns io.EOF, not ErrZeroChunk
pooledConn := &Snell{Conn: rawConn, reply: true}
factoryConn := &Snell{Conn: &recordingConn{}}
pool := NewPool(func(context.Context) (*Snell, error) {
return factoryConn, nil
})
conn := &PoolConn{Snell: pooledConn, pool: pool, uses: 1}

if err := conn.Close(); err != nil {
t.Fatal(err)
}
if !rawConn.closed {
t.Fatal("close should close the underlying conn when drain fails")
}

got, err := pool.pool.Get()
if err != nil {
t.Fatal(err)
}
if got.conn != factoryConn {
t.Fatal("drained-failed conn should not be returned to the pool")
}
}

func TestPoolConnCloseDiscardsAtMaxUsesPerConn(t *testing.T) {
rawConn := &drainConn{}
pooledConn := &Snell{Conn: rawConn, reply: true}
factoryConn := &Snell{Conn: &recordingConn{}}
pool := NewPool(func(context.Context) (*Snell, error) {
return factoryConn, nil
})
// uses == defaultMaxUsesPerConn — this Close must not return to the pool.
conn := &PoolConn{Snell: pooledConn, pool: pool, uses: defaultMaxUsesPerConn}

if err := conn.Close(); err != nil {
t.Fatal(err)
}
if !rawConn.closed {
t.Fatal("close should close the underlying conn at the per-conn use cap")
}

got, err := pool.pool.Get()
if err != nil {
t.Fatal(err)
}
if got.conn != factoryConn {
t.Fatal("capped conn should not be returned to the pool")
}
}

func TestPoolGetContextIncrementsUses(t *testing.T) {
factoryCalls := 0
factoryConn := &Snell{Conn: &drainConn{}, reply: true}
pool := NewPool(func(context.Context) (*Snell, error) {
factoryCalls++
return factoryConn, nil
})

first, err := pool.Get()
if err != nil {
t.Fatal(err)
}
pcFirst := first.(*PoolConn)
if pcFirst.uses != 1 {
t.Fatalf("first checkout uses: got %d, want 1", pcFirst.uses)
}
if err := pcFirst.Close(); err != nil {
t.Fatal(err)
}

second, err := pool.Get()
if err != nil {
t.Fatal(err)
}
pcSecond := second.(*PoolConn)
if pcSecond.uses != 2 {
t.Fatalf("second checkout uses: got %d, want 2", pcSecond.uses)
}
if pcSecond.Snell != factoryConn {
t.Fatal("second checkout should reuse the same TCP conn")
}
if factoryCalls != 1 {
t.Fatalf("factory should be called once across both checkouts, got %d", factoryCalls)
}
}

func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) {
rawConn := &recordingConn{}
rawConn := &drainConn{}
pooledConn := &Snell{Conn: rawConn, reply: true}
factoryConn := &Snell{Conn: &recordingConn{}}
pool := NewPool(func(context.Context) (*Snell, error) {
return factoryConn, nil
})
conn := &PoolConn{Snell: pooledConn, pool: pool}
conn := &PoolConn{Snell: pooledConn, pool: pool, uses: 1}

if err := conn.CloseWrite(); err != nil {
t.Fatal(err)
Expand All @@ -56,7 +152,7 @@ func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if got != factoryConn {
if got.conn != factoryConn {
t.Fatal("CloseWrite should not put the active connection back into the pool")
}
if rawConn.writes != 1 {
Expand All @@ -73,7 +169,7 @@ func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if got != pooledConn {
if got.conn != pooledConn {
t.Fatal("Close should return the connection to the pool after CloseWrite")
}
if rawConn.writes != 1 {
Expand Down Expand Up @@ -145,6 +241,52 @@ func (a recordingAddr) String() string {
return string(a)
}

// drainConn lets Close drain to completion: Read immediately surfaces the
// server's zero-chunk half-close, and SetReadDeadline records whether the
// caller cleared the drain deadline on the way out.
type drainConn struct {
writes int
closed bool
readDeadlineCleared bool
}

func (c *drainConn) Read([]byte) (int, error) {
return 0, shadowaead.ErrZeroChunk
}

func (c *drainConn) Write(b []byte) (int, error) {
c.writes++
return len(b), nil
}

func (c *drainConn) Close() error {
c.closed = true
return nil
}

func (*drainConn) LocalAddr() net.Addr {
return recordingAddr("local")
}

func (*drainConn) RemoteAddr() net.Addr {
return recordingAddr("remote")
}

func (*drainConn) SetDeadline(time.Time) error {
return nil
}

func (c *drainConn) SetReadDeadline(t time.Time) error {
if t.IsZero() {
c.readDeadlineCleared = true
}
return nil
}

func (*drainConn) SetWriteDeadline(time.Time) error {
return nil
}

type zeroChunkConn struct{}

func (zeroChunkConn) Read([]byte) (int, error) {
Expand Down