diff --git a/agent.go b/agent.go index 68104376..67691f10 100644 --- a/agent.go +++ b/agent.go @@ -175,6 +175,17 @@ type Agent struct { lastRenominationTime time.Time turnClientFactory func(*turn.ClientConfig) (turnClient, error) + + // How long consent remains valid without an authenticated non-error STUN Binding response. + consentFreshnessTimeout time.Duration + + // Timestamp of the last consent refresh for the selected candidate pair. + lastConsentAt time.Time + + // Callback that allows user to send an error response for inbound STUN Binding Requests before success is emitted. + // Returning nil continues normal success handling. + userBindingReqErrorRespHandler func( + m *stun.Message, local, remote Candidate, pair *CandidatePair) *BindingRequestErrorResponse } // NewAgent creates a new Agent. @@ -378,6 +389,8 @@ func createAgentBase(config *AgentConfig) (*Agent, error) { automaticRenomination: false, renominationInterval: 3 * time.Second, // Default matching libwebrtc turnClientFactory: defaultTurnClient, + userBindingReqErrorRespHandler: config.BindingRequestErrorResponseHandler, + consentFreshnessTimeout: defaultConsentFreshnessTimeout, } config.initWithDefaults(agent) @@ -750,6 +763,7 @@ func (a *Agent) setSelectedPair(pair *CandidatePair) { if pair == nil { var nilPair *CandidatePair a.selectedPair.Store(nilPair) + a.lastConsentAt = time.Time{} a.log.Tracef("Unset selected candidate pair") return @@ -757,6 +771,7 @@ func (a *Agent) setSelectedPair(pair *CandidatePair) { pair.nominated = true a.selectedPair.Store(pair) + a.lastConsentAt = time.Now() a.log.Tracef("Set selected candidate pair: %s", pair) // Signal connected: notify any Connect() calls waiting on onConnected @@ -769,6 +784,15 @@ func (a *Agent) setSelectedPair(pair *CandidatePair) { a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(pair) } +// consentExpired checks if the consent freshness has expired for the selected candidate pair. +func (a *Agent) consentExpired(now time.Time) bool { + if a.consentFreshnessTimeout <= 0 || a.lastConsentAt.IsZero() { + return false + } + + return now.Sub(a.lastConsentAt) > a.consentFreshnessTimeout +} + func (a *Agent) pingAllCandidates() { a.log.Trace("Pinging all candidates") @@ -885,6 +909,14 @@ func (a *Agent) validateSelectedPair() bool { return false } + now := time.Now() + if a.consentExpired(now) { + a.log.Warnf("Consent expired for selected pair after %v without valid response", a.consentFreshnessTimeout) + a.updateConnectionState(ConnectionStateFailed) + + return false + } + disconnectedTime := time.Since(selectedPair.Remote.LastReceived()) // Only allow transitions to failed if a.failedTimeout is non-zero @@ -1412,6 +1444,28 @@ func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) { } } +func (a *Agent) sendBindingError( + m *stun.Message, local, remote Candidate, bindingErrorResponse BindingRequestErrorResponse, +) { + setters := make([]stun.Setter, 0, len(bindingErrorResponse.ExtraAttributes)+5) + setters = append(setters, m, stun.BindingError) + setters = append(setters, bindingErrorResponse.ErrorCodeAttribute) + setters = append(setters, bindingErrorResponse.ExtraAttributes...) + setters = append(setters, + stun.NewShortTermIntegrity(a.localPwd), + stun.Fingerprint, + ) + + if out, err := stun.Build(setters...); err != nil { + a.log.Warnf("Failed to send binding error response from: %s to: %s error: %s", local, remote, err) + } else { + if pair := a.findPair(local, remote); pair != nil { + pair.UpdateResponseSent() + } + a.sendSTUN(out, local, remote) + } +} + // Removes pending binding requests that are over maxBindingRequestTimeout old // // Let HTO be the transaction timeout, which SHOULD be 2*RTT if @@ -1433,12 +1487,19 @@ func (a *Agent) invalidatePendingBindingRequests(filterTime time.Time) { } } -// Assert that the passed TransactionID is in our pendingBindingRequests and returns the destination -// If the bindingRequest was valid remove it from our pending cache. -func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bool, *bindingRequest, time.Duration) { +// consumePendingBindingRequest validates that the passed TransactionID and remote address match a pending binding +// request. If a match is found, the binding request is removed from the pending cache and returned along with how +// long ago it was sent. If no match is found, nil is returned. +func (a *Agent) consumePendingBindingRequest( + id [stun.TransactionIDSize]byte, remoteAddr net.Addr, +) (bool, *bindingRequest, time.Duration) { a.invalidatePendingBindingRequests(time.Now()) for i := range a.pendingBindingRequests { if a.pendingBindingRequests[i].transactionID == id { + if !addrEqual(a.pendingBindingRequests[i].destination, remoteAddr) { + return false, nil, 0 + } + validBindingRequest := a.pendingBindingRequests[i] a.pendingBindingRequests = append(a.pendingBindingRequests[:i], a.pendingBindingRequests[i+1:]...) @@ -1481,7 +1542,7 @@ func (a *Agent) handleRoleConflict(msg *stun.Message, local, remote Candidate, r } // handleInbound processes STUN traffic from a remote candidate. -func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { +func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { // nolint:cyclop if msg == nil || local == nil { return } @@ -1504,6 +1565,10 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add if remoteCandidate, ok = a.handleInboundRequest(remoteCandidate, local, remote, msg); !ok { return } + case stun.ClassErrorResponse: + if !a.handleInboundErrorResponse(remoteCandidate, local, remote, msg) { + return + } default: } @@ -1516,7 +1581,8 @@ func canHandleInbound(msg *stun.Message) bool { return msg.Type.Method == stun.MethodBinding && (msg.Type.Class == stun.ClassSuccessResponse || msg.Type.Class == stun.ClassRequest || - msg.Type.Class == stun.ClassIndication) + msg.Type.Class == stun.ClassIndication || + msg.Type.Class == stun.ClassErrorResponse) } func (a *Agent) handleInboundResponse( @@ -1534,9 +1600,17 @@ func (a *Agent) handleInboundResponse( return false } - a.getSelector().HandleSuccessResponse(msg, local, remoteCandidate, remote) + handled := a.getSelector().handleSuccessResponse(msg, local, remoteCandidate, remote) - return true + if handled { + if selectedPair := a.getSelectedPair(); selectedPair != nil && + selectedPair.Local.Equal(local) && selectedPair.Remote.Equal(remoteCandidate) { + // Note consent freshness + a.lastConsentAt = time.Now() + } + } + + return handled } func (a *Agent) handleInboundRequest( @@ -1602,6 +1676,54 @@ func (a *Agent) handleInboundRequest( return remoteCandidate, true } +func (a *Agent) handleInboundErrorResponse( + remoteCandidate, _ Candidate, remoteAddr net.Addr, msg *stun.Message, +) bool { // nolint:unparam + if err := stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil { + a.log.Warnf("Discard error response with broken integrity from (%s), %v", remoteAddr, err) + + return false + } + + if remoteCandidate == nil { + a.log.Warnf("Discard error response from (%s), no such remote", remoteAddr) + + return false + } + + ok, pendingRequest, _ := a.consumePendingBindingRequest(msg.TransactionID, remoteAddr) + if !ok { + a.log.Warnf("Discard error response from (%s), unknown TransactionID 0x%x or address mismatch", + remoteAddr, msg.TransactionID) + + return false + } + _ = pendingRequest + + errorCode := stun.ErrorCodeAttribute{} + if err := errorCode.GetFrom(msg); err != nil { + a.log.Debugf("Discard error response from (%s), no valid ERROR-CODE attribute: %v", remoteAddr, err) + + return false + } + + // Return true after successfully validating and accounting for an error response that doesn't immediately fail + // the agent, and false for cases where the error response should be discarded or indicates a fatal condition that + // should fail the agent. + switch errorCode.Code { + case stun.CodeForbidden: + a.log.Warnf("Received authenticated STUN 403 (Forbidden); revoking consent for %s", remoteAddr) + a.updateConnectionState(ConnectionStateFailed) + + return false + + default: + a.log.Debugf("Received authenticated STUN error response %d from %s", errorCode.Code, remoteAddr) + + return false + } +} + // validateNonSTUNTraffic processes non STUN traffic from a remote candidate, // and returns true if it is an actual remote candidate. func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) { diff --git a/agent_config.go b/agent_config.go index 0539207c..326580b1 100644 --- a/agent_config.go +++ b/agent_config.go @@ -56,6 +56,10 @@ const ( // maxBindingRequestTimeout is the wait time before binding requests can be deleted. maxBindingRequestTimeout = 4000 * time.Millisecond + + // defaultConsentFreshnessTimeout is the maximum time consent can remain valid + // without an authenticated, non-error STUN Binding response. + defaultConsentFreshnessTimeout = 30 * time.Second ) func defaultCandidateTypes() []CandidateType { @@ -226,6 +230,29 @@ type AgentConfig struct { // switched to that irrespective of relative priority between current selected pair // and priority of the pair being switched to. EnableUseCandidateCheckPriority bool + + // ConsentFreshnessTimeout determines how long consent remains valid without an authenticated, + // non-error STUN Binding response. + // When this is nil, it defaults to 30 seconds. A timeout of 0 disables consent freshness expiry. + ConsentFreshnessTimeout *time.Duration + + // BindingRequestErrorResponseHandler allows applications to send an error response for individual + // inbound STUN Binding Requests before a success response is emitted. + // It can be used to implement consent revocation by returning a Binding Error 403 (Forbidden) + // response when the agent receives a binding request for an existing candidate pair. + // Returning nil continues normal handling and sends a success response. + // Returning a non-nil BindingRequestErrorResponse causes the agent to send an authenticated + // STUN Binding Error response with the provided code/reason and optional extra attributes. + // Note: pair is nil when the binding request will create a new pair. + BindingRequestErrorResponseHandler func(m *stun.Message, local, remote Candidate, + pair *CandidatePair) *BindingRequestErrorResponse +} + +// BindingRequestErrorResponse defines the STUN Binding Error response emitted +// for an inbound STUN Binding Request. +type BindingRequestErrorResponse struct { + ErrorCodeAttribute stun.ErrorCodeAttribute + ExtraAttributes []stun.Setter } // initWithDefaults populates an agent and falls back to defaults if fields are unset. @@ -290,6 +317,12 @@ func (config *AgentConfig) initWithDefaults(agent *Agent) { //nolint:cyclop agent.keepaliveInterval = *config.KeepaliveInterval } + if config.ConsentFreshnessTimeout == nil { + agent.consentFreshnessTimeout = defaultConsentFreshnessTimeout + } else { + agent.consentFreshnessTimeout = *config.ConsentFreshnessTimeout + } + if config.CheckInterval == nil { agent.checkInterval = defaultCheckInterval } else { diff --git a/agent_config_test.go b/agent_config_test.go index 76450bb4..e950f13f 100644 --- a/agent_config_test.go +++ b/agent_config_test.go @@ -23,6 +23,7 @@ func TestAgentConfig_initWithDefaults(t *testing.T) { func(t *testing.T, result *Agent) { t.Helper() assert.Equal(t, result.relayAcceptanceMinWait, defaultRelayAcceptanceMinWait) + assert.Equal(t, result.consentFreshnessTimeout, defaultConsentFreshnessTimeout) }, }, { @@ -57,6 +58,18 @@ func TestAgentConfig_initWithDefaults(t *testing.T) { assert.Equal(t, result.relayAcceptanceMinWait, relayAcceptanceMinWait) }, }, + { + "consent freshness timeout can be disabled", + func() *AgentConfig { + zero := time.Duration(0) + + return &AgentConfig{ConsentFreshnessTimeout: &zero} + }(), + func(t *testing.T, result *Agent) { + t.Helper() + assert.Equal(t, time.Duration(0), result.consentFreshnessTimeout) + }, + }, } for _, test := range tests { diff --git a/agent_options.go b/agent_options.go index ac26c739..32adee78 100644 --- a/agent_options.go +++ b/agent_options.go @@ -960,3 +960,40 @@ func WithLoggerFactory(loggerFactory logging.LoggerFactory) AgentOption { return nil } } + +// WithConsentFreshnessTimeout sets how long consent remains valid without an authenticated, non-error +// STUN Binding response. +// A timeout of 0 disables consent freshness expiry. +func WithConsentFreshnessTimeout(timeout time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.consentFreshnessTimeout = timeout + + return nil + } +} + +// WithBindingRequestErrorResponseHandler sets a handler to return an optional STUN Binding Error response +// for inbound STUN Binding Requests. +// It can be used to implement consent revocation by returning a Binding Error 403 (Forbidden) response +// when the agent receives a binding request for an existing candidate pair. +// Returning nil continues normal success handling. +// Returning a non-nil response sends the configured ERROR-CODE and any additional attributes included +// in ExtraAttributes. +// Note: pair is nil when the binding request will create a new pair. +func WithBindingRequestErrorResponseHandler( + handler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) *BindingRequestErrorResponse, +) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.userBindingReqErrorRespHandler = handler + + return nil + } +} diff --git a/agent_options_test.go b/agent_options_test.go index 26b5e3d6..51d0f6a6 100644 --- a/agent_options_test.go +++ b/agent_options_test.go @@ -160,6 +160,7 @@ func TestWithTimeoutOptions(t *testing.T) { WithFailedTimeout(20*time.Second), WithKeepaliveInterval(3*time.Second), WithCheckInterval(150*time.Millisecond), + WithConsentFreshnessTimeout(40*time.Second), ) require.NoError(t, err) defer agent.Close() //nolint:errcheck @@ -168,6 +169,7 @@ func TestWithTimeoutOptions(t *testing.T) { assert.Equal(t, 20*time.Second, agent.failedTimeout) assert.Equal(t, 3*time.Second, agent.keepaliveInterval) assert.Equal(t, 150*time.Millisecond, agent.checkInterval) + assert.Equal(t, 40*time.Second, agent.consentFreshnessTimeout) } func TestWithAcceptanceWaitOptions(t *testing.T) { @@ -1681,3 +1683,63 @@ func (n *stubNet) CreateDialer(dialer *net.Dialer) transport.Dialer { func (n *stubNet) CreateListenConfig(listenerConfig *net.ListenConfig) transport.ListenConfig { return nil } + +func TestWithBindingRequestErrorResponseHandler(t *testing.T) { + t.Run("sets binding error response handler", func(t *testing.T) { + handlerCalled := false + handler := func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) *BindingRequestErrorResponse { + handlerCalled = true + + // Consent revocation should be sent by returning a Binding Error 403 (Forbidden). + return &BindingRequestErrorResponse{ + ErrorCodeAttribute: stun.ErrorCodeAttribute{Code: stun.CodeForbidden, Reason: []byte("Forbidden")}, + } + } + + agent, err := NewAgentWithOptions(WithBindingRequestErrorResponseHandler(handler)) + assert.NoError(t, err) + defer agent.Close() //nolint:errcheck + + assert.NotNil(t, agent.userBindingReqErrorRespHandler) + + if agent.userBindingReqErrorRespHandler != nil { + agent.userBindingReqErrorRespHandler(nil, nil, nil, nil) + assert.True(t, handlerCalled) + } + }) + + t.Run("default is nil", func(t *testing.T) { + agent, err := NewAgentWithOptions() + assert.NoError(t, err) + defer agent.Close() //nolint:errcheck + + assert.Nil(t, agent.userBindingReqErrorRespHandler) + }) + + t.Run("works with config", func(t *testing.T) { + handlerCalled := false + handler := func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) *BindingRequestErrorResponse { + handlerCalled = true + + return &BindingRequestErrorResponse{ + ErrorCodeAttribute: stun.ErrorCodeAttribute{Code: stun.CodeForbidden, Reason: []byte("Forbidden")}, + } + } + + config := &AgentConfig{ + NetworkTypes: []NetworkType{NetworkTypeUDP4}, + BindingRequestErrorResponseHandler: handler, + } + + agent, err := NewAgent(config) + assert.NoError(t, err) + defer agent.Close() //nolint:errcheck + + assert.NotNil(t, agent.userBindingReqErrorRespHandler) + + if agent.userBindingReqErrorRespHandler != nil { + agent.userBindingReqErrorRespHandler(nil, nil, nil, nil) + assert.True(t, handlerCalled) + } + }) +} diff --git a/agent_test.go b/agent_test.go index 3adcf6b9..3a9b09e9 100644 --- a/agent_test.go +++ b/agent_test.go @@ -48,6 +48,12 @@ func (r *recordingSelector) HandleSuccessResponse(*stun.Message, Candidate, Cand r.handledSuccess = true } +func (r *recordingSelector) handleSuccessResponse(*stun.Message, Candidate, Candidate, net.Addr) bool { + r.handledSuccess = true + + return true +} + func (r *recordingSelector) HandleBindingRequest(*stun.Message, Candidate, Candidate) { r.handledBindingRequest = true } @@ -634,6 +640,57 @@ func TestHandleInboundAdditionalCases(t *testing.T) { require.Equal(t, CandidatePairStateSucceeded, pair.state) require.Empty(t, agent.pendingBindingRequests) }) + + t.Run("Authenticated 403 error response revokes consent", func(t *testing.T) { + agent := newTestAgent(t) + defer func() { + require.NoError(t, agent.Close()) + }() + + local := newHostLocal(t) + remoteConfig := &CandidateHostConfig{ + Network: "udp", + Address: "192.0.2.3", + Port: 6666, + Component: 1, + } + remoteCandidate, err := NewCandidateHost(remoteConfig) + require.NoError(t, err) + remoteAddr := &net.UDPAddr{IP: net.ParseIP(remoteCandidate.Address()), Port: remoteCandidate.Port()} + transactionID := stun.NewTransactionID() + remotePwd := "remotekey403" + + require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) { + agent.selector = &controllingSelector{agent: agent, log: agent.log} + agent.selector.Start() + agent.localCandidates[local.NetworkType()] = append(agent.localCandidates[local.NetworkType()], local) + agent.addRemoteCandidate(remoteCandidate) //nolint:contextcheck + pair := agent.findPair(local, remoteCandidate) + require.NotNil(t, pair) + agent.setSelectedPair(pair) + agent.pendingBindingRequests = []bindingRequest{ // request awaiting response + { + timestamp: time.Now(), + transactionID: transactionID, + destination: remoteAddr, + }, + } + agent.remotePwd = remotePwd + })) + + msg, err := stun.Build(stun.BindingError, stun.NewTransactionIDSetter(transactionID), + stun.ErrorCodeAttribute{Code: stun.CodeForbidden, Reason: []byte("Forbidden")}, + stun.NewShortTermIntegrity(remotePwd), + stun.Fingerprint, + ) + require.NoError(t, err) + + agent.handleInbound(msg, local, remoteAddr) + + require.Equal(t, ConnectionStateFailed, agent.connectionState) + require.Nil(t, agent.getSelectedPair()) + require.Empty(t, agent.pendingBindingRequests) + }) } func TestInvalidAgentStarts(t *testing.T) { @@ -1751,9 +1808,10 @@ func TestLiteLifecycle(t *testing.T) { func TestValidateSelectedPairTransitions(t *testing.T) { agent := &Agent{ - disconnectedTimeout: time.Second, - failedTimeout: time.Second, - connectionState: ConnectionStateConnected, + disconnectedTimeout: time.Second, + failedTimeout: time.Second, + consentFreshnessTimeout: defaultConsentFreshnessTimeout, + connectionState: ConnectionStateConnected, connectionStateNotifier: &handlerNotifier{ connectionStateFunc: func(ConnectionState) {}, done: make(chan struct{}), @@ -1788,6 +1846,43 @@ func TestValidateSelectedPairTransitions(t *testing.T) { require.Equal(t, ConnectionStateFailed, agent.connectionState) } +func TestValidateSelectedPairConsentExpired(t *testing.T) { + agent := &Agent{ + disconnectedTimeout: time.Second, + failedTimeout: time.Second, + consentFreshnessTimeout: 30 * time.Second, + connectionState: ConnectionStateConnected, + connectionStateNotifier: &handlerNotifier{ + connectionStateFunc: func(ConnectionState) {}, + done: make(chan struct{}), + }, + log: logging.NewDefaultLoggerFactory().NewLogger("test"), + } + + local, err := NewCandidateHost(&CandidateHostConfig{ + Network: "udp", + Address: "1.1.1.1", + Port: 1000, + Component: ComponentRTP, + }) + require.NoError(t, err) + + remote, err := NewCandidateHost(&CandidateHostConfig{ + Network: "udp", + Address: "2.2.2.2", + Port: 2000, + Component: ComponentRTP, + }) + require.NoError(t, err) + remote.setLastReceived(time.Now()) + + agent.selectedPair.Store(newCandidatePair(local, remote, true)) + agent.lastConsentAt = time.Now().Add(-31 * time.Second) + + require.False(t, agent.validateSelectedPair()) + require.Equal(t, ConnectionStateFailed, agent.connectionState) +} + func TestNilCandidate(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) @@ -2550,42 +2645,44 @@ func TestAgentUpdateOptions(t *testing.T) { // All options except WithUrls should be rejected on a running agent. // When adding new options, add them here if they are not runtime-updatable. nonUpdatableOptions := map[string]AgentOption{ - "WithAddressRewriteRules": WithAddressRewriteRules(AddressRewriteRule{External: []string{"1.2.3.4"}}), - "WithICELite": WithICELite(true), - "WithPortRange": WithPortRange(5000, 6000), - "WithDisconnectedTimeout": WithDisconnectedTimeout(time.Second), - "WithFailedTimeout": WithFailedTimeout(time.Second), - "WithKeepaliveInterval": WithKeepaliveInterval(time.Second), - "WithHostAcceptanceMinWait": WithHostAcceptanceMinWait(time.Second), - "WithSrflxAcceptanceMinWait": WithSrflxAcceptanceMinWait(time.Second), - "WithPrflxAcceptanceMinWait": WithPrflxAcceptanceMinWait(time.Second), - "WithRelayAcceptanceMinWait": WithRelayAcceptanceMinWait(time.Second), - "WithSTUNGatherTimeout": WithSTUNGatherTimeout(time.Second), - "WithIPFilter": WithIPFilter(func(net.IP) bool { return true }), - "WithNet": WithNet(nil), - "WithMulticastDNSMode": WithMulticastDNSMode(MulticastDNSModeDisabled), - "WithMulticastDNSHostName": WithMulticastDNSHostName("test.local"), - "WithLocalCredentials": WithLocalCredentials("", ""), - "WithTCPMux": WithTCPMux(nil), - "WithUDPMux": WithUDPMux(nil), - "WithUDPMuxSrflx": WithUDPMuxSrflx(nil), - "WithProxyDialer": WithProxyDialer(nil), - "WithMaxBindingRequests": WithMaxBindingRequests(10), - "WithCheckInterval": WithCheckInterval(time.Second), - "WithRenomination": WithRenomination(DefaultNominationValueGenerator()), - "WithNominationAttribute": WithNominationAttribute(0x0030), - "WithIncludeLoopback": WithIncludeLoopback(), - "WithTCPPriorityOffset": WithTCPPriorityOffset(10), - "WithDisableActiveTCP": WithDisableActiveTCP(), - "WithBindingRequestHandler": WithBindingRequestHandler(nil), - "WithEnableUseCandidateCheckPriority": WithEnableUseCandidateCheckPriority(), - "WithContinualGatheringPolicy": WithContinualGatheringPolicy(GatherOnce), - "WithNetworkMonitorInterval": WithNetworkMonitorInterval(time.Second), - "WithNetworkTypes": WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), - "WithCandidateTypes": WithCandidateTypes([]CandidateType{CandidateTypeHost}), - "WithAutomaticRenomination": WithAutomaticRenomination(time.Second), - "WithInterfaceFilter": WithInterfaceFilter(func(string) bool { return true }), - "WithLoggerFactory": WithLoggerFactory(nil), + "WithAddressRewriteRules": WithAddressRewriteRules(AddressRewriteRule{External: []string{"1.2.3.4"}}), + "WithICELite": WithICELite(true), + "WithPortRange": WithPortRange(5000, 6000), + "WithDisconnectedTimeout": WithDisconnectedTimeout(time.Second), + "WithFailedTimeout": WithFailedTimeout(time.Second), + "WithKeepaliveInterval": WithKeepaliveInterval(time.Second), + "WithConsentFreshnessTimeout": WithConsentFreshnessTimeout(time.Second), + "WithHostAcceptanceMinWait": WithHostAcceptanceMinWait(time.Second), + "WithSrflxAcceptanceMinWait": WithSrflxAcceptanceMinWait(time.Second), + "WithPrflxAcceptanceMinWait": WithPrflxAcceptanceMinWait(time.Second), + "WithRelayAcceptanceMinWait": WithRelayAcceptanceMinWait(time.Second), + "WithSTUNGatherTimeout": WithSTUNGatherTimeout(time.Second), + "WithIPFilter": WithIPFilter(func(net.IP) bool { return true }), + "WithNet": WithNet(nil), + "WithMulticastDNSMode": WithMulticastDNSMode(MulticastDNSModeDisabled), + "WithMulticastDNSHostName": WithMulticastDNSHostName("test.local"), + "WithLocalCredentials": WithLocalCredentials("", ""), + "WithTCPMux": WithTCPMux(nil), + "WithUDPMux": WithUDPMux(nil), + "WithUDPMuxSrflx": WithUDPMuxSrflx(nil), + "WithProxyDialer": WithProxyDialer(nil), + "WithMaxBindingRequests": WithMaxBindingRequests(10), + "WithCheckInterval": WithCheckInterval(time.Second), + "WithRenomination": WithRenomination(DefaultNominationValueGenerator()), + "WithNominationAttribute": WithNominationAttribute(0x0030), + "WithIncludeLoopback": WithIncludeLoopback(), + "WithTCPPriorityOffset": WithTCPPriorityOffset(10), + "WithDisableActiveTCP": WithDisableActiveTCP(), + "WithBindingRequestHandler": WithBindingRequestHandler(nil), + "WithBindingRequestErrorResponseHandler": WithBindingRequestErrorResponseHandler(nil), + "WithEnableUseCandidateCheckPriority": WithEnableUseCandidateCheckPriority(), + "WithContinualGatheringPolicy": WithContinualGatheringPolicy(GatherOnce), + "WithNetworkMonitorInterval": WithNetworkMonitorInterval(time.Second), + "WithNetworkTypes": WithNetworkTypes([]NetworkType{NetworkTypeUDP4}), + "WithCandidateTypes": WithCandidateTypes([]CandidateType{CandidateTypeHost}), + "WithAutomaticRenomination": WithAutomaticRenomination(time.Second), + "WithInterfaceFilter": WithInterfaceFilter(func(string) bool { return true }), + "WithLoggerFactory": WithLoggerFactory(nil), } for name, opt := range nonUpdatableOptions { diff --git a/gather.go b/gather.go index 72f73abd..94f926a0 100644 --- a/gather.go +++ b/gather.go @@ -920,19 +920,21 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { return } - conn, connectErr := dtls.Client(&fakenet.PacketConn{Conn: udpConn}, udpConn.RemoteAddr(), &dtls.Config{ - ServerName: url.Host, - InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec - LoggerFactory: a.loggerFactory, - }) + conn, connectErr := dtls.ClientWithOptions( + &fakenet.PacketConn{Conn: udpConn}, + udpConn.RemoteAddr(), + dtls.WithServerName(url.Host), + dtls.WithInsecureSkipVerify(a.insecureSkipVerify), //nolint:gosec + dtls.WithLoggerFactory(a.loggerFactory), + ) if connectErr != nil { - a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) + a.log.Warnf("Failed to create DTLS client: %s: %v", turnServerAddr, connectErr) return } if connectErr = conn.HandshakeContext(ctx); connectErr != nil { - a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) + a.log.Warnf("Failed to create DTLS client: %s: %v", turnServerAddr, connectErr) return } diff --git a/gather_test.go b/gather_test.go index 991de91d..3aff1224 100644 --- a/gather_test.go +++ b/gather_test.go @@ -452,12 +452,10 @@ func TestTURNConcurrency(t *testing.T) { require.NoError(t, genErr) serverPort := randomPort(t) - serverListener, err := dtls.Listen( + serverListener, err := dtls.ListenWithOptions( "udp", &net.UDPAddr{IP: net.ParseIP(localhostIPStr), Port: serverPort}, - &dtls.Config{ - Certificates: []tls.Certificate{certificate}, - }, + dtls.WithCertificates(certificate), ) require.NoError(t, err) diff --git a/selection.go b/selection.go index efed3353..f39e3ef1 100644 --- a/selection.go +++ b/selection.go @@ -15,8 +15,12 @@ type pairCandidateSelector interface { Start() ContactCandidates() PingCandidate(local, remote Candidate) + // Deprecated: use handleSuccessResponse instead, which returns a boolean indicating if the response was + // successfully handled HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) HandleBindingRequest(m *stun.Message, local, remote Candidate) + + handleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) bool } type controllingSelector struct { @@ -103,17 +107,32 @@ func (s *controllingSelector) nominatePair(pair *CandidatePair) { } func (s *controllingSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop - s.agent.sendBindingSuccess(message, local, remote) - pair := s.agent.findPair(local, remote) + if pair != nil { + pair.UpdateRequestReceived() + } + + if s.agent.userBindingReqErrorRespHandler != nil { + bindingErrorResponse := s.agent.userBindingReqErrorRespHandler(message, local, remote, pair) + if bindingErrorResponse != nil { + s.agent.sendBindingError(message, local, remote, *bindingErrorResponse) + + return + } + } + newPairAdded := false if pair == nil { pair = s.agent.addPair(local, remote) + newPairAdded = true pair.UpdateRequestReceived() + } + + s.agent.sendBindingSuccess(message, local, remote) + if newPairAdded { return } - pair.UpdateRequestReceived() if pair.state == CandidatePairStateSucceeded && s.nominatedPair == nil && s.agent.getSelectedPair() == nil { bestPair := s.agent.getBestAvailableCandidatePair() @@ -138,11 +157,17 @@ func (s *controllingSelector) HandleBindingRequest(message *stun.Message, local, } func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) { - ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID) + s.handleSuccessResponse(m, local, remote, remoteAddr) +} + +func (s *controllingSelector) handleSuccessResponse( + m *stun.Message, local, remote Candidate, remoteAddr net.Addr, +) bool { + ok, pendingRequest, rtt := s.agent.consumePendingBindingRequest(m.TransactionID, remoteAddr) if !ok { s.log.Warnf("Discard success response from (%s), unknown TransactionID 0x%x", remote, m.TransactionID) - return + return false } transactionAddr := pendingRequest.destination @@ -156,7 +181,7 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo remote, ) - return + return false } s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local) @@ -166,7 +191,7 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo // This shouldn't happen s.log.Error("Success response from invalid candidate pair") - return + return false } pair.state = CandidatePairStateSucceeded @@ -188,6 +213,8 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo } pair.UpdateRoundTripTime(rtt) + + return true } func (s *controllingSelector) PingCandidate(local, remote Candidate) { @@ -355,6 +382,10 @@ func (s *controlledSelector) PingCandidate(local, remote Candidate) { } func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) { + s.handleSuccessResponse(m, local, remote, remoteAddr) +} + +func (s *controlledSelector) handleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) bool { //nolint:godox // TODO according to the standard we should specifically answer a failed nomination: // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 @@ -363,11 +394,11 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot // request with an appropriate error code response (e.g., 400) // [RFC5389]. - ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID) + ok, pendingRequest, rtt := s.agent.consumePendingBindingRequest(m.TransactionID, remoteAddr) if !ok { s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID) - return + return false } transactionAddr := pendingRequest.destination @@ -381,7 +412,7 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot remote, ) - return + return false } s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local) @@ -391,7 +422,7 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot // This shouldn't happen s.log.Error("Success response from invalid candidate pair") - return + return false } pair.state = CandidatePairStateSucceeded @@ -407,14 +438,29 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot } pair.UpdateRoundTripTime(rtt) + + return true } func (s *controlledSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop pair := s.agent.findPair(local, remote) + if pair != nil { + pair.UpdateRequestReceived() + } + + if s.agent.userBindingReqErrorRespHandler != nil { + bindingErrorResponse := s.agent.userBindingReqErrorRespHandler(message, local, remote, pair) + if bindingErrorResponse != nil { + s.agent.sendBindingError(message, local, remote, *bindingErrorResponse) + + return + } + } + if pair == nil { pair = s.agent.addPair(local, remote) + pair.UpdateRequestReceived() } - pair.UpdateRequestReceived() if message.Contains(stun.AttrUseCandidate) || message.Contains(s.agent.nominationAttribute) { //nolint:nestif // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 diff --git a/selection_test.go b/selection_test.go index bfc294f2..2dce6458 100644 --- a/selection_test.go +++ b/selection_test.go @@ -29,8 +29,17 @@ const ( selectionTestPassword = "pwd" selectionTestRemoteUfrag = "remote" selectionTestLocalUfrag = "local" + selectionTestCustomAttr = stun.AttrType(0xC001) ) +type selectionTestCustomErrorAttributeSetter struct{} + +func (selectionTestCustomErrorAttributeSetter) AddTo(m *stun.Message) error { + m.Add(selectionTestCustomAttr, []byte("custom-attr")) + + return nil +} + func sendUntilDone(t *testing.T, writingConn, readingConn net.Conn, maxAttempts int) bool { t.Helper() @@ -168,6 +177,86 @@ func TestBindingRequestHandler(t *testing.T) { closePipe(t, controllingConn, controlledConn) } +func TestBindingRequestErrorResponseHandler(t *testing.T) { + testCases := []struct { + name string + buildSelector func(*Agent) pairCandidateSelector + }{ + { + name: "controlledSelector", + buildSelector: func(agent *Agent) pairCandidateSelector { + return &controlledSelector{agent: agent, log: agent.log} + }, + }, + { + name: "controllingSelector", + buildSelector: func(agent *Agent) pairCandidateSelector { + return &controllingSelector{agent: agent, log: agent.log} + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + agent, err := NewAgent(&AgentConfig{ + BindingRequestErrorResponseHandler: func( + _ *stun.Message, _, _ Candidate, _ *CandidatePair, + ) *BindingRequestErrorResponse { + return &BindingRequestErrorResponse{ + ErrorCodeAttribute: stun.ErrorCodeAttribute{ + Code: stun.CodeForbidden, + Reason: []byte("Forbidden"), + }, + ExtraAttributes: []stun.Setter{selectionTestCustomErrorAttributeSetter{}}, + } + }, + }) + require.NoError(t, err) + defer func() { require.NoError(t, agent.Close()) }() + + local, err := NewCandidateHost(&CandidateHostConfig{ + Network: "udp", + Address: "127.0.0.1", + Port: 12345, + Component: ComponentRTP, + }) + require.NoError(t, err) + + remote, err := NewCandidateHost(&CandidateHostConfig{ + Network: "udp", + Address: "127.0.0.1", + Port: 54321, + Component: ComponentRTP, + }) + require.NoError(t, err) + + mockConn := &mockPacketConnWithCapture{} + local.conn = mockConn + + selector := tc.buildSelector(agent) + selector.Start() + + msg, err := stun.Build(stun.BindingRequest, stun.TransactionID) + require.NoError(t, err) + + selector.HandleBindingRequest(msg, local, remote) + + require.Len(t, mockConn.sentPackets, 1) + + out := &stun.Message{Raw: mockConn.sentPackets[0]} + require.NoError(t, out.Decode()) + require.Equal(t, stun.MethodBinding, out.Type.Method) + require.Equal(t, stun.ClassErrorResponse, out.Type.Class) + + errorCode := stun.ErrorCodeAttribute{} + require.NoError(t, errorCode.GetFrom(out)) + require.Equal(t, stun.CodeForbidden, errorCode.Code) + require.True(t, out.Contains(selectionTestCustomAttr)) + require.Empty(t, agent.checklist) + }) + } +} + // copied from pion/webrtc's peerconnection_go_test.go. type testICELogger struct { lastErrorMessage string @@ -266,15 +355,18 @@ func bareAgentForPing() *Agent { connectionStateNotifier: &handlerNotifier{ done: make(chan struct{}), - connectionStateFunc: func(ConnectionState) {}}, //nolint formatting + connectionStateFunc: func(ConnectionState) {}, + }, //nolint formatting candidateNotifier: &handlerNotifier{ done: make(chan struct{}), - candidateFunc: func(Candidate) {}}, //nolint formatting + candidateFunc: func(Candidate) {}, + }, //nolint formatting selectedCandidatePairNotifier: &handlerNotifier{ done: make(chan struct{}), - candidatePairFunc: func(*CandidatePair) {}}, //nolint formatting + candidatePairFunc: func(*CandidatePair) {}, + }, //nolint formatting } } @@ -1304,7 +1396,7 @@ func TestControllingSideRenomination(t *testing.T) { require.NoError(t, err) // Handle the success response - this should switch to pair2 - selector.HandleSuccessResponse(successMsg, local2, remote, remote.addr()) + selector.handleSuccessResponse(successMsg, local2, remote, remote.addr()) // The controlling agent should have switched to pair2 selectedPair := agent.getSelectedPair() @@ -1384,7 +1476,7 @@ func TestControllingSideRenomination(t *testing.T) { // Handle the success response - this should NOT switch since it's standard nomination // and a pair is already selected - selector.HandleSuccessResponse(successMsg, local2, remote, remote.addr()) + selector.handleSuccessResponse(successMsg, local2, remote, remote.addr()) // The controlling agent should remain with pair1 selectedPair := agent.getSelectedPair()