diff --git a/service/udp.go b/service/udp.go index e9f88c48..f3d7bb31 100644 --- a/service/udp.go +++ b/service/udp.go @@ -26,9 +26,9 @@ import ( "sync" "time" + "github.com/shadowsocks/go-shadowsocks2/socks" "golang.getoutline.org/sdk/transport" "golang.getoutline.org/sdk/transport/shadowsocks" - "github.com/shadowsocks/go-shadowsocks2/socks" "golang.getoutline.org/tunnel-server/internal/slicepool" onet "golang.getoutline.org/tunnel-server/net" @@ -295,7 +295,8 @@ func PacketServe(clientConn net.PacketConn, assocHandle AssociationHandleFunc, m pkt := &packet{payload: buffer[:n], done: lazySlice.Release} // TODO(#19): Include server address in the NAT key as well. - assoc := nm.Get(clientAddr.String()) + clientAddrKey := clientAddr.String() + assoc := nm.Get(clientAddrKey) if assoc == nil { assoc = &association{ pc: clientConn, @@ -303,28 +304,34 @@ func PacketServe(clientConn net.PacketConn, assocHandle AssociationHandleFunc, m readCh: make(chan *packet, 5), doneCh: make(chan struct{}), } - if err != nil { - slog.Error("Failed to handle association", slog.Any("err", err)) - return - } - var existing bool - assoc, existing = nm.Add(clientAddr.String(), assoc) + assoc, existing = nm.Add(clientAddrKey, assoc) if !existing { metrics.AddNATEntry() go func() { assocHandle(ctx, assoc) - metrics.RemoveNATEntry() _ = assoc.Close() + nm.DelIfMatches(clientAddrKey, assoc) + metrics.RemoveNATEntry() }() } } select { case <-assoc.doneCh: - nm.Del(clientAddr.String()) - case assoc.readCh <- pkt: + nm.DelIfMatches(clientAddrKey, assoc) + pkt.done() default: + queued, closed := assoc.enqueue(pkt) + if queued { + return + } + if closed { + nm.DelIfMatches(clientAddrKey, assoc) + pkt.done() + return + } slog.Debug("Dropping packet due to full read queue") + pkt.done() // TODO: Add a metric to track number of dropped packets. } }() @@ -350,6 +357,7 @@ type association struct { readCh chan *packet doneCh chan struct{} closeOnce sync.Once + mu sync.Mutex } var _ net.Conn = (*association)(nil) @@ -372,11 +380,41 @@ func (a *association) Write(b []byte) (n int, err error) { return a.pc.WriteTo(b, a.clientAddr) } +func (a *association) enqueue(pkt *packet) (queued bool, closed bool) { + a.mu.Lock() + defer a.mu.Unlock() + + select { + case <-a.doneCh: + return false, true + default: + } + + select { + case a.readCh <- pkt: + return true, false + default: + return false, false + } +} + func (a *association) Close() error { a.closeOnce.Do(func() { if a.doneCh != nil { close(a.doneCh) } + a.mu.Lock() + defer a.mu.Unlock() + for { + select { + case pkt := <-a.readCh: + if pkt != nil { + pkt.done() + } + default: + return + } + } }) return nil } @@ -507,6 +545,18 @@ func (m *natmap) Del(clientAddr string) { } } +// DelIfMatches deletes the entry for clientAddr only if the stored +// association is the same object as expected. This prevents a finishing +// goroutine from evicting a newer association that reused the same key. +func (m *natmap) DelIfMatches(clientAddr string, expected *association) { + m.Lock() + defer m.Unlock() + + if m.associations[clientAddr] == expected { + delete(m.associations, clientAddr) + } +} + // Add adds a UDP NAT entry to the natmap and returns it. If it already existed, // in the natmap, the existing entry is returned instead. func (m *natmap) Add(clientAddr string, assoc *association) (*association, bool) { diff --git a/service/udp_test.go b/service/udp_test.go index 2d4eb3fe..99ac4253 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -150,15 +150,23 @@ type udpReport struct { // Stub metrics implementation for testing NAT behaviors. type natTestMetrics struct { - natEntriesAdded int + natEntriesAdded int + natEntriesRemoved int + mu sync.Mutex } var _ NATMetrics = (*natTestMetrics)(nil) func (m *natTestMetrics) AddNATEntry() { + m.mu.Lock() + defer m.mu.Unlock() m.natEntriesAdded++ } -func (m *natTestMetrics) RemoveNATEntry() {} +func (m *natTestMetrics) RemoveNATEntry() { + m.mu.Lock() + defer m.mu.Unlock() + m.natEntriesRemoved++ +} type fakeUDPAssociationMetrics struct { accessKey string @@ -221,6 +229,7 @@ func TestAssociationCloseWhileReading(t *testing.T) { pc: makePacketConn(), clientAddr: &clientAddr, readCh: make(chan *packet), + doneCh: make(chan struct{}), } go func() { buf := make([]byte, 1024) @@ -232,6 +241,78 @@ func TestAssociationCloseWhileReading(t *testing.T) { assert.NoError(t, err, "Close should not panic or return an error") } +func TestAssociationCloseReleasesQueuedPackets(t *testing.T) { + assoc := &association{ + pc: makePacketConn(), + clientAddr: &clientAddr, + readCh: make(chan *packet, 2), + doneCh: make(chan struct{}), + } + released := 0 + assoc.readCh <- &packet{payload: []byte{1}, done: func() { released++ }} + assoc.readCh <- &packet{payload: []byte{2}, done: func() { released++ }} + + err := assoc.Close() + + require.NoError(t, err) + assert.Equal(t, 2, released, "Close should release queued packets") +} + +func TestAssociationEnqueueAfterClose(t *testing.T) { + assoc := &association{ + pc: makePacketConn(), + clientAddr: &clientAddr, + readCh: make(chan *packet, 1), + doneCh: make(chan struct{}), + } + require.NoError(t, assoc.Close()) + + queued, closed := assoc.enqueue(&packet{payload: []byte{1}, done: func() {}}) + require.False(t, queued) + require.True(t, closed) +} + +func TestPacketServeRemovesClosedAssociationsFromNAT(t *testing.T) { + clientConn := makePacketConn() + metrics := &natTestMetrics{} + handled := make(chan []byte, 2) + done := make(chan struct{}) + + go func() { + PacketServe(clientConn, func(ctx context.Context, conn net.Conn) { + buf := make([]byte, 16) + n, err := conn.Read(buf) + if err != nil { + t.Errorf("Read failed: %v", err) + return + } + handled <- append([]byte(nil), buf[:n]...) + }, metrics) + close(done) + }() + + clientConn.recv <- fakePacket{addr: &clientAddr, payload: []byte{1}} + require.Equal(t, []byte{1}, <-handled) + + require.Eventually(t, func() bool { + metrics.mu.Lock() + defer metrics.mu.Unlock() + return metrics.natEntriesAdded == 1 && metrics.natEntriesRemoved == 1 + }, time.Second, 10*time.Millisecond) + + clientConn.recv <- fakePacket{addr: &clientAddr, payload: []byte{2}} + require.Equal(t, []byte{2}, <-handled) + + require.Eventually(t, func() bool { + metrics.mu.Lock() + defer metrics.mu.Unlock() + return metrics.natEntriesAdded == 2 && metrics.natEntriesRemoved == 2 + }, time.Second, 10*time.Millisecond) + + require.NoError(t, clientConn.Close()) + <-done +} + func TestAssociationHandler_Handle_IPFilter(t *testing.T) { t.Run("RequirePublicIP blocks localhost", func(t *testing.T) { handler, sendPayload, targetConn := startTestHandler()