Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions app/dispatcher/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,43 @@ import (

var errSniffingTimeout = errors.New("timeout on sniffing")

type rateLimitedWriter struct {
buf.Writer
limiter *protocol.RateLimiter
}

func (w *rateLimitedWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if w.limiter != nil && !mb.IsEmpty() {
w.limiter.Wait(int(mb.Len()))
}
return w.Writer.WriteMultiBuffer(mb)
}

type rateLimitedReader struct {
reader buf.Reader
limiter *protocol.RateLimiter
}

func (r *rateLimitedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
mb, err := r.reader.ReadMultiBuffer()
if r.limiter != nil && !mb.IsEmpty() {
r.limiter.Wait(int(mb.Len()))
}
return mb, err
}

func (r *rateLimitedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
timeoutReader, ok := r.reader.(buf.TimeoutReader)
if !ok {
return nil, buf.ErrNotTimeoutReader
}
mb, err := timeoutReader.ReadMultiBufferTimeout(timeout)
if r.limiter != nil && !mb.IsEmpty() {
r.limiter.Wait(int(mb.Len()))
}
return mb, err
}

type cachedReader struct {
sync.Mutex
reader buf.TimeoutReader // *pipe.Reader or *buf.TimeoutWrapperReader
Expand Down Expand Up @@ -159,6 +196,19 @@ func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *tran
}

if user != nil && len(user.Email) > 0 {
if limiter := user.GetUplinkLimiter(); limiter != nil {
inboundLink.Writer = &rateLimitedWriter{
Writer: inboundLink.Writer,
limiter: limiter,
}
}
if limiter := user.GetDownlinkLimiter(); limiter != nil {
outboundLink.Writer = &rateLimitedWriter{
Writer: outboundLink.Writer,
limiter: limiter,
}
}

p := d.policy.ForLevel(user.Level)
if p.Stats.UserUplink {
name := "user>>>" + user.Email + ">>>traffic>>>uplink"
Expand Down Expand Up @@ -194,14 +244,21 @@ func WrapLink(ctx context.Context, policyManager policy.Manager, statsManager st
user = sessionInbound.User
}

link.Reader = &buf.TimeoutWrapperReader{Reader: link.Reader}
timeoutReader := &buf.TimeoutWrapperReader{Reader: link.Reader}
link.Reader = timeoutReader

if user != nil && len(user.Email) > 0 {
p := policyManager.ForLevel(user.Level)
if p.Stats.UserUplink {
name := "user>>>" + user.Email + ">>>traffic>>>uplink"
if c, _ := stats.GetOrRegisterCounter(statsManager, name); c != nil {
link.Reader.(*buf.TimeoutWrapperReader).Counter = c
timeoutReader.Counter = c
}
}
if limiter := user.GetUplinkLimiter(); limiter != nil {
link.Reader = &rateLimitedReader{
reader: link.Reader,
limiter: limiter,
}
}
if p.Stats.UserDownlink {
Expand All @@ -213,6 +270,12 @@ func WrapLink(ctx context.Context, policyManager policy.Manager, statsManager st
}
}
}
if limiter := user.GetDownlinkLimiter(); limiter != nil {
link.Writer = &rateLimitedWriter{
Writer: link.Writer,
limiter: limiter,
}
}
if p.Stats.UserOnline {
trackOnlineIP(ctx, statsManager, user.Email, sessionInbound.Source.Address.String())
}
Expand Down
136 changes: 136 additions & 0 deletions common/protocol/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package protocol

import (
"strings"
"sync"
"time"
)

// RateLimiter is a shared token bucket for per-user byte/sec limits.
type RateLimiter struct {
mu sync.Mutex
rate float64
capacity float64
available float64
last time.Time
}

var globalLimiterRegistry = newLimiterRegistry()

type limiterRegistry struct {
mu sync.Mutex
byUser map[string]*RateLimiter
}

func newLimiterRegistry() *limiterRegistry {
return &limiterRegistry{
byUser: make(map[string]*RateLimiter),
}
}

func normalizeLimiterKey(email, direction string) string {
email = strings.TrimSpace(strings.ToLower(email))
if email == "" {
return ""
}
return direction + ":" + email
}

func (r *limiterRegistry) Get(email, direction string, rate uint64) *RateLimiter {
key := normalizeLimiterKey(email, direction)
if key == "" {
return NewRateLimiter(rate)
}

r.mu.Lock()
defer r.mu.Unlock()

limiter, found := r.byUser[key]
if !found {
limiter = NewRateLimiter(rate)
if limiter == nil {
limiter = &RateLimiter{}
}
r.byUser[key] = limiter
}
limiter.SetRate(rate)
return limiter
}

func NewRateLimiter(rate uint64) *RateLimiter {
limiter := &RateLimiter{}
limiter.SetRate(rate)
return limiter
}

func (l *RateLimiter) SetRate(rate uint64) {
if l == nil {
return
}

l.mu.Lock()
defer l.mu.Unlock()

now := time.Now()
if !l.last.IsZero() && l.rate > 0 {
elapsed := now.Sub(l.last).Seconds()
if elapsed > 0 {
l.available += elapsed * l.rate
}
}

l.rate = float64(rate)
if rate == 0 {
l.capacity = 0
l.available = 0
l.last = now
return
}

capacity := float64(rate)
if capacity < 64*1024 {
capacity = 64 * 1024
}
l.capacity = capacity
if l.available > capacity || l.last.IsZero() {
l.available = capacity
}
l.last = now
}

func (l *RateLimiter) Wait(size int) {
if l == nil || size <= 0 {
return
}

need := float64(size)
for {
l.mu.Lock()
if l.rate <= 0 {
l.mu.Unlock()
return
}
now := time.Now()
elapsed := now.Sub(l.last).Seconds()
if elapsed > 0 {
l.available += elapsed * l.rate
if l.available > l.capacity {
l.available = l.capacity
}
l.last = now
}
if l.available >= need {
l.available -= need
l.mu.Unlock()
return
}
missing := need - l.available
wait := time.Duration(missing / l.rate * float64(time.Second))
l.mu.Unlock()

if wait <= 0 {
wait = time.Millisecond
}
time.Sleep(wait)
}
}
72 changes: 72 additions & 0 deletions common/protocol/ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package protocol

import "testing"

func TestLimiterRegistrySharesLimiterByEmailAndDirection(t *testing.T) {
email := "shared-limiter@example.com"

first := (&MemoryUser{
Email: email,
UplinkSpeedLimit: 1024,
}).GetUplinkLimiter()
if first == nil {
t.Fatal("expected first uplink limiter")
}
if got := uint64(first.rate); got != 1024 {
t.Fatalf("unexpected initial uplink rate: %d", got)
}

second := (&MemoryUser{
Email: email,
UplinkSpeedLimit: 4096,
}).GetUplinkLimiter()
if second == nil {
t.Fatal("expected second uplink limiter")
}
if first != second {
t.Fatal("expected limiter to be shared across user refreshes")
}
if got := uint64(first.rate); got != 4096 {
t.Fatalf("expected shared limiter rate to update, got %d", got)
}
}

func TestLimiterRegistrySeparatesDirections(t *testing.T) {
email := "directional-limiter@example.com"

uplink := (&MemoryUser{
Email: email,
UplinkSpeedLimit: 1024,
}).GetUplinkLimiter()
downlink := (&MemoryUser{
Email: email,
DownlinkSpeedLimit: 2048,
}).GetDownlinkLimiter()

if uplink == nil || downlink == nil {
t.Fatal("expected both limiters")
}
if uplink == downlink {
t.Fatal("expected uplink and downlink limiters to be separate")
}
if got := uint64(uplink.rate); got != 1024 {
t.Fatalf("unexpected uplink rate: %d", got)
}
if got := uint64(downlink.rate); got != 2048 {
t.Fatalf("unexpected downlink rate: %d", got)
}
}

func TestRateLimiterSetRateZeroDisablesLimiter(t *testing.T) {
limiter := NewRateLimiter(1024)
if limiter == nil {
t.Fatal("expected limiter")
}

limiter.SetRate(0)

if limiter.rate != 0 {
t.Fatalf("expected disabled rate, got %f", limiter.rate)
}
limiter.Wait(4096)
}
55 changes: 47 additions & 8 deletions common/protocol/user.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package protocol

import (
"sync"

"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/serial"
)
Expand Down Expand Up @@ -28,21 +30,28 @@ func (u *User) ToMemoryUser() (*MemoryUser, error) {
if err != nil {
return nil, err
}
return &MemoryUser{
Account: account,
Email: u.Email,
Level: u.Level,
}, nil
mu := &MemoryUser{
Account: account,
Email: u.Email,
Level: u.Level,
UplinkSpeedLimit: u.GetUplinkSpeedLimit(),
DownlinkSpeedLimit: u.GetDownlinkSpeedLimit(),
}
mu.uplinkLimiter = globalLimiterRegistry.Get(mu.Email, "uplink", mu.UplinkSpeedLimit)
mu.downlinkLimiter = globalLimiterRegistry.Get(mu.Email, "downlink", mu.DownlinkSpeedLimit)
return mu, nil
}

func ToProtoUser(mu *MemoryUser) *User {
if mu == nil {
return nil
}
return &User{
Account: serial.ToTypedMessage(mu.Account.ToProto()),
Email: mu.Email,
Level: mu.Level,
Account: serial.ToTypedMessage(mu.Account.ToProto()),
Email: mu.Email,
Level: mu.Level,
UplinkSpeedLimit: mu.UplinkSpeedLimit,
DownlinkSpeedLimit: mu.DownlinkSpeedLimit,
}
}

Expand All @@ -52,4 +61,34 @@ type MemoryUser struct {
Account Account
Email string
Level uint32
UplinkSpeedLimit uint64
DownlinkSpeedLimit uint64

limiterMu sync.Mutex
uplinkLimiter *RateLimiter
downlinkLimiter *RateLimiter
}

func (u *MemoryUser) GetUplinkLimiter() *RateLimiter {
if u == nil {
return nil
}
u.limiterMu.Lock()
defer u.limiterMu.Unlock()
if u.uplinkLimiter == nil {
u.uplinkLimiter = globalLimiterRegistry.Get(u.Email, "uplink", u.UplinkSpeedLimit)
}
return u.uplinkLimiter
}

func (u *MemoryUser) GetDownlinkLimiter() *RateLimiter {
if u == nil {
return nil
}
u.limiterMu.Lock()
defer u.limiterMu.Unlock()
if u.downlinkLimiter == nil {
u.downlinkLimiter = globalLimiterRegistry.Get(u.Email, "downlink", u.DownlinkSpeedLimit)
}
return u.downlinkLimiter
}
Loading