diff --git a/proxy/conntrack.go b/proxy/conntrack.go new file mode 100644 index 000000000000..f7bbc8d85b3d --- /dev/null +++ b/proxy/conntrack.go @@ -0,0 +1,83 @@ +package proxy + +import ( + "context" + "io" + "strings" + "sync" + "sync/atomic" +) + +type trackedConn struct { + cancel context.CancelFunc + closer io.Closer // the actual network connection +} + +// ConnTracker tracks active connections per user email and supports +// killing all connections for a given user on removal. +// It both cancels the context AND closes the underlying connection +// to ensure immediate termination even for long-lived streams. +type ConnTracker struct { + mu sync.Mutex + conns map[string]map[uint64]*trackedConn + seq atomic.Uint64 +} + +func NewConnTracker() *ConnTracker { + return &ConnTracker{ + conns: make(map[string]map[uint64]*trackedConn), + } +} + +// Track registers a connection under the given email. +// conn is the underlying network connection that will be forcibly closed +// when KillAll is called. It may be nil if only context cancellation is needed. +// Returns a wrapped context and a cleanup function that MUST be deferred. +func (t *ConnTracker) Track(ctx context.Context, email string, conn io.Closer) (context.Context, context.CancelFunc, func()) { + ctx, cancel := context.WithCancel(ctx) + key := strings.ToLower(email) + id := t.seq.Add(1) + + t.mu.Lock() + if t.conns[key] == nil { + t.conns[key] = make(map[uint64]*trackedConn) + } + t.conns[key][id] = &trackedConn{cancel: cancel, closer: conn} + t.mu.Unlock() + + cleanup := func() { + cancel() + t.mu.Lock() + delete(t.conns[key], id) + if len(t.conns[key]) == 0 { + delete(t.conns, key) + } + t.mu.Unlock() + } + return ctx, cancel, cleanup +} + +// KillAll cancels all active connections for the given email +// and forcibly closes the underlying network connections. +func (t *ConnTracker) KillAll(email string) { + key := strings.ToLower(email) + t.mu.Lock() + entries := t.conns[key] + delete(t.conns, key) + t.mu.Unlock() + + for _, tc := range entries { + tc.cancel() + if tc.closer != nil { + tc.closer.Close() + } + } +} + +// Count returns the number of active connections for the given email. +func (t *ConnTracker) Count(email string) int { + key := strings.ToLower(email) + t.mu.Lock() + defer t.mu.Unlock() + return len(t.conns[key]) +} diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 360ea38c8d53..fbafd1ff797c 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -16,6 +16,7 @@ import ( "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" + "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/udp" @@ -26,6 +27,7 @@ type Server struct { validator *Validator policyManager policy.Manager cone bool + connTracker *proxy.ConnTracker } // NewServer create a new Shadowsocks server. @@ -48,6 +50,7 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { validator: validator, policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), cone: ctx.Value("cone").(bool), + connTracker: proxy.NewConnTracker(), } return s, nil @@ -60,7 +63,11 @@ func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { // RemoveUser implements proxy.UserManager.RemoveUser(). func (s *Server) RemoveUser(ctx context.Context, e string) error { - return s.validator.Del(e) + err := s.validator.Del(e) + if err == nil { + s.connTracker.KillAll(e) + } + return err } // GetUser implements proxy.UserManager.GetUser(). @@ -132,6 +139,12 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis inbound := session.InboundFromContext(ctx) var dest *net.Destination + var connCleanup func() + defer func() { + if connCleanup != nil { + connCleanup() + } + }() reader := buf.NewPacketReader(conn) for { mpayload, err := reader.ReadMultiBuffer() @@ -152,6 +165,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis request, data, err = DecodeUDPPacket(s.validator, payload) if err == nil { inbound.User = request.User + if request.User.Email != "" && connCleanup == nil { + ctx, _, connCleanup = s.connTracker.Track(ctx, request.User.Email, conn) + } } } @@ -222,6 +238,12 @@ func (s *Server) handleConnection(ctx context.Context, conn stat.Connection, dis } inbound.User = request.User + if request.User.Email != "" { + var cleanup func() + ctx, _, cleanup = s.connTracker.Track(ctx, request.User.Email, conn) + defer cleanup() + } + dest := request.Destination() ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ From: conn.RemoteAddr(), diff --git a/proxy/shadowsocks_2022/inbound_multi.go b/proxy/shadowsocks_2022/inbound_multi.go index 4bfa086aa9ef..6cf5ffd2db51 100644 --- a/proxy/shadowsocks_2022/inbound_multi.go +++ b/proxy/shadowsocks_2022/inbound_multi.go @@ -18,6 +18,7 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" @@ -36,9 +37,10 @@ func init() { type MultiUserInbound struct { sync.Mutex - networks []net.Network - users []*protocol.MemoryUser - service *shadowaead_2022.MultiService[int] + networks []net.Network + users []*protocol.MemoryUser + service *shadowaead_2022.MultiService[int] + connTracker *proxy.ConnTracker } func NewMultiServer(ctx context.Context, config *MultiUserServerConfig) (*MultiUserInbound, error) { @@ -63,8 +65,9 @@ func NewMultiServer(ctx context.Context, config *MultiUserServerConfig) (*MultiU } inbound := &MultiUserInbound{ - networks: networks, - users: memUsers, + networks: networks, + users: memUsers, + connTracker: proxy.NewConnTracker(), } if config.Key == "" { return nil, errors.New("missing key") @@ -141,11 +144,14 @@ func (i *MultiUserInbound) RemoveUser(ctx context.Context, email string) error { i.users = i.users[:ulen-1] // sync to multi service - // Considering implements shadowsocks2022 in xray-core may have better performance. - i.service.UpdateUsersWithPasswords( + err := i.service.UpdateUsersWithPasswords( C.MapIndexed(i.users, func(index int, it *protocol.MemoryUser) int { return index }), C.Map(i.users, func(it *protocol.MemoryUser) string { return it.Account.(*MemoryAccount).Key }), ) + if err != nil { + return errors.New("failed to update users in service").Base(err) + } + i.connTracker.KillAll(email) return nil } @@ -226,9 +232,22 @@ func (i *MultiUserInbound) Process(ctx context.Context, network net.Network, con func (i *MultiUserInbound) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { inbound := session.InboundFromContext(ctx) + i.Lock() userInt, _ := A.UserFromContext[int](ctx) + if userInt >= len(i.users) { + i.Unlock() + return errors.New("user index out of range, user may have been removed") + } user := i.users[userInt] + i.Unlock() inbound.User = user + + if user.Email != "" { + var cleanup func() + ctx, _, cleanup = i.connTracker.Track(ctx, user.Email, conn) + defer cleanup() + } + ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ From: metadata.Source, To: metadata.Destination, @@ -251,9 +270,22 @@ func (i *MultiUserInbound) NewConnection(ctx context.Context, conn net.Conn, met func (i *MultiUserInbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error { inbound := session.InboundFromContext(ctx) + i.Lock() userInt, _ := A.UserFromContext[int](ctx) + if userInt >= len(i.users) { + i.Unlock() + return errors.New("user index out of range, user may have been removed") + } user := i.users[userInt] + i.Unlock() inbound.User = user + + if user.Email != "" { + var cleanup func() + ctx, _, cleanup = i.connTracker.Track(ctx, user.Email, conn) + defer cleanup() + } + ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ From: metadata.Source, To: metadata.Destination, diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index d66219c48d2c..8a8895163037 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -20,6 +20,7 @@ import ( "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" + "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/stat" @@ -39,6 +40,7 @@ type Server struct { validator *Validator fallbacks map[string]map[string]map[string]*Fallback // or nil cone bool + connTracker *proxy.ConnTracker } // NewServer creates a new trojan inbound handler. @@ -60,6 +62,7 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), validator: validator, cone: ctx.Value("cone").(bool), + connTracker: proxy.NewConnTracker(), } if config.Fallbacks != nil { @@ -122,7 +125,11 @@ func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { // RemoveUser implements proxy.UserManager.RemoveUser(). func (s *Server) RemoveUser(ctx context.Context, e string) error { - return s.validator.Del(e) + err := s.validator.Del(e) + if err == nil { + s.connTracker.KillAll(e) + } + return err } // GetUser implements proxy.UserManager.GetUser(). @@ -226,6 +233,13 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con inbound.Name = "trojan" inbound.CanSpliceCopy = 3 inbound.User = user + + if user.Email != "" { + var cleanup func() + ctx, _, cleanup = s.connTracker.Track(ctx, user.Email, conn) + defer cleanup() + } + sessionPolicy = s.policyManager.ForLevel(user.Level) if destination.Network == net.Network_UDP { // handle udp request diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 0301fb7889bb..d90dd0a02c7e 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -85,6 +85,7 @@ type Handler struct { ctx context.Context fallbacks map[string]map[string]map[string]*Fallback // or nil // regexps map[string]*regexp.Regexp // or nil + connTracker *proxy.ConnTracker } // New creates a new VLess inbound handler. @@ -99,6 +100,7 @@ func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Val observer: v.GetFeature(extension.ObservatoryType()), defaultDispatcher: v.GetFeature(routing.DispatcherType()).(routing.Dispatcher), ctx: ctx, + connTracker: proxy.NewConnTracker(), } if config.Decryption != "" && config.Decryption != "none" { @@ -244,7 +246,11 @@ func (h *Handler) AddUser(ctx context.Context, u *protocol.MemoryUser) error { // RemoveUser implements proxy.UserManager.RemoveUser(). func (h *Handler) RemoveUser(ctx context.Context, e string) error { h.RemoveReverse(h.validator.GetByEmail(e)) - return h.validator.Del(e) + err := h.validator.Del(e) + if err == nil { + h.connTracker.KillAll(e) + } + return err } // GetUser implements proxy.UserManager.GetUser(). @@ -537,6 +543,12 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s inbound.User = request.User inbound.VlessRoute = net.PortFromBytes(userSentID[6:8]) + if request.User.Email != "" { + var cleanup func() + ctx, _, cleanup = h.connTracker.Track(ctx, request.User.Email, connection) + defer cleanup() + } + account := request.User.Account.(*vless.MemoryAccount) if account.Reverse != nil && request.Command != protocol.RequestCommandRvs { diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 6a8591ad6929..40960693c09c 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -21,6 +21,7 @@ import ( feature_inbound "github.com/xtls/xray-core/features/inbound" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" + "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/proxy/vmess" "github.com/xtls/xray-core/proxy/vmess/encoding" "github.com/xtls/xray-core/transport/internet/stat" @@ -107,6 +108,7 @@ type Handler struct { clients *vmess.TimedUserValidator usersByEmail *userByEmail sessionHistory *encoding.SessionHistory + connTracker *proxy.ConnTracker } // New creates a new VMess inbound handler. @@ -118,6 +120,7 @@ func New(ctx context.Context, config *Config) (*Handler, error) { clients: vmess.NewTimedUserValidator(), usersByEmail: newUserByEmail(config.GetDefaultValue()), sessionHistory: encoding.NewSessionHistory(), + connTracker: proxy.NewConnTracker(), } for _, user := range config.User { @@ -181,6 +184,7 @@ func (h *Handler) RemoveUser(ctx context.Context, email string) error { return errors.New("User ", email, " not found.") } h.clients.Remove(email) + h.connTracker.KillAll(email) return nil } @@ -272,6 +276,12 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s inbound.CanSpliceCopy = 3 inbound.User = request.User + if request.User.Email != "" { + var cleanup func() + ctx, _, cleanup = h.connTracker.Track(ctx, request.User.Email, connection) + defer cleanup() + } + sessionPolicy = h.policyManager.ForLevel(request.User.Level) ctx, cancel := context.WithCancel(ctx)