Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions app/proxyman/outbound/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,20 @@ func (h *Handler) DestIpAddress() net.IP {
return internet.DestIpAddress()
}

func (h *Handler) SocketSettings() *internet.SocketConfig {
if h.streamSettings == nil {
return nil
}
return h.streamSettings.SocketSettings
}

func (h *Handler) UsesProxySettings() bool {
if h.senderSettings != nil && h.senderSettings.ProxySettings.HasTag() {
return true
}
return h.streamSettings != nil && h.streamSettings.SocketSettings != nil && len(h.streamSettings.SocketSettings.DialerProxy) > 0
}

// Dial implements internet.Dialer.
func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connection, error) {
if h.senderSettings != nil {
Expand Down
210 changes: 131 additions & 79 deletions proxy/freedom/freedom.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ var defaultBlockAllRule *FinalRule
func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
h := new(Handler)
if handler, ok := session.FullHandlerFromContext(ctx).(handlerWithSocketSettings); ok {
if sockopt := handler.SocketSettings(); sockopt != nil {
h.socketStrategy = sockopt.DomainStrategy
}
}
if handler, ok := session.FullHandlerFromContext(ctx).(handlerWithProxySettings); ok {
h.usesProxySettings = handler.UsesProxySettings()
}
if err := core.RequireFeatures(ctx, func(pm policy.Manager) error {
return h.Init(config.(*Config), pm)
}); err != nil {
Expand Down Expand Up @@ -88,6 +96,14 @@ func init() {
}
}

type handlerWithSocketSettings interface {
SocketSettings() *internet.SocketConfig
}

type handlerWithProxySettings interface {
UsesProxySettings() bool
}

type FinalRule struct {
action RuleAction
network [8]bool
Expand All @@ -98,9 +114,11 @@ type FinalRule struct {

// Handler handles Freedom connections.
type Handler struct {
policyManager policy.Manager
config *Config
finalRules []*FinalRule
policyManager policy.Manager
config *Config
finalRules []*FinalRule
socketStrategy internet.DomainStrategy
usesProxySettings bool
}

func buildFinalRule(config *FinalRuleConfig) (*FinalRule, error) {
Expand Down Expand Up @@ -177,22 +195,6 @@ func getDefaultFinalRule(inbound *session.Inbound) *FinalRule {
return nil
}

func (h *Handler) shouldResolveDomainBeforeFinalRules(dialDest net.Destination, defaultRule *FinalRule) bool {
if !dialDest.Address.Family().IsDomain() {
return false
}
if len(h.finalRules) > 0 {
rule := h.finalRules[0]
if rule.action == RuleAction_Allow && rule.network[dialDest.Network] && len(rule.port) == 0 && rule.ip == nil {
return false
}
}
if defaultRule != nil || len(h.finalRules) > 0 {
return true
}
return false
}

func (h *Handler) matchFinalRule(network net.Network, address net.Address, port net.Port, defaultRule *FinalRule) *FinalRule {
for _, rule := range h.finalRules {
if rule.Apply(network, address, port) {
Expand Down Expand Up @@ -239,11 +241,32 @@ func (h *Handler) blockDelay(rule *FinalRule) time.Duration {
min = rule.blockDelay.Min
max = rule.blockDelay.Max
}
abs := max - min
span := max - min
if max < min {
abs = min - max
span = min - max
}
return time.Duration(min+uint64(dice.Roll(int(abs+1)))) * time.Second
return time.Duration(min+uint64(dice.Roll(int(span+1)))) * time.Second
}

func (h *Handler) blackhole(ctx context.Context, input buf.Reader, output buf.Writer, rule *FinalRule, dest *net.Destination) error {
delay := h.blockDelay(rule)
errors.LogInfo(ctx, "blocked target: ", *dest, ", blackholing connection for ", delay)
timer := time.AfterFunc(delay, func() {
common.Interrupt(input)
common.Interrupt(output)
errors.LogInfo(ctx, "closed blackholed connection to blocked target: ", *dest)
})
defer timer.Stop()
defer common.Close(output)
_ = buf.Copy(input, buf.Discard)
return nil
}

func (h *Handler) udpDomainStrategy() internet.DomainStrategy {
if h.config.DomainStrategy.HasStrategy() {
return h.config.DomainStrategy
}
return h.socketStrategy
}

func isValidAddress(addr *net.IPOrDomain) bool {
Expand Down Expand Up @@ -295,40 +318,73 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
var blockedRule *FinalRule
err := retry.ExponentialBackoff(5, 100).On(func() error {
dialDest := destination
if h.config.DomainStrategy.HasStrategy() && dialDest.Address.Family().IsDomain() {
strategy := h.config.DomainStrategy
if destination.Network == net.Network_UDP && origTargetAddr != nil && outGateway == nil {
strategy = strategy.GetDynamicStrategy(origTargetAddr.Family())
}
ips, err := internet.LookupForIP(dialDest.Address.Domain(), strategy, outGateway)
if err != nil {
errors.LogInfoInner(ctx, err, "failed to get IP address for domain ", dialDest.Address.Domain())
if h.config.DomainStrategy.ForceIP() {
return err
}
} else {
dialDest = net.Destination{
Network: dialDest.Network,
Address: net.IPAddress(ips[dice.Roll(len(ips))]),
Port: dialDest.Port,

if dialDest.Address.Family().IsDomain() {
if strategy := h.config.DomainStrategy; strategy.HasStrategy() {
if destination.Network == net.Network_UDP && origTargetAddr != nil && outGateway == nil {
strategy = strategy.GetDynamicStrategy(origTargetAddr.Family())
}
errors.LogInfo(ctx, "dialing to ", dialDest)
}
} else if h.shouldResolveDomainBeforeFinalRules(dialDest, defaultRule) { // asis + domain + hasrules
addrs, err := net.DefaultResolver.LookupIPAddr(ctx, dialDest.Address.Domain())
if err != nil {
errors.LogInfoInner(ctx, err, "failed to get IP address for domain ", dialDest.Address.Domain())
} else if len(addrs) > 0 {
if addr := net.IPAddress(addrs[dice.Roll(len(addrs))].IP); addr != nil {
dialDest.Address = addr
ips, err := internet.LookupForIP(dialDest.Address.Domain(), strategy, outGateway)
if err != nil { // SRV/TXT
errors.LogInfoInner(ctx, err, "failed to get IP address for domain ", dialDest.Address.Domain())
if h.config.DomainStrategy.ForceIP() || defaultRule != nil || len(h.finalRules) > 0 {
return err // retry
}
} else { // to ip
dialDest = net.Destination{
Network: dialDest.Network,
Address: net.IPAddress(ips[dice.Roll(len(ips))]),
Port: dialDest.Port,
}
errors.LogInfo(ctx, "dialing to ", dialDest)
if rule := h.matchFinalRule(dialDest.Network, dialDest.Address, dialDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
blockedDest = &dialDest
blockedRule = rule
return nil
}
}
} else if defaultRule != nil || len(h.finalRules) > 0 { // freedom asis + hasrules
if strategy := h.socketStrategy; strategy.HasStrategy() {
ips, err := internet.LookupForIP(dialDest.Address.Domain(), strategy, outGateway)
if err != nil { // SRV/TXT
errors.LogInfoInner(ctx, err, "failed to get IP address for domain ", dialDest.Address.Domain())
if strategy.ForceIP() {
return err // retry
}
}
for _, ip := range ips {
if addr := net.IPAddress(ip); addr != nil {
if rule := h.matchFinalRule(dialDest.Network, addr, dialDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
blockedDest = &dialDest
blockedDest.Address = addr
blockedRule = rule
return nil
}
}
}
} else { // sockopt asis
addrs, err := net.DefaultResolver.LookupIPAddr(ctx, dialDest.Address.Domain())
if err != nil { // SRV/TXT
errors.LogInfoInner(ctx, err, "failed to get IP address for domain ", dialDest.Address.Domain())
}
for _, addr := range addrs {
if ipAddr := net.IPAddress(addr.IP); ipAddr != nil {
if rule := h.matchFinalRule(dialDest.Network, ipAddr, dialDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
blockedDest = &dialDest
blockedDest.Address = ipAddr
blockedRule = rule
return nil
}
}
}
}
}
}
if rule := h.matchFinalRule(dialDest.Network, dialDest.Address, dialDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
blockedDest = &dialDest
blockedRule = rule
return nil
} else {
if rule := h.matchFinalRule(dialDest.Network, dialDest.Address, dialDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
blockedDest = &dialDest
blockedRule = rule
return nil
}
}

rawConn, err := dialer.Dial(ctx, dialDest)
Expand All @@ -343,25 +399,21 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return errors.New("failed to open connection to ", destination).Base(err)
}
if blockedDest != nil {
delay := h.blockDelay(blockedRule)
errors.LogInfo(ctx, "blocked target: ", *blockedDest, ", blackholing connection for ", delay)
timer := time.AfterFunc(delay, func() {
common.Interrupt(input)
common.Interrupt(output)
errors.LogInfo(ctx, "closed blackholed connection to blocked target: ", *blockedDest)
})
defer timer.Stop()
defer common.Close(output)
if err := buf.Copy(input, buf.Discard); err != nil {
return nil
return h.blackhole(ctx, input, output, blockedRule, blockedDest)
}
if defaultRule != nil || len(h.finalRules) > 0 {
if h.usesProxySettings {
errors.LogInfo(ctx, "skipping final rule check for proxied remote endpoint, original target: ", destination)
} else {
// SRV/TXT, lookup failed
remoteDest := net.DestinationFromAddr(conn.RemoteAddr())
if rule := h.matchFinalRule(remoteDest.Network, remoteDest.Address, remoteDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
conn.Close()
return h.blackhole(ctx, input, output, rule, &remoteDest)
}
}
return nil
}
// TODO: SRV/TXT
// if remoteDest := net.DestinationFromAddr(conn.RemoteAddr()); h.applyFinalRules(remoteDest.Network, remoteDest.Address, remoteDest.Port, defaultRule) == RuleAction_Block {
// conn.Close()
// return blackhole(remoteDest)
// }

if h.config.ProxyProtocol > 0 && h.config.ProxyProtocol <= 2 {
version := byte(h.config.ProxyProtocol)
srcAddr := inbound.Source.RawNetAddr()
Expand Down Expand Up @@ -406,7 +458,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
writer = buf.NewWriter(conn)
}
} else {
writer = NewPacketWriter(conn, h, defaultRule, UDPOverride, destination)
writer = NewPacketWriter(conn, h, defaultRule, UDPOverride, destination, outGateway)
if h.config.Noises != nil {
errors.LogDebug(ctx, "NOISE", h.config.Noises)
writer = &NoisePacketWriter{
Expand Down Expand Up @@ -535,7 +587,7 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
}

// DialDest means the dial target used in the dialer when creating conn
func NewPacketWriter(conn net.Conn, h *Handler, defaultRule *FinalRule, UDPOverride net.Destination, DialDest net.Destination) buf.Writer {
func NewPacketWriter(conn net.Conn, h *Handler, defaultRule *FinalRule, UDPOverride net.Destination, DialDest net.Destination, outGateway net.Address) buf.Writer {
iConn := conn
statConn, ok := iConn.(*stat.CounterConnection)
if ok {
Expand All @@ -559,7 +611,7 @@ func NewPacketWriter(conn net.Conn, h *Handler, defaultRule *FinalRule, UDPOverr
DefaultRule: defaultRule,
UDPOverride: UDPOverride,
ResolvedUDPAddr: resolvedUDPAddr,
LocalAddr: net.DestinationFromAddr(conn.LocalAddr()).Address,
OutGateway: outGateway,
}

}
Expand All @@ -578,7 +630,7 @@ type PacketWriter struct {
// Resulting in these packets being sent to many different IPs randomly
// So, cache and keep the resolve result
ResolvedUDPAddr *utils.TypedSyncMap[string, net.Address]
LocalAddr net.Address
OutGateway net.Address
}

func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
Expand All @@ -601,21 +653,21 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if ip, ok := w.ResolvedUDPAddr.Load(b.UDP.Address.Domain()); ok {
b.UDP.Address = ip
} else {
ShouldUseSystemResolver := true
if w.Handler.config.DomainStrategy.HasStrategy() {
ips, err := internet.LookupForIP(b.UDP.Address.Domain(), w.Handler.config.DomainStrategy, w.LocalAddr)
shouldUseSystemResolver := true
if resolveStrategy := w.Handler.udpDomainStrategy(); resolveStrategy.HasStrategy() {
ips, err := internet.LookupForIP(b.UDP.Address.Domain(), resolveStrategy, w.OutGateway)
if err != nil {
// drop packet if resolve failed when forceIP
if w.Handler.config.DomainStrategy.ForceIP() {
if resolveStrategy.ForceIP() {
b.Release()
continue
}
} else {
ip = net.IPAddress(ips[dice.Roll(len(ips))])
ShouldUseSystemResolver = false
shouldUseSystemResolver = false
}
}
if ShouldUseSystemResolver {
if shouldUseSystemResolver {
udpAddr, err := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
if err != nil {
b.Release()
Expand Down
4 changes: 1 addition & 3 deletions transport/internet/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ var (
globalTransportConfigCreatorCache = make(map[string]ConfigCreator)
)

var strategy = [][]byte{
var strategy = [11][3]byte{
// name strategy, prefer, fallback
{0, 0, 0}, // AsIs none, /, /
{1, 0, 0}, // UseIP use, both, none
Expand All @@ -27,8 +27,6 @@ var strategy = [][]byte{
{2, 6, 4}, // ForceIPv6v4 force, 6, 4
}

const unknownProtocol = "unknown"

func RegisterProtocolConfigCreator(name string, creator ConfigCreator) error {
if _, found := globalTransportConfigCreatorCache[name]; found {
return errors.New("protocol ", name, " is already registered").AtError()
Expand Down