diff --git a/nonvalidator/chain_helpers_test.go b/nonvalidator/chain_helpers_test.go index cbe027b4..651bbbe0 100644 --- a/nonvalidator/chain_helpers_test.go +++ b/nonvalidator/chain_helpers_test.go @@ -29,6 +29,7 @@ var genesis = testutil.NewTestBlock(common.ProtocolMetadata{ type messageInfo struct { msg *common.Message from common.NodeID + to common.NodeID } // testChain is a helper that book-keeps the current chain-tip, alongside any @@ -186,8 +187,8 @@ func (tc *testChain) signatureAggregatorCreator(nodes []common.Node) common.Sign } } -// addEpochs adds sealing blocks at epochs, and normal blocks in between -func (tc *testChain) addEpochs(epochs ...uint64) { +// indexEpochs indexes sealing blocks at epochs, and normal blocks in between +func (tc *testChain) indexEpochs(epochs ...uint64) { // ensure that the new epoch we are adding is not already indexed require.Greater(tc.t, epochs[0], tc.seq) diff --git a/nonvalidator/comm_test.go b/nonvalidator/comm_test.go index 02888f7f..88864180 100644 --- a/nonvalidator/comm_test.go +++ b/nonvalidator/comm_test.go @@ -9,111 +9,180 @@ import ( "testing" "github.com/ava-labs/simplex/common" + "github.com/ava-labs/simplex/simplex" + "github.com/ava-labs/simplex/testutil" + "github.com/stretchr/testify/require" ) -// nonValidatorResponderComm implements common.Communication and is used during tests -// to create responses to any requests a node sends or broadcasts. -type nonValidatorResponderComm struct { - t *testing.T +// messageQueue keeps a queue of messages +type messageQueue struct { + responsesLock sync.Mutex + responses []*messageInfo +} + +func (m *messageQueue) clearResponses() { + m.responsesLock.Lock() + defer m.responsesLock.Unlock() + m.responses = []*messageInfo{} +} + +func (m *messageQueue) enqueue(mi *messageInfo) { + m.responsesLock.Lock() + defer m.responsesLock.Unlock() + m.responses = append(m.responses, mi) +} - // storage is used to create the responses - storage common.Storage +func (m *messageQueue) popResponse() (*messageInfo, bool) { + m.responsesLock.Lock() + defer m.responsesLock.Unlock() + if len(m.responses) == 0 { + return nil, false + } + msg := m.responses[0] + m.responses = m.responses[1:] + return msg, true +} + +// routerComm appends messages being sent or broadcast to the message queue. +type routerComm struct { + t *testing.T - // nodes should contain the validator set of the highest epoch nodes common.Nodes - // ID is the NodeID of the non-validator using this comm. Broadcasts - // pick the first node in `nodes` that is not equal to ID as the - // simulated responder. + // ID is the sender of messages and is the Node using the struct. ID common.NodeID - // responses is the queue of synthesized responses. Tests will pop messages - // and feed entries into NonValidator.HandleMessage. - responsesLock sync.Mutex - responses []*messageInfo + messageQueue *messageQueue } -func newTestResponder(t *testing.T, myNodeID common.NodeID, tc *testChain) *nonValidatorResponderComm { - return &nonValidatorResponderComm{ - nodes: tc.nodes(), - t: t, - ID: myNodeID, - storage: tc, - } +type testEpochs struct { + t *testing.T + epochs []*simplex.Epoch } -func (r *nonValidatorResponderComm) Validators() common.Nodes { return r.nodes } +func newTestEpochs(tc *testChain, msgQueue *messageQueue, maxSeqWindow uint64) *testEpochs { + nodes := tc.nodes() + epochs := make([]*simplex.Epoch, 0, len(nodes)) -// Enqueues a response coming from `destination`. -func (r *nonValidatorResponderComm) Send(msg *common.Message, destination common.NodeID) { - r.handle(msg, destination) -} + for _, node := range nodes { + epochNodeID := node.Id -// Enqueues responses coming from all other nodes in the network. -func (r *nonValidatorResponderComm) Broadcast(msg *common.Message) { - - for _, n := range r.nodes { - if bytes.Equal(n.Id, r.ID) { - continue + comm := &routerComm{ + nodes: nodes, + t: tc.t, + ID: epochNodeID, + messageQueue: msgQueue, } - r.handle(msg, n.Id) + conf, _, _ := testutil.DefaultTestNodeEpochConfig(tc.t, epochNodeID, comm, testutil.NewTestBlockBuilder()) + conf.MaxRoundWindow = maxSeqWindow + conf.Storage = tc + conf.SignatureAggregatorCreator = tc.signatureAggregatorCreator + conf.ReplicationEnabled = true + + epoch, err := simplex.NewEpoch(conf) + require.NoError(tc.t, err) + + epochs = append(epochs, epoch) + } + + return &testEpochs{ + t: tc.t, + epochs: epochs, } } -func (r *nonValidatorResponderComm) handle(msg *common.Message, from common.NodeID) { - switch { - case msg.ReplicationRequest != nil: - r.respondToReplicationRequest(msg.ReplicationRequest, from) +// start starts every epoch in the network, failing the test if any epoch +// fails to start. +func (e *testEpochs) start() { + for _, epoch := range e.epochs { + require.NoError(e.t, epoch.Start()) } } -func (r *nonValidatorResponderComm) clearResponses() { - r.responsesLock.Lock() - defer r.responsesLock.Unlock() - r.responses = []*messageInfo{} +// stop stops every epoch in the network. +func (e *testEpochs) stop() { + for _, epoch := range e.epochs { + epoch.Stop() + } } -func (r *nonValidatorResponderComm) respondToReplicationRequest(req *common.ReplicationRequest, from common.NodeID) { - resp := &common.ReplicationResponse{} +// handleMessage routes messages between the non-validator and its epochs: a request +// originating from the non-validator is delivered to the addressed epoch, while +// a response from an epoch is delivered back to the non-validator. +func handleMessage(epochs *testEpochs, nv *NonValidator, mi *messageInfo) { + if !bytes.Equal(mi.from, nv.ID) { + // a response from an epoch: deliver it to the non-validator. + require.NoError(epochs.t, nv.HandleMessage(mi.msg, mi.from)) + return + } - for _, seq := range req.Seqs { - block, fin, err := r.storage.Retrieve(seq) - if err == nil { - resp.Data = append(resp.Data, common.QuorumRound{Block: block.(common.Block), Finalization: &fin}) + // a request from the non-validator: route it to the addressed epoch. + for _, epoch := range epochs.epochs { + if bytes.Equal(epoch.ID, mi.to) { + require.NoError(epochs.t, epoch.HandleMessage(mi.msg, mi.from)) + return } } + require.Failf(epochs.t, "no epoch for destination", "destination %x", mi.to) +} + +func (r *routerComm) Validators() common.Nodes { return r.nodes } + +// Enqueues a message sent from this node to `destination`. +func (r *routerComm) Send(msg *common.Message, destination common.NodeID) { + r.handle(msg, destination) +} - if req.LatestFinalizedSeq > 0 && r.storage.NumBlocks() > 0 { - numBlocks := r.storage.NumBlocks() - if req.LatestFinalizedSeq < numBlocks-1 { - block, fin, err := r.storage.Retrieve(numBlocks - 1) - if err == nil { - resp.LatestSeq = &common.QuorumRound{Block: block.(common.Block), Finalization: &fin} - } +// Enqueues a copy of the message addressed to every other node in the network. +func (r *routerComm) Broadcast(msg *common.Message) { + for _, n := range r.nodes { + if bytes.Equal(n.Id, r.ID) { + continue } - } - if len(resp.Data) == 0 && resp.LatestSeq == nil { - return + r.handle(msg, n.Id) } - - r.enqueue(&messageInfo{msg: &common.Message{ReplicationResponse: resp}, from: from}) } -func (r *nonValidatorResponderComm) enqueue(m *messageInfo) { - r.responsesLock.Lock() - defer r.responsesLock.Unlock() - r.responses = append(r.responses, m) +func (r *routerComm) handle(msg *common.Message, to common.NodeID) { + switch { + case msg.VerifiedReplicationResponse != nil: + // Outgoing responses are of the verified type, but incoming responses + // are of the unverified type, so we translate before enqueuing. + vrr := msg.VerifiedReplicationResponse + data := make([]common.QuorumRound, 0, len(vrr.Data)) + for _, vqr := range vrr.Data { + data = append(data, *verifiedQRtoQR(&vqr)) + } + + msg = &common.Message{ + ReplicationResponse: &common.ReplicationResponse{ + Data: data, + LatestRound: verifiedQRtoQR(vrr.LatestRound), + LatestSeq: verifiedQRtoQR(vrr.LatestFinalizedSeq), + }, + } + r.messageQueue.enqueue(&messageInfo{msg: msg, from: r.ID, to: to}) + default: + r.messageQueue.enqueue(&messageInfo{msg: msg, from: r.ID, to: to}) + } } -func (r *nonValidatorResponderComm) popResponse() (*messageInfo, bool) { - r.responsesLock.Lock() - defer r.responsesLock.Unlock() - if len(r.responses) == 0 { - return nil, false +func verifiedQRtoQR(vqr *common.VerifiedQuorumRound) *common.QuorumRound { + if vqr == nil { + return nil + } + + qr := &common.QuorumRound{ + Notarization: vqr.Notarization, + Finalization: vqr.Finalization, + EmptyNotarization: vqr.EmptyNotarization, } - m := r.responses[0] - r.responses = r.responses[1:] - return m, true + + if vqr.VerifiedBlock != nil { + qr.Block = vqr.VerifiedBlock.(common.Block) + } + + return qr } diff --git a/nonvalidator/non_validator_test.go b/nonvalidator/non_validator_test.go index 6e5dd9d0..52239bef 100644 --- a/nonvalidator/non_validator_test.go +++ b/nonvalidator/non_validator_test.go @@ -322,18 +322,24 @@ func TestHandleMessages_DuplicateBlock(t *testing.T) { // epoch on startup. func TestNonValidator_RequestHighestEpochOnStart(t *testing.T) { tc := newSeededChain(t, testNodes, 2) - tc.addEpochs(4, 8) - responder := newTestResponder(t, testNodes.NodeIDs()[0], tc) + tc.indexEpochs(4, 8) + + myNodeID := common.NodeID{100} + comm := &routerComm{ + messageQueue: &messageQueue{}, + ID: myNodeID, + nodes: tc.nodes(), + } nvStorage := tc.CloneUntil(2) nv, err := NewNonValidator( Config{ Storage: nvStorage, - Comm: responder, + Comm: comm, Logger: testutil.MakeLogger(t, 1), SignatureAggregatorCreator: tc.signatureAggregatorCreator, MaxSequenceWindow: simplex.DefaultMaxRoundWindow, - ID: responder.ID, + ID: myNodeID, }, ) @@ -342,10 +348,10 @@ func TestNonValidator_RequestHighestEpochOnStart(t *testing.T) { nv.Start() defer nv.Stop() - msg, ok := responder.popResponse() + msg, ok := comm.messageQueue.popResponse() require.True(t, ok) - require.NotNil(t, msg.msg.ReplicationResponse) - require.NotNil(t, msg.msg.ReplicationResponse.LatestSeq) + require.NotNil(t, msg.msg.ReplicationRequest) + require.Equal(t, uint64(1), msg.msg.ReplicationRequest.LatestFinalizedSeq) } // TestNonValidator_Bootstrap ensures a non-validator can replicate sequences given different states of the chain. @@ -365,7 +371,7 @@ func TestNonValidator_Bootstrap(t *testing.T) { name: "replicates epochs", setup: func(t *testing.T) *testChain { tc := newSeededChain(t, testNodes, 2) - tc.addEpochs(5, 10, 20, 30, 40) + tc.indexEpochs(5, 10, 20, 30, 40) return tc }, maxSequenceWindow: 50, @@ -377,7 +383,7 @@ func TestNonValidator_Bootstrap(t *testing.T) { name: "past max round window", setup: func(t *testing.T) *testChain { tc := newSeededChain(t, testNodes, 2) - tc.addEpochs(5, 10, 20, 30, 40, 50, 60, 80, 100) + tc.indexEpochs(5, 10, 20, 30, 40, 50, 60, 80, 100) return tc }, maxSequenceWindow: 5, // significantly lower @@ -389,7 +395,7 @@ func TestNonValidator_Bootstrap(t *testing.T) { name: "from genesis", setup: func(t *testing.T) *testChain { tc := newSeededChain(t, testNodes, 2) - tc.addEpochs(5, 10, 20) + tc.indexEpochs(5, 10, 20) return tc }, maxSequenceWindow: 50, @@ -416,7 +422,7 @@ func TestNonValidator_Bootstrap(t *testing.T) { tc := newSnowToSimplexChain(t, 10) firstBlock := tc.appendFirstSimplexAfterGenesis(testNodes) tc.Index(context.Background(), firstBlock, tc.newFinalization(firstBlock)) - tc.addEpochs(20, 30) + tc.indexEpochs(20, 30) return tc }, maxSequenceWindow: 50, @@ -428,18 +434,29 @@ func TestNonValidator_Bootstrap(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tc := tt.setup(t) - responder := newTestResponder(t, testNodes.NodeIDs()[0], tc) + myNodeID := common.NodeID{16} + msgQueue := &messageQueue{} + nonValidatorComm := &routerComm{ + nodes: tc.nodes(), + t: tc.t, + ID: myNodeID, + messageQueue: msgQueue, + } + epochs := newTestEpochs(tc, msgQueue, tt.maxSequenceWindow) + epochs.start() + defer epochs.stop() + nvStorage := tc.CloneUntil(tt.initialHeight) require.Equal(t, tt.initialHeight, nvStorage.NumBlocks()) nv, err := NewNonValidator( Config{ Storage: nvStorage, - Comm: responder, + Comm: nonValidatorComm, Logger: testutil.MakeLogger(t, 1), SignatureAggregatorCreator: tc.signatureAggregatorCreator, MaxSequenceWindow: tt.maxSequenceWindow, - ID: responder.ID, + ID: myNodeID, StartTime: time.Now(), }, ) @@ -448,7 +465,7 @@ func TestNonValidator_Bootstrap(t *testing.T) { nv.Start() defer nv.Stop() - advanceUntil(nv, responder, tt.lastSeq) + advanceUntil(nv, epochs, msgQueue, tt.lastSeq) }) } } @@ -457,18 +474,30 @@ func TestNonValidator_ReplicationRequests(t *testing.T) { tc := newSeededChain(t, testNodes, 2) lastSeq := uint64(40) initialHeight := uint64(2) - tc.addEpochs(5, 10, 20, 30, lastSeq) - responder := newTestResponder(t, testNodes.NodeIDs()[0], tc) + tc.indexEpochs(5, 10, 20, 30, lastSeq) + maxSeqWindow := uint64(50) + myNodeID := common.NodeID{255} + msgQueue := &messageQueue{} + nonValidatorComm := &routerComm{ + nodes: tc.nodes(), + t: tc.t, + ID: myNodeID, + messageQueue: msgQueue, + } + epochs := newTestEpochs(tc, msgQueue, maxSeqWindow) + epochs.start() + defer epochs.stop() + nvStorage := tc.CloneUntil(initialHeight) startTime := time.Now() nv, err := NewNonValidator( Config{ Storage: nvStorage, - Comm: responder, + Comm: nonValidatorComm, Logger: testutil.MakeLogger(t, 1), SignatureAggregatorCreator: tc.signatureAggregatorCreator, - MaxSequenceWindow: 50, - ID: responder.ID, + MaxSequenceWindow: maxSeqWindow, + ID: myNodeID, StartTime: startTime, }, ) @@ -479,43 +508,48 @@ func TestNonValidator_ReplicationRequests(t *testing.T) { count := 0 for { - // Send any requests as responses back to the node - for msg, ok := responder.popResponse(); ok; { - // drop every other message - if count%2 == 0 { - require.NoError(t, nv.HandleMessage(msg.msg, msg.from)) + for msg, ok := msgQueue.popResponse(); ok; { + // drops 25% of messages + // TODO: we can handle a higher threshold once we implement https://github.com/ava-labs/Simplex/issues/425 + if count%4 != 0 { + handleMessage(epochs, nv, msg) } - count += 1 - msg, ok = responder.popResponse() + count++ + msg, ok = msgQueue.popResponse() } // check if storage has indexed all - if lastSeq == nvStorage.NumBlocks()-1 { + if lastSeq == nv.Storage.NumBlocks()-1 { break } - // update the time + // update the time so pending replication requests time out and re-fire startTime = startTime.Add(simplex.DefaultReplicationRequestTimeout) + time.Sleep(50 * time.Millisecond) nv.AdvanceTime(startTime) } - // clear in flight responses + // clear in flight messages startTime = startTime.Add(simplex.DefaultReplicationRequestTimeout) nv.AdvanceTime(startTime) - responder.clearResponses() + time.Sleep(50 * time.Millisecond) + msgQueue.clearResponses() - // ensure all timeout tasks were removed + // ensure all timeout tasks were removed: advancing time should no longer + // cause the non-validator to emit any further requests. count = 0 for { startTime = startTime.Add(simplex.DefaultReplicationRequestTimeout) nv.AdvanceTime(startTime) - msg, ok := responder.popResponse() + time.Sleep(50 * time.Millisecond) + + msg, ok := msgQueue.popResponse() require.False(t, ok, fmt.Sprintf("all replication request tasks should be finished %v", msg)) if count > 3 { break } - count += 1 + count++ } } @@ -618,14 +652,13 @@ func TestNonValidator_VerifiesFinalizationDuringReplication(t *testing.T) { ) } -func advanceUntil(nv *NonValidator, responder *nonValidatorResponderComm, seq uint64) { +func advanceUntil(nv *NonValidator, epochs *testEpochs, msgQueue *messageQueue, seq uint64) { startTime := nv.StartTime for { // Send any requests as responses back to the node - for msg, ok := responder.popResponse(); ok; { - // drop every other message - require.NoError(responder.t, nv.HandleMessage(msg.msg, msg.from)) - msg, ok = responder.popResponse() + for msg, ok := msgQueue.popResponse(); ok; { + handleMessage(epochs, nv, msg) + msg, ok = msgQueue.popResponse() } // check if storage has indexed all @@ -635,6 +668,7 @@ func advanceUntil(nv *NonValidator, responder *nonValidatorResponderComm, seq ui // update the time startTime = startTime.Add(simplex.DefaultReplicationRequestTimeout) + time.Sleep(50 * time.Millisecond) nv.AdvanceTime(startTime) } } diff --git a/simplex/epoch.go b/simplex/epoch.go index 5d24c7e2..9cf3c943 100644 --- a/simplex/epoch.go +++ b/simplex/epoch.go @@ -162,7 +162,6 @@ func (e *Epoch) HandleMessage(msg *common.Message, from common.NodeID) error { return nil } } - switch { case msg.BlockMessage != nil: return e.handleBlockMessage(msg.BlockMessage, from) @@ -285,7 +284,7 @@ func (e *Epoch) Start() error { } // Only init receiving messages once you have initialized the data structures required for it. - e.Logger.Debug("Epoch is ready to receive messages") + e.Logger.Debug("Epoch is ready to receive messages", zap.Uint64("epoch", e.Epoch)) e.canReceiveMessages.Store(true) e.broadcastReplicationSync() @@ -2994,7 +2993,7 @@ func (e *Epoch) storeProposal(block common.VerifiedBlock) bool { // HandleRequest processes a request and returns a response. It also sends a response to the sender. func (e *Epoch) handleReplicationRequest(req *common.ReplicationRequest, from common.NodeID) error { - e.Logger.Debug("Received replication request", zap.Stringer("from", from), zap.Int("num seqs", len(req.Seqs)), zap.Int("num rounds", len(req.Rounds)), zap.Uint64("latest round", req.LatestRound)) + e.Logger.Debug("Received replication request", zap.Stringer("from", from), zap.Uint64s("seqs", req.Seqs), zap.Int("num rounds", len(req.Rounds)), zap.Uint64("latest round", req.LatestRound), zap.Uint64("latest seq", req.LatestFinalizedSeq)) if !e.ReplicationEnabled { return nil }