diff --git a/nonvalidator/non_validator.go b/nonvalidator/non_validator.go index e890710e..8726ec7b 100644 --- a/nonvalidator/non_validator.go +++ b/nonvalidator/non_validator.go @@ -36,6 +36,12 @@ func (f *finalizedSeq) String() string { return fmt.Sprintf("FinalizedSeq {BlockDigest: %s, Seq: %d, BlockExists %t, FinalizationExists %t}", digest, seq, f.block != nil, f.finalization != nil) } +var ( + maxRebroadcastAttempts = uint64(5) + // How often we send broadcast requests until we validate the latest epoch + defaultRebroadcastTimeout = 5 * time.Second +) + type Config struct { Storage common.Storage Comm common.Communication @@ -81,6 +87,10 @@ type NonValidator struct { epochs epochs verifier *common.BlockDependencyManager + + // bootstrapRebroadcastHandler handles rebroadcasting our latest epoch to catch up with the current tip + // in the case that our original broadcast requests failed. + bootstrapRebroadcastHandler *common.TimeoutHandler[uint64] } // NewNonValidator creates a NonValidator with the given `config`. @@ -105,7 +115,7 @@ func NewNonValidator(config Config) (*NonValidator, error) { replicator := simplex.NewReplicationState(config.Logger, config.Comm, config.ID, config.MaxSequenceWindow, true, config.StartTime, lock, randomSource) - return &NonValidator{ + n := &NonValidator{ Config: config, incompleteSequences: make(map[uint64]*finalizedSeq), ctx: ctx, @@ -116,7 +126,10 @@ func NewNonValidator(config Config) (*NonValidator, error) { highestEpochCollector: newEpochReplicator(config.Logger, config.Comm), oneTimeVerifier: simplex.NewOneTimeVerifier(config.Logger), sequenceReplicator: replicator, - }, nil + } + + n.bootstrapRebroadcastHandler = common.NewTimeoutHandler(config.Logger, "NonValidator TimeoutHandler", config.StartTime, defaultRebroadcastTimeout, n.bootstrapRunner) + return n, nil } func (n *NonValidator) Start() { @@ -294,6 +307,11 @@ func (n *NonValidator) maybeValidateNextEpoch(block common.Block) { n.Logger.Info("We have a valid sealing block, messages for that epoch can be processed.", zap.Uint64("Epoch", nextEpoch)) n.epochs[nextEpoch] = newEpochMetadata(nextEpoch, sealingInfo, n.SignatureAggregatorCreator) + + // remove all the rebroadcast tasks once we advanced to the next epoch + n.bootstrapRebroadcastHandler.RemoveOldTasks(func(_ uint64, _ struct{}) bool { + return true + }) } func (n *NonValidator) removeOldSequencesAndEpochs(lastCommittedSeq, minEpochToKeep uint64) { @@ -569,3 +587,23 @@ func (n *NonValidator) sendRequest(seq uint64, to common.NodeID) { func (n *NonValidator) nextSeqToCommit() uint64 { return n.Storage.NumBlocks() } + +func (n *NonValidator) bootstrapRunner(taskIds []uint64) { + if len(taskIds) != 1 { + return + } + + // drop the task we just processed; we reschedule the next attempt below. + n.bootstrapRebroadcastHandler.RemoveTask(taskIds[0]) + + attempt := taskIds[0] + 1 + + // too many attempts, don't rebroadcast + if attempt > maxRebroadcastAttempts { + return + } + + n.Logger.Debug("Rebroadcasting latest epoch", zap.Uint64("Next to commit", n.nextSeqToCommit())) + n.broadcastLatestEpoch() + n.bootstrapRebroadcastHandler.AddTask(attempt) +} diff --git a/nonvalidator/non_validator_test.go b/nonvalidator/non_validator_test.go index 52239bef..eb7cfd9b 100644 --- a/nonvalidator/non_validator_test.go +++ b/nonvalidator/non_validator_test.go @@ -652,6 +652,82 @@ func TestNonValidator_VerifiesFinalizationDuringReplication(t *testing.T) { ) } +// TestNonValidator_RebroadcastsUntilMaxAttempts verifies that a pending +// rebroadcast task re-fires on every interval until it reaches +// maxRebroadcastAttempts, then stops on its own. +func TestNonValidator_RebroadcastsUntilMaxAttempts(t *testing.T) { + tc := newSeededChain(t, testNodes, 2) + myNodeID := common.NodeID{100} + msgQueue := &messageQueue{} + comm := &routerComm{nodes: tc.nodes(), t: t, ID: myNodeID, messageQueue: msgQueue} + + start := time.Now() + nv, err := NewNonValidator(Config{ + Storage: tc, + Comm: comm, + Logger: testutil.MakeLogger(t, 1), + SignatureAggregatorCreator: tc.signatureAggregatorCreator, + MaxSequenceWindow: simplex.DefaultMaxRoundWindow, + ID: myNodeID, + StartTime: start, + }) + require.NoError(t, err) + defer nv.Stop() + + // Seed the first attempt, then tick a few intervals past the cap. Each tick + // fires one rebroadcast and schedules the next attempt; ticks past the cap + // should be no-ops. + nv.bootstrapRebroadcastHandler.AddTask(0) + for i := 1; i <= int(maxRebroadcastAttempts)+3; i++ { + nv.bootstrapRebroadcastHandler.Tick(start.Add(time.Duration(i) * defaultRebroadcastTimeout)) + time.Sleep(50 * time.Millisecond) + } + + // broadcastLatestEpoch fans a ReplicationRequest out to every peer, so one + // rebroadcast equals len(testNodes) queued messages. + msgs := 0 + for msg, ok := msgQueue.popResponse(); ok; msg, ok = msgQueue.popResponse() { + require.NotNil(t, msg.msg.ReplicationRequest) + msgs++ + } + require.Equal(t, int(maxRebroadcastAttempts)*len(testNodes), msgs) +} + +// TestNonValidator_CancelsRebroadcastOnEpochValidation verifies that once the +// non-validator validates a new epoch, it cancels the pending rebroadcast +// attempts and stops sending. +func TestNonValidator_CancelsRebroadcastOnEpochValidation(t *testing.T) { + tc := newSeededChain(t, testNodes, 2) + myNodeID := common.NodeID{100} + msgQueue := &messageQueue{} + comm := &routerComm{nodes: tc.nodes(), t: t, ID: myNodeID, messageQueue: msgQueue} + + start := time.Now() + nv, err := NewNonValidator(Config{ + Storage: tc, + Comm: comm, + Logger: testutil.MakeLogger(t, 1), + SignatureAggregatorCreator: tc.signatureAggregatorCreator, + MaxSequenceWindow: simplex.DefaultMaxRoundWindow, + ID: myNodeID, + StartTime: start, + }) + require.NoError(t, err) + defer nv.Stop() + + // Validating a new epoch should drop every pending rebroadcast task. + nv.bootstrapRebroadcastHandler.AddTask(0) + nv.maybeValidateNextEpoch(tc.appendSealing(testNodes)) + + // Ticking past the interval no longer rebroadcasts. + nv.bootstrapRebroadcastHandler.Tick(start.Add(defaultRebroadcastTimeout)) + + require.Never(t, func() bool { + _, ok := msgQueue.popResponse() + return ok + }, 2*time.Second, 50*time.Millisecond) +} + func advanceUntil(nv *NonValidator, epochs *testEpochs, msgQueue *messageQueue, seq uint64) { startTime := nv.StartTime for {