Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 61 additions & 11 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import (
"sync"
"time"

"github.com/shadowsocks/go-shadowsocks2/socks"
"golang.getoutline.org/sdk/transport"
"golang.getoutline.org/sdk/transport/shadowsocks"
"github.com/shadowsocks/go-shadowsocks2/socks"

"golang.getoutline.org/tunnel-server/internal/slicepool"
onet "golang.getoutline.org/tunnel-server/net"
Expand Down Expand Up @@ -295,36 +295,43 @@ func PacketServe(clientConn net.PacketConn, assocHandle AssociationHandleFunc, m
pkt := &packet{payload: buffer[:n], done: lazySlice.Release}

// TODO(#19): Include server address in the NAT key as well.
assoc := nm.Get(clientAddr.String())
clientAddrKey := clientAddr.String()
assoc := nm.Get(clientAddrKey)
if assoc == nil {
assoc = &association{
pc: clientConn,
clientAddr: clientAddr,
readCh: make(chan *packet, 5),
doneCh: make(chan struct{}),
}
if err != nil {
slog.Error("Failed to handle association", slog.Any("err", err))
return
}

var existing bool
assoc, existing = nm.Add(clientAddr.String(), assoc)
assoc, existing = nm.Add(clientAddrKey, assoc)
if !existing {
metrics.AddNATEntry()
go func() {
assocHandle(ctx, assoc)
metrics.RemoveNATEntry()
_ = assoc.Close()
nm.DelIfMatches(clientAddrKey, assoc)
metrics.RemoveNATEntry()
}()
}
}
select {
case <-assoc.doneCh:
nm.Del(clientAddr.String())
case assoc.readCh <- pkt:
nm.DelIfMatches(clientAddrKey, assoc)
pkt.done()
default:
queued, closed := assoc.enqueue(pkt)
if queued {
return
}
if closed {
nm.DelIfMatches(clientAddrKey, assoc)
pkt.done()
return
}
slog.Debug("Dropping packet due to full read queue")
pkt.done()
// TODO: Add a metric to track number of dropped packets.
Comment thread
greptile-apps[bot] marked this conversation as resolved.
}
}()
Expand All @@ -350,6 +357,7 @@ type association struct {
readCh chan *packet
doneCh chan struct{}
closeOnce sync.Once
mu sync.Mutex
}

var _ net.Conn = (*association)(nil)
Expand All @@ -372,11 +380,41 @@ func (a *association) Write(b []byte) (n int, err error) {
return a.pc.WriteTo(b, a.clientAddr)
}

func (a *association) enqueue(pkt *packet) (queued bool, closed bool) {
a.mu.Lock()
defer a.mu.Unlock()

select {
case <-a.doneCh:
return false, true
default:
}

select {
case a.readCh <- pkt:
return true, false
default:
return false, false
}
}

func (a *association) Close() error {
a.closeOnce.Do(func() {
if a.doneCh != nil {
close(a.doneCh)
}
a.mu.Lock()
defer a.mu.Unlock()
for {
select {
case pkt := <-a.readCh:
if pkt != nil {
pkt.done()
}
default:
return
}
}
})
return nil
}
Expand Down Expand Up @@ -507,6 +545,18 @@ func (m *natmap) Del(clientAddr string) {
}
}

// DelIfMatches deletes the entry for clientAddr only if the stored
// association is the same object as expected. This prevents a finishing
// goroutine from evicting a newer association that reused the same key.
func (m *natmap) DelIfMatches(clientAddr string, expected *association) {
m.Lock()
defer m.Unlock()

if m.associations[clientAddr] == expected {
delete(m.associations, clientAddr)
}
}

// Add adds a UDP NAT entry to the natmap and returns it. If it already existed,
// in the natmap, the existing entry is returned instead.
func (m *natmap) Add(clientAddr string, assoc *association) (*association, bool) {
Expand Down
85 changes: 83 additions & 2 deletions service/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,23 @@ type udpReport struct {

// Stub metrics implementation for testing NAT behaviors.
type natTestMetrics struct {
natEntriesAdded int
natEntriesAdded int
natEntriesRemoved int
mu sync.Mutex
}

var _ NATMetrics = (*natTestMetrics)(nil)

func (m *natTestMetrics) AddNATEntry() {
m.mu.Lock()
defer m.mu.Unlock()
m.natEntriesAdded++
}
func (m *natTestMetrics) RemoveNATEntry() {}
func (m *natTestMetrics) RemoveNATEntry() {
m.mu.Lock()
defer m.mu.Unlock()
m.natEntriesRemoved++
}

type fakeUDPAssociationMetrics struct {
accessKey string
Expand Down Expand Up @@ -221,6 +229,7 @@ func TestAssociationCloseWhileReading(t *testing.T) {
pc: makePacketConn(),
clientAddr: &clientAddr,
readCh: make(chan *packet),
doneCh: make(chan struct{}),
}
go func() {
buf := make([]byte, 1024)
Expand All @@ -232,6 +241,78 @@ func TestAssociationCloseWhileReading(t *testing.T) {
assert.NoError(t, err, "Close should not panic or return an error")
}

func TestAssociationCloseReleasesQueuedPackets(t *testing.T) {
assoc := &association{
pc: makePacketConn(),
clientAddr: &clientAddr,
readCh: make(chan *packet, 2),
doneCh: make(chan struct{}),
}
released := 0
assoc.readCh <- &packet{payload: []byte{1}, done: func() { released++ }}
assoc.readCh <- &packet{payload: []byte{2}, done: func() { released++ }}

err := assoc.Close()

require.NoError(t, err)
assert.Equal(t, 2, released, "Close should release queued packets")
}

func TestAssociationEnqueueAfterClose(t *testing.T) {
assoc := &association{
pc: makePacketConn(),
clientAddr: &clientAddr,
readCh: make(chan *packet, 1),
doneCh: make(chan struct{}),
}
require.NoError(t, assoc.Close())

queued, closed := assoc.enqueue(&packet{payload: []byte{1}, done: func() {}})
require.False(t, queued)
require.True(t, closed)
}

func TestPacketServeRemovesClosedAssociationsFromNAT(t *testing.T) {
clientConn := makePacketConn()
metrics := &natTestMetrics{}
handled := make(chan []byte, 2)
done := make(chan struct{})

go func() {
PacketServe(clientConn, func(ctx context.Context, conn net.Conn) {
buf := make([]byte, 16)
n, err := conn.Read(buf)
if err != nil {
t.Errorf("Read failed: %v", err)
return
}
handled <- append([]byte(nil), buf[:n]...)
}, metrics)
close(done)
}()

clientConn.recv <- fakePacket{addr: &clientAddr, payload: []byte{1}}
require.Equal(t, []byte{1}, <-handled)

require.Eventually(t, func() bool {
metrics.mu.Lock()
defer metrics.mu.Unlock()
return metrics.natEntriesAdded == 1 && metrics.natEntriesRemoved == 1
}, time.Second, 10*time.Millisecond)

clientConn.recv <- fakePacket{addr: &clientAddr, payload: []byte{2}}
require.Equal(t, []byte{2}, <-handled)

require.Eventually(t, func() bool {
metrics.mu.Lock()
defer metrics.mu.Unlock()
return metrics.natEntriesAdded == 2 && metrics.natEntriesRemoved == 2
}, time.Second, 10*time.Millisecond)

require.NoError(t, clientConn.Close())
<-done
}

func TestAssociationHandler_Handle_IPFilter(t *testing.T) {
t.Run("RequirePublicIP blocks localhost", func(t *testing.T) {
handler, sendPayload, targetConn := startTestHandler()
Expand Down
Loading