From 6ba75e10fc8734ef6ea37e0a3e1f30c3c10af3b2 Mon Sep 17 00:00:00 2001 From: sirzooro Date: Sun, 1 Mar 2026 13:35:01 +0100 Subject: [PATCH] Add RFC 7675 Consent Freshness checks Implemented RFC 7675 Consent Freshness checks using existing KeepAlives. Handle inbound 403 Forbidden STUN Binding Error response to immediately revoke consent and transition connection state to Failed. Added application hook to return Binding Error responses for inbound Binding Requests to allow sending 403 Forbidden STUN Binding Error to revoke consent. Hook is generic and allows sending any error with custom attributes. --- agent.go | 136 ++++++++++++++++++++++++++++++-- agent_config.go | 33 ++++++++ agent_config_test.go | 13 ++++ agent_options.go | 37 +++++++++ agent_options_test.go | 62 +++++++++++++++ agent_test.go | 175 ++++++++++++++++++++++++++++++++---------- gather.go | 16 ++-- gather_test.go | 6 +- selection.go | 70 ++++++++++++++--- selection_test.go | 102 ++++++++++++++++++++++-- 10 files changed, 576 insertions(+), 74 deletions(-) diff --git a/agent.go b/agent.go index 68104376..67691f10 100644 --- a/agent.go +++ b/agent.go @@ -175,6 +175,17 @@ type Agent struct { lastRenominationTime time.Time turnClientFactory func(*turn.ClientConfig) (turnClient, error) + + // How long consent remains valid without an authenticated non-error STUN Binding response. + consentFreshnessTimeout time.Duration + + // Timestamp of the last consent refresh for the selected candidate pair. + lastConsentAt time.Time + + // Callback that allows user to send an error response for inbound STUN Binding Requests before success is emitted. + // Returning nil continues normal success handling. + userBindingReqErrorRespHandler func( + m *stun.Message, local, remote Candidate, pair *CandidatePair) *BindingRequestErrorResponse } // NewAgent creates a new Agent. @@ -378,6 +389,8 @@ func createAgentBase(config *AgentConfig) (*Agent, error) { automaticRenomination: false, renominationInterval: 3 * time.Second, // Default matching libwebrtc turnClientFactory: defaultTurnClient, + userBindingReqErrorRespHandler: config.BindingRequestErrorResponseHandler, + consentFreshnessTimeout: defaultConsentFreshnessTimeout, } config.initWithDefaults(agent) @@ -750,6 +763,7 @@ func (a *Agent) setSelectedPair(pair *CandidatePair) { if pair == nil { var nilPair *CandidatePair a.selectedPair.Store(nilPair) + a.lastConsentAt = time.Time{} a.log.Tracef("Unset selected candidate pair") return @@ -757,6 +771,7 @@ func (a *Agent) setSelectedPair(pair *CandidatePair) { pair.nominated = true a.selectedPair.Store(pair) + a.lastConsentAt = time.Now() a.log.Tracef("Set selected candidate pair: %s", pair) // Signal connected: notify any Connect() calls waiting on onConnected @@ -769,6 +784,15 @@ func (a *Agent) setSelectedPair(pair *CandidatePair) { a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(pair) } +// consentExpired checks if the consent freshness has expired for the selected candidate pair. +func (a *Agent) consentExpired(now time.Time) bool { + if a.consentFreshnessTimeout <= 0 || a.lastConsentAt.IsZero() { + return false + } + + return now.Sub(a.lastConsentAt) > a.consentFreshnessTimeout +} + func (a *Agent) pingAllCandidates() { a.log.Trace("Pinging all candidates") @@ -885,6 +909,14 @@ func (a *Agent) validateSelectedPair() bool { return false } + now := time.Now() + if a.consentExpired(now) { + a.log.Warnf("Consent expired for selected pair after %v without valid response", a.consentFreshnessTimeout) + a.updateConnectionState(ConnectionStateFailed) + + return false + } + disconnectedTime := time.Since(selectedPair.Remote.LastReceived()) // Only allow transitions to failed if a.failedTimeout is non-zero @@ -1412,6 +1444,28 @@ func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) { } } +func (a *Agent) sendBindingError( + m *stun.Message, local, remote Candidate, bindingErrorResponse BindingRequestErrorResponse, +) { + setters := make([]stun.Setter, 0, len(bindingErrorResponse.ExtraAttributes)+5) + setters = append(setters, m, stun.BindingError) + setters = append(setters, bindingErrorResponse.ErrorCodeAttribute) + setters = append(setters, bindingErrorResponse.ExtraAttributes...) + setters = append(setters, + stun.NewShortTermIntegrity(a.localPwd), + stun.Fingerprint, + ) + + if out, err := stun.Build(setters...); err != nil { + a.log.Warnf("Failed to send binding error response from: %s to: %s error: %s", local, remote, err) + } else { + if pair := a.findPair(local, remote); pair != nil { + pair.UpdateResponseSent() + } + a.sendSTUN(out, local, remote) + } +} + // Removes pending binding requests that are over maxBindingRequestTimeout old // // Let HTO be the transaction timeout, which SHOULD be 2*RTT if @@ -1433,12 +1487,19 @@ func (a *Agent) invalidatePendingBindingRequests(filterTime time.Time) { } } -// Assert that the passed TransactionID is in our pendingBindingRequests and returns the destination -// If the bindingRequest was valid remove it from our pending cache. -func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bool, *bindingRequest, time.Duration) { +// consumePendingBindingRequest validates that the passed TransactionID and remote address match a pending binding +// request. If a match is found, the binding request is removed from the pending cache and returned along with how +// long ago it was sent. If no match is found, nil is returned. +func (a *Agent) consumePendingBindingRequest( + id [stun.TransactionIDSize]byte, remoteAddr net.Addr, +) (bool, *bindingRequest, time.Duration) { a.invalidatePendingBindingRequests(time.Now()) for i := range a.pendingBindingRequests { if a.pendingBindingRequests[i].transactionID == id { + if !addrEqual(a.pendingBindingRequests[i].destination, remoteAddr) { + return false, nil, 0 + } + validBindingRequest := a.pendingBindingRequests[i] a.pendingBindingRequests = append(a.pendingBindingRequests[:i], a.pendingBindingRequests[i+1:]...) @@ -1481,7 +1542,7 @@ func (a *Agent) handleRoleConflict(msg *stun.Message, local, remote Candidate, r } // handleInbound processes STUN traffic from a remote candidate. -func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { +func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { // nolint:cyclop if msg == nil || local == nil { return } @@ -1504,6 +1565,10 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add if remoteCandidate, ok = a.handleInboundRequest(remoteCandidate, local, remote, msg); !ok { return } + case stun.ClassErrorResponse: + if !a.handleInboundErrorResponse(remoteCandidate, local, remote, msg) { + return + } default: } @@ -1516,7 +1581,8 @@ func canHandleInbound(msg *stun.Message) bool { return msg.Type.Method == stun.MethodBinding && (msg.Type.Class == stun.ClassSuccessResponse || msg.Type.Class == stun.ClassRequest || - msg.Type.Class == stun.ClassIndication) + msg.Type.Class == stun.ClassIndication || + msg.Type.Class == stun.ClassErrorResponse) } func (a *Agent) handleInboundResponse( @@ -1534,9 +1600,17 @@ func (a *Agent) handleInboundResponse( return false } - a.getSelector().HandleSuccessResponse(msg, local, remoteCandidate, remote) + handled := a.getSelector().handleSuccessResponse(msg, local, remoteCandidate, remote) - return true + if handled { + if selectedPair := a.getSelectedPair(); selectedPair != nil && + selectedPair.Local.Equal(local) && selectedPair.Remote.Equal(remoteCandidate) { + // Note consent freshness + a.lastConsentAt = time.Now() + } + } + + return handled } func (a *Agent) handleInboundRequest( @@ -1602,6 +1676,54 @@ func (a *Agent) handleInboundRequest( return remoteCandidate, true } +func (a *Agent) handleInboundErrorResponse( + remoteCandidate, _ Candidate, remoteAddr net.Addr, msg *stun.Message, +) bool { // nolint:unparam + if err := stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil { + a.log.Warnf("Discard error response with broken integrity from (%s), %v", remoteAddr, err) + + return false + } + + if remoteCandidate == nil { + a.log.Warnf("Discard error response from (%s), no such remote", remoteAddr) + + return false + } + + ok, pendingRequest, _ := a.consumePendingBindingRequest(msg.TransactionID, remoteAddr) + if !ok { + a.log.Warnf("Discard error response from (%s), unknown TransactionID 0x%x or address mismatch", + remoteAddr, msg.TransactionID) + + return false + } + _ = pendingRequest + + errorCode := stun.ErrorCodeAttribute{} + if err := errorCode.GetFrom(msg); err != nil { + a.log.Debugf("Discard error response from (%s), no valid ERROR-CODE attribute: %v", remoteAddr, err) + + return false + } + + // Return true after successfully validating and accounting for an error response that doesn't immediately fail + // the agent, and false for cases where the error response should be discarded or indicates a fatal condition that + // should fail the agent. + switch errorCode.Code { + case stun.CodeForbidden: + a.log.Warnf("Received authenticated STUN 403 (Forbidden); revoking consent for %s", remoteAddr) + a.updateConnectionState(ConnectionStateFailed) + + return false + + default: + a.log.Debugf("Received authenticated STUN error response %d from %s", errorCode.Code, remoteAddr) + + return false + } +} + // validateNonSTUNTraffic processes non STUN traffic from a remote candidate, // and returns true if it is an actual remote candidate. func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) { diff --git a/agent_config.go b/agent_config.go index 0539207c..326580b1 100644 --- a/agent_config.go +++ b/agent_config.go @@ -56,6 +56,10 @@ const ( // maxBindingRequestTimeout is the wait time before binding requests can be deleted. maxBindingRequestTimeout = 4000 * time.Millisecond + + // defaultConsentFreshnessTimeout is the maximum time consent can remain valid + // without an authenticated, non-error STUN Binding response. + defaultConsentFreshnessTimeout = 30 * time.Second ) func defaultCandidateTypes() []CandidateType { @@ -226,6 +230,29 @@ type AgentConfig struct { // switched to that irrespective of relative priority between current selected pair // and priority of the pair being switched to. EnableUseCandidateCheckPriority bool + + // ConsentFreshnessTimeout determines how long consent remains valid without an authenticated, + // non-error STUN Binding response. + // When this is nil, it defaults to 30 seconds. A timeout of 0 disables consent freshness expiry. + ConsentFreshnessTimeout *time.Duration + + // BindingRequestErrorResponseHandler allows applications to send an error response for individual + // inbound STUN Binding Requests before a success response is emitted. + // It can be used to implement consent revocation by returning a Binding Error 403 (Forbidden) + // response when the agent receives a binding request for an existing candidate pair. + // Returning nil continues normal handling and sends a success response. + // Returning a non-nil BindingRequestErrorResponse causes the agent to send an authenticated + // STUN Binding Error response with the provided code/reason and optional extra attributes. + // Note: pair is nil when the binding request will create a new pair. + BindingRequestErrorResponseHandler func(m *stun.Message, local, remote Candidate, + pair *CandidatePair) *BindingRequestErrorResponse +} + +// BindingRequestErrorResponse defines the STUN Binding Error response emitted +// for an inbound STUN Binding Request. +type BindingRequestErrorResponse struct { + ErrorCodeAttribute stun.ErrorCodeAttribute + ExtraAttributes []stun.Setter } // initWithDefaults populates an agent and falls back to defaults if fields are unset. @@ -290,6 +317,12 @@ func (config *AgentConfig) initWithDefaults(agent *Agent) { //nolint:cyclop agent.keepaliveInterval = *config.KeepaliveInterval } + if config.ConsentFreshnessTimeout == nil { + agent.consentFreshnessTimeout = defaultConsentFreshnessTimeout + } else { + agent.consentFreshnessTimeout = *config.ConsentFreshnessTimeout + } + if config.CheckInterval == nil { agent.checkInterval = defaultCheckInterval } else { diff --git a/agent_config_test.go b/agent_config_test.go index 76450bb4..e950f13f 100644 --- a/agent_config_test.go +++ b/agent_config_test.go @@ -23,6 +23,7 @@ func TestAgentConfig_initWithDefaults(t *testing.T) { func(t *testing.T, result *Agent) { t.Helper() assert.Equal(t, result.relayAcceptanceMinWait, defaultRelayAcceptanceMinWait) + assert.Equal(t, result.consentFreshnessTimeout, defaultConsentFreshnessTimeout) }, }, { @@ -57,6 +58,18 @@ func TestAgentConfig_initWithDefaults(t *testing.T) { assert.Equal(t, result.relayAcceptanceMinWait, relayAcceptanceMinWait) }, }, + { + "consent freshness timeout can be disabled", + func() *AgentConfig { + zero := time.Duration(0) + + return &AgentConfig{ConsentFreshnessTimeout: &zero} + }(), + func(t *testing.T, result *Agent) { + t.Helper() + assert.Equal(t, time.Duration(0), result.consentFreshnessTimeout) + }, + }, } for _, test := range tests { diff --git a/agent_options.go b/agent_options.go index ac26c739..32adee78 100644 --- a/agent_options.go +++ b/agent_options.go @@ -960,3 +960,40 @@ func WithLoggerFactory(loggerFactory logging.LoggerFactory) AgentOption { return nil } } + +// WithConsentFreshnessTimeout sets how long consent remains valid without an authenticated, non-error +// STUN Binding response. +// A timeout of 0 disables consent freshness expiry. +func WithConsentFreshnessTimeout(timeout time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.consentFreshnessTimeout = timeout + + return nil + } +} + +// WithBindingRequestErrorResponseHandler sets a handler to return an optional STUN Binding Error response +// for inbound STUN Binding Requests. +// It can be used to implement consent revocation by returning a Binding Error 403 (Forbidden) response +// when the agent receives a binding request for an existing candidate pair. +// Returning nil continues normal success handling. +// Returning a non-nil response sends the configured ERROR-CODE and any additional attributes included +// in ExtraAttributes. +// Note: pair is nil when the binding request will create a new pair. +func WithBindingRequestErrorResponseHandler( + handler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) *BindingRequestErrorResponse, +) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.userBindingReqErrorRespHandler = handler + + return nil + } +} diff --git a/agent_options_test.go b/agent_options_test.go index 26b5e3d6..51d0f6a6 100644 --- a/agent_options_test.go +++ b/agent_options_test.go @@ -160,6 +160,7 @@ func TestWithTimeoutOptions(t *testing.T) { WithFailedTimeout(20*time.Second), WithKeepaliveInterval(3*time.Second), WithCheckInterval(150*time.Millisecond), + WithConsentFreshnessTimeout(40*time.Second), ) require.NoError(t, err) defer agent.Close() //nolint:errcheck @@ -168,6 +169,7 @@ func TestWithTimeoutOptions(t *testing.T) { assert.Equal(t, 20*time.Second, agent.failedTimeout) assert.Equal(t, 3*time.Second, agent.keepaliveInterval) assert.Equal(t, 150*time.Millisecond, agent.checkInterval) + assert.Equal(t, 40*time.Second, agent.consentFreshnessTimeout) } func TestWithAcceptanceWaitOptions(t *testing.T) { @@ -1681,3 +1683,63 @@ func (n *stubNet) CreateDialer(dialer *net.Dialer) transport.Dialer { func (n *stubNet) CreateListenConfig(listenerConfig *net.ListenConfig) transport.ListenConfig { return nil } + +func TestWithBindingRequestErrorResponseHandler(t *testing.T) { + t.Run("sets binding error response handler", func(t *testing.T) { + handlerCalled := false + handler := func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) *BindingRequestErrorResponse { + handlerCalled = true + + // Consent revocation should be sent by returning a Binding Error 403 (Forbidden). + return &BindingRequestErrorResponse{ + ErrorCodeAttribute: stun.ErrorCodeAttribute{Code: stun.CodeForbidden, Reason: []byte("Forbidden")}, + } + } + + agent, err := NewAgentWithOptions(WithBindingRequestErrorResponseHandler(handler)) + assert.NoError(t, err) + defer agent.Close() //nolint:errcheck + + assert.NotNil(t, agent.userBindingReqErrorRespHandler) + + if agent.userBindingReqErrorRespHandler != nil { + agent.userBindingReqErrorRespHandler(nil, nil, nil, nil) + assert.True(t, handlerCalled) + } + }) + + t.Run("default is nil", func(t *testing.T) { + agent, err := NewAgentWithOptions() + assert.NoError(t, err) + defer agent.Close() //nolint:errcheck + + assert.Nil(t, agent.userBindingReqErrorRespHandler) + }) + + t.Run("works with config", func(t *testing.T) { + handlerCalled := false + handler := func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) *BindingRequestErrorResponse { + handlerCalled = true + + return &BindingRequestErrorResponse{ + ErrorCodeAttribute: stun.ErrorCodeAttribute{Code: stun.CodeForbidden, Reason: []byte("Forbidden")}, + } + } + + config := &AgentConfig{ + NetworkTypes: []NetworkType{NetworkTypeUDP4}, + BindingRequestErrorResponseHandler: handler, + } + + agent, err := NewAgent(config) + assert.NoError(t, err) + defer agent.Close() //nolint:errcheck + + assert.NotNil(t, agent.userBindingReqErrorRespHandler) + + if agent.userBindingReqErrorRespHandler != nil { + agent.userBindingReqErrorRespHandler(nil, nil, nil, nil) + assert.True(t, handlerCalled) + } + }) +} diff --git a/agent_test.go b/agent_test.go index 3adcf6b9..3a9b09e9 100644 --- a/agent_test.go +++ b/agent_test.go @@ -48,6 +48,12 @@ func (r *recordingSelector) HandleSuccessResponse(*stun.Message, Candidate, Cand r.handledSuccess = true } +func (r *recordingSelector) handleSuccessResponse(*stun.Message, Candidate, Candidate, net.Addr) bool { + r.handledSuccess = true + + return true +} + func (r *recordingSelector) HandleBindingRequest(*stun.Message, Candidate, Candidate) { r.handledBindingRequest = true } @@ -634,6 +640,57 @@ func TestHandleInboundAdditionalCases(t *testing.T) { require.Equal(t, CandidatePairStateSucceeded, pair.state) require.Empty(t, agent.pendingBindingRequests) }) + + t.Run("Authenticated 403 error response revokes consent", func(t *testing.T) { + agent := newTestAgent(t) + defer func() { + require.NoError(t, agent.Close()) + }() + + local := newHostLocal(t) + remoteConfig := &CandidateHostConfig{ + Network: "udp", + Address: "192.0.2.3", + Port: 6666, + Component: 1, + } + remoteCandidate, err := NewCandidateHost(remoteConfig) + require.NoError(t, err) + remoteAddr := &net.UDPAddr{IP: net.ParseIP(remoteCandidate.Address()), Port: remoteCandidate.Port()} + transactionID := stun.NewTransactionID() + remotePwd := "remotekey403" + + require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) { + agent.selector = &controllingSelector{agent: agent, log: agent.log} + agent.selector.Start() + agent.localCandidates[local.NetworkType()] = append(agent.localCandidates[local.NetworkType()], local) + agent.addRemoteCandidate(remoteCandidate) //nolint:contextcheck + pair := agent.findPair(local, remoteCandidate) + require.NotNil(t, pair) + agent.setSelectedPair(pair) + agent.pendingBindingRequests = []bindingRequest{ // request awaiting response + { + timestamp: time.Now(), + transactionID: transactionID, + destination: remoteAddr, + }, + } + agent.remotePwd = remotePwd + })) + + msg, err := stun.Build(stun.BindingError, stun.NewTransactionIDSetter(transactionID), + stun.ErrorCodeAttribute{Code: stun.CodeForbidden, Reason: []byte("Forbidden")}, + stun.NewShortTermIntegrity(remotePwd), + stun.Fingerprint, + ) + require.NoError(t, err) + + agent.handleInbound(msg, local, remoteAddr) + + require.Equal(t, ConnectionStateFailed, agent.connectionState) + require.Nil(t, agent.getSelectedPair()) + require.Empty(t, agent.pendingBindingRequests) + }) } func TestInvalidAgentStarts(t *testing.T) { @@ -1751,9 +1808,10 @@ func TestLiteLifecycle(t *testing.T) { func TestValidateSelectedPairTransitions(t *testing.T) { agent := &Agent{ - disconnectedTimeout: time.Second, - failedTimeout: time.Second, - connectionState: ConnectionStateConnected, + disconnectedTimeout: time.Second, + failedTimeout: time.Second, + consentFreshnessTimeout: defaultConsentFreshnessTimeout, + connectionState: ConnectionStateConnected, connectionStateNotifier: &handlerNotifier{ connectionStateFunc: func(ConnectionState) {}, done: make(chan struct{}), @@ -1788,6 +1846,43 @@ func TestValidateSelectedPairTransitions(t *testing.T) { require.Equal(t, ConnectionStateFailed, agent.connectionState) } +func TestValidateSelectedPairConsentExpired(t *testing.T) { + agent := &Agent{ + disconnectedTimeout: time.Second, + failedTimeout: time.Second, + consentFreshnessTimeout: 30 * time.Second, + connectionState: ConnectionStateConnected, + connectionStateNotifier: &handlerNotifier{ + connectionStateFunc: func(ConnectionState) {}, + done: make(chan struct{}), + }, + log: logging.NewDefaultLoggerFactory().NewLogger("test"), + } + + local, err := NewCandidateHost(&CandidateHostConfig{ + Network: "udp", + Address: "1.1.1.1", + Port: 1000, + Component: ComponentRTP, + }) + require.NoError(t, err) + + remote, err := NewCandidateHost(&CandidateHostConfig{ + Network: "udp", + Address: "2.2.2.2", + Port: 2000, + Component: ComponentRTP, + }) + require.NoError(t, err) + remote.setLastReceived(time.Now()) + + agent.selectedPair.Store(newCandidatePair(local, remote, true)) + agent.lastConsentAt = time.Now().Add(-31 * time.Second) + + require.False(t, agent.validateSelectedPair()) + require.Equal(t, ConnectionStateFailed, agent.connectionState) +} + func TestNilCandidate(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) @@ -2550,42 +2645,44 @@ func TestAgentUpdateOptions(t *testing.T) { // All options except WithUrls should be rejected on a running agent. // When adding new options, add them here if they are not runtime-updatable. nonUpdatableOptions := map[string]AgentOption{ - "WithAddressRewriteRules": WithAddressRewriteRules(AddressRewriteRule{External: []string{"1.2.3.4"}}), - "WithICELite": WithICELite(true), - "WithPortRange": WithPortRange(5000, 6000), - "WithDisconnectedTimeout": WithDisconnectedTimeout(time.Second), - "WithFailedTimeout": WithFailedTimeout(time.Second), - "WithKeepaliveInterval": WithKeepaliveInterval(time.Second), - "WithHostAcceptanceMinWait": WithHostAcceptanceMinWait(time.Second), - "WithSrflxAcceptanceMinWait": WithSrflxAcceptanceMinWait(time.Second), - "WithPrflxAcceptanceMinWait": WithPrflxAcceptanceMinWait(time.Second), - "WithRelayAcceptanceMinWait": WithRelayAcceptanceMinWait(time.Second), - "WithSTUNGatherTimeout": WithSTUNGatherTimeout(time.Second), - "WithIPFilter": WithIPFilter(func(net.IP) bool { return true }), - "WithNet": WithNet(nil), - "WithMulticastDNSMode": WithMulticastDNSMode(MulticastDNSModeDisabled), - "WithMulticastDNSHostName": WithMulticastDNSHostName("test.local"), - "WithLocalCredentials": WithLocalCredentials("", ""), - "WithTCPMux": WithTCPMux(nil), - "WithUDPMux": WithUDPMux(nil), - "WithUDPMuxSrflx": WithUDPMuxSrflx(nil), - "WithProxyDialer": WithProxyDialer(nil), - "WithMaxBindingRequests": WithMaxBindingRequests(10), - "WithCheckInterval": WithCheckInterval(time.Second), - "WithRenomination": WithRenomination(DefaultNominationValueGenerator()), - "WithNominationAttribute": WithNominationAttribute(0x0030), - "WithIncludeLoopback": WithIncludeLoopback(), - "WithTCPPriorityOffset": WithTCPPriorityOffset(10), - "WithDisableActiveTCP": WithDisableActiveTCP(), - "WithBindingRequestHandler": WithBindingRequestHandler(nil), - "WithEnableUseCandidateCheckPriority": WithEnableUseCandidateCheckPriority(), - "WithContinualGatheringPolicy": WithContinualGatheringPolicy(GatherOnce), - "WithNetworkMonitorInterval": WithNetworkMonitorInterval(time.Second), - "WithNetworkTypes": WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), - "WithCandidateTypes": WithCandidateTypes([]CandidateType{CandidateTypeHost}), - "WithAutomaticRenomination": WithAutomaticRenomination(time.Second), - "WithInterfaceFilter": WithInterfaceFilter(func(string) bool { return true }), - "WithLoggerFactory": WithLoggerFactory(nil), + "WithAddressRewriteRules": WithAddressRewriteRules(AddressRewriteRule{External: []string{"1.2.3.4"}}), + "WithICELite": WithICELite(true), + "WithPortRange": WithPortRange(5000, 6000), + "WithDisconnectedTimeout": WithDisconnectedTimeout(time.Second), + "WithFailedTimeout": WithFailedTimeout(time.Second), + "WithKeepaliveInterval": WithKeepaliveInterval(time.Second), + "WithConsentFreshnessTimeout": WithConsentFreshnessTimeout(time.Second), + "WithHostAcceptanceMinWait": WithHostAcceptanceMinWait(time.Second), + "WithSrflxAcceptanceMinWait": WithSrflxAcceptanceMinWait(time.Second), + "WithPrflxAcceptanceMinWait": WithPrflxAcceptanceMinWait(time.Second), + "WithRelayAcceptanceMinWait": WithRelayAcceptanceMinWait(time.Second), + "WithSTUNGatherTimeout": WithSTUNGatherTimeout(time.Second), + "WithIPFilter": WithIPFilter(func(net.IP) bool { return true }), + "WithNet": WithNet(nil), + "WithMulticastDNSMode": WithMulticastDNSMode(MulticastDNSModeDisabled), + "WithMulticastDNSHostName": WithMulticastDNSHostName("test.local"), + "WithLocalCredentials": WithLocalCredentials("", ""), + "WithTCPMux": WithTCPMux(nil), + "WithUDPMux": WithUDPMux(nil), + "WithUDPMuxSrflx": WithUDPMuxSrflx(nil), + "WithProxyDialer": WithProxyDialer(nil), + "WithMaxBindingRequests": WithMaxBindingRequests(10), + "WithCheckInterval": WithCheckInterval(time.Second), + "WithRenomination": WithRenomination(DefaultNominationValueGenerator()), + "WithNominationAttribute": WithNominationAttribute(0x0030), + "WithIncludeLoopback": WithIncludeLoopback(), + "WithTCPPriorityOffset": WithTCPPriorityOffset(10), + "WithDisableActiveTCP": WithDisableActiveTCP(), + "WithBindingRequestHandler": WithBindingRequestHandler(nil), + "WithBindingRequestErrorResponseHandler": WithBindingRequestErrorResponseHandler(nil), + "WithEnableUseCandidateCheckPriority": WithEnableUseCandidateCheckPriority(), + "WithContinualGatheringPolicy": WithContinualGatheringPolicy(GatherOnce), + "WithNetworkMonitorInterval": WithNetworkMonitorInterval(time.Second), + "WithNetworkTypes": WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), + "WithCandidateTypes": WithCandidateTypes([]CandidateType{CandidateTypeHost}), + "WithAutomaticRenomination": WithAutomaticRenomination(time.Second), + "WithInterfaceFilter": WithInterfaceFilter(func(string) bool { return true }), + "WithLoggerFactory": WithLoggerFactory(nil), } for name, opt := range nonUpdatableOptions { diff --git a/gather.go b/gather.go index 72f73abd..94f926a0 100644 --- a/gather.go +++ b/gather.go @@ -920,19 +920,21 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { return } - conn, connectErr := dtls.Client(&fakenet.PacketConn{Conn: udpConn}, udpConn.RemoteAddr(), &dtls.Config{ - ServerName: url.Host, - InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec - LoggerFactory: a.loggerFactory, - }) + conn, connectErr := dtls.ClientWithOptions( + &fakenet.PacketConn{Conn: udpConn}, + udpConn.RemoteAddr(), + dtls.WithServerName(url.Host), + dtls.WithInsecureSkipVerify(a.insecureSkipVerify), //nolint:gosec + dtls.WithLoggerFactory(a.loggerFactory), + ) if connectErr != nil { - a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) + a.log.Warnf("Failed to create DTLS client: %s: %v", turnServerAddr, connectErr) return } if connectErr = conn.HandshakeContext(ctx); connectErr != nil { - a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) + a.log.Warnf("Failed to create DTLS client: %s: %v", turnServerAddr, connectErr) return } diff --git a/gather_test.go b/gather_test.go index 991de91d..3aff1224 100644 --- a/gather_test.go +++ b/gather_test.go @@ -452,12 +452,10 @@ func TestTURNConcurrency(t *testing.T) { require.NoError(t, genErr) serverPort := randomPort(t) - serverListener, err := dtls.Listen( + serverListener, err := dtls.ListenWithOptions( "udp", &net.UDPAddr{IP: net.ParseIP(localhostIPStr), Port: serverPort}, - &dtls.Config{ - Certificates: []tls.Certificate{certificate}, - }, + dtls.WithCertificates(certificate), ) require.NoError(t, err) diff --git a/selection.go b/selection.go index efed3353..f39e3ef1 100644 --- a/selection.go +++ b/selection.go @@ -15,8 +15,12 @@ type pairCandidateSelector interface { Start() ContactCandidates() PingCandidate(local, remote Candidate) + // Deprecated: use handleSuccessResponse instead, which returns a boolean indicating if the response was + // successfully handled HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) HandleBindingRequest(m *stun.Message, local, remote Candidate) + + handleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) bool } type controllingSelector struct { @@ -103,17 +107,32 @@ func (s *controllingSelector) nominatePair(pair *CandidatePair) { } func (s *controllingSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop - s.agent.sendBindingSuccess(message, local, remote) - pair := s.agent.findPair(local, remote) + if pair != nil { + pair.UpdateRequestReceived() + } + + if s.agent.userBindingReqErrorRespHandler != nil { + bindingErrorResponse := s.agent.userBindingReqErrorRespHandler(message, local, remote, pair) + if bindingErrorResponse != nil { + s.agent.sendBindingError(message, local, remote, *bindingErrorResponse) + + return + } + } + newPairAdded := false if pair == nil { pair = s.agent.addPair(local, remote) + newPairAdded = true pair.UpdateRequestReceived() + } + + s.agent.sendBindingSuccess(message, local, remote) + if newPairAdded { return } - pair.UpdateRequestReceived() if pair.state == CandidatePairStateSucceeded && s.nominatedPair == nil && s.agent.getSelectedPair() == nil { bestPair := s.agent.getBestAvailableCandidatePair() @@ -138,11 +157,17 @@ func (s *controllingSelector) HandleBindingRequest(message *stun.Message, local, } func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) { - ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID) + s.handleSuccessResponse(m, local, remote, remoteAddr) +} + +func (s *controllingSelector) handleSuccessResponse( + m *stun.Message, local, remote Candidate, remoteAddr net.Addr, +) bool { + ok, pendingRequest, rtt := s.agent.consumePendingBindingRequest(m.TransactionID, remoteAddr) if !ok { s.log.Warnf("Discard success response from (%s), unknown TransactionID 0x%x", remote, m.TransactionID) - return + return false } transactionAddr := pendingRequest.destination @@ -156,7 +181,7 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo remote, ) - return + return false } s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local) @@ -166,7 +191,7 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo // This shouldn't happen s.log.Error("Success response from invalid candidate pair") - return + return false } pair.state = CandidatePairStateSucceeded @@ -188,6 +213,8 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo } pair.UpdateRoundTripTime(rtt) + + return true } func (s *controllingSelector) PingCandidate(local, remote Candidate) { @@ -355,6 +382,10 @@ func (s *controlledSelector) PingCandidate(local, remote Candidate) { } func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) { + s.handleSuccessResponse(m, local, remote, remoteAddr) +} + +func (s *controlledSelector) handleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) bool { //nolint:godox // TODO according to the standard we should specifically answer a failed nomination: // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 @@ -363,11 +394,11 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot // request with an appropriate error code response (e.g., 400) // [RFC5389]. - ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID) + ok, pendingRequest, rtt := s.agent.consumePendingBindingRequest(m.TransactionID, remoteAddr) if !ok { s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID) - return + return false } transactionAddr := pendingRequest.destination @@ -381,7 +412,7 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot remote, ) - return + return false } s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local) @@ -391,7 +422,7 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot // This shouldn't happen s.log.Error("Success response from invalid candidate pair") - return + return false } pair.state = CandidatePairStateSucceeded @@ -407,14 +438,29 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot } pair.UpdateRoundTripTime(rtt) + + return true } func (s *controlledSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop pair := s.agent.findPair(local, remote) + if pair != nil { + pair.UpdateRequestReceived() + } + + if s.agent.userBindingReqErrorRespHandler != nil { + bindingErrorResponse := s.agent.userBindingReqErrorRespHandler(message, local, remote, pair) + if bindingErrorResponse != nil { + s.agent.sendBindingError(message, local, remote, *bindingErrorResponse) + + return + } + } + if pair == nil { pair = s.agent.addPair(local, remote) + pair.UpdateRequestReceived() } - pair.UpdateRequestReceived() if message.Contains(stun.AttrUseCandidate) || message.Contains(s.agent.nominationAttribute) { //nolint:nestif // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 diff --git a/selection_test.go b/selection_test.go index bfc294f2..2dce6458 100644 --- a/selection_test.go +++ b/selection_test.go @@ -29,8 +29,17 @@ const ( selectionTestPassword = "pwd" selectionTestRemoteUfrag = "remote" selectionTestLocalUfrag = "local" + selectionTestCustomAttr = stun.AttrType(0xC001) ) +type selectionTestCustomErrorAttributeSetter struct{} + +func (selectionTestCustomErrorAttributeSetter) AddTo(m *stun.Message) error { + m.Add(selectionTestCustomAttr, []byte("custom-attr")) + + return nil +} + func sendUntilDone(t *testing.T, writingConn, readingConn net.Conn, maxAttempts int) bool { t.Helper() @@ -168,6 +177,86 @@ func TestBindingRequestHandler(t *testing.T) { closePipe(t, controllingConn, controlledConn) } +func TestBindingRequestErrorResponseHandler(t *testing.T) { + testCases := []struct { + name string + buildSelector func(*Agent) pairCandidateSelector + }{ + { + name: "controlledSelector", + buildSelector: func(agent *Agent) pairCandidateSelector { + return &controlledSelector{agent: agent, log: agent.log} + }, + }, + { + name: "controllingSelector", + buildSelector: func(agent *Agent) pairCandidateSelector { + return &controllingSelector{agent: agent, log: agent.log} + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + agent, err := NewAgent(&AgentConfig{ + BindingRequestErrorResponseHandler: func( + _ *stun.Message, _, _ Candidate, _ *CandidatePair, + ) *BindingRequestErrorResponse { + return &BindingRequestErrorResponse{ + ErrorCodeAttribute: stun.ErrorCodeAttribute{ + Code: stun.CodeForbidden, + Reason: []byte("Forbidden"), + }, + ExtraAttributes: []stun.Setter{selectionTestCustomErrorAttributeSetter{}}, + } + }, + }) + require.NoError(t, err) + defer func() { require.NoError(t, agent.Close()) }() + + local, err := NewCandidateHost(&CandidateHostConfig{ + Network: "udp", + Address: "127.0.0.1", + Port: 12345, + Component: ComponentRTP, + }) + require.NoError(t, err) + + remote, err := NewCandidateHost(&CandidateHostConfig{ + Network: "udp", + Address: "127.0.0.1", + Port: 54321, + Component: ComponentRTP, + }) + require.NoError(t, err) + + mockConn := &mockPacketConnWithCapture{} + local.conn = mockConn + + selector := tc.buildSelector(agent) + selector.Start() + + msg, err := stun.Build(stun.BindingRequest, stun.TransactionID) + require.NoError(t, err) + + selector.HandleBindingRequest(msg, local, remote) + + require.Len(t, mockConn.sentPackets, 1) + + out := &stun.Message{Raw: mockConn.sentPackets[0]} + require.NoError(t, out.Decode()) + require.Equal(t, stun.MethodBinding, out.Type.Method) + require.Equal(t, stun.ClassErrorResponse, out.Type.Class) + + errorCode := stun.ErrorCodeAttribute{} + require.NoError(t, errorCode.GetFrom(out)) + require.Equal(t, stun.CodeForbidden, errorCode.Code) + require.True(t, out.Contains(selectionTestCustomAttr)) + require.Empty(t, agent.checklist) + }) + } +} + // copied from pion/webrtc's peerconnection_go_test.go. type testICELogger struct { lastErrorMessage string @@ -266,15 +355,18 @@ func bareAgentForPing() *Agent { connectionStateNotifier: &handlerNotifier{ done: make(chan struct{}), - connectionStateFunc: func(ConnectionState) {}}, //nolint formatting + connectionStateFunc: func(ConnectionState) {}, + }, //nolint formatting candidateNotifier: &handlerNotifier{ done: make(chan struct{}), - candidateFunc: func(Candidate) {}}, //nolint formatting + candidateFunc: func(Candidate) {}, + }, //nolint formatting selectedCandidatePairNotifier: &handlerNotifier{ done: make(chan struct{}), - candidatePairFunc: func(*CandidatePair) {}}, //nolint formatting + candidatePairFunc: func(*CandidatePair) {}, + }, //nolint formatting } } @@ -1304,7 +1396,7 @@ func TestControllingSideRenomination(t *testing.T) { require.NoError(t, err) // Handle the success response - this should switch to pair2 - selector.HandleSuccessResponse(successMsg, local2, remote, remote.addr()) + selector.handleSuccessResponse(successMsg, local2, remote, remote.addr()) // The controlling agent should have switched to pair2 selectedPair := agent.getSelectedPair() @@ -1384,7 +1476,7 @@ func TestControllingSideRenomination(t *testing.T) { // Handle the success response - this should NOT switch since it's standard nomination // and a pair is already selected - selector.HandleSuccessResponse(successMsg, local2, remote, remote.addr()) + selector.handleSuccessResponse(successMsg, local2, remote, remote.addr()) // The controlling agent should remain with pair1 selectedPair := agent.getSelectedPair()