Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 80 additions & 18 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"math"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -58,8 +59,10 @@ type Agent struct {
tieBreaker uint64
lite bool

connectionState ConnectionState
gatheringState GatheringState
connectionState ConnectionState
gatheringState GatheringState
gatherGeneration uint64
gatherEndSent bool

mDNSMode MulticastDNSMode
mDNSName string
Expand Down Expand Up @@ -1052,28 +1055,39 @@ func (a *Agent) addRemoteCandidate(cand Candidate) { //nolint:cyclop
a.requestConnectivityCheck()
}

func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn net.PacketConn) error {
func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn net.PacketConn, gen uint64) error {
if err := ctx.Err(); err != nil {
return err
}

cleanupCandidate := func(reason string) {
if err := cand.close(); err != nil {
a.log.Warnf("Failed to close %s candidate: %v", reason, err)
}
if err := candidateConn.Close(); err != nil {
a.log.Warnf("Failed to close %s candidate connection: %v", reason, err)
}
}

return a.loop.Run(ctx, func(context.Context) {
if a.gatherGeneration != gen {
a.log.Debugf("Ignoring candidate from different gather generation (a: %d c: %d)", a.gatherGeneration, gen)
cleanupCandidate("old")

return
}

set := a.localCandidates[cand.NetworkType()]
for _, candidate := range set {
if candidate.Equal(cand) {
a.log.Debugf("Ignore duplicate candidate: %s", cand)
if err := cand.close(); err != nil {
a.log.Warnf("Failed to close duplicate candidate: %v", err)
}
if err := candidateConn.Close(); err != nil {
a.log.Warnf("Failed to close duplicate candidate connection: %v", err)
}
cleanupCandidate("duplicate")

return
}
}

a.setCandidateExtensions(cand)
a.setCandidateExtensions(cand, gen)
cand.start(a, candidateConn, a.startedCh)

set = append(set, cand)
Expand All @@ -1093,14 +1107,22 @@ func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn
})
}

func (a *Agent) setCandidateExtensions(cand Candidate) {
func (a *Agent) setCandidateExtensions(cand Candidate, candidateGeneration uint64) {
err := cand.AddExtension(CandidateExtension{
Key: "ufrag",
Value: a.localUfrag,
})
if err != nil {
a.log.Errorf("Failed to add ufrag extension to candidate: %v", err)
}

err = cand.AddExtension(CandidateExtension{
Key: "generation",
Value: strconv.FormatUint(candidateGeneration, 10),
})
if err != nil {
a.log.Errorf("Failed to add generation extension to candidate: %v", err)
}
}

// GetRemoteCandidates returns the remote candidates.
Expand Down Expand Up @@ -1633,18 +1655,22 @@ func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
}

var err error
var shouldNotifyGatherEnd bool
if runErr := a.loop.Run(a.loop, func(_ context.Context) {
if a.gatheringState == GatheringStateGathering {
a.gatherCandidateCancel()
}
if a.gatheringState != GatheringStateNew {
shouldNotifyGatherEnd = a.setGatheringStateLocked(GatheringStateComplete, a.gatherGeneration)
}
a.bumpGatheringGenerationLocked()

// Clear all agent needed to take back to fresh state
a.removeUfragFromMux()
a.localUfrag = ufrag
a.localPwd = pwd
a.remoteUfrag = ""
a.remotePwd = ""
a.gatheringState = GatheringStateNew
a.checklist = make([]*CandidatePair, 0)
a.pairsByID = make(map[uint64]*CandidatePair)
a.pendingBindingRequests = make([]bindingRequest, 0)
Expand All @@ -1661,27 +1687,63 @@ func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
return runErr
}

if shouldNotifyGatherEnd {
a.candidateNotifier.EnqueueCandidate(nil)
}

return err
}

func (a *Agent) setGatheringState(newState GatheringState) error {
func (a *Agent) setGatheringState(newState GatheringState, generation uint64) error {
var shouldNotifyGatherEnd bool
done := make(chan struct{})
if err := a.loop.Run(a.loop, func(context.Context) {
if a.gatheringState != newState && newState == GatheringStateComplete {
a.candidateNotifier.EnqueueCandidate(nil)
}

a.gatheringState = newState
shouldNotifyGatherEnd = a.setGatheringStateLocked(newState, generation)
close(done)
}); err != nil {
return err
}

<-done

if shouldNotifyGatherEnd {
a.candidateNotifier.EnqueueCandidate(nil)
}

return nil
}

// setGatheringStateLocked updates the gathering state and returns true if
// the caller should send an end-of-candidates notification (nil candidate)
// via a.candidateNotifier.EnqueueCandidate(nil). The notification is deferred
// to the caller so that it happens outside the agent loop, avoiding nested
// lock acquisition between the loop and the notifier.
func (a *Agent) setGatheringStateLocked(newState GatheringState, generation uint64) bool {
if generation != a.gatherGeneration {
return false
}

if a.gatheringState == newState {
return false
}

a.gatheringState = newState

if newState == GatheringStateComplete && !a.gatherEndSent {
a.gatherEndSent = true

return true
}

return false
}

func (a *Agent) bumpGatheringGenerationLocked() {
a.gatherGeneration++
a.gatherEndSent = false
a.gatheringState = GatheringStateNew
}

func (a *Agent) needsToCheckPriorityOnNominated() bool {
return !a.lite || a.enableUseCandidateCheckPriority
}
Expand Down
89 changes: 80 additions & 9 deletions agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1385,9 +1385,7 @@ func TestAgentRestart(t *testing.T) {

t.Run("Restart Both Sides", func(t *testing.T) {
// Get all addresses of candidates concatenated
generateCandidateAddressStrings := func(candidates []Candidate, err error) (out string) {
require.NoError(t, err)

generateCandidateAddressStrings := func(candidates []Candidate) (out string) {
for _, c := range candidates {
out += c.Address() + ":"
out += strconv.Itoa(c.Port())
Expand All @@ -1396,14 +1394,31 @@ func TestAgentRestart(t *testing.T) {
return
}

candidateHasGeneration := func(generation uint64, candidate Candidate) {
genString := strconv.FormatUint(generation, 10)
ext, ok := candidate.GetExtension("generation")

require.True(t, ok)
require.Equal(t, genString, ext.Value)
}

// Store the original candidates, confirm that after we reconnect we have new pairs
connA, connB := pipe(t, &AgentConfig{
DisconnectedTimeout: &oneSecond,
FailedTimeout: &oneSecond,
})
defer closePipe(t, connA, connB)
connAFirstCandidates := generateCandidateAddressStrings(connA.agent.GetLocalCandidates())
connBFirstCandidates := generateCandidateAddressStrings(connB.agent.GetLocalCandidates())

aFirstGeneration := connA.agent.gatherGeneration
bFirstGeneration := connB.agent.gatherGeneration

connAFirstCandidates, err := connA.agent.GetLocalCandidates()
require.NoError(t, err)
connBFirstCandidates, err := connB.agent.GetLocalCandidates()
require.NoError(t, err)

candidateHasGeneration(aFirstGeneration, connAFirstCandidates[0])
candidateHasGeneration(bFirstGeneration, connBFirstCandidates[0])

aNotifier, aConnected := onConnected()
require.NoError(t, connA.agent.OnConnectionStateChange(aNotifier))
Expand All @@ -1415,6 +1430,10 @@ func TestAgentRestart(t *testing.T) {
require.NoError(t, connA.agent.Restart("", ""))
require.NoError(t, connB.agent.Restart("", ""))

// Generation should change after Restart call
require.NotEqual(t, aFirstGeneration, connA.agent.gatherGeneration)
require.NotEqual(t, bFirstGeneration, connB.agent.gatherGeneration)

// Exchange Candidates and Credentials
ufrag, pwd, err := connB.agent.GetLocalUserCredentials()
require.NoError(t, err)
Expand All @@ -1430,9 +1449,21 @@ func TestAgentRestart(t *testing.T) {
<-aConnected
<-bConnected

connASecondCandidates, err := connA.agent.GetLocalCandidates()
require.NoError(t, err)
connBSecondCandidates, err := connB.agent.GetLocalCandidates()
require.NoError(t, err)

candidateHasGeneration(connA.agent.gatherGeneration, connASecondCandidates[0])
candidateHasGeneration(connB.agent.gatherGeneration, connASecondCandidates[0])

// Assert that we have new candidates each time
require.NotEqual(t, connAFirstCandidates, generateCandidateAddressStrings(connA.agent.GetLocalCandidates()))
require.NotEqual(t, connBFirstCandidates, generateCandidateAddressStrings(connB.agent.GetLocalCandidates()))
aFirstCandidatesString := generateCandidateAddressStrings(connAFirstCandidates)
aSecondCandidatesString := generateCandidateAddressStrings(connASecondCandidates)
bFirstCandidatesString := generateCandidateAddressStrings(connBFirstCandidates)
bSecondCandidatesString := generateCandidateAddressStrings(connBSecondCandidates)
require.NotEqual(t, aFirstCandidatesString, aSecondCandidatesString)
require.NotEqual(t, bFirstCandidatesString, bSecondCandidatesString)
})
}

Expand Down Expand Up @@ -1511,7 +1542,7 @@ func TestGetLocalCandidates(t *testing.T) {

expectedCandidates = append(expectedCandidates, cand)

err = agent.addCandidate(context.Background(), cand, dummyConn)
err = agent.addCandidate(context.Background(), cand, dummyConn, agent.gatherGeneration)
require.NoError(t, err)
}

Expand Down Expand Up @@ -2151,7 +2182,7 @@ func TestSetCandidatesUfrag(t *testing.T) {
cand, errCand := NewCandidateHost(&cfg)
require.NoError(t, errCand)

err = agent.addCandidate(context.Background(), cand, dummyConn)
err = agent.addCandidate(context.Background(), cand, dummyConn, agent.gatherGeneration)
require.NoError(t, err)
}

Expand All @@ -2166,6 +2197,46 @@ func TestSetCandidatesUfrag(t *testing.T) {
}
}

func TestAddingCandidatesFromOtherGenerations(t *testing.T) {
var config AgentConfig

agent, err := NewAgent(&config)
require.NoError(t, err)
defer func() {
require.NoError(t, agent.Close())
}()

agent.gatherGeneration = 3

dummyConn := &net.UDPConn{}

for i := 0; i < 5; i++ {
cfg := CandidateHostConfig{
Network: "udp",
Address: "192.168.0.2",
Port: 1000 + i,
Component: 1,
}

cand, errCand := NewCandidateHost(&cfg)
require.NoError(t, errCand)

err = agent.addCandidate(context.Background(), cand, dummyConn, uint64(i)) // nolint:gosec
require.NoError(t, err)
}

actualCandidates, err := agent.GetLocalCandidates()
require.NoError(t, err)
require.Equal(t, 1, len(actualCandidates), "Only the candidate with a matching generation should be added")

ext, ok := actualCandidates[0].GetExtension("generation")
require.True(t, ok)

generation, err := strconv.ParseUint(ext.Value, 10, 64)
require.NoError(t, err)
require.Equal(t, agent.gatherGeneration, generation)
}

func TestAlwaysSentKeepAlive(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()

Expand Down
Loading
Loading