Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions adapter/outbound/snell.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
_ = c.Close()
return nil, err
}
if pc, ok := c.(*snell.PoolConn); ok {
pc.MarkReusable()
}
return NewConn(c, s), err
}

Expand Down
31 changes: 30 additions & 1 deletion transport/snell/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"time"

"github.com/metacubex/mihomo/component/pool"
Expand All @@ -15,6 +16,12 @@ type Pool struct {
pool *pool.Pool[*Snell]
}

const (
poolConnNew int32 = iota
poolConnReusable
poolConnClosedBeforeReuse
)

func (p *Pool) Get() (net.Conn, error) {
return p.GetContext(context.Background())
}
Expand Down Expand Up @@ -48,6 +55,8 @@ type PoolConn struct {
closeWriteErr error
closeOnce sync.Once
closeErr error
requestStarted atomic.Bool
reusableState atomic.Int32
}

func (pc *PoolConn) Read(b []byte) (int, error) {
Expand All @@ -59,18 +68,38 @@ func (pc *PoolConn) Read(b []byte) (int, error) {
}

func (pc *PoolConn) Write(b []byte) (int, error) {
return pc.Snell.Write(b)
n, err := pc.Snell.Write(b)
if err == nil && n == len(b) && len(b) > 0 {
pc.requestStarted.Store(true)
}
return n, err
}

// MarkReusable allows Close to return this request to the pool after DialContext succeeds.
func (pc *PoolConn) MarkReusable() {
if pc.requestStarted.Load() {
pc.reusableState.CompareAndSwap(poolConnNew, poolConnReusable)
}
}

func (pc *PoolConn) CloseWrite() error {
pc.closeWriteOnce.Do(func() {
if pc.reusableState.Load() != poolConnReusable {
pc.reusableState.CompareAndSwap(poolConnNew, poolConnClosedBeforeReuse)
pc.closeWriteErr = pc.Snell.Close()
return
}
pc.closeWriteErr = writeZeroChunk(pc.Snell)
})
return pc.closeWriteErr
}

func (pc *PoolConn) Close() error {
pc.closeOnce.Do(func() {
if pc.reusableState.Load() != poolConnReusable {
pc.closeErr = pc.CloseWrite()
return
}
if err := pc.CloseWrite(); err != nil {
pc.closeErr = err
_ = pc.Snell.Close()
Expand Down
118 changes: 112 additions & 6 deletions transport/snell/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@ func TestPoolConnCloseIsIdempotent(t *testing.T) {
})
conn := &PoolConn{Snell: pooledConn, pool: pool}

if _, err := conn.Write([]byte{Version, CommandConnectV2, 0}); err != nil {
t.Fatal(err)
}
conn.MarkReusable()
if err := conn.Close(); err != nil {
t.Fatal(err)
}
if err := conn.Close(); err != nil {
t.Fatal(err)
}

if rawConn.writes != 1 {
t.Fatalf("close should send one half-close record, got %d", rawConn.writes)
if rawConn.writes != 2 {
t.Fatalf("close should send the request and one half-close record, got %d writes", rawConn.writes)
}

got, err := pool.pool.Get()
Expand All @@ -39,6 +43,104 @@ func TestPoolConnCloseIsIdempotent(t *testing.T) {
}
}

func TestPoolConnCloseBeforeRequestClosesRawConnection(t *testing.T) {
rawConn := &recordingConn{}
pooledConn := &Snell{Conn: rawConn}
factoryConn := &Snell{Conn: &recordingConn{}}
pool := NewPool(func(context.Context) (*Snell, error) {
return factoryConn, nil
})
conn := &PoolConn{Snell: pooledConn, pool: pool}

if err := conn.Close(); err != nil {
t.Fatal(err)
}
if err := conn.Close(); err != nil {
t.Fatal(err)
}

if rawConn.writes != 0 {
t.Fatalf("close before request should not send half-close record, got %d writes", rawConn.writes)
}
if !rawConn.closed {
t.Fatal("close before request should close the raw connection")
}

got, err := pool.pool.Get()
if err != nil {
t.Fatal(err)
}
if got != factoryConn {
t.Fatal("unstarted connection should not be returned to the pool")
}
}

func TestPoolConnCloseAfterRequestBeforeReusableClosesRawConnection(t *testing.T) {
rawConn := &recordingConn{}
pooledConn := &Snell{Conn: rawConn}
factoryConn := &Snell{Conn: &recordingConn{}}
pool := NewPool(func(context.Context) (*Snell, error) {
return factoryConn, nil
})
conn := &PoolConn{Snell: pooledConn, pool: pool}

if _, err := conn.Write([]byte{Version, CommandConnectV2, 0}); err != nil {
t.Fatal(err)
}
if err := conn.Close(); err != nil {
t.Fatal(err)
}

if rawConn.writes != 1 {
t.Fatalf("close before reusable should only send the request, got %d writes", rawConn.writes)
}
if !rawConn.closed {
t.Fatal("close before reusable should close the raw connection")
}

got, err := pool.pool.Get()
if err != nil {
t.Fatal(err)
}
if got != factoryConn {
t.Fatal("connection closed before reusable should not be returned to the pool")
}
}

func TestPoolConnCloseWriteBeforeRequestClosesRawConnection(t *testing.T) {
rawConn := &recordingConn{}
pooledConn := &Snell{Conn: rawConn}
factoryConn := &Snell{Conn: &recordingConn{}}
pool := NewPool(func(context.Context) (*Snell, error) {
return factoryConn, nil
})
conn := &PoolConn{Snell: pooledConn, pool: pool}

if err := conn.CloseWrite(); err != nil {
t.Fatal(err)
}
if rawConn.writes != 0 {
t.Fatalf("CloseWrite before request should not send half-close record, got %d writes", rawConn.writes)
}
if !rawConn.closed {
t.Fatal("CloseWrite before request should close the raw connection")
}

// Simulate a concurrent write reporting success after the unstarted close path won.
conn.requestStarted.Store(true)
conn.MarkReusable()
if err := conn.Close(); err != nil {
t.Fatal(err)
}
got, err := pool.pool.Get()
if err != nil {
t.Fatal(err)
}
if got != factoryConn {
t.Fatal("unstarted connection should not be returned to the pool")
}
}

func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) {
rawConn := &recordingConn{}
pooledConn := &Snell{Conn: rawConn, reply: true}
Expand All @@ -48,6 +150,10 @@ func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) {
})
conn := &PoolConn{Snell: pooledConn, pool: pool}

if _, err := conn.Write([]byte{Version, CommandConnectV2, 0}); err != nil {
t.Fatal(err)
}
conn.MarkReusable()
if err := conn.CloseWrite(); err != nil {
t.Fatal(err)
}
Expand All @@ -59,8 +165,8 @@ func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) {
if got != factoryConn {
t.Fatal("CloseWrite should not put the active connection back into the pool")
}
if rawConn.writes != 1 {
t.Fatalf("CloseWrite should send one half-close record, got %d", rawConn.writes)
if rawConn.writes != 2 {
t.Fatalf("CloseWrite should send the request and one half-close record, got %d writes", rawConn.writes)
}
if !pooledConn.reply {
t.Fatal("CloseWrite should not reset reply while the read side may still be active")
Expand All @@ -76,8 +182,8 @@ func TestPoolConnCloseWriteDoesNotReturnConnectionToPool(t *testing.T) {
if got != pooledConn {
t.Fatal("Close should return the connection to the pool after CloseWrite")
}
if rawConn.writes != 1 {
t.Fatalf("Close after CloseWrite should not send another half-close record, got %d", rawConn.writes)
if rawConn.writes != 2 {
t.Fatalf("Close after CloseWrite should not send another half-close record, got %d writes", rawConn.writes)
}
if pooledConn.reply {
t.Fatal("Close should reset reply before returning the connection to the pool")
Expand Down
Loading