From da56e92dbb6644e5e258594169b4f17abb2d70a9 Mon Sep 17 00:00:00 2001 From: Saeed Date: Fri, 1 May 2026 11:05:14 +0330 Subject: [PATCH] add speed limiter for download and upload speed in megabytes --- app/dispatcher/default.go | 67 ++++++++++++++- common/protocol/ratelimit.go | 136 ++++++++++++++++++++++++++++++ common/protocol/ratelimit_test.go | 72 ++++++++++++++++ common/protocol/user.go | 55 ++++++++++-- common/protocol/user.pb.go | 34 ++++++-- common/protocol/user.proto | 2 + go.mod | 2 +- 7 files changed, 349 insertions(+), 19 deletions(-) create mode 100644 common/protocol/ratelimit.go create mode 100644 common/protocol/ratelimit_test.go diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index f6cfd76ebf6f..41445da17921 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -26,6 +26,43 @@ import ( var errSniffingTimeout = errors.New("timeout on sniffing") +type rateLimitedWriter struct { + buf.Writer + limiter *protocol.RateLimiter +} + +func (w *rateLimitedWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { + if w.limiter != nil && !mb.IsEmpty() { + w.limiter.Wait(int(mb.Len())) + } + return w.Writer.WriteMultiBuffer(mb) +} + +type rateLimitedReader struct { + reader buf.Reader + limiter *protocol.RateLimiter +} + +func (r *rateLimitedReader) ReadMultiBuffer() (buf.MultiBuffer, error) { + mb, err := r.reader.ReadMultiBuffer() + if r.limiter != nil && !mb.IsEmpty() { + r.limiter.Wait(int(mb.Len())) + } + return mb, err +} + +func (r *rateLimitedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) { + timeoutReader, ok := r.reader.(buf.TimeoutReader) + if !ok { + return nil, buf.ErrNotTimeoutReader + } + mb, err := timeoutReader.ReadMultiBufferTimeout(timeout) + if r.limiter != nil && !mb.IsEmpty() { + r.limiter.Wait(int(mb.Len())) + } + return mb, err +} + type cachedReader struct { sync.Mutex reader buf.TimeoutReader // *pipe.Reader or *buf.TimeoutWrapperReader @@ -159,6 +196,19 @@ func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *tran } if user != nil && len(user.Email) > 0 { + if limiter := user.GetUplinkLimiter(); limiter != nil { + inboundLink.Writer = &rateLimitedWriter{ + Writer: inboundLink.Writer, + limiter: limiter, + } + } + if limiter := user.GetDownlinkLimiter(); limiter != nil { + outboundLink.Writer = &rateLimitedWriter{ + Writer: outboundLink.Writer, + limiter: limiter, + } + } + p := d.policy.ForLevel(user.Level) if p.Stats.UserUplink { name := "user>>>" + user.Email + ">>>traffic>>>uplink" @@ -194,14 +244,21 @@ func WrapLink(ctx context.Context, policyManager policy.Manager, statsManager st user = sessionInbound.User } - link.Reader = &buf.TimeoutWrapperReader{Reader: link.Reader} + timeoutReader := &buf.TimeoutWrapperReader{Reader: link.Reader} + link.Reader = timeoutReader if user != nil && len(user.Email) > 0 { p := policyManager.ForLevel(user.Level) if p.Stats.UserUplink { name := "user>>>" + user.Email + ">>>traffic>>>uplink" if c, _ := stats.GetOrRegisterCounter(statsManager, name); c != nil { - link.Reader.(*buf.TimeoutWrapperReader).Counter = c + timeoutReader.Counter = c + } + } + if limiter := user.GetUplinkLimiter(); limiter != nil { + link.Reader = &rateLimitedReader{ + reader: link.Reader, + limiter: limiter, } } if p.Stats.UserDownlink { @@ -213,6 +270,12 @@ func WrapLink(ctx context.Context, policyManager policy.Manager, statsManager st } } } + if limiter := user.GetDownlinkLimiter(); limiter != nil { + link.Writer = &rateLimitedWriter{ + Writer: link.Writer, + limiter: limiter, + } + } if p.Stats.UserOnline { trackOnlineIP(ctx, statsManager, user.Email, sessionInbound.Source.Address.String()) } diff --git a/common/protocol/ratelimit.go b/common/protocol/ratelimit.go new file mode 100644 index 000000000000..ef8b0bdca63d --- /dev/null +++ b/common/protocol/ratelimit.go @@ -0,0 +1,136 @@ +package protocol + +import ( + "strings" + "sync" + "time" +) + +// RateLimiter is a shared token bucket for per-user byte/sec limits. +type RateLimiter struct { + mu sync.Mutex + rate float64 + capacity float64 + available float64 + last time.Time +} + +var globalLimiterRegistry = newLimiterRegistry() + +type limiterRegistry struct { + mu sync.Mutex + byUser map[string]*RateLimiter +} + +func newLimiterRegistry() *limiterRegistry { + return &limiterRegistry{ + byUser: make(map[string]*RateLimiter), + } +} + +func normalizeLimiterKey(email, direction string) string { + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" { + return "" + } + return direction + ":" + email +} + +func (r *limiterRegistry) Get(email, direction string, rate uint64) *RateLimiter { + key := normalizeLimiterKey(email, direction) + if key == "" { + return NewRateLimiter(rate) + } + + r.mu.Lock() + defer r.mu.Unlock() + + limiter, found := r.byUser[key] + if !found { + limiter = NewRateLimiter(rate) + if limiter == nil { + limiter = &RateLimiter{} + } + r.byUser[key] = limiter + } + limiter.SetRate(rate) + return limiter +} + +func NewRateLimiter(rate uint64) *RateLimiter { + limiter := &RateLimiter{} + limiter.SetRate(rate) + return limiter +} + +func (l *RateLimiter) SetRate(rate uint64) { + if l == nil { + return + } + + l.mu.Lock() + defer l.mu.Unlock() + + now := time.Now() + if !l.last.IsZero() && l.rate > 0 { + elapsed := now.Sub(l.last).Seconds() + if elapsed > 0 { + l.available += elapsed * l.rate + } + } + + l.rate = float64(rate) + if rate == 0 { + l.capacity = 0 + l.available = 0 + l.last = now + return + } + + capacity := float64(rate) + if capacity < 64*1024 { + capacity = 64 * 1024 + } + l.capacity = capacity + if l.available > capacity || l.last.IsZero() { + l.available = capacity + } + l.last = now +} + +func (l *RateLimiter) Wait(size int) { + if l == nil || size <= 0 { + return + } + + need := float64(size) + for { + l.mu.Lock() + if l.rate <= 0 { + l.mu.Unlock() + return + } + now := time.Now() + elapsed := now.Sub(l.last).Seconds() + if elapsed > 0 { + l.available += elapsed * l.rate + if l.available > l.capacity { + l.available = l.capacity + } + l.last = now + } + if l.available >= need { + l.available -= need + l.mu.Unlock() + return + } + missing := need - l.available + wait := time.Duration(missing / l.rate * float64(time.Second)) + l.mu.Unlock() + + if wait <= 0 { + wait = time.Millisecond + } + time.Sleep(wait) + } +} diff --git a/common/protocol/ratelimit_test.go b/common/protocol/ratelimit_test.go new file mode 100644 index 000000000000..a601fdd29eb4 --- /dev/null +++ b/common/protocol/ratelimit_test.go @@ -0,0 +1,72 @@ +package protocol + +import "testing" + +func TestLimiterRegistrySharesLimiterByEmailAndDirection(t *testing.T) { + email := "shared-limiter@example.com" + + first := (&MemoryUser{ + Email: email, + UplinkSpeedLimit: 1024, + }).GetUplinkLimiter() + if first == nil { + t.Fatal("expected first uplink limiter") + } + if got := uint64(first.rate); got != 1024 { + t.Fatalf("unexpected initial uplink rate: %d", got) + } + + second := (&MemoryUser{ + Email: email, + UplinkSpeedLimit: 4096, + }).GetUplinkLimiter() + if second == nil { + t.Fatal("expected second uplink limiter") + } + if first != second { + t.Fatal("expected limiter to be shared across user refreshes") + } + if got := uint64(first.rate); got != 4096 { + t.Fatalf("expected shared limiter rate to update, got %d", got) + } +} + +func TestLimiterRegistrySeparatesDirections(t *testing.T) { + email := "directional-limiter@example.com" + + uplink := (&MemoryUser{ + Email: email, + UplinkSpeedLimit: 1024, + }).GetUplinkLimiter() + downlink := (&MemoryUser{ + Email: email, + DownlinkSpeedLimit: 2048, + }).GetDownlinkLimiter() + + if uplink == nil || downlink == nil { + t.Fatal("expected both limiters") + } + if uplink == downlink { + t.Fatal("expected uplink and downlink limiters to be separate") + } + if got := uint64(uplink.rate); got != 1024 { + t.Fatalf("unexpected uplink rate: %d", got) + } + if got := uint64(downlink.rate); got != 2048 { + t.Fatalf("unexpected downlink rate: %d", got) + } +} + +func TestRateLimiterSetRateZeroDisablesLimiter(t *testing.T) { + limiter := NewRateLimiter(1024) + if limiter == nil { + t.Fatal("expected limiter") + } + + limiter.SetRate(0) + + if limiter.rate != 0 { + t.Fatalf("expected disabled rate, got %f", limiter.rate) + } + limiter.Wait(4096) +} diff --git a/common/protocol/user.go b/common/protocol/user.go index 75e8e65415ba..5d9862941c28 100644 --- a/common/protocol/user.go +++ b/common/protocol/user.go @@ -1,6 +1,8 @@ package protocol import ( + "sync" + "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/serial" ) @@ -28,11 +30,16 @@ func (u *User) ToMemoryUser() (*MemoryUser, error) { if err != nil { return nil, err } - return &MemoryUser{ - Account: account, - Email: u.Email, - Level: u.Level, - }, nil + mu := &MemoryUser{ + Account: account, + Email: u.Email, + Level: u.Level, + UplinkSpeedLimit: u.GetUplinkSpeedLimit(), + DownlinkSpeedLimit: u.GetDownlinkSpeedLimit(), + } + mu.uplinkLimiter = globalLimiterRegistry.Get(mu.Email, "uplink", mu.UplinkSpeedLimit) + mu.downlinkLimiter = globalLimiterRegistry.Get(mu.Email, "downlink", mu.DownlinkSpeedLimit) + return mu, nil } func ToProtoUser(mu *MemoryUser) *User { @@ -40,9 +47,11 @@ func ToProtoUser(mu *MemoryUser) *User { return nil } return &User{ - Account: serial.ToTypedMessage(mu.Account.ToProto()), - Email: mu.Email, - Level: mu.Level, + Account: serial.ToTypedMessage(mu.Account.ToProto()), + Email: mu.Email, + Level: mu.Level, + UplinkSpeedLimit: mu.UplinkSpeedLimit, + DownlinkSpeedLimit: mu.DownlinkSpeedLimit, } } @@ -52,4 +61,34 @@ type MemoryUser struct { Account Account Email string Level uint32 + UplinkSpeedLimit uint64 + DownlinkSpeedLimit uint64 + + limiterMu sync.Mutex + uplinkLimiter *RateLimiter + downlinkLimiter *RateLimiter +} + +func (u *MemoryUser) GetUplinkLimiter() *RateLimiter { + if u == nil { + return nil + } + u.limiterMu.Lock() + defer u.limiterMu.Unlock() + if u.uplinkLimiter == nil { + u.uplinkLimiter = globalLimiterRegistry.Get(u.Email, "uplink", u.UplinkSpeedLimit) + } + return u.uplinkLimiter +} + +func (u *MemoryUser) GetDownlinkLimiter() *RateLimiter { + if u == nil { + return nil + } + u.limiterMu.Lock() + defer u.limiterMu.Unlock() + if u.downlinkLimiter == nil { + u.downlinkLimiter = globalLimiterRegistry.Get(u.Email, "downlink", u.DownlinkSpeedLimit) + } + return u.downlinkLimiter } diff --git a/common/protocol/user.pb.go b/common/protocol/user.pb.go index bdefecb9c7e0..b6220a891bea 100644 --- a/common/protocol/user.pb.go +++ b/common/protocol/user.pb.go @@ -29,9 +29,11 @@ type User struct { Email string `protobuf:"bytes,2,opt,name=email,proto3" json:"email,omitempty"` // Protocol specific account information. Must be the account proto in one of // the proxies. - Account *serial.TypedMessage `protobuf:"bytes,3,opt,name=account,proto3" json:"account,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Account *serial.TypedMessage `protobuf:"bytes,3,opt,name=account,proto3" json:"account,omitempty"` + UplinkSpeedLimit uint64 `protobuf:"varint,4,opt,name=uplink_speed_limit,json=uplinkSpeedLimit,proto3" json:"uplink_speed_limit,omitempty"` + DownlinkSpeedLimit uint64 `protobuf:"varint,5,opt,name=downlink_speed_limit,json=downlinkSpeedLimit,proto3" json:"downlink_speed_limit,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *User) Reset() { @@ -85,15 +87,31 @@ func (x *User) GetAccount() *serial.TypedMessage { return nil } +func (x *User) GetUplinkSpeedLimit() uint64 { + if x != nil { + return x.UplinkSpeedLimit + } + return 0 +} + +func (x *User) GetDownlinkSpeedLimit() uint64 { + if x != nil { + return x.DownlinkSpeedLimit + } + return 0 +} + var File_common_protocol_user_proto protoreflect.FileDescriptor const file_common_protocol_user_proto_rawDesc = "" + "\n" + - "\x1acommon/protocol/user.proto\x12\x14xray.common.protocol\x1a!common/serial/typed_message.proto\"n\n" + - "\x04User\x12\x14\n" + - "\x05level\x18\x01 \x01(\rR\x05level\x12\x14\n" + - "\x05email\x18\x02 \x01(\tR\x05email\x12:\n" + - "\aaccount\x18\x03 \x01(\v2 .xray.common.serial.TypedMessageR\aaccountB^\n" + + "\x1acommon/protocol/user.proto\x12\x14xray.common.protocol\x1a!common/serial/typed_message.proto\"\x91\x01\n" + + "\x04User\x12\r\n" + + "\x05level\x18\x01 \x01(\r\x12\r\n" + + "\x05email\x18\x02 \x01(\t\x121\n" + + "\aaccount\x18\x03 \x01(\v2 .xray.common.serial.TypedMessage\x12\x1a\n" + + "\x12uplink_speed_limit\x18\x04 \x01(\x04\x12\x1c\n" + + "\x14downlink_speed_limit\x18\x05 \x01(\x04B^\n" + "\x18com.xray.common.protocolP\x01Z)github.com/xtls/xray-core/common/protocol\xaa\x02\x14Xray.Common.Protocolb\x06proto3" var ( diff --git a/common/protocol/user.proto b/common/protocol/user.proto index 14cf995b9226..8bad6989d6b6 100644 --- a/common/protocol/user.proto +++ b/common/protocol/user.proto @@ -16,4 +16,6 @@ message User { // Protocol specific account information. Must be the account proto in one of // the proxies. xray.common.serial.TypedMessage account = 3; + uint64 uplink_speed_limit = 4; + uint64 downlink_speed_limit = 5; } diff --git a/go.mod b/go.mod index 5fe853cbe686..941f25bce654 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/xtls/xray-core -go 1.26 +go 1.26.2 require ( github.com/apernet/quic-go v0.59.1-0.20260330051153-c402ee641eb6