From 46bbe6a5250f3d1e0ad88239609949c3d8968833 Mon Sep 17 00:00:00 2001 From: samliok Date: Wed, 24 Jun 2026 15:20:19 -0400 Subject: [PATCH] update tests --- nonvalidator/chain_helpers_test.go | 5 +- nonvalidator/comm_test.go | 211 ++++++++++++++++++---------- nonvalidator/non_validator_test.go | 112 ++++++++++----- simplex/epoch.go | 90 ++++++------ simplex/replication_request_test.go | 72 ++++++++++ 5 files changed, 335 insertions(+), 155 deletions(-) diff --git a/nonvalidator/chain_helpers_test.go b/nonvalidator/chain_helpers_test.go index cbe027b4..651bbbe0 100644 --- a/nonvalidator/chain_helpers_test.go +++ b/nonvalidator/chain_helpers_test.go @@ -29,6 +29,7 @@ var genesis = testutil.NewTestBlock(common.ProtocolMetadata{ type messageInfo struct { msg *common.Message from common.NodeID + to common.NodeID } // testChain is a helper that book-keeps the current chain-tip, alongside any @@ -186,8 +187,8 @@ func (tc *testChain) signatureAggregatorCreator(nodes []common.Node) common.Sign } } -// addEpochs adds sealing blocks at epochs, and normal blocks in between -func (tc *testChain) addEpochs(epochs ...uint64) { +// indexEpochs indexes sealing blocks at epochs, and normal blocks in between +func (tc *testChain) indexEpochs(epochs ...uint64) { // ensure that the new epoch we are adding is not already indexed require.Greater(tc.t, epochs[0], tc.seq) diff --git a/nonvalidator/comm_test.go b/nonvalidator/comm_test.go index 02888f7f..88864180 100644 --- a/nonvalidator/comm_test.go +++ b/nonvalidator/comm_test.go @@ -9,111 +9,180 @@ import ( "testing" "github.com/ava-labs/simplex/common" + "github.com/ava-labs/simplex/simplex" + "github.com/ava-labs/simplex/testutil" + "github.com/stretchr/testify/require" ) -// nonValidatorResponderComm implements common.Communication and is used during tests -// to create responses to any requests a node sends or broadcasts. -type nonValidatorResponderComm struct { - t *testing.T +// messageQueue keeps a queue of messages +type messageQueue struct { + responsesLock sync.Mutex + responses []*messageInfo +} + +func (m *messageQueue) clearResponses() { + m.responsesLock.Lock() + defer m.responsesLock.Unlock() + m.responses = []*messageInfo{} +} + +func (m *messageQueue) enqueue(mi *messageInfo) { + m.responsesLock.Lock() + defer m.responsesLock.Unlock() + m.responses = append(m.responses, mi) +} - // storage is used to create the responses - storage common.Storage +func (m *messageQueue) popResponse() (*messageInfo, bool) { + m.responsesLock.Lock() + defer m.responsesLock.Unlock() + if len(m.responses) == 0 { + return nil, false + } + msg := m.responses[0] + m.responses = m.responses[1:] + return msg, true +} + +// routerComm appends messages being sent or broadcast to the message queue. +type routerComm struct { + t *testing.T - // nodes should contain the validator set of the highest epoch nodes common.Nodes - // ID is the NodeID of the non-validator using this comm. Broadcasts - // pick the first node in `nodes` that is not equal to ID as the - // simulated responder. + // ID is the sender of messages and is the Node using the struct. ID common.NodeID - // responses is the queue of synthesized responses. Tests will pop messages - // and feed entries into NonValidator.HandleMessage. - responsesLock sync.Mutex - responses []*messageInfo + messageQueue *messageQueue } -func newTestResponder(t *testing.T, myNodeID common.NodeID, tc *testChain) *nonValidatorResponderComm { - return &nonValidatorResponderComm{ - nodes: tc.nodes(), - t: t, - ID: myNodeID, - storage: tc, - } +type testEpochs struct { + t *testing.T + epochs []*simplex.Epoch } -func (r *nonValidatorResponderComm) Validators() common.Nodes { return r.nodes } +func newTestEpochs(tc *testChain, msgQueue *messageQueue, maxSeqWindow uint64) *testEpochs { + nodes := tc.nodes() + epochs := make([]*simplex.Epoch, 0, len(nodes)) -// Enqueues a response coming from `destination`. -func (r *nonValidatorResponderComm) Send(msg *common.Message, destination common.NodeID) { - r.handle(msg, destination) -} + for _, node := range nodes { + epochNodeID := node.Id -// Enqueues responses coming from all other nodes in the network. -func (r *nonValidatorResponderComm) Broadcast(msg *common.Message) { - - for _, n := range r.nodes { - if bytes.Equal(n.Id, r.ID) { - continue + comm := &routerComm{ + nodes: nodes, + t: tc.t, + ID: epochNodeID, + messageQueue: msgQueue, } - r.handle(msg, n.Id) + conf, _, _ := testutil.DefaultTestNodeEpochConfig(tc.t, epochNodeID, comm, testutil.NewTestBlockBuilder()) + conf.MaxRoundWindow = maxSeqWindow + conf.Storage = tc + conf.SignatureAggregatorCreator = tc.signatureAggregatorCreator + conf.ReplicationEnabled = true + + epoch, err := simplex.NewEpoch(conf) + require.NoError(tc.t, err) + + epochs = append(epochs, epoch) + } + + return &testEpochs{ + t: tc.t, + epochs: epochs, } } -func (r *nonValidatorResponderComm) handle(msg *common.Message, from common.NodeID) { - switch { - case msg.ReplicationRequest != nil: - r.respondToReplicationRequest(msg.ReplicationRequest, from) +// start starts every epoch in the network, failing the test if any epoch +// fails to start. +func (e *testEpochs) start() { + for _, epoch := range e.epochs { + require.NoError(e.t, epoch.Start()) } } -func (r *nonValidatorResponderComm) clearResponses() { - r.responsesLock.Lock() - defer r.responsesLock.Unlock() - r.responses = []*messageInfo{} +// stop stops every epoch in the network. +func (e *testEpochs) stop() { + for _, epoch := range e.epochs { + epoch.Stop() + } } -func (r *nonValidatorResponderComm) respondToReplicationRequest(req *common.ReplicationRequest, from common.NodeID) { - resp := &common.ReplicationResponse{} +// handleMessage routes messages between the non-validator and its epochs: a request +// originating from the non-validator is delivered to the addressed epoch, while +// a response from an epoch is delivered back to the non-validator. +func handleMessage(epochs *testEpochs, nv *NonValidator, mi *messageInfo) { + if !bytes.Equal(mi.from, nv.ID) { + // a response from an epoch: deliver it to the non-validator. + require.NoError(epochs.t, nv.HandleMessage(mi.msg, mi.from)) + return + } - for _, seq := range req.Seqs { - block, fin, err := r.storage.Retrieve(seq) - if err == nil { - resp.Data = append(resp.Data, common.QuorumRound{Block: block.(common.Block), Finalization: &fin}) + // a request from the non-validator: route it to the addressed epoch. + for _, epoch := range epochs.epochs { + if bytes.Equal(epoch.ID, mi.to) { + require.NoError(epochs.t, epoch.HandleMessage(mi.msg, mi.from)) + return } } + require.Failf(epochs.t, "no epoch for destination", "destination %x", mi.to) +} + +func (r *routerComm) Validators() common.Nodes { return r.nodes } + +// Enqueues a message sent from this node to `destination`. +func (r *routerComm) Send(msg *common.Message, destination common.NodeID) { + r.handle(msg, destination) +} - if req.LatestFinalizedSeq > 0 && r.storage.NumBlocks() > 0 { - numBlocks := r.storage.NumBlocks() - if req.LatestFinalizedSeq < numBlocks-1 { - block, fin, err := r.storage.Retrieve(numBlocks - 1) - if err == nil { - resp.LatestSeq = &common.QuorumRound{Block: block.(common.Block), Finalization: &fin} - } +// Enqueues a copy of the message addressed to every other node in the network. +func (r *routerComm) Broadcast(msg *common.Message) { + for _, n := range r.nodes { + if bytes.Equal(n.Id, r.ID) { + continue } - } - if len(resp.Data) == 0 && resp.LatestSeq == nil { - return + r.handle(msg, n.Id) } - - r.enqueue(&messageInfo{msg: &common.Message{ReplicationResponse: resp}, from: from}) } -func (r *nonValidatorResponderComm) enqueue(m *messageInfo) { - r.responsesLock.Lock() - defer r.responsesLock.Unlock() - r.responses = append(r.responses, m) +func (r *routerComm) handle(msg *common.Message, to common.NodeID) { + switch { + case msg.VerifiedReplicationResponse != nil: + // Outgoing responses are of the verified type, but incoming responses + // are of the unverified type, so we translate before enqueuing. + vrr := msg.VerifiedReplicationResponse + data := make([]common.QuorumRound, 0, len(vrr.Data)) + for _, vqr := range vrr.Data { + data = append(data, *verifiedQRtoQR(&vqr)) + } + + msg = &common.Message{ + ReplicationResponse: &common.ReplicationResponse{ + Data: data, + LatestRound: verifiedQRtoQR(vrr.LatestRound), + LatestSeq: verifiedQRtoQR(vrr.LatestFinalizedSeq), + }, + } + r.messageQueue.enqueue(&messageInfo{msg: msg, from: r.ID, to: to}) + default: + r.messageQueue.enqueue(&messageInfo{msg: msg, from: r.ID, to: to}) + } } -func (r *nonValidatorResponderComm) popResponse() (*messageInfo, bool) { - r.responsesLock.Lock() - defer r.responsesLock.Unlock() - if len(r.responses) == 0 { - return nil, false +func verifiedQRtoQR(vqr *common.VerifiedQuorumRound) *common.QuorumRound { + if vqr == nil { + return nil + } + + qr := &common.QuorumRound{ + Notarization: vqr.Notarization, + Finalization: vqr.Finalization, + EmptyNotarization: vqr.EmptyNotarization, } - m := r.responses[0] - r.responses = r.responses[1:] - return m, true + + if vqr.VerifiedBlock != nil { + qr.Block = vqr.VerifiedBlock.(common.Block) + } + + return qr } diff --git a/nonvalidator/non_validator_test.go b/nonvalidator/non_validator_test.go index 6e5dd9d0..52239bef 100644 --- a/nonvalidator/non_validator_test.go +++ b/nonvalidator/non_validator_test.go @@ -322,18 +322,24 @@ func TestHandleMessages_DuplicateBlock(t *testing.T) { // epoch on startup. func TestNonValidator_RequestHighestEpochOnStart(t *testing.T) { tc := newSeededChain(t, testNodes, 2) - tc.addEpochs(4, 8) - responder := newTestResponder(t, testNodes.NodeIDs()[0], tc) + tc.indexEpochs(4, 8) + + myNodeID := common.NodeID{100} + comm := &routerComm{ + messageQueue: &messageQueue{}, + ID: myNodeID, + nodes: tc.nodes(), + } nvStorage := tc.CloneUntil(2) nv, err := NewNonValidator( Config{ Storage: nvStorage, - Comm: responder, + Comm: comm, Logger: testutil.MakeLogger(t, 1), SignatureAggregatorCreator: tc.signatureAggregatorCreator, MaxSequenceWindow: simplex.DefaultMaxRoundWindow, - ID: responder.ID, + ID: myNodeID, }, ) @@ -342,10 +348,10 @@ func TestNonValidator_RequestHighestEpochOnStart(t *testing.T) { nv.Start() defer nv.Stop() - msg, ok := responder.popResponse() + msg, ok := comm.messageQueue.popResponse() require.True(t, ok) - require.NotNil(t, msg.msg.ReplicationResponse) - require.NotNil(t, msg.msg.ReplicationResponse.LatestSeq) + require.NotNil(t, msg.msg.ReplicationRequest) + require.Equal(t, uint64(1), msg.msg.ReplicationRequest.LatestFinalizedSeq) } // TestNonValidator_Bootstrap ensures a non-validator can replicate sequences given different states of the chain. @@ -365,7 +371,7 @@ func TestNonValidator_Bootstrap(t *testing.T) { name: "replicates epochs", setup: func(t *testing.T) *testChain { tc := newSeededChain(t, testNodes, 2) - tc.addEpochs(5, 10, 20, 30, 40) + tc.indexEpochs(5, 10, 20, 30, 40) return tc }, maxSequenceWindow: 50, @@ -377,7 +383,7 @@ func TestNonValidator_Bootstrap(t *testing.T) { name: "past max round window", setup: func(t *testing.T) *testChain { tc := newSeededChain(t, testNodes, 2) - tc.addEpochs(5, 10, 20, 30, 40, 50, 60, 80, 100) + tc.indexEpochs(5, 10, 20, 30, 40, 50, 60, 80, 100) return tc }, maxSequenceWindow: 5, // significantly lower @@ -389,7 +395,7 @@ func TestNonValidator_Bootstrap(t *testing.T) { name: "from genesis", setup: func(t *testing.T) *testChain { tc := newSeededChain(t, testNodes, 2) - tc.addEpochs(5, 10, 20) + tc.indexEpochs(5, 10, 20) return tc }, maxSequenceWindow: 50, @@ -416,7 +422,7 @@ func TestNonValidator_Bootstrap(t *testing.T) { tc := newSnowToSimplexChain(t, 10) firstBlock := tc.appendFirstSimplexAfterGenesis(testNodes) tc.Index(context.Background(), firstBlock, tc.newFinalization(firstBlock)) - tc.addEpochs(20, 30) + tc.indexEpochs(20, 30) return tc }, maxSequenceWindow: 50, @@ -428,18 +434,29 @@ func TestNonValidator_Bootstrap(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tc := tt.setup(t) - responder := newTestResponder(t, testNodes.NodeIDs()[0], tc) + myNodeID := common.NodeID{16} + msgQueue := &messageQueue{} + nonValidatorComm := &routerComm{ + nodes: tc.nodes(), + t: tc.t, + ID: myNodeID, + messageQueue: msgQueue, + } + epochs := newTestEpochs(tc, msgQueue, tt.maxSequenceWindow) + epochs.start() + defer epochs.stop() + nvStorage := tc.CloneUntil(tt.initialHeight) require.Equal(t, tt.initialHeight, nvStorage.NumBlocks()) nv, err := NewNonValidator( Config{ Storage: nvStorage, - Comm: responder, + Comm: nonValidatorComm, Logger: testutil.MakeLogger(t, 1), SignatureAggregatorCreator: tc.signatureAggregatorCreator, MaxSequenceWindow: tt.maxSequenceWindow, - ID: responder.ID, + ID: myNodeID, StartTime: time.Now(), }, ) @@ -448,7 +465,7 @@ func TestNonValidator_Bootstrap(t *testing.T) { nv.Start() defer nv.Stop() - advanceUntil(nv, responder, tt.lastSeq) + advanceUntil(nv, epochs, msgQueue, tt.lastSeq) }) } } @@ -457,18 +474,30 @@ func TestNonValidator_ReplicationRequests(t *testing.T) { tc := newSeededChain(t, testNodes, 2) lastSeq := uint64(40) initialHeight := uint64(2) - tc.addEpochs(5, 10, 20, 30, lastSeq) - responder := newTestResponder(t, testNodes.NodeIDs()[0], tc) + tc.indexEpochs(5, 10, 20, 30, lastSeq) + maxSeqWindow := uint64(50) + myNodeID := common.NodeID{255} + msgQueue := &messageQueue{} + nonValidatorComm := &routerComm{ + nodes: tc.nodes(), + t: tc.t, + ID: myNodeID, + messageQueue: msgQueue, + } + epochs := newTestEpochs(tc, msgQueue, maxSeqWindow) + epochs.start() + defer epochs.stop() + nvStorage := tc.CloneUntil(initialHeight) startTime := time.Now() nv, err := NewNonValidator( Config{ Storage: nvStorage, - Comm: responder, + Comm: nonValidatorComm, Logger: testutil.MakeLogger(t, 1), SignatureAggregatorCreator: tc.signatureAggregatorCreator, - MaxSequenceWindow: 50, - ID: responder.ID, + MaxSequenceWindow: maxSeqWindow, + ID: myNodeID, StartTime: startTime, }, ) @@ -479,43 +508,48 @@ func TestNonValidator_ReplicationRequests(t *testing.T) { count := 0 for { - // Send any requests as responses back to the node - for msg, ok := responder.popResponse(); ok; { - // drop every other message - if count%2 == 0 { - require.NoError(t, nv.HandleMessage(msg.msg, msg.from)) + for msg, ok := msgQueue.popResponse(); ok; { + // drops 25% of messages + // TODO: we can handle a higher threshold once we implement https://github.com/ava-labs/Simplex/issues/425 + if count%4 != 0 { + handleMessage(epochs, nv, msg) } - count += 1 - msg, ok = responder.popResponse() + count++ + msg, ok = msgQueue.popResponse() } // check if storage has indexed all - if lastSeq == nvStorage.NumBlocks()-1 { + if lastSeq == nv.Storage.NumBlocks()-1 { break } - // update the time + // update the time so pending replication requests time out and re-fire startTime = startTime.Add(simplex.DefaultReplicationRequestTimeout) + time.Sleep(50 * time.Millisecond) nv.AdvanceTime(startTime) } - // clear in flight responses + // clear in flight messages startTime = startTime.Add(simplex.DefaultReplicationRequestTimeout) nv.AdvanceTime(startTime) - responder.clearResponses() + time.Sleep(50 * time.Millisecond) + msgQueue.clearResponses() - // ensure all timeout tasks were removed + // ensure all timeout tasks were removed: advancing time should no longer + // cause the non-validator to emit any further requests. count = 0 for { startTime = startTime.Add(simplex.DefaultReplicationRequestTimeout) nv.AdvanceTime(startTime) - msg, ok := responder.popResponse() + time.Sleep(50 * time.Millisecond) + + msg, ok := msgQueue.popResponse() require.False(t, ok, fmt.Sprintf("all replication request tasks should be finished %v", msg)) if count > 3 { break } - count += 1 + count++ } } @@ -618,14 +652,13 @@ func TestNonValidator_VerifiesFinalizationDuringReplication(t *testing.T) { ) } -func advanceUntil(nv *NonValidator, responder *nonValidatorResponderComm, seq uint64) { +func advanceUntil(nv *NonValidator, epochs *testEpochs, msgQueue *messageQueue, seq uint64) { startTime := nv.StartTime for { // Send any requests as responses back to the node - for msg, ok := responder.popResponse(); ok; { - // drop every other message - require.NoError(responder.t, nv.HandleMessage(msg.msg, msg.from)) - msg, ok = responder.popResponse() + for msg, ok := msgQueue.popResponse(); ok; { + handleMessage(epochs, nv, msg) + msg, ok = msgQueue.popResponse() } // check if storage has indexed all @@ -635,6 +668,7 @@ func advanceUntil(nv *NonValidator, responder *nonValidatorResponderComm, seq ui // update the time startTime = startTime.Add(simplex.DefaultReplicationRequestTimeout) + time.Sleep(50 * time.Millisecond) nv.AdvanceTime(startTime) } } diff --git a/simplex/epoch.go b/simplex/epoch.go index 543359b1..9cf3c943 100644 --- a/simplex/epoch.go +++ b/simplex/epoch.go @@ -93,8 +93,8 @@ type Epoch struct { finishFn context.CancelFunc blockBuilderCtx context.Context blockBuilderCancelFunc context.CancelFunc - nodeIDs common.NodeIDs - nodes common.Nodes + validatorNodeIDs common.NodeIDs + validators common.Nodes validatorsToPKs map[string][]byte rounds map[uint64]*Round emptyVotes map[uint64]*EmptyVoteSet @@ -151,13 +151,17 @@ func (e *Epoch) HandleMessage(msg *common.Message, from common.NodeID) error { return nil } - // Guard against receiving messages from unknown nodes _, known := e.validatorsToPKs[string(from)] if !known { - e.Logger.Debug("Received message from an unknown node", zap.Stringer("nodeID", from)) - return nil + e.Logger.Debug("Received message from a non-validator node", zap.Stringer("nodeID", from)) + switch { + case msg.ReplicationRequest != nil && e.ReplicationEnabled: + return e.handleReplicationRequest(msg.ReplicationRequest, from) + default: + e.Logger.Debug("Invalid message type", zap.Stringer("from", from)) + return nil + } } - switch { case msg.BlockMessage != nil: return e.handleBlockMessage(msg.BlockMessage, from) @@ -199,22 +203,22 @@ func (e *Epoch) init() error { e.finishCtx, e.finishFn = context.WithCancel(context.Background()) e.blockBuilderCtx = context.Background() e.blockBuilderCancelFunc = func() {} - e.nodes = e.Comm.Validators() - common.SortNodes(e.nodes) - e.nodeIDs = e.nodes.NodeIDs() - e.timedOutRounds = make(map[uint16]uint64, len(e.nodeIDs)) - e.redeemedRounds = make(map[uint16]uint64, len(e.nodeIDs)) + e.validators = e.Comm.Validators() + common.SortNodes(e.validators) + e.validatorNodeIDs = e.validators.NodeIDs() + e.timedOutRounds = make(map[uint16]uint64, len(e.validatorNodeIDs)) + e.redeemedRounds = make(map[uint16]uint64, len(e.validatorNodeIDs)) e.rounds = make(map[uint64]*Round) e.emptyVotes = make(map[uint64]*EmptyVoteSet) - e.futureMessages = make(messagesFromNode, len(e.nodeIDs)) + e.futureMessages = make(messagesFromNode, len(e.validatorNodeIDs)) e.replicationState = NewReplicationState(e.Logger, e.Comm, e.ID, e.MaxRoundWindow, e.ReplicationEnabled, e.StartTime, &e.lock, e.RandomSource) e.timeoutHandler = common.NewTimeoutHandler(e.Logger, "emptyVoteRebroadcast", e.StartTime, e.MaxRebroadcastWait, e.emptyVoteTimeoutTaskRunner) - e.signatureAggregator = e.SignatureAggregatorCreator(e.nodes) - e.validatorsToPKs = make(map[string][]byte, len(e.nodeIDs)) - for _, node := range e.nodes { + e.signatureAggregator = e.SignatureAggregatorCreator(e.validators) + e.validatorsToPKs = make(map[string][]byte, len(e.validatorNodeIDs)) + for _, node := range e.validators { e.validatorsToPKs[string(node.Id)] = node.PK } - for _, node := range e.nodeIDs { + for _, node := range e.validatorNodeIDs { e.futureMessages[string(node)] = make(map[uint64]*messagesForRound) } err := e.loadLastBlock() @@ -222,7 +226,7 @@ func (e *Epoch) init() error { return err } - e.Logger.Info("Starting Simplex Epoch", zap.String("ID", e.ID.String()), zap.Stringer("nodes", e.nodeIDs)) + e.Logger.Info("Starting Simplex Epoch", zap.String("ID", e.ID.String()), zap.Stringer("nodes", e.validatorNodeIDs)) return e.setMetadataFromStorage() } @@ -280,7 +284,7 @@ func (e *Epoch) Start() error { } // Only init receiving messages once you have initialized the data structures required for it. - e.Logger.Debug("Epoch is ready to receive messages") + e.Logger.Debug("Epoch is ready to receive messages", zap.Uint64("epoch", e.Epoch)) e.canReceiveMessages.Store(true) e.broadcastReplicationSync() @@ -574,7 +578,7 @@ func (e *Epoch) resumeFromWal(highestRoundRecord *walRound) error { return fmt.Errorf("could not find round %d for block", block.BlockHeader().Round) } - if e.ID.Equals(LeaderForRound(e.nodeIDs, block.BlockHeader().Round)) { + if e.ID.Equals(LeaderForRound(e.validatorNodeIDs, block.BlockHeader().Round)) { vote, err := e.voteOnBlock(round.block) if err != nil { return err @@ -741,7 +745,7 @@ func (e *Epoch) handleFinalizationMessage(message *common.Finalization, from com return nil } - if err := VerifyQC(message.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, message, e.nodes); err != nil { + if err := VerifyQC(message.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, message, e.validators); err != nil { e.Logger.Debug(fmt.Sprintf("Finalization %s", err), zap.Int("round", int(message.Finalization.Round)), zap.Stringer("NodeID", from)) @@ -779,7 +783,7 @@ func (e *Epoch) handleFinalizationForPendingOrFutureRound(message *common.Finali // TODO: delay requesting future finalizations and blocks, since blocks could be in transit e.Logger.Debug("Received finalization for a pending or future round, and we don't have the block", zap.Uint64("round", round), zap.Uint64("our round", e.round)) - if LeaderForRound(e.nodeIDs, e.round).Equals(e.ID) { + if LeaderForRound(e.validatorNodeIDs, e.round).Equals(e.ID) { e.Logger.Debug("We are the leader of this round, but a higher round has been finalized. Aborting block building.") e.blockBuilderCancelFunc() } @@ -1491,7 +1495,7 @@ func (e *Epoch) persistEmptyNotarization(emptyVotes *EmptyVoteSet, shouldBroadca } func (e *Epoch) maybeMarkLeaderAsTimedOutForFutureBlacklisting(emptyNotarization *common.EmptyNotarization) error { - e.Logger.Debug("Marking the leader as timed out", zap.Uint64("round", emptyNotarization.Vote.Round), zap.Stringer("leader", LeaderForRound(e.nodeIDs, emptyNotarization.Vote.Round))) + e.Logger.Debug("Marking the leader as timed out", zap.Uint64("round", emptyNotarization.Vote.Round), zap.Stringer("leader", LeaderForRound(e.validatorNodeIDs, emptyNotarization.Vote.Round))) var blacklist common.Blacklist if e.lastBlock != nil { if e.lastBlock.VerifiedBlock == nil { @@ -1501,7 +1505,7 @@ func (e *Epoch) maybeMarkLeaderAsTimedOutForFutureBlacklisting(emptyNotarization blacklist = e.lastBlock.VerifiedBlock.Blacklist() } round := emptyNotarization.Vote.Round - leaderIndex := round % uint64(len(e.nodeIDs)) + leaderIndex := round % uint64(len(e.validatorNodeIDs)) if !blacklist.IsNodeSuspected(uint16(leaderIndex)) { e.timedOutRounds[uint16(leaderIndex)] = round } @@ -1573,7 +1577,7 @@ func (e *Epoch) persistNotarization(notarization common.Notarization) error { round := notarization.Vote.Round for _, signer := range notarization.QC.Signers() { - if signerIndex := e.nodeIDs.IndexOf(signer); signerIndex != -1 { + if signerIndex := e.validatorNodeIDs.IndexOf(signer); signerIndex != -1 { e.Logger.Debug("Potentially redeeming node", zap.Stringer("signer", signer), zap.Uint64("round", round)) e.redeemedRounds[uint16(signerIndex)] = round } else { @@ -1618,7 +1622,7 @@ func (e *Epoch) handleEmptyNotarizationMessage(emptyNotarization *common.EmptyNo } // Otherwise, this round is not notarized or finalized yet, so verify the empty notarization and store it. - if err := VerifyQC(emptyNotarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, emptyNotarization, e.nodes); err != nil { + if err := VerifyQC(emptyNotarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, emptyNotarization, e.validators); err != nil { e.Logger.Debug(fmt.Sprintf("Empty notarization %s", err), zap.Stringer("NodeID", from)) return nil @@ -1676,7 +1680,7 @@ func (e *Epoch) handleNotarizationMessage(message *common.Notarization, from com return nil } - if err := VerifyQC(message.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, message, e.nodes); err != nil { + if err := VerifyQC(message.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, message, e.validators); err != nil { e.Logger.Debug(fmt.Sprintf("Notarization %s", err), zap.Stringer("NodeID", from)) return nil @@ -1758,7 +1762,7 @@ func (e *Epoch) handleBlockMessage(message *common.BlockMessage, from common.Nod } // Check that the node is a leader for the round corresponding to the block. - if !LeaderForRound(e.nodeIDs, md.Round).Equals(from) { + if !LeaderForRound(e.validatorNodeIDs, md.Round).Equals(from) { // The block is associated with a round in which the sender is not the leader, // it should not be sending us any block at all. e.Logger.Debug("Got block from a block proposer that is not the leader of the round", zap.Stringer("NodeID", from), zap.Uint64("round", md.Round)) @@ -2017,7 +2021,7 @@ func (e *Epoch) createBlockVerificationTask(block common.Block, from common.Node defer e.lock.Unlock() if err != nil { - leader := LeaderForRound(e.nodeIDs, md.Round) + leader := LeaderForRound(e.validatorNodeIDs, md.Round) e.Logger.Info("Triggering empty block agreement", zap.String("reason", "Failed verifying block"), zap.Uint64("round", md.Round), @@ -2253,7 +2257,7 @@ func (e *Epoch) verifyProposalMetadataAndBlacklist(block common.Block) bool { // Else, either it's not the first block, or we haven't committed the first block, and it is the first block. // If it's the latter we have nothing else to do. // If it's the former, we need to find the parent of the block and ensure it is correct. - prevBlacklist := common.NewBlacklist(uint16(len(e.nodeIDs))) + prevBlacklist := common.NewBlacklist(uint16(len(e.validatorNodeIDs))) if bh.Seq > 0 { prevBlock, _, found := e.locateBlock(bh.Seq-1, bh.Prev[:]) if !found { @@ -2270,7 +2274,7 @@ func (e *Epoch) verifyProposalMetadataAndBlacklist(block common.Block) bool { prevBlacklist = prevBlock.Blacklist() if prevBlacklist.IsEmpty() { - prevBlacklist = common.NewBlacklist(uint16(len(e.nodeIDs))) + prevBlacklist = common.NewBlacklist(uint16(len(e.validatorNodeIDs))) } } @@ -2370,7 +2374,7 @@ func (e *Epoch) buildBlock() { } // If I'm blacklisted, I cannot propose a block. - if prevBlacklist.IsNodeSuspected(uint16(e.nodeIDs.IndexOf(e.ID))) { + if prevBlacklist.IsNodeSuspected(uint16(e.validatorNodeIDs.IndexOf(e.ID))) { e.Logger.Debug("I'm blacklisted, cannot propose a block", zap.Uint64("round", metadata.Round), zap.Stringer("blacklist", &prevBlacklist)) e.triggerEmptyBlockNotarization(metadata.Round) return @@ -2384,7 +2388,7 @@ func (e *Epoch) buildBlock() { e.Logger.Debug("Computing blacklist updates", zap.String("timedOutRounds", fmt.Sprintf("%v", e.timedOutRounds)), zap.String("redeemedRounds", fmt.Sprintf("%v", e.redeemedRounds))) - updates := prevBlacklist.ComputeBlacklistUpdates(metadata.Round, uint16(len(e.nodeIDs)), e.timedOutRounds, e.redeemedRounds) + updates := prevBlacklist.ComputeBlacklistUpdates(metadata.Round, uint16(len(e.validatorNodeIDs)), e.timedOutRounds, e.redeemedRounds) // 3) Apply the updates to the blacklist. nextBlacklist := prevBlacklist.ApplyUpdates(updates, metadata.Round) @@ -2426,7 +2430,7 @@ func (e *Epoch) retrieveBlacklistOfParentBlock(metadata common.ProtocolMetadata) } if blacklist.IsEmpty() { - blacklist = common.NewBlacklist(uint16(len(e.nodeIDs))) + blacklist = common.NewBlacklist(uint16(len(e.validatorNodeIDs))) } return blacklist, true @@ -2626,7 +2630,7 @@ func (e *Epoch) monitorProgress(round uint64) { noop := func() {} - leader := LeaderForRound(e.nodeIDs, round) + leader := LeaderForRound(e.validatorNodeIDs, round) // If we have a task pending to be executed, remove it from execution because we're about to schedule // a task for a higher round. @@ -2646,7 +2650,7 @@ func (e *Epoch) monitorProgress(round uint64) { return } - leader := LeaderForRound(e.nodeIDs, round) + leader := LeaderForRound(e.validatorNodeIDs, round) e.Logger.Debug("Triggering empty block agreement", zap.String("reason", "Timed out on block agreement"), zap.Uint64("round", round), @@ -2666,7 +2670,7 @@ func (e *Epoch) monitorProgress(round uint64) { // If the current leader is blacklisted, we should not wait for it to propose a block. // Instead, we should immediately trigger the empty block agreement. - leaderIndex := e.nodeIDs.IndexOf(leader) + leaderIndex := e.validatorNodeIDs.IndexOf(leader) if leaderIndex >= 0 && blacklist.IsNodeSuspected(uint16(leaderIndex)) { e.Logger.Debug("Leader is blacklisted, will not wait for it to propose a block", zap.Uint64("round", round), zap.Stringer("leader", leader)) @@ -2749,7 +2753,7 @@ func (e *Epoch) startRound() error { return err } - leaderForCurrentRound := LeaderForRound(e.nodeIDs, e.round) + leaderForCurrentRound := LeaderForRound(e.validatorNodeIDs, e.round) if e.ID.Equals(leaderForCurrentRound) { e.buildBlock() @@ -2829,8 +2833,8 @@ func (e *Epoch) increaseRound() { // remove the rebroadcast empty vote task e.timeoutHandler.RemoveTask(EmptyVoteTimeoutID) - prevLeader := LeaderForRound(e.nodeIDs, e.round) - nextLeader := LeaderForRound(e.nodeIDs, e.round+1) + prevLeader := LeaderForRound(e.validatorNodeIDs, e.round) + nextLeader := LeaderForRound(e.validatorNodeIDs, e.round+1) e.Logger.Info("Moving to a new round", zap.Uint64("prev round", e.round), @@ -2989,7 +2993,7 @@ func (e *Epoch) storeProposal(block common.VerifiedBlock) bool { // HandleRequest processes a request and returns a response. It also sends a response to the sender. func (e *Epoch) handleReplicationRequest(req *common.ReplicationRequest, from common.NodeID) error { - e.Logger.Debug("Received replication request", zap.Stringer("from", from), zap.Int("num seqs", len(req.Seqs)), zap.Int("num rounds", len(req.Rounds)), zap.Uint64("latest round", req.LatestRound)) + e.Logger.Debug("Received replication request", zap.Stringer("from", from), zap.Uint64s("seqs", req.Seqs), zap.Int("num rounds", len(req.Rounds)), zap.Uint64("latest round", req.LatestRound), zap.Uint64("latest seq", req.LatestFinalizedSeq)) if !e.ReplicationEnabled { return nil } @@ -3236,19 +3240,19 @@ func (e *Epoch) verifyQuorumRound(q common.QuorumRound) error { if q.Finalization != nil { // extra check needed if we have a finalized block - if err := VerifyQC(q.Finalization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.Finalization, e.nodes); err != nil { + if err := VerifyQC(q.Finalization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.Finalization, e.validators); err != nil { return fmt.Errorf("invalid finalization: %w", err) } } if q.Notarization != nil { - if err := VerifyQC(q.Notarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.Notarization, e.nodes); err != nil { + if err := VerifyQC(q.Notarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.Notarization, e.validators); err != nil { return fmt.Errorf("invalid notarization: %w", err) } } if q.EmptyNotarization != nil { - if err := VerifyQC(q.EmptyNotarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.EmptyNotarization, e.nodes); err != nil { + if err := VerifyQC(q.EmptyNotarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.EmptyNotarization, e.validators); err != nil { return fmt.Errorf("invalid empty notarization QC: %w", err) } } diff --git a/simplex/replication_request_test.go b/simplex/replication_request_test.go index 54ac1765..f1f34d82 100644 --- a/simplex/replication_request_test.go +++ b/simplex/replication_request_test.go @@ -302,6 +302,78 @@ func TestReplicationRequestUnknownSeqsAndRounds(t *testing.T) { require.Never(t, func() bool { return len(comm.in) > 0 }, 5*time.Second, 100*time.Millisecond) } +// TestNonValidatorReplicationRequestServed ensures that a replication request +// from a non-validator node is still served when replication is enabled. +func TestNonValidatorReplicationRequestServed(t *testing.T) { + bb := testutil.NewTestBlockBuilder() + nodes := []common.NodeID{{1}, {2}, {3}, {4}} + comm := NewListenerComm(nodes) + ctx := context.Background() + conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], comm, bb) + conf.ReplicationEnabled = true + + numBlocks := uint64(10) + seqs := createBlocks(t, nodes, numBlocks) + for _, data := range seqs { + err := conf.Storage.Index(ctx, data.VerifiedBlock, data.Finalization) + require.NoError(t, err) + } + e, err := simplex.NewEpoch(conf) + require.NoError(t, err) + t.Cleanup(e.Stop) + require.NoError(t, e.Start()) + + // a node that is not part of the validator set + nonValidator := common.NodeID{5} + + sequences := []uint64{0, 1, 2, 3} + req := &common.Message{ + ReplicationRequest: &common.ReplicationRequest{ + Seqs: sequences, + LatestRound: numBlocks, + }, + } + + err = e.HandleMessage(req, nonValidator) + require.NoError(t, err) + + msg := <-comm.in + resp := msg.VerifiedReplicationResponse + require.Equal(t, len(sequences), len(resp.Data)) + for i, data := range resp.Data { + require.Equal(t, seqs[i].Finalization, *data.Finalization) + require.Equal(t, seqs[i].VerifiedBlock, data.VerifiedBlock) + } +} + +// TestNonValidatorNonReplicationMessageDropped ensures that non-replication +// messages from a non-validator node are dropped and have no effect on the +// epoch's state. +func TestNonValidatorNonReplicationMessageDropped(t *testing.T) { + bb := testutil.NewTestBlockBuilder() + nodes := []common.NodeID{{1}, {2}, {3}, {4}} + comm := NewListenerComm(nodes) + conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], comm, bb) + conf.ReplicationEnabled = true + + e, err := simplex.NewEpoch(conf) + require.NoError(t, err) + t.Cleanup(e.Stop) + require.NoError(t, e.Start()) + + nonValidator := common.NodeID{5} + + // an empty notarization from a validator would advance the round; from a + // non-validator it must be dropped, leaving the round unchanged. + emptyNotarization := testutil.NewEmptyNotarization(nodes, 0) + err = e.HandleMessage(&common.Message{ + EmptyNotarization: emptyNotarization, + }, nonValidator) + require.NoError(t, err) + + require.Equal(t, uint64(0), e.Metadata().Round) +} + func TestNilReplicationResponse(t *testing.T) { nodes := []common.NodeID{{1}, {2}, {3}, {4}} net := testutil.NewControlledNetwork(t, nodes)