diff --git a/simplex/epoch.go b/simplex/epoch.go index 543359b1..5d24c7e2 100644 --- a/simplex/epoch.go +++ b/simplex/epoch.go @@ -93,8 +93,8 @@ type Epoch struct { finishFn context.CancelFunc blockBuilderCtx context.Context blockBuilderCancelFunc context.CancelFunc - nodeIDs common.NodeIDs - nodes common.Nodes + validatorNodeIDs common.NodeIDs + validators common.Nodes validatorsToPKs map[string][]byte rounds map[uint64]*Round emptyVotes map[uint64]*EmptyVoteSet @@ -151,11 +151,16 @@ func (e *Epoch) HandleMessage(msg *common.Message, from common.NodeID) error { return nil } - // Guard against receiving messages from unknown nodes _, known := e.validatorsToPKs[string(from)] if !known { - e.Logger.Debug("Received message from an unknown node", zap.Stringer("nodeID", from)) - return nil + e.Logger.Debug("Received message from a non-validator node", zap.Stringer("nodeID", from)) + switch { + case msg.ReplicationRequest != nil && e.ReplicationEnabled: + return e.handleReplicationRequest(msg.ReplicationRequest, from) + default: + e.Logger.Debug("Invalid message type", zap.Stringer("from", from)) + return nil + } } switch { @@ -199,22 +204,22 @@ func (e *Epoch) init() error { e.finishCtx, e.finishFn = context.WithCancel(context.Background()) e.blockBuilderCtx = context.Background() e.blockBuilderCancelFunc = func() {} - e.nodes = e.Comm.Validators() - common.SortNodes(e.nodes) - e.nodeIDs = e.nodes.NodeIDs() - e.timedOutRounds = make(map[uint16]uint64, len(e.nodeIDs)) - e.redeemedRounds = make(map[uint16]uint64, len(e.nodeIDs)) + e.validators = e.Comm.Validators() + common.SortNodes(e.validators) + e.validatorNodeIDs = e.validators.NodeIDs() + e.timedOutRounds = make(map[uint16]uint64, len(e.validatorNodeIDs)) + e.redeemedRounds = make(map[uint16]uint64, len(e.validatorNodeIDs)) e.rounds = make(map[uint64]*Round) e.emptyVotes = make(map[uint64]*EmptyVoteSet) - e.futureMessages = make(messagesFromNode, len(e.nodeIDs)) + e.futureMessages = make(messagesFromNode, len(e.validatorNodeIDs)) e.replicationState = NewReplicationState(e.Logger, e.Comm, e.ID, e.MaxRoundWindow, e.ReplicationEnabled, e.StartTime, &e.lock, e.RandomSource) e.timeoutHandler = common.NewTimeoutHandler(e.Logger, "emptyVoteRebroadcast", e.StartTime, e.MaxRebroadcastWait, e.emptyVoteTimeoutTaskRunner) - e.signatureAggregator = e.SignatureAggregatorCreator(e.nodes) - e.validatorsToPKs = make(map[string][]byte, len(e.nodeIDs)) - for _, node := range e.nodes { + e.signatureAggregator = e.SignatureAggregatorCreator(e.validators) + e.validatorsToPKs = make(map[string][]byte, len(e.validatorNodeIDs)) + for _, node := range e.validators { e.validatorsToPKs[string(node.Id)] = node.PK } - for _, node := range e.nodeIDs { + for _, node := range e.validatorNodeIDs { e.futureMessages[string(node)] = make(map[uint64]*messagesForRound) } err := e.loadLastBlock() @@ -222,7 +227,7 @@ func (e *Epoch) init() error { return err } - e.Logger.Info("Starting Simplex Epoch", zap.String("ID", e.ID.String()), zap.Stringer("nodes", e.nodeIDs)) + e.Logger.Info("Starting Simplex Epoch", zap.String("ID", e.ID.String()), zap.Stringer("nodes", e.validatorNodeIDs)) return e.setMetadataFromStorage() } @@ -574,7 +579,7 @@ func (e *Epoch) resumeFromWal(highestRoundRecord *walRound) error { return fmt.Errorf("could not find round %d for block", block.BlockHeader().Round) } - if e.ID.Equals(LeaderForRound(e.nodeIDs, block.BlockHeader().Round)) { + if e.ID.Equals(LeaderForRound(e.validatorNodeIDs, block.BlockHeader().Round)) { vote, err := e.voteOnBlock(round.block) if err != nil { return err @@ -741,7 +746,7 @@ func (e *Epoch) handleFinalizationMessage(message *common.Finalization, from com return nil } - if err := VerifyQC(message.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, message, e.nodes); err != nil { + if err := VerifyQC(message.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, message, e.validators); err != nil { e.Logger.Debug(fmt.Sprintf("Finalization %s", err), zap.Int("round", int(message.Finalization.Round)), zap.Stringer("NodeID", from)) @@ -779,7 +784,7 @@ func (e *Epoch) handleFinalizationForPendingOrFutureRound(message *common.Finali // TODO: delay requesting future finalizations and blocks, since blocks could be in transit e.Logger.Debug("Received finalization for a pending or future round, and we don't have the block", zap.Uint64("round", round), zap.Uint64("our round", e.round)) - if LeaderForRound(e.nodeIDs, e.round).Equals(e.ID) { + if LeaderForRound(e.validatorNodeIDs, e.round).Equals(e.ID) { e.Logger.Debug("We are the leader of this round, but a higher round has been finalized. Aborting block building.") e.blockBuilderCancelFunc() } @@ -1491,7 +1496,7 @@ func (e *Epoch) persistEmptyNotarization(emptyVotes *EmptyVoteSet, shouldBroadca } func (e *Epoch) maybeMarkLeaderAsTimedOutForFutureBlacklisting(emptyNotarization *common.EmptyNotarization) error { - e.Logger.Debug("Marking the leader as timed out", zap.Uint64("round", emptyNotarization.Vote.Round), zap.Stringer("leader", LeaderForRound(e.nodeIDs, emptyNotarization.Vote.Round))) + e.Logger.Debug("Marking the leader as timed out", zap.Uint64("round", emptyNotarization.Vote.Round), zap.Stringer("leader", LeaderForRound(e.validatorNodeIDs, emptyNotarization.Vote.Round))) var blacklist common.Blacklist if e.lastBlock != nil { if e.lastBlock.VerifiedBlock == nil { @@ -1501,7 +1506,7 @@ func (e *Epoch) maybeMarkLeaderAsTimedOutForFutureBlacklisting(emptyNotarization blacklist = e.lastBlock.VerifiedBlock.Blacklist() } round := emptyNotarization.Vote.Round - leaderIndex := round % uint64(len(e.nodeIDs)) + leaderIndex := round % uint64(len(e.validatorNodeIDs)) if !blacklist.IsNodeSuspected(uint16(leaderIndex)) { e.timedOutRounds[uint16(leaderIndex)] = round } @@ -1573,7 +1578,7 @@ func (e *Epoch) persistNotarization(notarization common.Notarization) error { round := notarization.Vote.Round for _, signer := range notarization.QC.Signers() { - if signerIndex := e.nodeIDs.IndexOf(signer); signerIndex != -1 { + if signerIndex := e.validatorNodeIDs.IndexOf(signer); signerIndex != -1 { e.Logger.Debug("Potentially redeeming node", zap.Stringer("signer", signer), zap.Uint64("round", round)) e.redeemedRounds[uint16(signerIndex)] = round } else { @@ -1618,7 +1623,7 @@ func (e *Epoch) handleEmptyNotarizationMessage(emptyNotarization *common.EmptyNo } // Otherwise, this round is not notarized or finalized yet, so verify the empty notarization and store it. - if err := VerifyQC(emptyNotarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, emptyNotarization, e.nodes); err != nil { + if err := VerifyQC(emptyNotarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, emptyNotarization, e.validators); err != nil { e.Logger.Debug(fmt.Sprintf("Empty notarization %s", err), zap.Stringer("NodeID", from)) return nil @@ -1676,7 +1681,7 @@ func (e *Epoch) handleNotarizationMessage(message *common.Notarization, from com return nil } - if err := VerifyQC(message.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, message, e.nodes); err != nil { + if err := VerifyQC(message.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, message, e.validators); err != nil { e.Logger.Debug(fmt.Sprintf("Notarization %s", err), zap.Stringer("NodeID", from)) return nil @@ -1758,7 +1763,7 @@ func (e *Epoch) handleBlockMessage(message *common.BlockMessage, from common.Nod } // Check that the node is a leader for the round corresponding to the block. - if !LeaderForRound(e.nodeIDs, md.Round).Equals(from) { + if !LeaderForRound(e.validatorNodeIDs, md.Round).Equals(from) { // The block is associated with a round in which the sender is not the leader, // it should not be sending us any block at all. e.Logger.Debug("Got block from a block proposer that is not the leader of the round", zap.Stringer("NodeID", from), zap.Uint64("round", md.Round)) @@ -2017,7 +2022,7 @@ func (e *Epoch) createBlockVerificationTask(block common.Block, from common.Node defer e.lock.Unlock() if err != nil { - leader := LeaderForRound(e.nodeIDs, md.Round) + leader := LeaderForRound(e.validatorNodeIDs, md.Round) e.Logger.Info("Triggering empty block agreement", zap.String("reason", "Failed verifying block"), zap.Uint64("round", md.Round), @@ -2253,7 +2258,7 @@ func (e *Epoch) verifyProposalMetadataAndBlacklist(block common.Block) bool { // Else, either it's not the first block, or we haven't committed the first block, and it is the first block. // If it's the latter we have nothing else to do. // If it's the former, we need to find the parent of the block and ensure it is correct. - prevBlacklist := common.NewBlacklist(uint16(len(e.nodeIDs))) + prevBlacklist := common.NewBlacklist(uint16(len(e.validatorNodeIDs))) if bh.Seq > 0 { prevBlock, _, found := e.locateBlock(bh.Seq-1, bh.Prev[:]) if !found { @@ -2270,7 +2275,7 @@ func (e *Epoch) verifyProposalMetadataAndBlacklist(block common.Block) bool { prevBlacklist = prevBlock.Blacklist() if prevBlacklist.IsEmpty() { - prevBlacklist = common.NewBlacklist(uint16(len(e.nodeIDs))) + prevBlacklist = common.NewBlacklist(uint16(len(e.validatorNodeIDs))) } } @@ -2370,7 +2375,7 @@ func (e *Epoch) buildBlock() { } // If I'm blacklisted, I cannot propose a block. - if prevBlacklist.IsNodeSuspected(uint16(e.nodeIDs.IndexOf(e.ID))) { + if prevBlacklist.IsNodeSuspected(uint16(e.validatorNodeIDs.IndexOf(e.ID))) { e.Logger.Debug("I'm blacklisted, cannot propose a block", zap.Uint64("round", metadata.Round), zap.Stringer("blacklist", &prevBlacklist)) e.triggerEmptyBlockNotarization(metadata.Round) return @@ -2384,7 +2389,7 @@ func (e *Epoch) buildBlock() { e.Logger.Debug("Computing blacklist updates", zap.String("timedOutRounds", fmt.Sprintf("%v", e.timedOutRounds)), zap.String("redeemedRounds", fmt.Sprintf("%v", e.redeemedRounds))) - updates := prevBlacklist.ComputeBlacklistUpdates(metadata.Round, uint16(len(e.nodeIDs)), e.timedOutRounds, e.redeemedRounds) + updates := prevBlacklist.ComputeBlacklistUpdates(metadata.Round, uint16(len(e.validatorNodeIDs)), e.timedOutRounds, e.redeemedRounds) // 3) Apply the updates to the blacklist. nextBlacklist := prevBlacklist.ApplyUpdates(updates, metadata.Round) @@ -2426,7 +2431,7 @@ func (e *Epoch) retrieveBlacklistOfParentBlock(metadata common.ProtocolMetadata) } if blacklist.IsEmpty() { - blacklist = common.NewBlacklist(uint16(len(e.nodeIDs))) + blacklist = common.NewBlacklist(uint16(len(e.validatorNodeIDs))) } return blacklist, true @@ -2626,7 +2631,7 @@ func (e *Epoch) monitorProgress(round uint64) { noop := func() {} - leader := LeaderForRound(e.nodeIDs, round) + leader := LeaderForRound(e.validatorNodeIDs, round) // If we have a task pending to be executed, remove it from execution because we're about to schedule // a task for a higher round. @@ -2646,7 +2651,7 @@ func (e *Epoch) monitorProgress(round uint64) { return } - leader := LeaderForRound(e.nodeIDs, round) + leader := LeaderForRound(e.validatorNodeIDs, round) e.Logger.Debug("Triggering empty block agreement", zap.String("reason", "Timed out on block agreement"), zap.Uint64("round", round), @@ -2666,7 +2671,7 @@ func (e *Epoch) monitorProgress(round uint64) { // If the current leader is blacklisted, we should not wait for it to propose a block. // Instead, we should immediately trigger the empty block agreement. - leaderIndex := e.nodeIDs.IndexOf(leader) + leaderIndex := e.validatorNodeIDs.IndexOf(leader) if leaderIndex >= 0 && blacklist.IsNodeSuspected(uint16(leaderIndex)) { e.Logger.Debug("Leader is blacklisted, will not wait for it to propose a block", zap.Uint64("round", round), zap.Stringer("leader", leader)) @@ -2749,7 +2754,7 @@ func (e *Epoch) startRound() error { return err } - leaderForCurrentRound := LeaderForRound(e.nodeIDs, e.round) + leaderForCurrentRound := LeaderForRound(e.validatorNodeIDs, e.round) if e.ID.Equals(leaderForCurrentRound) { e.buildBlock() @@ -2829,8 +2834,8 @@ func (e *Epoch) increaseRound() { // remove the rebroadcast empty vote task e.timeoutHandler.RemoveTask(EmptyVoteTimeoutID) - prevLeader := LeaderForRound(e.nodeIDs, e.round) - nextLeader := LeaderForRound(e.nodeIDs, e.round+1) + prevLeader := LeaderForRound(e.validatorNodeIDs, e.round) + nextLeader := LeaderForRound(e.validatorNodeIDs, e.round+1) e.Logger.Info("Moving to a new round", zap.Uint64("prev round", e.round), @@ -3236,19 +3241,19 @@ func (e *Epoch) verifyQuorumRound(q common.QuorumRound) error { if q.Finalization != nil { // extra check needed if we have a finalized block - if err := VerifyQC(q.Finalization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.Finalization, e.nodes); err != nil { + if err := VerifyQC(q.Finalization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.Finalization, e.validators); err != nil { return fmt.Errorf("invalid finalization: %w", err) } } if q.Notarization != nil { - if err := VerifyQC(q.Notarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.Notarization, e.nodes); err != nil { + if err := VerifyQC(q.Notarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.Notarization, e.validators); err != nil { return fmt.Errorf("invalid notarization: %w", err) } } if q.EmptyNotarization != nil { - if err := VerifyQC(q.EmptyNotarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.EmptyNotarization, e.nodes); err != nil { + if err := VerifyQC(q.EmptyNotarization.QC, e.signatureAggregator.IsQuorum, e.validatorsToPKs, q.EmptyNotarization, e.validators); err != nil { return fmt.Errorf("invalid empty notarization QC: %w", err) } } diff --git a/simplex/replication_request_test.go b/simplex/replication_request_test.go index 54ac1765..f1f34d82 100644 --- a/simplex/replication_request_test.go +++ b/simplex/replication_request_test.go @@ -302,6 +302,78 @@ func TestReplicationRequestUnknownSeqsAndRounds(t *testing.T) { require.Never(t, func() bool { return len(comm.in) > 0 }, 5*time.Second, 100*time.Millisecond) } +// TestNonValidatorReplicationRequestServed ensures that a replication request +// from a non-validator node is still served when replication is enabled. +func TestNonValidatorReplicationRequestServed(t *testing.T) { + bb := testutil.NewTestBlockBuilder() + nodes := []common.NodeID{{1}, {2}, {3}, {4}} + comm := NewListenerComm(nodes) + ctx := context.Background() + conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], comm, bb) + conf.ReplicationEnabled = true + + numBlocks := uint64(10) + seqs := createBlocks(t, nodes, numBlocks) + for _, data := range seqs { + err := conf.Storage.Index(ctx, data.VerifiedBlock, data.Finalization) + require.NoError(t, err) + } + e, err := simplex.NewEpoch(conf) + require.NoError(t, err) + t.Cleanup(e.Stop) + require.NoError(t, e.Start()) + + // a node that is not part of the validator set + nonValidator := common.NodeID{5} + + sequences := []uint64{0, 1, 2, 3} + req := &common.Message{ + ReplicationRequest: &common.ReplicationRequest{ + Seqs: sequences, + LatestRound: numBlocks, + }, + } + + err = e.HandleMessage(req, nonValidator) + require.NoError(t, err) + + msg := <-comm.in + resp := msg.VerifiedReplicationResponse + require.Equal(t, len(sequences), len(resp.Data)) + for i, data := range resp.Data { + require.Equal(t, seqs[i].Finalization, *data.Finalization) + require.Equal(t, seqs[i].VerifiedBlock, data.VerifiedBlock) + } +} + +// TestNonValidatorNonReplicationMessageDropped ensures that non-replication +// messages from a non-validator node are dropped and have no effect on the +// epoch's state. +func TestNonValidatorNonReplicationMessageDropped(t *testing.T) { + bb := testutil.NewTestBlockBuilder() + nodes := []common.NodeID{{1}, {2}, {3}, {4}} + comm := NewListenerComm(nodes) + conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], comm, bb) + conf.ReplicationEnabled = true + + e, err := simplex.NewEpoch(conf) + require.NoError(t, err) + t.Cleanup(e.Stop) + require.NoError(t, e.Start()) + + nonValidator := common.NodeID{5} + + // an empty notarization from a validator would advance the round; from a + // non-validator it must be dropped, leaving the round unchanged. + emptyNotarization := testutil.NewEmptyNotarization(nodes, 0) + err = e.HandleMessage(&common.Message{ + EmptyNotarization: emptyNotarization, + }, nonValidator) + require.NoError(t, err) + + require.Equal(t, uint64(0), e.Metadata().Round) +} + func TestNilReplicationResponse(t *testing.T) { nodes := []common.NodeID{{1}, {2}, {3}, {4}} net := testutil.NewControlledNetwork(t, nodes)