Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 66 additions & 2 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,7 @@ func (a *Agent) handleRoleConflict(msg *stun.Message, local, remote Candidate, r
}

// handleInbound processes STUN traffic from a remote candidate.
func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) {
func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { //nolint:cyclop
if msg == nil || local == nil {
return
}
Expand All @@ -1415,6 +1415,10 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add
if remoteCandidate, ok = a.handleInboundRequest(remoteCandidate, local, remote, msg); !ok {
return
}
case stun.ClassErrorResponse:
a.handleInboundErrorResponse(remoteCandidate, local, remote, msg)

return
default:
}

Expand All @@ -1427,7 +1431,8 @@ func canHandleInbound(msg *stun.Message) bool {
return msg.Type.Method == stun.MethodBinding &&
(msg.Type.Class == stun.ClassSuccessResponse ||
msg.Type.Class == stun.ClassRequest ||
msg.Type.Class == stun.ClassIndication)
msg.Type.Class == stun.ClassIndication ||
msg.Type.Class == stun.ClassErrorResponse)
}

func (a *Agent) handleInboundResponse(
Expand Down Expand Up @@ -1513,6 +1518,65 @@ func (a *Agent) handleInboundRequest(
return remoteCandidate, true
}

func (a *Agent) handleInboundErrorResponse(
remoteCandidate, local Candidate, remote net.Addr, msg *stun.Message,
) bool {
a.log.Tracef("Inbound STUN (Error) from %s to %s", remote, local)

// Verify message integrity
if err := stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil {
a.log.Warnf("Discard error response with broken integrity from (%s), %v", remote, err)

return false
}

// Extract error code from the message
var errCode stun.ErrorCodeAttribute
if err := errCode.GetFrom(msg); err != nil {
a.log.Warnf("Failed to get error code from error response: %v", err)

return false
}

// Handle 487 Role Conflict error as per RFC 8445 section 7.2.5.1
if errCode.Code == stun.CodeRoleConflict {
a.log.Warnf("Received role conflict error (487) from %s, switching role", remote)

// Find the corresponding pending binding request
found, bindingReq, _ := a.handleInboundBindingSuccess(msg.TransactionID)
if !found {
a.log.Debugf("Received role conflict error for unknown transaction ID, ignoring")

return false
}

// Switch our role and regenerate tiebreaker
oldRole := a.role()
a.isControlling.Store(!a.isControlling.Load())
a.tieBreaker = globalMathRandomGenerator.Uint64()
a.setSelector()

a.log.Debugf("Switched ICE role %s → %s after receiving 487 error", oldRole, a.role())

// Re-enqueue the candidate pair in the triggered-check queue per RFC 8445 §7.2.5.1.
if remoteCandidate == nil {
a.log.Warnf("Cannot re-enqueue candidate pair, remote candidate not found for %s", bindingReq.destination)
} else if pair := a.findPair(local, remoteCandidate); pair != nil {
pair.state = CandidatePairStateWaiting
pair.bindingRequestCount = 0
} else {
a.log.Warnf("Cannot re-enqueue candidate pair for %s, not found in checklist", bindingReq.destination)
}

return true
}

// Log other error codes but don't handle them
a.log.Debugf("Received STUN error response %d (%s) from %s", errCode.Code, errCode.Reason, remote)

return false
}

// validateNonSTUNTraffic processes non STUN traffic from a remote candidate,
// and returns true if it is an actual remote candidate.
func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) {
Expand Down
89 changes: 89 additions & 0 deletions agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2268,6 +2268,95 @@ func TestRoleConflict(t *testing.T) {
})
}

func TestRoleConflictErrorResponse(t *testing.T) {
defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 30).Stop()

cfg := &AgentConfig{
NetworkTypes: supportedNetworkTypes(),
MulticastDNSMode: MulticastDNSModeDisabled,
InterfaceFilter: problematicNetworkInterfaces,
}

aAgent, err := NewAgent(cfg)
require.NoError(t, err)

bAgent, err := NewAgent(cfg)
require.NoError(t, err)

// Set up both agents and exchange candidates
gatherAndExchangeCandidates(t, aAgent, bAgent)

// Start connectivity checks - aAgent will be controlling (Dial)
ufragB, pwdB, err := bAgent.GetLocalUserCredentials()
require.NoError(t, err)

connDone := make(chan struct{})
go func() {
defer close(connDone)
_, dialErr := aAgent.Dial(context.TODO(), ufragB, pwdB)
if dialErr != nil {
t.Logf("Dial error (expected): %v", dialErr)
}
}()

// Wait a bit for binding requests to be sent
time.Sleep(200 * time.Millisecond)

// Store the original role
originalRole := aAgent.isControlling.Load()
require.True(t, originalRole, "aAgent should be controlling after Dial")

// Access agent internals to get pending request and create error response
var txID [stun.TransactionIDSize]byte
var destAddr net.Addr
var localCand Candidate

err = aAgent.loop.Run(context.Background(), func(_ context.Context) {
if len(aAgent.pendingBindingRequests) > 0 {
txID = aAgent.pendingBindingRequests[0].transactionID
destAddr = aAgent.pendingBindingRequests[0].destination
}
})
require.NoError(t, err)
require.NotEqual(t, [stun.TransactionIDSize]byte{}, txID, "Should have a transaction ID")

localCands, err := aAgent.GetLocalCandidates()
require.NoError(t, err)
require.NotEmpty(t, localCands)
localCand = localCands[0]

// Create and send 487 error response
errorMsg, err := stun.Build(
stun.NewTransactionIDSetter(txID),
stun.BindingError,
stun.ErrorCodeAttribute{
Code: stun.CodeRoleConflict,
Reason: []byte("Role Conflict"),
},
stun.NewShortTermIntegrity(aAgent.remotePwd),
stun.Fingerprint,
)
require.NoError(t, err)

// Inject the error response through the agent's event loop to avoid data races
err = aAgent.loop.Run(context.Background(), func(_ context.Context) {
// nolint: contextcheck
aAgent.handleInbound(errorMsg, localCand, destAddr)
})
require.NoError(t, err)

// Verify role switched
newRole := aAgent.isControlling.Load()
require.False(t, newRole, "Agent should have switched to controlled after receiving 487 error")

// Clean up
require.NoError(t, aAgent.Close())
require.NoError(t, bAgent.Close())

<-connDone
}

func TestDefaultCandidateTypes(t *testing.T) {
expected := []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive, CandidateTypeRelay}

Expand Down
Loading