diff --git a/app/observatory/burst/healthping.go b/app/observatory/burst/healthping.go index cb1c3402d7cc..d57f51e8802d 100644 --- a/app/observatory/burst/healthping.go +++ b/app/observatory/burst/healthping.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "time" "github.com/xtls/xray-core/common/dice" @@ -24,11 +25,12 @@ type HealthPingSettings struct { // HealthPing is the health checker for balancers type HealthPing struct { - ctx context.Context - dispatcher routing.Dispatcher - access sync.Mutex - ticker *time.Ticker - tickerClose chan struct{} + ctx context.Context + cancelCtx context.CancelFunc + cancelPending atomic.Pointer[context.CancelFunc] + dispatcher routing.Dispatcher + access sync.Mutex + ticker *time.Ticker Settings *HealthPingSettings Results map[string]*HealthPingRTTS @@ -62,10 +64,10 @@ func NewHealthPing(ctx context.Context, dispatcher routing.Dispatcher, config *H settings.Destination = "https://connectivitycheck.gstatic.com/generate_204" } if settings.Interval == 0 { - settings.Interval = time.Duration(1) * time.Minute - } else if settings.Interval < 10 { + settings.Interval = 1 * time.Minute + } else if settings.Interval < 10*time.Second { errors.LogWarning(ctx, "health check interval is too small, 10s is applied") - settings.Interval = time.Duration(10) * time.Second + settings.Interval = 10 * time.Second } if settings.SamplingCount <= 0 { settings.SamplingCount = 10 @@ -73,10 +75,12 @@ func NewHealthPing(ctx context.Context, dispatcher routing.Dispatcher, config *H if settings.Timeout <= 0 { // results are saved after all health pings finish, // a larger timeout could possibly makes checks run longer - settings.Timeout = time.Duration(5) * time.Second + settings.Timeout = 5 * time.Second } + ctx, cancel := context.WithCancel(ctx) return &HealthPing{ ctx: ctx, + cancelCtx: cancel, dispatcher: dispatcher, Settings: settings, Results: nil, @@ -90,17 +94,7 @@ func (h *HealthPing) StartScheduler(selector func() ([]string, error)) { } interval := h.Settings.Interval * time.Duration(h.Settings.SamplingCount) ticker := time.NewTicker(interval) - tickerClose := make(chan struct{}) h.ticker = ticker - h.tickerClose = tickerClose - go func() { - tags, err := selector() - if err != nil { - errors.LogWarning(h.ctx, "error select outbounds for initial health check: ", err) - return - } - h.Check(tags) - }() go func() { for { @@ -110,13 +104,20 @@ func (h *HealthPing) StartScheduler(selector func() ([]string, error)) { errors.LogWarning(h.ctx, "error select outbounds for scheduled health check: ", err) return } - h.doCheck(tags, interval, h.Settings.SamplingCount) + subCtx, cancel := context.WithCancel(h.ctx) + old := h.cancelPending.Swap(&cancel) + if old != nil { + errors.LogDebug(h.ctx, "scheduled health check not finished before next round, canceling previous one") + (*old)() + } + h.doCheck(subCtx, tags, interval, h.Settings.SamplingCount) + h.cancelPending.CompareAndSwap(&cancel, nil) h.Cleanup(tags) }() select { case <-ticker.C: continue - case <-tickerClose: + case <-h.ctx.Done(): return } } @@ -130,8 +131,7 @@ func (h *HealthPing) StopScheduler() { } h.ticker.Stop() h.ticker = nil - close(h.tickerClose) - h.tickerClose = nil + h.cancelCtx() } // Check implements the HealthChecker @@ -140,7 +140,7 @@ func (h *HealthPing) Check(tags []string) error { return nil } errors.LogInfo(h.ctx, "perform one-time health check for tags ", tags) - h.doCheck(tags, 0, 1) + h.doCheck(h.ctx, tags, 0, 1) return nil } @@ -151,13 +151,14 @@ type rtt struct { // doCheck performs the 'rounds' amount checks in given 'duration'. You should make // sure all tags are valid for current balancer -func (h *HealthPing) doCheck(tags []string, duration time.Duration, rounds int) { +// cancel ctx will stop all pending checks +func (h *HealthPing) doCheck(ctx context.Context, tags []string, duration time.Duration, rounds int) { count := len(tags) * rounds if count == 0 { return } ch := make(chan *rtt, count) - + timers := make([]*time.Timer, 0, count) for _, tag := range tags { handler := tag client := newPingClient( @@ -172,7 +173,7 @@ func (h *HealthPing) doCheck(tags []string, duration time.Duration, rounds int) if duration > 0 { delay = time.Duration(dice.RollInt63n(int64(duration))) } - time.AfterFunc(delay, func() { + timers = append(timers, time.AfterFunc(delay, func() { errors.LogDebug(h.ctx, "checking ", handler) delay, err := client.MeasureDelay(h.Settings.HttpMethod) if err == nil { @@ -200,14 +201,21 @@ func (h *HealthPing) doCheck(tags []string, duration time.Duration, rounds int) handler: handler, value: rttFailed, } - }) + })) } } for i := 0; i < count; i++ { - rtt := <-ch - if rtt.value > 0 { - // should not put results when network is down - h.PutResult(rtt.handler, rtt.value) + select { + case rtt := <-ch: + if rtt.value > 0 { + // should not put results when network is down + h.PutResult(rtt.handler, rtt.value) + } + case <-ctx.Done(): + for _, timer := range timers { + timer.Stop() + } + return } } } diff --git a/app/observatory/burst/healthping_result.go b/app/observatory/burst/healthping_result.go index f48d37b60684..4759c5e32242 100644 --- a/app/observatory/burst/healthping_result.go +++ b/app/observatory/burst/healthping_result.go @@ -59,7 +59,7 @@ func (h *HealthPingRTTS) Put(d time.Duration) { if h.rtts == nil { h.rtts = make([]*pingRTT, h.cap) for i := 0; i < h.cap; i++ { - h.rtts[i] = &pingRTT{} + h.rtts[i] = &pingRTT{value: rttUntested} } h.idx = -1 } @@ -88,7 +88,7 @@ func (h *HealthPingRTTS) getStatistics() *HealthPingStats { validRTTs := make([]time.Duration, 0) for _, rtt := range h.rtts { switch { - case rtt.value == 0 || time.Since(rtt.time) > h.validity: + case rtt.value == rttUntested || time.Since(rtt.time) > h.validity: continue case rtt.value == rttFailed: stats.Fail++