Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions proxy/conntrack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package proxy

import (
"context"
"io"
"strings"
"sync"
"sync/atomic"
)

type trackedConn struct {
cancel context.CancelFunc
closer io.Closer // the actual network connection
}

// ConnTracker tracks active connections per user email and supports
// killing all connections for a given user on removal.
// It both cancels the context AND closes the underlying connection
// to ensure immediate termination even for long-lived streams.
type ConnTracker struct {
mu sync.Mutex
conns map[string]map[uint64]*trackedConn
seq atomic.Uint64
}

func NewConnTracker() *ConnTracker {
return &ConnTracker{
conns: make(map[string]map[uint64]*trackedConn),
}
}

// Track registers a connection under the given email.
// conn is the underlying network connection that will be forcibly closed
// when KillAll is called. It may be nil if only context cancellation is needed.
// Returns a wrapped context and a cleanup function that MUST be deferred.
func (t *ConnTracker) Track(ctx context.Context, email string, conn io.Closer) (context.Context, context.CancelFunc, func()) {
ctx, cancel := context.WithCancel(ctx)
key := strings.ToLower(email)
id := t.seq.Add(1)

t.mu.Lock()
if t.conns[key] == nil {
t.conns[key] = make(map[uint64]*trackedConn)
}
t.conns[key][id] = &trackedConn{cancel: cancel, closer: conn}
t.mu.Unlock()

cleanup := func() {
cancel()
t.mu.Lock()
delete(t.conns[key], id)
if len(t.conns[key]) == 0 {
delete(t.conns, key)
}
t.mu.Unlock()
}
return ctx, cancel, cleanup
}

// KillAll cancels all active connections for the given email
// and forcibly closes the underlying network connections.
func (t *ConnTracker) KillAll(email string) {
key := strings.ToLower(email)
t.mu.Lock()
entries := t.conns[key]
delete(t.conns, key)
t.mu.Unlock()

for _, tc := range entries {
tc.cancel()
if tc.closer != nil {
tc.closer.Close()
}
}
}

// Count returns the number of active connections for the given email.
func (t *ConnTracker) Count(email string) int {
key := strings.ToLower(email)
t.mu.Lock()
defer t.mu.Unlock()
return len(t.conns[key])
}
24 changes: 23 additions & 1 deletion proxy/shadowsocks/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/transport/internet/udp"
Expand All @@ -26,6 +27,7 @@ type Server struct {
validator *Validator
policyManager policy.Manager
cone bool
connTracker *proxy.ConnTracker
}

// NewServer create a new Shadowsocks server.
Expand All @@ -48,6 +50,7 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
validator: validator,
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
cone: ctx.Value("cone").(bool),
connTracker: proxy.NewConnTracker(),
}

return s, nil
Expand All @@ -60,7 +63,11 @@ func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {

// RemoveUser implements proxy.UserManager.RemoveUser().
func (s *Server) RemoveUser(ctx context.Context, e string) error {
return s.validator.Del(e)
err := s.validator.Del(e)
if err == nil {
s.connTracker.KillAll(e)
}
return err
}

// GetUser implements proxy.UserManager.GetUser().
Expand Down Expand Up @@ -132,6 +139,12 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis

inbound := session.InboundFromContext(ctx)
var dest *net.Destination
var connCleanup func()
defer func() {
if connCleanup != nil {
connCleanup()
}
}()
reader := buf.NewPacketReader(conn)
for {
mpayload, err := reader.ReadMultiBuffer()
Expand All @@ -152,6 +165,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
request, data, err = DecodeUDPPacket(s.validator, payload)
if err == nil {
inbound.User = request.User
if request.User.Email != "" && connCleanup == nil {
ctx, _, connCleanup = s.connTracker.Track(ctx, request.User.Email, conn)
}
}
}

Expand Down Expand Up @@ -222,6 +238,12 @@ func (s *Server) handleConnection(ctx context.Context, conn stat.Connection, dis
}
inbound.User = request.User

if request.User.Email != "" {
var cleanup func()
ctx, _, cleanup = s.connTracker.Track(ctx, request.User.Email, conn)
defer cleanup()
}

dest := request.Destination()
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: conn.RemoteAddr(),
Expand Down
46 changes: 39 additions & 7 deletions proxy/shadowsocks_2022/inbound_multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
Expand All @@ -36,9 +37,10 @@ func init() {

type MultiUserInbound struct {
sync.Mutex
networks []net.Network
users []*protocol.MemoryUser
service *shadowaead_2022.MultiService[int]
networks []net.Network
users []*protocol.MemoryUser
service *shadowaead_2022.MultiService[int]
connTracker *proxy.ConnTracker
}

func NewMultiServer(ctx context.Context, config *MultiUserServerConfig) (*MultiUserInbound, error) {
Expand All @@ -63,8 +65,9 @@ func NewMultiServer(ctx context.Context, config *MultiUserServerConfig) (*MultiU
}

inbound := &MultiUserInbound{
networks: networks,
users: memUsers,
networks: networks,
users: memUsers,
connTracker: proxy.NewConnTracker(),
}
if config.Key == "" {
return nil, errors.New("missing key")
Expand Down Expand Up @@ -141,11 +144,14 @@ func (i *MultiUserInbound) RemoveUser(ctx context.Context, email string) error {
i.users = i.users[:ulen-1]

// sync to multi service
// Considering implements shadowsocks2022 in xray-core may have better performance.
i.service.UpdateUsersWithPasswords(
err := i.service.UpdateUsersWithPasswords(
C.MapIndexed(i.users, func(index int, it *protocol.MemoryUser) int { return index }),
C.Map(i.users, func(it *protocol.MemoryUser) string { return it.Account.(*MemoryAccount).Key }),
)
if err != nil {
return errors.New("failed to update users in service").Base(err)
}
i.connTracker.KillAll(email)

return nil
}
Expand Down Expand Up @@ -226,9 +232,22 @@ func (i *MultiUserInbound) Process(ctx context.Context, network net.Network, con

func (i *MultiUserInbound) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
inbound := session.InboundFromContext(ctx)
i.Lock()
userInt, _ := A.UserFromContext[int](ctx)
if userInt >= len(i.users) {
i.Unlock()
return errors.New("user index out of range, user may have been removed")
}
user := i.users[userInt]
i.Unlock()
inbound.User = user

if user.Email != "" {
var cleanup func()
ctx, _, cleanup = i.connTracker.Track(ctx, user.Email, conn)
defer cleanup()
}

ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: metadata.Source,
To: metadata.Destination,
Expand All @@ -251,9 +270,22 @@ func (i *MultiUserInbound) NewConnection(ctx context.Context, conn net.Conn, met

func (i *MultiUserInbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
inbound := session.InboundFromContext(ctx)
i.Lock()
userInt, _ := A.UserFromContext[int](ctx)
if userInt >= len(i.users) {
i.Unlock()
return errors.New("user index out of range, user may have been removed")
}
user := i.users[userInt]
i.Unlock()
inbound.User = user

if user.Email != "" {
var cleanup func()
ctx, _, cleanup = i.connTracker.Track(ctx, user.Email, conn)
defer cleanup()
}

ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: metadata.Source,
To: metadata.Destination,
Expand Down
16 changes: 15 additions & 1 deletion proxy/trojan/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport/internet/reality"
"github.com/xtls/xray-core/transport/internet/stat"
Expand All @@ -39,6 +40,7 @@ type Server struct {
validator *Validator
fallbacks map[string]map[string]map[string]*Fallback // or nil
cone bool
connTracker *proxy.ConnTracker
}

// NewServer creates a new trojan inbound handler.
Expand All @@ -60,6 +62,7 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
validator: validator,
cone: ctx.Value("cone").(bool),
connTracker: proxy.NewConnTracker(),
}

if config.Fallbacks != nil {
Expand Down Expand Up @@ -122,7 +125,11 @@ func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {

// RemoveUser implements proxy.UserManager.RemoveUser().
func (s *Server) RemoveUser(ctx context.Context, e string) error {
return s.validator.Del(e)
err := s.validator.Del(e)
if err == nil {
s.connTracker.KillAll(e)
}
return err
}

// GetUser implements proxy.UserManager.GetUser().
Expand Down Expand Up @@ -226,6 +233,13 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
inbound.Name = "trojan"
inbound.CanSpliceCopy = 3
inbound.User = user

if user.Email != "" {
var cleanup func()
ctx, _, cleanup = s.connTracker.Track(ctx, user.Email, conn)
defer cleanup()
}

sessionPolicy = s.policyManager.ForLevel(user.Level)

if destination.Network == net.Network_UDP { // handle udp request
Expand Down
14 changes: 13 additions & 1 deletion proxy/vless/inbound/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type Handler struct {
ctx context.Context
fallbacks map[string]map[string]map[string]*Fallback // or nil
// regexps map[string]*regexp.Regexp // or nil
connTracker *proxy.ConnTracker
}

// New creates a new VLess inbound handler.
Expand All @@ -99,6 +100,7 @@ func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Val
observer: v.GetFeature(extension.ObservatoryType()),
defaultDispatcher: v.GetFeature(routing.DispatcherType()).(routing.Dispatcher),
ctx: ctx,
connTracker: proxy.NewConnTracker(),
}

if config.Decryption != "" && config.Decryption != "none" {
Expand Down Expand Up @@ -244,7 +246,11 @@ func (h *Handler) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
// RemoveUser implements proxy.UserManager.RemoveUser().
func (h *Handler) RemoveUser(ctx context.Context, e string) error {
h.RemoveReverse(h.validator.GetByEmail(e))
return h.validator.Del(e)
err := h.validator.Del(e)
if err == nil {
h.connTracker.KillAll(e)
}
return err
}

// GetUser implements proxy.UserManager.GetUser().
Expand Down Expand Up @@ -537,6 +543,12 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
inbound.User = request.User
inbound.VlessRoute = net.PortFromBytes(userSentID[6:8])

if request.User.Email != "" {
var cleanup func()
ctx, _, cleanup = h.connTracker.Track(ctx, request.User.Email, connection)
defer cleanup()
}

account := request.User.Account.(*vless.MemoryAccount)

if account.Reverse != nil && request.Command != protocol.RequestCommandRvs {
Expand Down
10 changes: 10 additions & 0 deletions proxy/vmess/inbound/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
feature_inbound "github.com/xtls/xray-core/features/inbound"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/proxy/vmess"
"github.com/xtls/xray-core/proxy/vmess/encoding"
"github.com/xtls/xray-core/transport/internet/stat"
Expand Down Expand Up @@ -107,6 +108,7 @@ type Handler struct {
clients *vmess.TimedUserValidator
usersByEmail *userByEmail
sessionHistory *encoding.SessionHistory
connTracker *proxy.ConnTracker
}

// New creates a new VMess inbound handler.
Expand All @@ -118,6 +120,7 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
clients: vmess.NewTimedUserValidator(),
usersByEmail: newUserByEmail(config.GetDefaultValue()),
sessionHistory: encoding.NewSessionHistory(),
connTracker: proxy.NewConnTracker(),
}

for _, user := range config.User {
Expand Down Expand Up @@ -181,6 +184,7 @@ func (h *Handler) RemoveUser(ctx context.Context, email string) error {
return errors.New("User ", email, " not found.")
}
h.clients.Remove(email)
h.connTracker.KillAll(email)
return nil
}

Expand Down Expand Up @@ -272,6 +276,12 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
inbound.CanSpliceCopy = 3
inbound.User = request.User

if request.User.Email != "" {
var cleanup func()
ctx, _, cleanup = h.connTracker.Track(ctx, request.User.Email, connection)
defer cleanup()
}

sessionPolicy = h.policyManager.ForLevel(request.User.Level)

ctx, cancel := context.WithCancel(ctx)
Expand Down