diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index d667866a87..5af4434870 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -14,6 +14,7 @@ import ( "time" "github.com/NethermindEth/juno/blockchain/networks" + "github.com/NethermindEth/juno/cmd/juno/verify" _ "github.com/NethermindEth/juno/encoder/registry" _ "github.com/NethermindEth/juno/jemalloc" "github.com/NethermindEth/juno/node" @@ -485,7 +486,12 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr junoCmd.Flags().Bool( disableReceivedTxnStreamF, defaultDisableReceivedTxnStream, disableReceivedTxnStreamUsage, ) - junoCmd.AddCommand(GenP2PKeyPair(), DBCmd(defaultDBPath), CompileSierraCmd()) + junoCmd.AddCommand( + GenP2PKeyPair(), + DBCmd(defaultDBPath), + CompileSierraCmd(), + verify.VerifyCmd(defaultDBPath), + ) return junoCmd } diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go new file mode 100644 index 0000000000..8dc125cc2b --- /dev/null +++ b/cmd/juno/verify/trie.go @@ -0,0 +1,102 @@ +package verify + +import ( + "fmt" + "slices" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/utils/log" + verifytrie "github.com/NethermindEth/juno/verify/trie" + "github.com/spf13/cobra" +) + +const ( + verifyTrieType = "type" + verifyContractAddr = "address" +) + +func verifyTrieCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "trie", + Short: "Verify trie integrity", + Long: `Verify trie integrity by rebuilding tries and comparing root hashes.`, + RunE: runTrieVerify, + SilenceUsage: true, + SilenceErrors: true, + } + + cmd.Flags().StringSlice( + verifyTrieType, + nil, + "Trie types to verify (contract, class, contract-storage)."+ + "If not specified, all trie types are verified.", + ) + + cmd.Flags().String( + verifyContractAddr, + "", + "Contract address to verify (only used with --type contract-storage). "+ + "If not specified, all contract storage tries are verified.", + ) + + return cmd +} + +func runTrieVerify(cmd *cobra.Command, args []string) error { + dbPath, err := cmd.Flags().GetString(verifyDBPathF) + if err != nil { + return err + } + + trieTypes, err := cmd.Flags().GetStringSlice(verifyTrieType) + if err != nil { + return err + } + + contractAddrStr, err := cmd.Flags().GetString(verifyContractAddr) + if err != nil { + return err + } + + var tries []verifytrie.TrieType + if len(trieTypes) > 0 { + tries = make([]verifytrie.TrieType, 0, len(trieTypes)) + for _, t := range trieTypes { + tt := verifytrie.TrieType(t) + if !tt.IsValid() { + return fmt.Errorf("invalid trie type %q (allowed: contract, class, contract-storage)", t) + } + tries = append(tries, tt) + } + } + + var contractAddr *felt.Felt + if contractAddrStr != "" { + hasContractStorage := len(tries) == 0 || slices.Contains(tries, verifytrie.ContractStorageTrie) + if !hasContractStorage { + return fmt.Errorf("--address flag can only be used with --type contract-storage") + } + + var addr felt.Felt + _, err = addr.SetString(contractAddrStr) + if err != nil { + return fmt.Errorf("invalid contract address %s: %w", contractAddrStr, err) + } + contractAddr = &addr + } + + database, err := openDB(dbPath) + if err != nil { + return err + } + defer database.Close() + + logger, err := log.NewZapLogger(log.NewLevel(log.INFO)) + if err != nil { + return fmt.Errorf("failed to create logger: %w", err) + } + + verifier := verifytrie.NewTrieVerifier(database, logger, tries, contractAddr) + ctx := cmd.Context() + return verifier.Run(ctx) +} diff --git a/cmd/juno/verify/trie_test.go b/cmd/juno/verify/trie_test.go new file mode 100644 index 0000000000..b5c238635e --- /dev/null +++ b/cmd/juno/verify/trie_test.go @@ -0,0 +1,96 @@ +package verify + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRunTrieVerify_AddressFlagValidation(t *testing.T) { + tests := []struct { + name string + trieTypes []string + address string + expectError bool + expectedErrMsg string + }{ + { + name: "address with contract-storage type should succeed", + trieTypes: []string{"contract-storage"}, + address: "0x123", + expectError: false, + }, + { + name: "address with contract and class types should fail", + trieTypes: []string{"contract", "class"}, + address: "0x123", + expectError: true, + expectedErrMsg: "--address flag can only be used with --type contract-storage", + }, + { + name: "address with contract type only should fail", + trieTypes: []string{"contract"}, + address: "0x1", + expectError: true, + expectedErrMsg: "--address flag can only be used with --type contract-storage", + }, + { + name: "address with no type specified should succeed (default includes contract-storage)", + trieTypes: []string{}, + address: "0x123", + expectError: false, + }, + { + name: "invalid type should fail", + trieTypes: []string{"invalid-type"}, + address: "", + expectError: true, + expectedErrMsg: "invalid trie type", + }, + { + name: "invalid address format should fail", + trieTypes: []string{"contract-storage"}, + address: "not-a-hex", + expectError: true, + expectedErrMsg: "invalid contract address", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parentCmd := VerifyCmd("") + args := []string{"--db-path", "ignored", "trie"} + + for _, trieType := range tt.trieTypes { + args = append(args, "--type", trieType) + } + + if tt.address != "" { + args = append(args, "--address", tt.address) + } + + parentCmd.SetArgs(args) + parentCmd.SetOut(os.Stderr) + parentCmd.SetErr(os.Stderr) + + err := parentCmd.ExecuteContext(context.Background()) + + if tt.expectError { + require.Error(t, err) + if tt.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrMsg) + } + } else if err != nil { + // For "success" cases, we're testing flag validation, not full execution. + // The command may fail downstream (empty DB, no data) - that's expected. + // We only verify that the specific flag validation error we're testing didn't occur. + addrFlagErr := "--address flag can only be used with --type contract-storage" + assert.NotContains(t, err.Error(), addrFlagErr, + "flag validation should pass; downstream errors are acceptable") + } + }) + } +} diff --git a/cmd/juno/verify/verify.go b/cmd/juno/verify/verify.go new file mode 100644 index 0000000000..3611ed2789 --- /dev/null +++ b/cmd/juno/verify/verify.go @@ -0,0 +1,46 @@ +package verify + +import ( + "context" + "errors" + "fmt" + "os" + + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/pebblev2" + "github.com/spf13/cobra" +) + +type Verifier interface { + Name() string + Run(ctx context.Context) error +} + +const verifyDBPathF = "db-path" + +func VerifyCmd(defaultDBPath string) *cobra.Command { + verifyCmd := &cobra.Command{ + Use: "verify", + Short: "Verify database integrity", + Long: `Verify database integrity using various verification methods.`, + } + + verifyCmd.PersistentFlags().String(verifyDBPathF, defaultDBPath, "Path to the database") + verifyCmd.AddCommand(verifyTrieCmd()) + + return verifyCmd +} + +func openDB(path string) (db.KeyValueStore, error) { + _, err := os.Stat(path) + if os.IsNotExist(err) { + return nil, errors.New("database path does not exist") + } + + database, err := pebblev2.New(path) + if err != nil { + return nil, fmt.Errorf("failed to open db: %w", err) + } + + return database, nil +} diff --git a/verify/trie/traversal.go b/verify/trie/traversal.go new file mode 100644 index 0000000000..f4323780ca --- /dev/null +++ b/verify/trie/traversal.go @@ -0,0 +1,56 @@ +package trie + +import ( + "context" + + "golang.org/x/sync/errgroup" +) + +func TraverseBinary[T any]( + ctx context.Context, + depth uint8, + maxConcurrentDepth uint8, + leftFn func(ctx context.Context) (T, error), + rightFn func(ctx context.Context) (T, error), +) (left, right T, err error) { + if depth <= maxConcurrentDepth { + return traverseConcurrently(ctx, leftFn, rightFn) + } + return traverseSequentially(ctx, leftFn, rightFn) +} + +func traverseConcurrently[T any]( + ctx context.Context, + leftFn func(ctx context.Context) (T, error), + rightFn func(ctx context.Context) (T, error), +) (left, right T, err error) { + eg, gCtx := errgroup.WithContext(ctx) + + eg.Go(func() error { + var err error + left, err = leftFn(gCtx) + return err + }) + + eg.Go(func() error { + var err error + right, err = rightFn(gCtx) + return err + }) + + err = eg.Wait() + return left, right, err +} + +func traverseSequentially[T any]( + ctx context.Context, + leftFn func(ctx context.Context) (T, error), + rightFn func(ctx context.Context) (T, error), +) (left, right T, err error) { + left, err = leftFn(ctx) + if err != nil { + return left, right, err + } + right, err = rightFn(ctx) + return left, right, err +} diff --git a/verify/trie/trie_core.go b/verify/trie/trie_core.go new file mode 100644 index 0000000000..1a89b11e8b --- /dev/null +++ b/verify/trie/trie_core.go @@ -0,0 +1,107 @@ +package trie + +import ( + "context" + "errors" + "fmt" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" +) + +var ErrCorruptionDetected = errors.New("corruption detected") + +func VerifyTrie( + ctx context.Context, + reader *trie.ReadStorage, + height uint8, + hashFn crypto.HashFn, + expectedRoot *felt.Felt, +) error { + rootKey, err := reader.RootKey() + if err != nil { + return fmt.Errorf("failed to get root key: %w", err) + } + + if rootKey == nil { + return nil + } + + rootHash, err := verifyNode(ctx, reader, rootKey, nil, height, hashFn) + if err != nil { + return err + } + + if rootHash.Cmp(expectedRoot) != 0 { + return fmt.Errorf( + "%w: root hash mismatch, expected %s, got %s", + ErrCorruptionDetected, expectedRoot.String(), rootHash.String(), + ) + } + + return nil +} + +func verifyNode( + ctx context.Context, + reader *trie.ReadStorage, + key *trie.BitArray, + parentKey *trie.BitArray, + height uint8, + hashFn crypto.HashFn, +) (felt.Felt, error) { + if err := ctx.Err(); err != nil { + return felt.Zero, err + } + + node, err := reader.Get(key) + if err != nil { + return felt.Zero, fmt.Errorf("failed to get node at key %s: %w", key.String(), err) + } + + p := path(key, parentKey) + + if key.Len() == height { + return node.Hash(&p, hashFn), nil + } + + leftFn := func(ctx context.Context) (felt.Felt, error) { + if node.Left.IsEmpty() { + return felt.Zero, nil + } + return verifyNode(ctx, reader, node.Left, key, height, hashFn) + } + + rightFn := func(ctx context.Context) (felt.Felt, error) { + if node.Right.IsEmpty() { + return felt.Zero, nil + } + return verifyNode(ctx, reader, node.Right, key, height, hashFn) + } + + leftHash, rightHash, err := TraverseBinary(ctx, key.Len(), ConcurrencyMaxDepth, leftFn, rightFn) + if err != nil { + return felt.Zero, err + } + + recomputed := hashFn(&leftHash, &rightHash) + if recomputed.Cmp(node.Value) != 0 { + return felt.Zero, fmt.Errorf( + "%w: node at key %s, stored hash=%s, recomputed hash=%s", + ErrCorruptionDetected, key.String(), node.Value.String(), recomputed.String(), + ) + } + + return node.Hash(&p, hashFn), nil +} + +func path(key, parentKey *trie.BitArray) trie.BitArray { + if parentKey == nil { + return key.Copy() + } + + var pathKey trie.BitArray + pathKey.LSBs(key, parentKey.Len()+1) + return pathKey +} diff --git a/verify/trie/trie_verifier.go b/verify/trie/trie_verifier.go new file mode 100644 index 0000000000..dff5e66012 --- /dev/null +++ b/verify/trie/trie_verifier.go @@ -0,0 +1,245 @@ +package trie + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/utils/log" + "go.uber.org/zap" +) + +var ( + stateTrieInfo = TrieInfo{ + Name: "ContractsTrie", + Prefix: db.StateTrie.Key(), + HashFn: crypto.Pedersen, + ReaderFunc: func(r db.KeyValueReader, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPedersen(r, db.StateTrie.Key(), height) + }, + Height: StarknetTrieHeight, + } + + classTrieInfo = TrieInfo{ + Name: "ClassesTrie", + Prefix: db.ClassesTrie.Key(), + HashFn: crypto.Poseidon, + ReaderFunc: func(r db.KeyValueReader, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPoseidon(r, db.ClassesTrie.Key(), height) + }, + Height: StarknetTrieHeight, + } +) + +type TrieVerifier struct { + database db.KeyValueStore + logger log.StructuredLogger + trieTypes []TrieType + contractAddress *felt.Felt +} + +func NewTrieVerifier( + database db.KeyValueStore, + logger log.StructuredLogger, + trieTypes []TrieType, + contractAddress *felt.Felt, +) *TrieVerifier { + if len(trieTypes) == 0 { + trieTypes = allTrieTypes + } + return &TrieVerifier{ + database: database, + logger: logger, + trieTypes: trieTypes, + contractAddress: contractAddress, + } +} + +func (v *TrieVerifier) Name() string { + return "trie" +} + +func (v *TrieVerifier) Run(ctx context.Context) error { + startTime := time.Now() + defer func() { + v.logger.Info("Trie verification finished", + zap.Duration("total_elapsed", time.Since(startTime).Round(time.Second))) + }() + + err := v.database.View(func(snap db.Snapshot) error { + return v.verifyAll(ctx, snap) + }) + if errors.Is(err, context.Canceled) { + return nil + } + if err != nil { + return err + } + + v.logger.Info("Trie verification completed successfully") + return nil +} + +func (v *TrieVerifier) verifyAll(ctx context.Context, snap db.Snapshot) error { + for _, t := range v.trieTypes { + var ( + name string + err error + ) + switch t { + case ContractTrie: + name, err = stateTrieInfo.Name, v.verifyTrie(ctx, snap, stateTrieInfo) + case ClassTrie: + name, err = classTrieInfo.Name, v.verifyTrie(ctx, snap, classTrieInfo) + case ContractStorageTrie: + name, err = "ContractStorageTries", v.verifyContractStorageTries(ctx, snap, v.contractAddress) + } + if err != nil { + v.logResult(err, name) + return err + } + } + return nil +} + +func (v *TrieVerifier) logResult(err error, trieName string) { + switch { + case errors.Is(err, context.Canceled): + v.logger.Info("Verification stopped", zap.String("trie", trieName)) + case errors.Is(err, ErrCorruptionDetected): + v.logger.Error("Corruption detected", + zap.String("trie", trieName), + zap.String("details", err.Error())) + default: + v.logger.Error("Verification error", + zap.String("trie", trieName), + zap.Error(err)) + } +} + +func contractStorageTrieInfo(addr *felt.Felt) TrieInfo { + prefix := db.ContractStorage.Key(addr.Marshal()) + return TrieInfo{ + Name: fmt.Sprintf("ContractStorage[%s]", addr.String()), + Prefix: prefix, + HashFn: crypto.Pedersen, + ReaderFunc: func(r db.KeyValueReader, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPedersen(r, prefix, height) + }, + Height: StarknetTrieHeight, + } +} + +func (v *TrieVerifier) verifyTrie(ctx context.Context, snap db.Snapshot, trieInfo TrieInfo) error { + v.logger.Info("Starting trie verification", zap.String("trie", trieInfo.Name)) + + reader, err := trieInfo.ReaderFunc(snap, trieInfo.Height) + if err != nil { + return fmt.Errorf("failed to open reader for %s: %w", trieInfo.Name, err) + } + if reader.RootKey() == nil { + v.logger.Info("Trie is empty", zap.String("trie", trieInfo.Name)) + return nil + } + + expectedRoot, err := reader.Hash() + if err != nil { + return fmt.Errorf("failed to get stored root hash for %s: %w", trieInfo.Name, err) + } + + v.logger.Info("Verifying trie", + zap.String("trie", trieInfo.Name), + zap.String("expectedRoot", expectedRoot.String())) + + storageReader := trie.NewReadStorage(snap, trieInfo.Prefix) + if err := VerifyTrie( + ctx, + storageReader, + trieInfo.Height, + trieInfo.HashFn, + &expectedRoot, + ); err != nil { + return err + } + + v.logger.Info("Trie verification successful", + zap.String("trie", trieInfo.Name), + zap.String("root", expectedRoot.String())) + return nil +} + +func (v *TrieVerifier) verifyContractStorageTries( + ctx context.Context, snap db.Snapshot, filterAddress *felt.Felt, +) error { + if filterAddress != nil { + return v.verifyTrie(ctx, snap, contractStorageTrieInfo(filterAddress)) + } + + bucketPrefix := db.ContractStorage.Key() + it, err := snap.NewIterator(bucketPrefix, true) + if err != nil { + return fmt.Errorf("failed to open contract storage iterator: %w", err) + } + defer it.Close() + + v.logger.Info("Starting contract storage tries verification") + + count := 0 + addrStart := len(bucketPrefix) + addrEnd := addrStart + felt.Bytes + + for ok := it.First(); ok; { + if err := ctx.Err(); err != nil { + return err + } + + key := it.Key() + if len(key) < addrEnd { + // Unexpected short key in bucket — skip it. + ok = it.Next() + continue + } + addrBytes := key[addrStart:addrEnd] + + var addr felt.Felt + addr.SetBytes(addrBytes) + + count++ + v.logger.Info("Verifying contract storage", + zap.String("contract", addr.String()), + zap.Int("index", count)) + + if err := v.verifyTrie(ctx, snap, contractStorageTrieInfo(&addr)); err != nil { + return err + } + + nextAddr := nextLexAddr(addrBytes) + if nextAddr == nil { + break + } + ok = it.Seek(db.ContractStorage.Key(nextAddr)) + } + + v.logger.Info("All contract storage tries verified successfully", + zap.Int("count", count)) + return nil +} + +// nextLexAddr returns the lexicographically next 32-byte address, or nil if +// addr is the maximum value (all 0xff). +func nextLexAddr(addr []byte) []byte { + out := make([]byte, len(addr)) + copy(out, addr) + for i := len(out) - 1; i >= 0; i-- { + out[i]++ + if out[i] != 0 { + return out + } + } + return nil +} diff --git a/verify/trie/trie_verifier_test.go b/verify/trie/trie_verifier_test.go new file mode 100644 index 0000000000..d083ffc314 --- /dev/null +++ b/verify/trie/trie_verifier_test.go @@ -0,0 +1,490 @@ +package trie + +import ( + "context" + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" + "github.com/NethermindEth/juno/utils/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTrieVerifier_Run_ValidStateTrie(t *testing.T) { + logger := log.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = testTrie.Put(key2, value2) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_ValidClassTrie(t *testing.T) { + logger := log.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.ClassesTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePoseidon(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](10) + value1 := felt.NewFromUint64[felt.Felt](1000) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ClassTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_CorruptedTrie(t *testing.T) { + logger := log.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = testTrie.Put(key2, value2) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + var nodeKey trie.BitArray + nodeKey.SetFelt(StarknetTrieHeight, key1) + + node, err := trieStorage.Get(&nodeKey) + require.NoError(t, err) + require.NotNil(t, node) + require.NotNil(t, node.Value) + + assert.True(t, node.Value.Equal(value1), "Expected value1 but got different value") + + corruptedValue := felt.NewFromUint64[felt.Felt](999999) + node.Value = corruptedValue + + err = trieStorage.Put(&nodeKey, node) + require.NoError(t, err) + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrCorruptionDetected) +} + +func TestTrieVerifier_Run_MultipleTrieTypes(t *testing.T) { + logger := log.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + statePrefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + stateTrie, err := trie.NewTriePedersen(txn, statePrefix, StarknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = stateTrie.Put(key1, value1) + require.NoError(t, err) + + err = stateTrie.Commit() + require.NoError(t, err) + + stateStorage := trie.NewStorage(txn, statePrefix) + if stateTrie.RootKey() != nil { + err = stateStorage.PutRootKey(stateTrie.RootKey()) + require.NoError(t, err) + } + + classPrefix := db.ClassesTrie.Key() + classTrie, err := trie.NewTriePoseidon(txn, classPrefix, StarknetTrieHeight) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = classTrie.Put(key2, value2) + require.NoError(t, err) + + err = classTrie.Commit() + require.NoError(t, err) + + classStorage := trie.NewStorage(txn, classPrefix) + if classTrie.RootKey() != nil { + err = classStorage.PutRootKey(classTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie, ClassTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_EmptyTrie(t *testing.T) { + logger := log.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + + testTrie, err := trie.NewTriePedersen(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + // Empty trie has nil RootKey, so no need to store it + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_ValidContractStorageTrie(t *testing.T) { + logger := log.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + txn := testDB.NewIndexedBatch() + + // Create a state trie with a contract address + // The key at height 251 represents a contract address + statePrefix := db.StateTrie.Key() + stateTrie, err := trie.NewTriePedersen(txn, statePrefix, StarknetTrieHeight) + require.NoError(t, err) + + contractAddr := felt.NewFromUint64[felt.Felt](0x1234) + contractValue := felt.NewFromUint64[felt.Felt](1) + _, err = stateTrie.Put(contractAddr, contractValue) + require.NoError(t, err) + + err = stateTrie.Commit() + require.NoError(t, err) + + stateStorage := trie.NewStorage(txn, statePrefix) + if stateTrie.RootKey() != nil { + err = stateStorage.PutRootKey(stateTrie.RootKey()) + require.NoError(t, err) + } + + // Create a contract storage trie for this contract + storagePrefix := db.ContractStorage.Key(contractAddr.Marshal()) + storageTrie, err := trie.NewTriePedersen(txn, storagePrefix, StarknetTrieHeight) + require.NoError(t, err) + + storageKey := felt.NewFromUint64[felt.Felt](1) + storageValue := felt.NewFromUint64[felt.Felt](100) + _, err = storageTrie.Put(storageKey, storageValue) + require.NoError(t, err) + + err = storageTrie.Commit() + require.NoError(t, err) + + storageTrieStorage := trie.NewStorage(txn, storagePrefix) + if storageTrie.RootKey() != nil { + err = storageTrieStorage.PutRootKey(storageTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractStorageTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_ContractStorageCorruption(t *testing.T) { + logger := log.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + txn := testDB.NewIndexedBatch() + + statePrefix := db.StateTrie.Key() + stateTrie, err := trie.NewTriePedersen(txn, statePrefix, StarknetTrieHeight) + require.NoError(t, err) + + contractAddr := felt.NewFromUint64[felt.Felt](0x1234) + _, err = stateTrie.Put(contractAddr, felt.NewFromUint64[felt.Felt](1)) + require.NoError(t, err) + require.NoError(t, stateTrie.Commit()) + + stateStorage := trie.NewStorage(txn, statePrefix) + if stateTrie.RootKey() != nil { + require.NoError(t, stateStorage.PutRootKey(stateTrie.RootKey())) + } + + storagePrefix := db.ContractStorage.Key(contractAddr.Marshal()) + storageTrie, err := trie.NewTriePedersen(txn, storagePrefix, StarknetTrieHeight) + require.NoError(t, err) + + storageKey1 := felt.NewFromUint64[felt.Felt](1) + storageKey2 := felt.NewFromUint64[felt.Felt](2) + _, err = storageTrie.Put(storageKey1, felt.NewFromUint64[felt.Felt](100)) + require.NoError(t, err) + _, err = storageTrie.Put(storageKey2, felt.NewFromUint64[felt.Felt](200)) + require.NoError(t, err) + require.NoError(t, storageTrie.Commit()) + + storageTrieStorage := trie.NewStorage(txn, storagePrefix) + if storageTrie.RootKey() != nil { + require.NoError(t, storageTrieStorage.PutRootKey(storageTrie.RootKey())) + } + + var leafKey trie.BitArray + leafKey.SetFelt(StarknetTrieHeight, storageKey1) + leafNode, err := storageTrieStorage.Get(&leafKey) + require.NoError(t, err) + require.NotNil(t, leafNode) + + leafNode.Value = felt.NewFromUint64[felt.Felt](999999) + require.NoError(t, storageTrieStorage.Put(&leafKey, leafNode)) + + require.NoError(t, txn.Write()) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractStorageTrie}, contractAddr) + + err = verifier.Run(context.Background()) + require.Error(t, err) + assert.ErrorIs(t, err, ErrCorruptionDetected) +} + +func TestTrieVerifier_Run_ContractStorageWithFilter(t *testing.T) { + logger := log.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + txn := testDB.NewIndexedBatch() + + // Create state trie with two contracts + statePrefix := db.StateTrie.Key() + stateTrie, err := trie.NewTriePedersen(txn, statePrefix, StarknetTrieHeight) + require.NoError(t, err) + + contractAddr1 := felt.NewFromUint64[felt.Felt](0x1111) + contractAddr2 := felt.NewFromUint64[felt.Felt](0x2222) + + _, err = stateTrie.Put(contractAddr1, felt.NewFromUint64[felt.Felt](1)) + require.NoError(t, err) + _, err = stateTrie.Put(contractAddr2, felt.NewFromUint64[felt.Felt](2)) + require.NoError(t, err) + + err = stateTrie.Commit() + require.NoError(t, err) + + stateStorage := trie.NewStorage(txn, statePrefix) + if stateTrie.RootKey() != nil { + err = stateStorage.PutRootKey(stateTrie.RootKey()) + require.NoError(t, err) + } + + // Create contract storage trie only for contract 1 + storagePrefix1 := db.ContractStorage.Key(contractAddr1.Marshal()) + storageTrie1, err := trie.NewTriePedersen(txn, storagePrefix1, StarknetTrieHeight) + require.NoError(t, err) + + _, err = storageTrie1.Put( + felt.NewFromUint64[felt.Felt](1), + felt.NewFromUint64[felt.Felt](100), + ) + require.NoError(t, err) + + err = storageTrie1.Commit() + require.NoError(t, err) + + storage1 := trie.NewStorage(txn, storagePrefix1) + if storageTrie1.RootKey() != nil { + err = storage1.PutRootKey(storageTrie1.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + // Verify only contract 1 using the filter - should succeed + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractStorageTrie}, contractAddr1) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_ContextCancellation(t *testing.T) { + logger := log.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie}, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_RootHashMismatch(t *testing.T) { + logger := log.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = testTrie.Put(key2, value2) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + // Get the root key and corrupt the stored root hash (not node values) + rootKey := testTrie.RootKey() + require.NotNil(t, rootKey) + + rootNode, err := trieStorage.Get(rootKey) + require.NoError(t, err) + require.NotNil(t, rootNode) + + // Corrupt the root node's value (stored hash) while keeping node structure intact + // This will cause root hash mismatch, not node corruption during traversal + corruptedHash := felt.NewFromUint64[felt.Felt](999999) + rootNode.Value = corruptedHash + + err = trieStorage.Put(rootKey, rootNode) + require.NoError(t, err) + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrCorruptionDetected) +} diff --git a/verify/trie/types.go b/verify/trie/types.go new file mode 100644 index 0000000000..859091aaa8 --- /dev/null +++ b/verify/trie/types.go @@ -0,0 +1,36 @@ +package trie + +import ( + "slices" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" +) + +const ( + StarknetTrieHeight uint8 = 251 + ConcurrencyMaxDepth uint8 = 8 +) + +type TrieType string + +const ( + ContractTrie TrieType = "contract" + ClassTrie TrieType = "class" + ContractStorageTrie TrieType = "contract-storage" +) + +var allTrieTypes = []TrieType{ContractTrie, ClassTrie, ContractStorageTrie} + +func (t TrieType) IsValid() bool { + return slices.Contains(allTrieTypes, t) +} + +type TrieInfo struct { + Name string + Prefix []byte + HashFn crypto.HashFn + ReaderFunc func(db.KeyValueReader, uint8) (trie.TrieReader, error) + Height uint8 +}