From 093be27c46c58aae81b448231154bfe03ce383e8 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Wed, 17 Dec 2025 19:11:25 +0100 Subject: [PATCH 01/15] first commit --- cmd/juno/juno.go | 2 + cmd/juno/verify/trie.go | 95 +++++++++++++++++++++++++++++++++++++ cmd/juno/verify/verifier.go | 43 +++++++++++++++++ cmd/juno/verify/verify.go | 94 ++++++++++++++++++++++++++++++++++++ 4 files changed, 234 insertions(+) create mode 100644 cmd/juno/verify/trie.go create mode 100644 cmd/juno/verify/verifier.go create mode 100644 cmd/juno/verify/verify.go diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index d667866a87..cd1a85176c 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -14,6 +14,7 @@ import ( "time" "github.com/NethermindEth/juno/blockchain/networks" + "github.com/NethermindEth/juno/cmd/juno/verify" _ "github.com/NethermindEth/juno/encoder/registry" _ "github.com/NethermindEth/juno/jemalloc" "github.com/NethermindEth/juno/node" @@ -486,6 +487,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr disableReceivedTxnStreamF, defaultDisableReceivedTxnStream, disableReceivedTxnStreamUsage, ) junoCmd.AddCommand(GenP2PKeyPair(), DBCmd(defaultDBPath), CompileSierraCmd()) + junoCmd.AddCommand(GenP2PKeyPair(), DBCmd(defaultDBPath), verify.VerifyCmd(defaultDBPath)) return junoCmd } diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go new file mode 100644 index 0000000000..52a2eacf35 --- /dev/null +++ b/cmd/juno/verify/trie.go @@ -0,0 +1,95 @@ +package verify + +import ( + "context" + "fmt" + + "github.com/NethermindEth/juno/db" +) + +type TrieType string + +const ( + TrieTypeState TrieType = "state" + TrieTypeClass TrieType = "class" + TrieTypeContract TrieType = "contract" +) + +type TrieConfig struct { + Tries []TrieType +} + +type TrieVerifier struct { + database db.KeyValueStore +} + +func NewTrieVerifier(database db.KeyValueStore) *TrieVerifier { + return &TrieVerifier{ + database: database, + } +} + +// Name returns the name of this verifier. +func (v *TrieVerifier) Name() string { + return "trie" +} + +// DefaultConfig returns the default configuration (verify all tries). +func (v *TrieVerifier) DefaultConfig() Config { + return &TrieConfig{ + Tries: nil, // nil = all tries + } +} + +// Run executes the trie verification. +func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { + trieCfg, ok := cfg.(*TrieConfig) + if !ok { + return fmt.Errorf("invalid config type for trie verifier: expected *TrieConfig") + } + + // If Tries is empty, verify all tries + typesToVerify := trieCfg.Tries + if len(typesToVerify) == 0 { + typesToVerify = []TrieType{TrieTypeState, TrieTypeClass, TrieTypeContract} + } + + for _, trieType := range typesToVerify { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if err := v.verifyTrie(ctx, trieType); err != nil { + return fmt.Errorf("trie type %s: %w", trieType, err) + } + } + + return nil +} + +// verifyTrie verifies a specific trie type. +func (v *TrieVerifier) verifyTrie(ctx context.Context, trieType TrieType) error { + // TODO: Implement actual trie verification logic + // This is a stub implementation that demonstrates the structure. + // The actual implementation should: + // 1. Collect leaf nodes from the specified trie + // 2. Rebuild the trie in memory from leaves + // 3. Calculate the root hash + // 4. Compare with stored root hash + + switch trieType { + case TrieTypeState: + // Verify StateTrie + return nil + case TrieTypeClass: + // Verify ClassesTrie + return nil + case TrieTypeContract: + // Verify all contract storage tries + return nil + default: + return fmt.Errorf("unknown trie type: %s", trieType) + } +} diff --git a/cmd/juno/verify/verifier.go b/cmd/juno/verify/verifier.go new file mode 100644 index 0000000000..fac5d5b3b4 --- /dev/null +++ b/cmd/juno/verify/verifier.go @@ -0,0 +1,43 @@ +package verify + +import ( + "context" + "fmt" +) + +// Config is the base configuration interface that all verifier configs must implement. +// Empty values in config structs imply default behavior. +type Config interface{} + +// Verifier defines the interface for all database verifiers. +// Each verifier is atomic, has its own strongly-typed config, and contains no CLI logic. +type Verifier interface { + // Name returns the name of the verifier (e.g., "trie", "tx"). + Name() string + // DefaultConfig returns the default configuration for this verifier. + DefaultConfig() Config + // Run executes the verification with the given config and context. + // Returns an error if verification fails. + Run(ctx context.Context, cfg Config) error +} + +// VerifyRunner orchestrates the execution of multiple verifiers. +type VerifyRunner struct { + Verifiers []Verifier +} + +// Run executes all registered verifiers sequentially. +// Stops on the first error and wraps it with the verifier name. +func (r *VerifyRunner) Run(ctx context.Context, configs map[string]Config) error { + for _, verifier := range r.Verifiers { + cfg, ok := configs[verifier.Name()] + if !ok { + cfg = verifier.DefaultConfig() + } + + if err := verifier.Run(ctx, cfg); err != nil { + return fmt.Errorf("%s: %w", verifier.Name(), err) + } + } + return nil +} diff --git a/cmd/juno/verify/verify.go b/cmd/juno/verify/verify.go new file mode 100644 index 0000000000..791a3b067e --- /dev/null +++ b/cmd/juno/verify/verify.go @@ -0,0 +1,94 @@ +package verify + +import ( + "github.com/spf13/cobra" +) + +const ( + verifyDBPathF = "db-path" + verifyTrieType = "type" +) + +func VerifyCmd(defaultDBPath string) *cobra.Command { + verifyCmd := &cobra.Command{ + Use: "verify", + Short: "Verify database integrity", + Long: `Verify database integrity using various verification methods.`, + } + + verifyCmd.PersistentFlags().String(verifyDBPathF, defaultDBPath, "Path to the database") + verifyCmd.AddCommand(verifyTrieCmd()) + verifyCmd.RunE = verifyAll + + return verifyCmd +} + +// verifyAll runs all verifiers with default scope when no subcommand is specified. +func verifyAll(cmd *cobra.Command, args []string) error { + dbPath, err := cmd.Flags().GetString(verifyDBPathF) + if err != nil { + return err + } + + database, err := openDB(dbPath) + if err != nil { + return err + } + defer database.Close() + + trieVerifier := NewTrieVerifier(database) + + runner := &VerifyRunner{ + Verifiers: []Verifier{ + trieVerifier, + }, + } + + configs := make(map[string]Config) + ctx := cmd.Context() + return runner.Run(ctx, configs) +} + +func verifyTrieCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "trie", + Short: "Verify trie integrity", + Long: `Verify trie integrity by rebuilding tries from leaf nodes and comparing root hashes.`, + RunE: runTrieVerify, + } + + cmd.Flags().StringSlice(verifyTrieType, nil, "Trie types to verify (state, class, contract). Can be specified multiple times. Empty = all.") + + return cmd +} + +func runTrieVerify(cmd *cobra.Command, args []string) error { + dbPath, err := cmd.Flags().GetString(verifyDBPathF) + if err != nil { + return err + } + + database, err := openDB(dbPath) + if err != nil { + return err + } + defer database.Close() + + trieTypes, err := cmd.Flags().GetStringSlice(verifyTrieType) + if err != nil { + return err + } + + cfg := &TrieConfig{} + + if len(trieTypes) > 0 { + cfg.Tries = make([]TrieType, len(trieTypes)) + for i, t := range trieTypes { + cfg.Tries[i] = TrieType(t) + } + } + + verifier := NewTrieVerifier(database) + ctx := cmd.Context() + return verifier.Run(ctx, cfg) +} From 39665546426dba3d423df0a56b1c27559ee2249a Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Thu, 18 Dec 2025 00:05:53 +0100 Subject: [PATCH 02/15] add trie verifier --- cmd/juno/verify/trie.go | 345 +++++++++++++++++++++++++++++++++--- cmd/juno/verify/verifier.go | 11 -- cmd/juno/verify/verify.go | 37 +++- 3 files changed, 352 insertions(+), 41 deletions(-) diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go index 52a2eacf35..20ac72f1a0 100644 --- a/cmd/juno/verify/trie.go +++ b/cmd/juno/verify/trie.go @@ -1,10 +1,19 @@ package verify import ( + "bytes" "context" "fmt" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" + "github.com/NethermindEth/juno/utils" +) + +const ( + starknetTrieHeight = 251 ) type TrieType string @@ -21,75 +30,355 @@ type TrieConfig struct { type TrieVerifier struct { database db.KeyValueStore + logger utils.SimpleLogger } -func NewTrieVerifier(database db.KeyValueStore) *TrieVerifier { +func NewTrieVerifier(database db.KeyValueStore, logger utils.SimpleLogger) *TrieVerifier { return &TrieVerifier{ database: database, + logger: logger, } } -// Name returns the name of this verifier. func (v *TrieVerifier) Name() string { return "trie" } -// DefaultConfig returns the default configuration (verify all tries). func (v *TrieVerifier) DefaultConfig() Config { return &TrieConfig{ - Tries: nil, // nil = all tries + Tries: nil, } } -// Run executes the trie verification. func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { trieCfg, ok := cfg.(*TrieConfig) if !ok { return fmt.Errorf("invalid config type for trie verifier: expected *TrieConfig") } - // If Tries is empty, verify all tries typesToVerify := trieCfg.Tries if len(typesToVerify) == 0 { typesToVerify = []TrieType{TrieTypeState, TrieTypeClass, TrieTypeContract} } - for _, trieType := range typesToVerify { + typeSet := make(map[TrieType]bool) + for _, t := range typesToVerify { + typeSet[t] = true + } + + var allErrors []error + + if typeSet[TrieTypeState] { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + stateTrieInfo := TrieInfo{ + Name: "StateTrie", + Bucket: db.StateTrie, + HashFunc: trie.NewTriePedersen, + ReaderFunc: func(r db.KeyValueReader, prefix []byte, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPedersen(r, prefix, height) + }, + Height: starknetTrieHeight, + } + + v.logger.Infow("=== Scanning StateTrie ===") + if err := v.scanTrie(v.database, stateTrieInfo); err != nil { + v.logger.Errorw("Error scanning StateTrie", "error", err) + allErrors = append(allErrors, fmt.Errorf("StateTrie: %w", err)) + } + } + + if typeSet[TrieTypeClass] { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + classesTrieInfo := TrieInfo{ + Name: "ClassesTrie", + Bucket: db.ClassesTrie, + HashFunc: trie.NewTriePoseidon, + ReaderFunc: func(r db.KeyValueReader, prefix []byte, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPoseidon(r, prefix, height) + }, + Height: starknetTrieHeight, + } + + v.logger.Infow("=== Scanning ClassesTrie ===") + if err := v.scanTrie(v.database, classesTrieInfo); err != nil { + v.logger.Errorw("Error scanning ClassesTrie", "error", err) + allErrors = append(allErrors, fmt.Errorf("ClassesTrie: %w", err)) + } + } + + if typeSet[TrieTypeContract] { select { case <-ctx.Done(): return ctx.Err() default: } - if err := v.verifyTrie(ctx, trieType); err != nil { - return fmt.Errorf("trie type %s: %w", trieType, err) + contractAddresses := v.collectContractAddresses() + if len(contractAddresses) > 0 { + v.logger.Infow("=== Scanning Contract Storage Tries ===") + v.logger.Infow("Found contracts to scan", "count", len(contractAddresses)) + + contractErrors := v.scanContractStorageTries(v.database, contractAddresses) + if len(contractErrors) > 0 { + v.logger.Errorw("Errors in contracts", "errorCount", len(contractErrors), "totalCount", len(contractAddresses)) + allErrors = append(allErrors, fmt.Errorf("contract storage tries: %d errors", len(contractErrors))) + for _, err := range contractErrors { + v.logger.Errorw("Contract error", "error", err) + } + } } } + if len(allErrors) > 0 { + return fmt.Errorf("trie verification completed with errors: %v", allErrors) + } + + v.logger.Infow("=== Trie verification completed successfully ===") return nil } -// verifyTrie verifies a specific trie type. -func (v *TrieVerifier) verifyTrie(ctx context.Context, trieType TrieType) error { - // TODO: Implement actual trie verification logic - // This is a stub implementation that demonstrates the structure. - // The actual implementation should: - // 1. Collect leaf nodes from the specified trie - // 2. Rebuild the trie in memory from leaves - // 3. Calculate the root hash - // 4. Compare with stored root hash - - switch trieType { - case TrieTypeState: - // Verify StateTrie +type TrieInfo struct { + Name string + Bucket db.Bucket + HashFunc trie.NewTrieFunc + ReaderFunc func(db.KeyValueReader, []byte, uint8) (trie.TrieReader, error) + Height uint8 +} + +func (v *TrieVerifier) scanTrie(database db.KeyValueStore, trieInfo TrieInfo) error { + prefix := trieInfo.Bucket.Key() + return v.scanTrieWithPrefix(database, trieInfo, prefix) +} + +func (v *TrieVerifier) collectContractAddresses() []felt.Felt { + contractAddresses := make([]felt.Felt, 0) + stateTriePrefix := db.StateTrie.Key() + + err := v.database.View(func(snap db.Snapshot) error { + it, err := snap.NewIterator(stateTriePrefix, true) + if err != nil { + return err + } + defer it.Close() + + for it.First(); it.Valid(); it.Next() { + keyBytes := it.Key() + if bytes.Equal(keyBytes, stateTriePrefix) { + continue + } + + if !bytes.HasPrefix(keyBytes, stateTriePrefix) { + continue + } + nodeKeyBytes := keyBytes[len(stateTriePrefix):] + + var nodeKey trie.BitArray + if err := nodeKey.UnmarshalBinary(nodeKeyBytes); err != nil { + continue + } + + if nodeKey.Len() == starknetTrieHeight { + contractAddr := nodeKey.Felt() + contractAddresses = append(contractAddresses, contractAddr) + } + } return nil - case TrieTypeClass: - // Verify ClassesTrie + }) + if err != nil { return nil - case TrieTypeContract: - // Verify all contract storage tries + } + + return contractAddresses +} + +func (v *TrieVerifier) scanContractStorageTries(database db.KeyValueStore, contractAddresses []felt.Felt) []error { + var allErrors []error + + for i, contractAddr := range contractAddresses { + v.logger.Infow("Scanning contract", "current", i+1, "total", len(contractAddresses)) + + addrBytes := contractAddr.Marshal() + prefix := db.ContractStorage.Key(addrBytes) + + trieInfo := TrieInfo{ + Name: fmt.Sprintf("ContractStorage[%s]", contractAddr.String()), + Bucket: db.ContractStorage, + HashFunc: trie.NewTriePedersen, + ReaderFunc: func(r db.KeyValueReader, _ []byte, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPedersen(r, prefix, height) + }, + Height: starknetTrieHeight, + } + + err := v.scanTrieWithPrefix(database, trieInfo, prefix) + if err != nil { + allErrors = append(allErrors, fmt.Errorf("contract %s: %w", contractAddr.String(), err)) + } + } + + v.logger.Infow("Scanned all contracts", "count", len(contractAddresses)) + + return allErrors +} + +func (v *TrieVerifier) scanTrieWithPrefix(database db.KeyValueStore, trieInfo TrieInfo, prefix []byte) error { + var storedRootHash felt.Felt + var hasRootKey bool + err := database.View(func(snap db.Snapshot) error { + reader, err := trieInfo.ReaderFunc(snap, prefix, trieInfo.Height) + if err != nil { + return err + } + if reader.RootKey() == nil { + storedRootHash = felt.Zero + hasRootKey = false + return nil + } + hasRootKey = true + storedRootHash, err = reader.Hash() + return err + }) + if err != nil { + if trieInfo.Bucket == db.ContractStorage { + return nil + } + return fmt.Errorf("failed to get stored root hash: %w", err) + } + + isContractStorage := trieInfo.Bucket == db.ContractStorage + if !isContractStorage { + v.logger.Infow("Stored root hash", "hash", storedRootHash.String()) + v.logger.Infow("Scanning nodes...") + } + + leaves := make(map[felt.Felt]felt.Felt) + totalNodes := 0 + leafCount := 0 + + err = database.View(func(snap db.Snapshot) error { + it, err := snap.NewIterator(prefix, true) + if err != nil { + return err + } + defer it.Close() + + for it.First(); it.Valid(); it.Next() { + keyBytes := it.Key() + if bytes.Equal(keyBytes, prefix) { + continue + } + + if !bytes.HasPrefix(keyBytes, prefix) { + continue + } + nodeKeyBytes := keyBytes[len(prefix):] + + var nodeKey trie.BitArray + if err := nodeKey.UnmarshalBinary(nodeKeyBytes); err != nil { + continue + } + + totalNodes++ + + valBytes, err := it.Value() + if err != nil { + return fmt.Errorf("failed to read value: %w", err) + } + + var node trie.Node + if err := node.UnmarshalBinary(valBytes); err != nil { + return fmt.Errorf("failed to unmarshal node at key %s: %w", nodeKey.String(), err) + } + + isLeaf := nodeKey.Len() == trieInfo.Height + + if isLeaf && node.Value != nil { + leafCount++ + keyFelt := nodeKey.Felt() + leaves[keyFelt] = *node.Value + } + } return nil - default: - return fmt.Errorf("unknown trie type: %s", trieType) + }) + if err != nil { + return fmt.Errorf("failed to iterate nodes: %w", err) } + + if len(leaves) == 0 { + v.logger.Infow("No leaves found, calculating root from empty trie") + if isContractStorage && !hasRootKey { + return nil + } + if isContractStorage && hasRootKey { + return nil + } + } else { + v.logger.Infow("Rebuilding trie from leaves", "leafCount", len(leaves)) + } + + calculatedRoot, err := v.rebuildTrieFromLeaves(leaves, trieInfo) + if err != nil { + return fmt.Errorf("failed to rebuild trie: %w", err) + } + + if calculatedRoot.Equal(&storedRootHash) { + v.logger.Infow("Root hashes match - no corruption detected", + "calculated", calculatedRoot.String(), + "stored", storedRootHash.String()) + return nil + } + + v.logger.Errorw("ROOT MISMATCH DETECTED!", + "calculated", calculatedRoot.String(), + "stored", storedRootHash.String()) + return fmt.Errorf("root hash mismatch for %s", trieInfo.Name) +} + +func (v *TrieVerifier) rebuildTrieFromLeaves(leaves map[felt.Felt]felt.Felt, trieInfo TrieInfo) (felt.Felt, error) { + memoryDB := memory.New() + txn := memoryDB.NewIndexedBatch() + defer func() { + _ = memoryDB.Close() + }() + + t, err := trieInfo.HashFunc(txn, nil, trieInfo.Height) + if err != nil { + return felt.Zero, fmt.Errorf("failed to create in-memory trie: %w", err) + } + + totalLeaves := len(leaves) + const rebuildProgressInterval = 10000 + inserted := 0 + + for key, value := range leaves { + keyCopy := key + valueCopy := value + if _, err := t.Put(&keyCopy, &valueCopy); err != nil { + return felt.Zero, fmt.Errorf("failed to insert leaf %s: %w", key.String(), err) + } + inserted++ + } + + if totalLeaves > 0 && inserted%rebuildProgressInterval != 0 { + v.logger.Infow("Inserted all leaves", "count", inserted) + } + + v.logger.Infow("Calculating root hash...") + + rootHash, err := t.Hash() + if err != nil { + return felt.Zero, fmt.Errorf("failed to calculate root hash: %w", err) + } + + return rootHash, nil } diff --git a/cmd/juno/verify/verifier.go b/cmd/juno/verify/verifier.go index fac5d5b3b4..6abe284de4 100644 --- a/cmd/juno/verify/verifier.go +++ b/cmd/juno/verify/verifier.go @@ -5,29 +5,18 @@ import ( "fmt" ) -// Config is the base configuration interface that all verifier configs must implement. -// Empty values in config structs imply default behavior. type Config interface{} -// Verifier defines the interface for all database verifiers. -// Each verifier is atomic, has its own strongly-typed config, and contains no CLI logic. type Verifier interface { - // Name returns the name of the verifier (e.g., "trie", "tx"). Name() string - // DefaultConfig returns the default configuration for this verifier. DefaultConfig() Config - // Run executes the verification with the given config and context. - // Returns an error if verification fails. Run(ctx context.Context, cfg Config) error } -// VerifyRunner orchestrates the execution of multiple verifiers. type VerifyRunner struct { Verifiers []Verifier } -// Run executes all registered verifiers sequentially. -// Stops on the first error and wraps it with the verifier name. func (r *VerifyRunner) Run(ctx context.Context, configs map[string]Config) error { for _, verifier := range r.Verifiers { cfg, ok := configs[verifier.Name()] diff --git a/cmd/juno/verify/verify.go b/cmd/juno/verify/verify.go index 791a3b067e..2dd28d5085 100644 --- a/cmd/juno/verify/verify.go +++ b/cmd/juno/verify/verify.go @@ -1,6 +1,13 @@ package verify import ( + "errors" + "fmt" + "os" + + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/pebblev2" + "github.com/NethermindEth/juno/utils" "github.com/spf13/cobra" ) @@ -36,7 +43,13 @@ func verifyAll(cmd *cobra.Command, args []string) error { } defer database.Close() - trieVerifier := NewTrieVerifier(database) + logLevel := utils.NewLogLevel(utils.INFO) + logger, err := utils.NewZapLogger(logLevel, true) + if err != nil { + return fmt.Errorf("failed to create logger: %w", err) + } + + trieVerifier := NewTrieVerifier(database, logger) runner := &VerifyRunner{ Verifiers: []Verifier{ @@ -88,7 +101,27 @@ func runTrieVerify(cmd *cobra.Command, args []string) error { } } - verifier := NewTrieVerifier(database) + logLevel := utils.NewLogLevel(utils.INFO) + logger, err := utils.NewZapLogger(logLevel, true) + if err != nil { + return fmt.Errorf("failed to create logger: %w", err) + } + + verifier := NewTrieVerifier(database, logger) ctx := cmd.Context() return verifier.Run(ctx, cfg) } + +func openDB(path string) (db.KeyValueStore, error) { + _, err := os.Stat(path) + if os.IsNotExist(err) { + return nil, errors.New("database path does not exist") + } + + database, err := pebblev2.New(path) + if err != nil { + return nil, fmt.Errorf("failed to open db: %w", err) + } + + return database, nil +} From bf9f07cd3848953750a030e866ffc7c6bffe9881 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Thu, 18 Dec 2025 17:37:29 +0100 Subject: [PATCH 03/15] add unit tests --- cmd/juno/verify/trie.go | 41 +++--- cmd/juno/verify/trie_test.go | 252 +++++++++++++++++++++++++++++++++++ 2 files changed, 270 insertions(+), 23 deletions(-) create mode 100644 cmd/juno/verify/trie_test.go diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go index 20ac72f1a0..8bb7ca6bc3 100644 --- a/cmd/juno/verify/trie.go +++ b/cmd/juno/verify/trie.go @@ -19,9 +19,9 @@ const ( type TrieType string const ( - TrieTypeState TrieType = "state" - TrieTypeClass TrieType = "class" - TrieTypeContract TrieType = "contract" + ContractTrieType TrieType = "contract" + ClassTrieType TrieType = "class" + ContractStorageTrieType TrieType = "contract-storage" ) type TrieConfig struct { @@ -58,7 +58,7 @@ func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { typesToVerify := trieCfg.Tries if len(typesToVerify) == 0 { - typesToVerify = []TrieType{TrieTypeState, TrieTypeClass, TrieTypeContract} + typesToVerify = []TrieType{ContractTrieType, ClassTrieType, ContractStorageTrieType} } typeSet := make(map[TrieType]bool) @@ -68,7 +68,7 @@ func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { var allErrors []error - if typeSet[TrieTypeState] { + if typeSet[ContractTrieType] { select { case <-ctx.Done(): return ctx.Err() @@ -76,7 +76,7 @@ func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { } stateTrieInfo := TrieInfo{ - Name: "StateTrie", + Name: "ContractTrie", Bucket: db.StateTrie, HashFunc: trie.NewTriePedersen, ReaderFunc: func(r db.KeyValueReader, prefix []byte, height uint8) (trie.TrieReader, error) { @@ -85,21 +85,21 @@ func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { Height: starknetTrieHeight, } - v.logger.Infow("=== Scanning StateTrie ===") + v.logger.Infow("=== Scanning ContractTrie ===") if err := v.scanTrie(v.database, stateTrieInfo); err != nil { - v.logger.Errorw("Error scanning StateTrie", "error", err) - allErrors = append(allErrors, fmt.Errorf("StateTrie: %w", err)) + v.logger.Errorw("Error scanning ContractTrie", "error", err) + allErrors = append(allErrors, fmt.Errorf("ContractTrie: %w", err)) } } - if typeSet[TrieTypeClass] { + if typeSet[ClassTrieType] { select { case <-ctx.Done(): return ctx.Err() default: } - classesTrieInfo := TrieInfo{ + classTrieInfo := TrieInfo{ Name: "ClassesTrie", Bucket: db.ClassesTrie, HashFunc: trie.NewTriePoseidon, @@ -109,14 +109,14 @@ func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { Height: starknetTrieHeight, } - v.logger.Infow("=== Scanning ClassesTrie ===") - if err := v.scanTrie(v.database, classesTrieInfo); err != nil { - v.logger.Errorw("Error scanning ClassesTrie", "error", err) - allErrors = append(allErrors, fmt.Errorf("ClassesTrie: %w", err)) + v.logger.Infow("=== Scanning ClassTrie ===") + if err := v.scanTrie(v.database, classTrieInfo); err != nil { + v.logger.Errorw("Error scanning ClassTrie", "error", err) + allErrors = append(allErrors, fmt.Errorf("ClassTrie: %w", err)) } } - if typeSet[TrieTypeContract] { + if typeSet[ContractStorageTrieType] { select { case <-ctx.Done(): return ctx.Err() @@ -341,7 +341,7 @@ func (v *TrieVerifier) scanTrieWithPrefix(database db.KeyValueStore, trieInfo Tr v.logger.Errorw("ROOT MISMATCH DETECTED!", "calculated", calculatedRoot.String(), "stored", storedRootHash.String()) - return fmt.Errorf("root hash mismatch for %s", trieInfo.Name) + return fmt.Errorf("root hash mismatch for %s, expected %s, got %s", trieInfo.Name, calculatedRoot.String(), storedRootHash.String()) } func (v *TrieVerifier) rebuildTrieFromLeaves(leaves map[felt.Felt]felt.Felt, trieInfo TrieInfo) (felt.Felt, error) { @@ -356,8 +356,6 @@ func (v *TrieVerifier) rebuildTrieFromLeaves(leaves map[felt.Felt]felt.Felt, tri return felt.Zero, fmt.Errorf("failed to create in-memory trie: %w", err) } - totalLeaves := len(leaves) - const rebuildProgressInterval = 10000 inserted := 0 for key, value := range leaves { @@ -369,10 +367,7 @@ func (v *TrieVerifier) rebuildTrieFromLeaves(leaves map[felt.Felt]felt.Felt, tri inserted++ } - if totalLeaves > 0 && inserted%rebuildProgressInterval != 0 { - v.logger.Infow("Inserted all leaves", "count", inserted) - } - + v.logger.Infow("Inserted all leaves", "count", inserted) v.logger.Infow("Calculating root hash...") rootHash, err := t.Hash() diff --git a/cmd/juno/verify/trie_test.go b/cmd/juno/verify/trie_test.go new file mode 100644 index 0000000000..faa331eeca --- /dev/null +++ b/cmd/juno/verify/trie_test.go @@ -0,0 +1,252 @@ +package verify + +import ( + "context" + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTrieVerifier_Run_ValidStateTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = testTrie.Put(key2, value2) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + rootHash, err := testTrie.Hash() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger) + cfg := &TrieConfig{ + Tries: []TrieType{ContractTrieType}, + } + + ctx := context.Background() + err = verifier.Run(ctx, cfg) + assert.NoError(t, err) + + reader, err := trie.NewTrieReaderPedersen(testDB, prefix, starknetTrieHeight) + require.NoError(t, err) + storedHash, err := reader.Hash() + require.NoError(t, err) + assert.True(t, rootHash.Equal(&storedHash)) +} + +func TestTrieVerifier_Run_ValidClassTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.ClassesTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePoseidon(txn, prefix, starknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](10) + value1 := felt.NewFromUint64[felt.Felt](1000) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger) + cfg := &TrieConfig{ + Tries: []TrieType{ClassTrieType}, + } + + ctx := context.Background() + err = verifier.Run(ctx, cfg) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_CorruptedTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = testTrie.Put(key2, value2) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + var nodeKey trie.BitArray + nodeKey.SetFelt(starknetTrieHeight, key1) + + node, err := trieStorage.Get(&nodeKey) + require.NoError(t, err) + require.NotNil(t, node) + require.NotNil(t, node.Value) + + assert.True(t, node.Value.Equal(value1), "Expected value1 but got different value") + + corruptedValue := felt.NewFromUint64[felt.Felt](999999) + node.Value = corruptedValue + + err = trieStorage.Put(&nodeKey, node) + require.NoError(t, err) + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger) + cfg := &TrieConfig{ + Tries: []TrieType{ContractTrieType}, + } + + ctx := context.Background() + err = verifier.Run(ctx, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "root hash mismatch") +} + +func TestTrieVerifier_Run_MultipleTrieTypes(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + statePrefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + stateTrie, err := trie.NewTriePedersen(txn, statePrefix, starknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = stateTrie.Put(key1, value1) + require.NoError(t, err) + + err = stateTrie.Commit() + require.NoError(t, err) + + stateStorage := trie.NewStorage(txn, statePrefix) + if stateTrie.RootKey() != nil { + err = stateStorage.PutRootKey(stateTrie.RootKey()) + require.NoError(t, err) + } + + classPrefix := db.ClassesTrie.Key() + classTrie, err := trie.NewTriePoseidon(txn, classPrefix, starknetTrieHeight) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = classTrie.Put(key2, value2) + require.NoError(t, err) + + err = classTrie.Commit() + require.NoError(t, err) + + classStorage := trie.NewStorage(txn, classPrefix) + if classTrie.RootKey() != nil { + err = classStorage.PutRootKey(classTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger) + cfg := &TrieConfig{ + Tries: []TrieType{ContractTrieType, ClassTrieType}, + } + + ctx := context.Background() + err = verifier.Run(ctx, cfg) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_EmptyTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger) + cfg := &TrieConfig{ + Tries: []TrieType{ContractTrieType}, + } + + ctx := context.Background() + err = verifier.Run(ctx, cfg) + assert.NoError(t, err) +} From b616ec571a4c38d633e21156aba6a752eaafe6b0 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Sun, 21 Dec 2025 12:58:39 +0100 Subject: [PATCH 04/15] trie verification reworked --- cmd/juno/verify/trie.go | 441 +++++++++++++++++++++----------------- cmd/juno/verify/verify.go | 6 +- 2 files changed, 248 insertions(+), 199 deletions(-) diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go index 8bb7ca6bc3..c2a8db15ac 100644 --- a/cmd/juno/verify/trie.go +++ b/cmd/juno/verify/trie.go @@ -3,17 +3,21 @@ package verify import ( "bytes" "context" + "errors" "fmt" + "sync" + "time" + "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/utils" ) const ( - starknetTrieHeight = 251 + starknetTrieHeight = 251 + concurrencyMaxDepth = 8 ) type TrieType string @@ -50,7 +54,22 @@ func (v *TrieVerifier) DefaultConfig() Config { } } +type TrieInfo struct { + Name string + prefix []byte + HashFunc trie.NewTrieFunc + HashFn crypto.HashFn + ReaderFunc func(db.KeyValueReader, []byte, uint8) (trie.TrieReader, error) + Height uint8 +} + func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { + startTime := time.Now() + defer func() { + elapsed := time.Since(startTime) + v.logger.Infow("=== Trie verification finished ===", "total_elapsed", elapsed.Round(time.Second)) + }() + trieCfg, ok := cfg.(*TrieConfig) if !ok { return fmt.Errorf("invalid config type for trie verifier: expected *TrieConfig") @@ -66,100 +85,53 @@ func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { typeSet[t] = true } - var allErrors []error - if typeSet[ContractTrieType] { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - stateTrieInfo := TrieInfo{ - Name: "ContractTrie", - Bucket: db.StateTrie, - HashFunc: trie.NewTriePedersen, - ReaderFunc: func(r db.KeyValueReader, prefix []byte, height uint8) (trie.TrieReader, error) { - return trie.NewTrieReaderPedersen(r, prefix, height) - }, - Height: starknetTrieHeight, + Name: "ContractsTrie", + prefix: db.StateTrie.Key(), + HashFunc: trie.NewTriePedersen, + HashFn: crypto.Pedersen, + ReaderFunc: trie.NewTrieReaderPedersen, + Height: starknetTrieHeight, } - - v.logger.Infow("=== Scanning ContractTrie ===") - if err := v.scanTrie(v.database, stateTrieInfo); err != nil { - v.logger.Errorw("Error scanning ContractTrie", "error", err) - allErrors = append(allErrors, fmt.Errorf("ContractTrie: %w", err)) + if err := v.verifyTrieWithLogging(ctx, stateTrieInfo); err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err } } if typeSet[ClassTrieType] { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - classTrieInfo := TrieInfo{ - Name: "ClassesTrie", - Bucket: db.ClassesTrie, - HashFunc: trie.NewTriePoseidon, - ReaderFunc: func(r db.KeyValueReader, prefix []byte, height uint8) (trie.TrieReader, error) { - return trie.NewTrieReaderPoseidon(r, prefix, height) - }, - Height: starknetTrieHeight, + Name: "ClassesTrie", + prefix: db.ClassesTrie.Key(), + HashFunc: trie.NewTriePoseidon, + HashFn: crypto.Poseidon, + ReaderFunc: trie.NewTrieReaderPoseidon, + Height: starknetTrieHeight, } - - v.logger.Infow("=== Scanning ClassTrie ===") - if err := v.scanTrie(v.database, classTrieInfo); err != nil { - v.logger.Errorw("Error scanning ClassTrie", "error", err) - allErrors = append(allErrors, fmt.Errorf("ClassTrie: %w", err)) + if err := v.verifyTrieWithLogging(ctx, classTrieInfo); err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err } } if typeSet[ContractStorageTrieType] { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - contractAddresses := v.collectContractAddresses() - if len(contractAddresses) > 0 { - v.logger.Infow("=== Scanning Contract Storage Tries ===") - v.logger.Infow("Found contracts to scan", "count", len(contractAddresses)) - - contractErrors := v.scanContractStorageTries(v.database, contractAddresses) - if len(contractErrors) > 0 { - v.logger.Errorw("Errors in contracts", "errorCount", len(contractErrors), "totalCount", len(contractAddresses)) - allErrors = append(allErrors, fmt.Errorf("contract storage tries: %d errors", len(contractErrors))) - for _, err := range contractErrors { - v.logger.Errorw("Contract error", "error", err) - } + if err := v.verifyContractStorageTries(ctx); err != nil { + if errors.Is(err, context.Canceled) { + return nil } + return err } } - if len(allErrors) > 0 { - return fmt.Errorf("trie verification completed with errors: %v", allErrors) - } - v.logger.Infow("=== Trie verification completed successfully ===") return nil } -type TrieInfo struct { - Name string - Bucket db.Bucket - HashFunc trie.NewTrieFunc - ReaderFunc func(db.KeyValueReader, []byte, uint8) (trie.TrieReader, error) - Height uint8 -} - -func (v *TrieVerifier) scanTrie(database db.KeyValueStore, trieInfo TrieInfo) error { - prefix := trieInfo.Bucket.Key() - return v.scanTrieWithPrefix(database, trieInfo, prefix) -} - func (v *TrieVerifier) collectContractAddresses() []felt.Felt { contractAddresses := make([]felt.Felt, 0) stateTriePrefix := db.StateTrie.Key() @@ -201,179 +173,252 @@ func (v *TrieVerifier) collectContractAddresses() []felt.Felt { return contractAddresses } -func (v *TrieVerifier) scanContractStorageTries(database db.KeyValueStore, contractAddresses []felt.Felt) []error { - var allErrors []error - - for i, contractAddr := range contractAddresses { - v.logger.Infow("Scanning contract", "current", i+1, "total", len(contractAddresses)) - - addrBytes := contractAddr.Marshal() - prefix := db.ContractStorage.Key(addrBytes) - - trieInfo := TrieInfo{ - Name: fmt.Sprintf("ContractStorage[%s]", contractAddr.String()), - Bucket: db.ContractStorage, - HashFunc: trie.NewTriePedersen, - ReaderFunc: func(r db.KeyValueReader, _ []byte, height uint8) (trie.TrieReader, error) { - return trie.NewTrieReaderPedersen(r, prefix, height) - }, - Height: starknetTrieHeight, - } - - err := v.scanTrieWithPrefix(database, trieInfo, prefix) - if err != nil { - allErrors = append(allErrors, fmt.Errorf("contract %s: %w", contractAddr.String(), err)) +func (v *TrieVerifier) verifyTrieWithLogging(ctx context.Context, trieInfo TrieInfo) error { + v.logger.Infow(fmt.Sprintf("=== Starting %s verification ===", trieInfo.Name)) + err := v.verifyTrie(ctx, trieInfo) + if err != nil { + if errors.Is(err, context.Canceled) { + v.logger.Infow("Verification stopped", "trie", trieInfo.Name) + return err } + v.logger.Errorw(fmt.Sprintf("%s verification failed", trieInfo.Name), "error", err) + return fmt.Errorf("%s verification failed: %w", trieInfo.Name, err) } - - v.logger.Infow("Scanned all contracts", "count", len(contractAddresses)) - - return allErrors + v.logger.Infow(fmt.Sprintf("%s verification completed successfully", trieInfo.Name)) + return nil } -func (v *TrieVerifier) scanTrieWithPrefix(database db.KeyValueStore, trieInfo TrieInfo, prefix []byte) error { - var storedRootHash felt.Felt - var hasRootKey bool - err := database.View(func(snap db.Snapshot) error { - reader, err := trieInfo.ReaderFunc(snap, prefix, trieInfo.Height) +func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error { + expectedRoot := felt.Zero + err := v.database.View(func(snap db.Snapshot) error { + reader, err := trieInfo.ReaderFunc(snap, trieInfo.prefix, trieInfo.Height) if err != nil { return err } if reader.RootKey() == nil { - storedRootHash = felt.Zero - hasRootKey = false + expectedRoot = felt.Zero return nil } - hasRootKey = true - storedRootHash, err = reader.Hash() + expectedRoot, err = reader.Hash() return err }) if err != nil { - if trieInfo.Bucket == db.ContractStorage { - return nil - } - return fmt.Errorf("failed to get stored root hash: %w", err) + v.logger.Errorw("Failed to get stored root hash", "trie", trieInfo.Name, "error", err) + return fmt.Errorf("failed to get stored root hash for %s: %w", trieInfo.Name, err) } - isContractStorage := trieInfo.Bucket == db.ContractStorage - if !isContractStorage { - v.logger.Infow("Stored root hash", "hash", storedRootHash.String()) - v.logger.Infow("Scanning nodes...") + if expectedRoot.IsZero() { + v.logger.Infow("Trie is empty (zero root)", "trie", trieInfo.Name) + return nil } - leaves := make(map[felt.Felt]felt.Felt) - totalNodes := 0 - leafCount := 0 + v.logger.Infow("Starting verification", "trie", trieInfo.Name, "expectedRoot", expectedRoot.String()) + storageReader := trie.NewReadStorage(v.database, trieInfo.prefix) - err = database.View(func(snap db.Snapshot) error { - it, err := snap.NewIterator(prefix, true) - if err != nil { + err = verifyTrie(ctx, storageReader, starknetTrieHeight, trieInfo.HashFn, &expectedRoot) + if err != nil { + if errors.Is(err, context.Canceled) { return err } - defer it.Close() - - for it.First(); it.Valid(); it.Next() { - keyBytes := it.Key() - if bytes.Equal(keyBytes, prefix) { - continue - } + v.logger.Errorw("Trie verification failed", "trie", trieInfo.Name, "error", err) + return err + } - if !bytes.HasPrefix(keyBytes, prefix) { - continue - } - nodeKeyBytes := keyBytes[len(prefix):] + v.logger.Infow("Trie verification successful", "trie", trieInfo.Name, "root", expectedRoot.String()) + return nil +} - var nodeKey trie.BitArray - if err := nodeKey.UnmarshalBinary(nodeKeyBytes); err != nil { - continue - } +func (v *TrieVerifier) verifyContractStorageTries(ctx context.Context) error { + v.logger.Infow("=== Starting Contract Storage Tries verification ===") + contractAddresses := v.collectContractAddresses() - totalNodes++ + if len(contractAddresses) == 0 { + v.logger.Infow("No contract addresses found, skipping contract storage verification") + return nil + } - valBytes, err := it.Value() - if err != nil { - return fmt.Errorf("failed to read value: %w", err) - } + v.logger.Infow("Found contracts to verify", "count", len(contractAddresses)) - var node trie.Node - if err := node.UnmarshalBinary(valBytes); err != nil { - return fmt.Errorf("failed to unmarshal node at key %s: %w", nodeKey.String(), err) - } + for i, contractAddress := range contractAddresses { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + addrBytes := contractAddress.Marshal() + prefix := db.ContractStorage.Key(addrBytes) + trieInfo := TrieInfo{ + Name: fmt.Sprintf("ContractStorage[%s]", contractAddress.String()), + prefix: db.ContractStorage.Key(addrBytes), + HashFunc: trie.NewTriePedersen, + HashFn: crypto.Pedersen, + ReaderFunc: func(r db.KeyValueReader, _ []byte, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPedersen(r, prefix, height) + }, + Height: starknetTrieHeight, + } - isLeaf := nodeKey.Len() == trieInfo.Height + v.logger.Infow("Verifying contract storage", + "contract", contractAddress.String(), + "progress", fmt.Sprintf("%d/%d", i+1, len(contractAddresses))) - if isLeaf && node.Value != nil { - leafCount++ - keyFelt := nodeKey.Felt() - leaves[keyFelt] = *node.Value + err := v.verifyTrie(ctx, trieInfo) + if err != nil { + if errors.Is(err, context.Canceled) { + return err } + v.logger.Errorw("Contract storage verification failed", + "contract", contractAddress.String(), + "error", err) + return fmt.Errorf("contract storage verification failed for %s: %w", contractAddress.String(), err) } - return nil - }) + } + + v.logger.Infow("All contract storage tries verified successfully", "count", len(contractAddresses)) + return nil +} + +func verifyTrie( + ctx context.Context, + reader *trie.ReadStorage, + height uint8, + hashFn crypto.HashFn, + expectedRoot *felt.Felt, +) error { + rootKey, err := reader.RootKey() if err != nil { - return fmt.Errorf("failed to iterate nodes: %w", err) + return fmt.Errorf("failed to get root key: %w", err) } - if len(leaves) == 0 { - v.logger.Infow("No leaves found, calculating root from empty trie") - if isContractStorage && !hasRootKey { - return nil - } - if isContractStorage && hasRootKey { - return nil - } - } else { - v.logger.Infow("Rebuilding trie from leaves", "leafCount", len(leaves)) + if rootKey == nil { + return nil } - calculatedRoot, err := v.rebuildTrieFromLeaves(leaves, trieInfo) + startTime := time.Now() + rootHash, err := verifyNode(ctx, reader, rootKey, nil, height, hashFn) if err != nil { - return fmt.Errorf("failed to rebuild trie: %w", err) + return fmt.Errorf("node verification failed: %w", err) } - if calculatedRoot.Equal(&storedRootHash) { - v.logger.Infow("Root hashes match - no corruption detected", - "calculated", calculatedRoot.String(), - "stored", storedRootHash.String()) - return nil + elapsed := time.Since(startTime) + + if rootHash.Cmp(expectedRoot) != 0 { + return fmt.Errorf( + "root hash mismatch: expected %s, got %s (verification took %v)", + expectedRoot, rootHash, elapsed.Round(time.Second), + ) } - v.logger.Errorw("ROOT MISMATCH DETECTED!", - "calculated", calculatedRoot.String(), - "stored", storedRootHash.String()) - return fmt.Errorf("root hash mismatch for %s, expected %s, got %s", trieInfo.Name, calculatedRoot.String(), storedRootHash.String()) + return nil } -func (v *TrieVerifier) rebuildTrieFromLeaves(leaves map[felt.Felt]felt.Felt, trieInfo TrieInfo) (felt.Felt, error) { - memoryDB := memory.New() - txn := memoryDB.NewIndexedBatch() - defer func() { - _ = memoryDB.Close() - }() +func verifyNode( + ctx context.Context, + reader *trie.ReadStorage, + key *trie.BitArray, + parentKey *trie.BitArray, + height uint8, + hashFn crypto.HashFn, +) (*felt.Felt, error) { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("verification cancelled: %w", ctx.Err()) + default: + } - t, err := trieInfo.HashFunc(txn, nil, trieInfo.Height) + node, err := reader.Get(key) if err != nil { - return felt.Zero, fmt.Errorf("failed to create in-memory trie: %w", err) + return nil, fmt.Errorf("failed to get node at key %s: %w", key.String(), err) } - inserted := 0 + if key.Len() == height { + p := path(key, parentKey) + h := node.Hash(&p, hashFn) + return &h, nil + } - for key, value := range leaves { - keyCopy := key - valueCopy := value - if _, err := t.Put(&keyCopy, &valueCopy); err != nil { - return felt.Zero, fmt.Errorf("failed to insert leaf %s: %w", key.String(), err) + useConcurrency := key.Len() <= concurrencyMaxDepth + var leftHash, rightHash *felt.Felt + var leftErr, rightErr error + + if useConcurrency { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if node.Left.IsEmpty() { + zero := felt.Zero + leftHash = &zero + return + } + h, err := verifyNode(ctx, reader, node.Left, key, height, hashFn) + leftHash = h + leftErr = err + }() + + if node.Right.IsEmpty() { + zero := felt.Zero + rightHash = &zero + } else { + h, err := verifyNode(ctx, reader, node.Right, key, height, hashFn) + rightHash = h + rightErr = err + } + + wg.Wait() + + if leftErr != nil { + return nil, leftErr + } + if rightErr != nil { + return nil, rightErr + } + } else { + if node.Left.IsEmpty() { + zero := felt.Zero + leftHash = &zero + } else { + h, err := verifyNode(ctx, reader, node.Left, key, height, hashFn) + if err != nil { + return nil, err + } + leftHash = h + } + + if node.Right.IsEmpty() { + zero := felt.Zero + rightHash = &zero + } else { + h, err := verifyNode(ctx, reader, node.Right, key, height, hashFn) + if err != nil { + return nil, err + } + rightHash = h } - inserted++ } - v.logger.Infow("Inserted all leaves", "count", inserted) - v.logger.Infow("Calculating root hash...") + recomputed := hashFn(leftHash, rightHash) + if recomputed.Cmp(node.Value) != 0 { + return nil, fmt.Errorf( + "node corruption detected at key %s: stored hash=%s, recomputed hash=%s", + key.String(), node.Value.String(), recomputed.String(), + ) + } - rootHash, err := t.Hash() - if err != nil { - return felt.Zero, fmt.Errorf("failed to calculate root hash: %w", err) + tmp := *node + tmp.Value = &recomputed + + p := path(key, parentKey) + h := tmp.Hash(&p, hashFn) + return &h, nil +} + +func path(key, parentKey *trie.BitArray) trie.BitArray { + if parentKey == nil { + return key.Copy() } - return rootHash, nil + var pathKey trie.BitArray + pathKey.LSBs(key, parentKey.Len()+1) + return pathKey } diff --git a/cmd/juno/verify/verify.go b/cmd/juno/verify/verify.go index 2dd28d5085..ec4b4bb79e 100644 --- a/cmd/juno/verify/verify.go +++ b/cmd/juno/verify/verify.go @@ -70,7 +70,11 @@ func verifyTrieCmd() *cobra.Command { RunE: runTrieVerify, } - cmd.Flags().StringSlice(verifyTrieType, nil, "Trie types to verify (state, class, contract). Can be specified multiple times. Empty = all.") + cmd.Flags().StringSlice( + verifyTrieType, + nil, + "Trie types to verify (state, class, contract). Can be specified multiple times. Empty = all.", + ) return cmd } From 2d71a8ba4e848827b399fe0b429b16f4463de83a Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Sun, 21 Dec 2025 16:23:17 +0100 Subject: [PATCH 05/15] lint --- cmd/juno/verify/trie.go | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go index c2a8db15ac..b6ddf54816 100644 --- a/cmd/juno/verify/trie.go +++ b/cmd/juno/verify/trie.go @@ -212,7 +212,12 @@ func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error return nil } - v.logger.Infow("Starting verification", "trie", trieInfo.Name, "expectedRoot", expectedRoot.String()) + v.logger.Infow("Starting verification", + "trie", + trieInfo.Name, + "expectedRoot", + expectedRoot.String(), + ) storageReader := trie.NewReadStorage(v.database, trieInfo.prefix) err = verifyTrie(ctx, storageReader, starknetTrieHeight, trieInfo.HashFn, &expectedRoot) @@ -224,7 +229,10 @@ func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error return err } - v.logger.Infow("Trie verification successful", "trie", trieInfo.Name, "root", expectedRoot.String()) + v.logger.Infow("Trie verification successful", + "trie", trieInfo.Name, "root", + expectedRoot.String(), + ) return nil } @@ -258,19 +266,30 @@ func (v *TrieVerifier) verifyContractStorageTries(ctx context.Context) error { Height: starknetTrieHeight, } - v.logger.Infow("Verifying contract storage", - "contract", contractAddress.String(), - "progress", fmt.Sprintf("%d/%d", i+1, len(contractAddresses))) + v.logger.Infow( + "Verifying contract storage", + "contract", + contractAddress.String(), + "progress", + fmt.Sprintf("%d/%d", i+1, len(contractAddresses)), + ) err := v.verifyTrie(ctx, trieInfo) if err != nil { if errors.Is(err, context.Canceled) { return err } - v.logger.Errorw("Contract storage verification failed", - "contract", contractAddress.String(), - "error", err) - return fmt.Errorf("contract storage verification failed for %s: %w", contractAddress.String(), err) + v.logger.Errorw( + "Contract storage verification failed", + "contract", + contractAddress.String(), + "error", + err, + ) + return fmt.Errorf( + "contract storage verification failed for %s: %w", + contractAddress.String(), err, + ) } } From 20dc28700810193412d538a954692296d42015ac Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 22 Dec 2025 15:42:15 +0100 Subject: [PATCH 06/15] restructure --- cmd/juno/verify/trie.go | 55 +++++++++++++++++++++++++++++++++ cmd/juno/verify/trie_test.go | 2 +- cmd/juno/verify/verify.go | 59 ++---------------------------------- 3 files changed, 58 insertions(+), 58 deletions(-) diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go index b6ddf54816..4a6fdeff7b 100644 --- a/cmd/juno/verify/trie.go +++ b/cmd/juno/verify/trie.go @@ -13,6 +13,7 @@ import ( "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/utils" + "github.com/spf13/cobra" ) const ( @@ -28,6 +29,60 @@ const ( ContractStorageTrieType TrieType = "contract-storage" ) +func verifyTrieCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "trie", + Short: "Verify trie integrity", + Long: `Verify trie integrity by rebuilding tries from leaf nodes and comparing root hashes.`, + RunE: runTrieVerify, + } + + cmd.Flags().StringSlice( + verifyTrieType, + nil, + "Trie types to verify (state, class, contract). Can be specified multiple times. Empty = all.", + ) + + return cmd +} + +func runTrieVerify(cmd *cobra.Command, args []string) error { + dbPath, err := cmd.Flags().GetString(verifyDBPathF) + if err != nil { + return err + } + + database, err := openDB(dbPath) + if err != nil { + return err + } + defer database.Close() + + trieTypes, err := cmd.Flags().GetStringSlice(verifyTrieType) + if err != nil { + return err + } + + cfg := &TrieConfig{} + + if len(trieTypes) > 0 { + cfg.Tries = make([]TrieType, len(trieTypes)) + for i, t := range trieTypes { + cfg.Tries[i] = TrieType(t) + } + } + + logLevel := utils.NewLogLevel(utils.INFO) + logger, err := utils.NewZapLogger(logLevel, true) + if err != nil { + return fmt.Errorf("failed to create logger: %w", err) + } + + verifier := NewTrieVerifier(database, logger) + ctx := cmd.Context() + return verifier.Run(ctx, cfg) +} + type TrieConfig struct { Tries []TrieType } diff --git a/cmd/juno/verify/trie_test.go b/cmd/juno/verify/trie_test.go index faa331eeca..74b1769a55 100644 --- a/cmd/juno/verify/trie_test.go +++ b/cmd/juno/verify/trie_test.go @@ -160,7 +160,7 @@ func TestTrieVerifier_Run_CorruptedTrie(t *testing.T) { ctx := context.Background() err = verifier.Run(ctx, cfg) require.Error(t, err) - assert.Contains(t, err.Error(), "root hash mismatch") + assert.Contains(t, err.Error(), "node corruption detected") } func TestTrieVerifier_Run_MultipleTrieTypes(t *testing.T) { diff --git a/cmd/juno/verify/verify.go b/cmd/juno/verify/verify.go index ec4b4bb79e..c21004664b 100644 --- a/cmd/juno/verify/verify.go +++ b/cmd/juno/verify/verify.go @@ -51,69 +51,14 @@ func verifyAll(cmd *cobra.Command, args []string) error { trieVerifier := NewTrieVerifier(database, logger) - runner := &VerifyRunner{ + verifier := &VerifyRunner{ Verifiers: []Verifier{ trieVerifier, }, } - configs := make(map[string]Config) ctx := cmd.Context() - return runner.Run(ctx, configs) -} - -func verifyTrieCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "trie", - Short: "Verify trie integrity", - Long: `Verify trie integrity by rebuilding tries from leaf nodes and comparing root hashes.`, - RunE: runTrieVerify, - } - - cmd.Flags().StringSlice( - verifyTrieType, - nil, - "Trie types to verify (state, class, contract). Can be specified multiple times. Empty = all.", - ) - - return cmd -} - -func runTrieVerify(cmd *cobra.Command, args []string) error { - dbPath, err := cmd.Flags().GetString(verifyDBPathF) - if err != nil { - return err - } - - database, err := openDB(dbPath) - if err != nil { - return err - } - defer database.Close() - - trieTypes, err := cmd.Flags().GetStringSlice(verifyTrieType) - if err != nil { - return err - } - - cfg := &TrieConfig{} - - if len(trieTypes) > 0 { - cfg.Tries = make([]TrieType, len(trieTypes)) - for i, t := range trieTypes { - cfg.Tries[i] = TrieType(t) - } - } - - logLevel := utils.NewLogLevel(utils.INFO) - logger, err := utils.NewZapLogger(logLevel, true) - if err != nil { - return fmt.Errorf("failed to create logger: %w", err) - } - - verifier := NewTrieVerifier(database, logger) - ctx := cmd.Context() - return verifier.Run(ctx, cfg) + return verifier.Run(ctx, nil) } func openDB(path string) (db.KeyValueStore, error) { From 56a912697e3aea80aedb723abb5cd5f6099f8acd Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 22 Dec 2025 15:49:17 +0100 Subject: [PATCH 07/15] split pr into two --- cmd/juno/verify/trie.go | 498 ----------------------------------- cmd/juno/verify/trie_test.go | 252 ------------------ cmd/juno/verify/verify.go | 17 +- 3 files changed, 2 insertions(+), 765 deletions(-) delete mode 100644 cmd/juno/verify/trie.go delete mode 100644 cmd/juno/verify/trie_test.go diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go deleted file mode 100644 index 4a6fdeff7b..0000000000 --- a/cmd/juno/verify/trie.go +++ /dev/null @@ -1,498 +0,0 @@ -package verify - -import ( - "bytes" - "context" - "errors" - "fmt" - "sync" - "time" - - "github.com/NethermindEth/juno/core/crypto" - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/utils" - "github.com/spf13/cobra" -) - -const ( - starknetTrieHeight = 251 - concurrencyMaxDepth = 8 -) - -type TrieType string - -const ( - ContractTrieType TrieType = "contract" - ClassTrieType TrieType = "class" - ContractStorageTrieType TrieType = "contract-storage" -) - -func verifyTrieCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "trie", - Short: "Verify trie integrity", - Long: `Verify trie integrity by rebuilding tries from leaf nodes and comparing root hashes.`, - RunE: runTrieVerify, - } - - cmd.Flags().StringSlice( - verifyTrieType, - nil, - "Trie types to verify (state, class, contract). Can be specified multiple times. Empty = all.", - ) - - return cmd -} - -func runTrieVerify(cmd *cobra.Command, args []string) error { - dbPath, err := cmd.Flags().GetString(verifyDBPathF) - if err != nil { - return err - } - - database, err := openDB(dbPath) - if err != nil { - return err - } - defer database.Close() - - trieTypes, err := cmd.Flags().GetStringSlice(verifyTrieType) - if err != nil { - return err - } - - cfg := &TrieConfig{} - - if len(trieTypes) > 0 { - cfg.Tries = make([]TrieType, len(trieTypes)) - for i, t := range trieTypes { - cfg.Tries[i] = TrieType(t) - } - } - - logLevel := utils.NewLogLevel(utils.INFO) - logger, err := utils.NewZapLogger(logLevel, true) - if err != nil { - return fmt.Errorf("failed to create logger: %w", err) - } - - verifier := NewTrieVerifier(database, logger) - ctx := cmd.Context() - return verifier.Run(ctx, cfg) -} - -type TrieConfig struct { - Tries []TrieType -} - -type TrieVerifier struct { - database db.KeyValueStore - logger utils.SimpleLogger -} - -func NewTrieVerifier(database db.KeyValueStore, logger utils.SimpleLogger) *TrieVerifier { - return &TrieVerifier{ - database: database, - logger: logger, - } -} - -func (v *TrieVerifier) Name() string { - return "trie" -} - -func (v *TrieVerifier) DefaultConfig() Config { - return &TrieConfig{ - Tries: nil, - } -} - -type TrieInfo struct { - Name string - prefix []byte - HashFunc trie.NewTrieFunc - HashFn crypto.HashFn - ReaderFunc func(db.KeyValueReader, []byte, uint8) (trie.TrieReader, error) - Height uint8 -} - -func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { - startTime := time.Now() - defer func() { - elapsed := time.Since(startTime) - v.logger.Infow("=== Trie verification finished ===", "total_elapsed", elapsed.Round(time.Second)) - }() - - trieCfg, ok := cfg.(*TrieConfig) - if !ok { - return fmt.Errorf("invalid config type for trie verifier: expected *TrieConfig") - } - - typesToVerify := trieCfg.Tries - if len(typesToVerify) == 0 { - typesToVerify = []TrieType{ContractTrieType, ClassTrieType, ContractStorageTrieType} - } - - typeSet := make(map[TrieType]bool) - for _, t := range typesToVerify { - typeSet[t] = true - } - - if typeSet[ContractTrieType] { - stateTrieInfo := TrieInfo{ - Name: "ContractsTrie", - prefix: db.StateTrie.Key(), - HashFunc: trie.NewTriePedersen, - HashFn: crypto.Pedersen, - ReaderFunc: trie.NewTrieReaderPedersen, - Height: starknetTrieHeight, - } - if err := v.verifyTrieWithLogging(ctx, stateTrieInfo); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - return err - } - } - - if typeSet[ClassTrieType] { - classTrieInfo := TrieInfo{ - Name: "ClassesTrie", - prefix: db.ClassesTrie.Key(), - HashFunc: trie.NewTriePoseidon, - HashFn: crypto.Poseidon, - ReaderFunc: trie.NewTrieReaderPoseidon, - Height: starknetTrieHeight, - } - if err := v.verifyTrieWithLogging(ctx, classTrieInfo); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - return err - } - } - - if typeSet[ContractStorageTrieType] { - if err := v.verifyContractStorageTries(ctx); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - return err - } - } - - v.logger.Infow("=== Trie verification completed successfully ===") - return nil -} - -func (v *TrieVerifier) collectContractAddresses() []felt.Felt { - contractAddresses := make([]felt.Felt, 0) - stateTriePrefix := db.StateTrie.Key() - - err := v.database.View(func(snap db.Snapshot) error { - it, err := snap.NewIterator(stateTriePrefix, true) - if err != nil { - return err - } - defer it.Close() - - for it.First(); it.Valid(); it.Next() { - keyBytes := it.Key() - if bytes.Equal(keyBytes, stateTriePrefix) { - continue - } - - if !bytes.HasPrefix(keyBytes, stateTriePrefix) { - continue - } - nodeKeyBytes := keyBytes[len(stateTriePrefix):] - - var nodeKey trie.BitArray - if err := nodeKey.UnmarshalBinary(nodeKeyBytes); err != nil { - continue - } - - if nodeKey.Len() == starknetTrieHeight { - contractAddr := nodeKey.Felt() - contractAddresses = append(contractAddresses, contractAddr) - } - } - return nil - }) - if err != nil { - return nil - } - - return contractAddresses -} - -func (v *TrieVerifier) verifyTrieWithLogging(ctx context.Context, trieInfo TrieInfo) error { - v.logger.Infow(fmt.Sprintf("=== Starting %s verification ===", trieInfo.Name)) - err := v.verifyTrie(ctx, trieInfo) - if err != nil { - if errors.Is(err, context.Canceled) { - v.logger.Infow("Verification stopped", "trie", trieInfo.Name) - return err - } - v.logger.Errorw(fmt.Sprintf("%s verification failed", trieInfo.Name), "error", err) - return fmt.Errorf("%s verification failed: %w", trieInfo.Name, err) - } - v.logger.Infow(fmt.Sprintf("%s verification completed successfully", trieInfo.Name)) - return nil -} - -func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error { - expectedRoot := felt.Zero - err := v.database.View(func(snap db.Snapshot) error { - reader, err := trieInfo.ReaderFunc(snap, trieInfo.prefix, trieInfo.Height) - if err != nil { - return err - } - if reader.RootKey() == nil { - expectedRoot = felt.Zero - return nil - } - expectedRoot, err = reader.Hash() - return err - }) - if err != nil { - v.logger.Errorw("Failed to get stored root hash", "trie", trieInfo.Name, "error", err) - return fmt.Errorf("failed to get stored root hash for %s: %w", trieInfo.Name, err) - } - - if expectedRoot.IsZero() { - v.logger.Infow("Trie is empty (zero root)", "trie", trieInfo.Name) - return nil - } - - v.logger.Infow("Starting verification", - "trie", - trieInfo.Name, - "expectedRoot", - expectedRoot.String(), - ) - storageReader := trie.NewReadStorage(v.database, trieInfo.prefix) - - err = verifyTrie(ctx, storageReader, starknetTrieHeight, trieInfo.HashFn, &expectedRoot) - if err != nil { - if errors.Is(err, context.Canceled) { - return err - } - v.logger.Errorw("Trie verification failed", "trie", trieInfo.Name, "error", err) - return err - } - - v.logger.Infow("Trie verification successful", - "trie", trieInfo.Name, "root", - expectedRoot.String(), - ) - return nil -} - -func (v *TrieVerifier) verifyContractStorageTries(ctx context.Context) error { - v.logger.Infow("=== Starting Contract Storage Tries verification ===") - contractAddresses := v.collectContractAddresses() - - if len(contractAddresses) == 0 { - v.logger.Infow("No contract addresses found, skipping contract storage verification") - return nil - } - - v.logger.Infow("Found contracts to verify", "count", len(contractAddresses)) - - for i, contractAddress := range contractAddresses { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - addrBytes := contractAddress.Marshal() - prefix := db.ContractStorage.Key(addrBytes) - trieInfo := TrieInfo{ - Name: fmt.Sprintf("ContractStorage[%s]", contractAddress.String()), - prefix: db.ContractStorage.Key(addrBytes), - HashFunc: trie.NewTriePedersen, - HashFn: crypto.Pedersen, - ReaderFunc: func(r db.KeyValueReader, _ []byte, height uint8) (trie.TrieReader, error) { - return trie.NewTrieReaderPedersen(r, prefix, height) - }, - Height: starknetTrieHeight, - } - - v.logger.Infow( - "Verifying contract storage", - "contract", - contractAddress.String(), - "progress", - fmt.Sprintf("%d/%d", i+1, len(contractAddresses)), - ) - - err := v.verifyTrie(ctx, trieInfo) - if err != nil { - if errors.Is(err, context.Canceled) { - return err - } - v.logger.Errorw( - "Contract storage verification failed", - "contract", - contractAddress.String(), - "error", - err, - ) - return fmt.Errorf( - "contract storage verification failed for %s: %w", - contractAddress.String(), err, - ) - } - } - - v.logger.Infow("All contract storage tries verified successfully", "count", len(contractAddresses)) - return nil -} - -func verifyTrie( - ctx context.Context, - reader *trie.ReadStorage, - height uint8, - hashFn crypto.HashFn, - expectedRoot *felt.Felt, -) error { - rootKey, err := reader.RootKey() - if err != nil { - return fmt.Errorf("failed to get root key: %w", err) - } - - if rootKey == nil { - return nil - } - - startTime := time.Now() - rootHash, err := verifyNode(ctx, reader, rootKey, nil, height, hashFn) - if err != nil { - return fmt.Errorf("node verification failed: %w", err) - } - - elapsed := time.Since(startTime) - - if rootHash.Cmp(expectedRoot) != 0 { - return fmt.Errorf( - "root hash mismatch: expected %s, got %s (verification took %v)", - expectedRoot, rootHash, elapsed.Round(time.Second), - ) - } - - return nil -} - -func verifyNode( - ctx context.Context, - reader *trie.ReadStorage, - key *trie.BitArray, - parentKey *trie.BitArray, - height uint8, - hashFn crypto.HashFn, -) (*felt.Felt, error) { - select { - case <-ctx.Done(): - return nil, fmt.Errorf("verification cancelled: %w", ctx.Err()) - default: - } - - node, err := reader.Get(key) - if err != nil { - return nil, fmt.Errorf("failed to get node at key %s: %w", key.String(), err) - } - - if key.Len() == height { - p := path(key, parentKey) - h := node.Hash(&p, hashFn) - return &h, nil - } - - useConcurrency := key.Len() <= concurrencyMaxDepth - var leftHash, rightHash *felt.Felt - var leftErr, rightErr error - - if useConcurrency { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - if node.Left.IsEmpty() { - zero := felt.Zero - leftHash = &zero - return - } - h, err := verifyNode(ctx, reader, node.Left, key, height, hashFn) - leftHash = h - leftErr = err - }() - - if node.Right.IsEmpty() { - zero := felt.Zero - rightHash = &zero - } else { - h, err := verifyNode(ctx, reader, node.Right, key, height, hashFn) - rightHash = h - rightErr = err - } - - wg.Wait() - - if leftErr != nil { - return nil, leftErr - } - if rightErr != nil { - return nil, rightErr - } - } else { - if node.Left.IsEmpty() { - zero := felt.Zero - leftHash = &zero - } else { - h, err := verifyNode(ctx, reader, node.Left, key, height, hashFn) - if err != nil { - return nil, err - } - leftHash = h - } - - if node.Right.IsEmpty() { - zero := felt.Zero - rightHash = &zero - } else { - h, err := verifyNode(ctx, reader, node.Right, key, height, hashFn) - if err != nil { - return nil, err - } - rightHash = h - } - } - - recomputed := hashFn(leftHash, rightHash) - if recomputed.Cmp(node.Value) != 0 { - return nil, fmt.Errorf( - "node corruption detected at key %s: stored hash=%s, recomputed hash=%s", - key.String(), node.Value.String(), recomputed.String(), - ) - } - - tmp := *node - tmp.Value = &recomputed - - p := path(key, parentKey) - h := tmp.Hash(&p, hashFn) - return &h, nil -} - -func path(key, parentKey *trie.BitArray) trie.BitArray { - if parentKey == nil { - return key.Copy() - } - - var pathKey trie.BitArray - pathKey.LSBs(key, parentKey.Len()+1) - return pathKey -} diff --git a/cmd/juno/verify/trie_test.go b/cmd/juno/verify/trie_test.go deleted file mode 100644 index 74b1769a55..0000000000 --- a/cmd/juno/verify/trie_test.go +++ /dev/null @@ -1,252 +0,0 @@ -package verify - -import ( - "context" - "testing" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/db/memory" - "github.com/NethermindEth/juno/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestTrieVerifier_Run_ValidStateTrie(t *testing.T) { - logger := utils.NewNopZapLogger() - testDB := memory.New() - defer testDB.Close() - - prefix := db.StateTrie.Key() - txn := testDB.NewIndexedBatch() - trieStorage := trie.NewStorage(txn, prefix) - - testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) - require.NoError(t, err) - - key1 := felt.NewFromUint64[felt.Felt](1) - value1 := felt.NewFromUint64[felt.Felt](100) - _, err = testTrie.Put(key1, value1) - require.NoError(t, err) - - key2 := felt.NewFromUint64[felt.Felt](2) - value2 := felt.NewFromUint64[felt.Felt](200) - _, err = testTrie.Put(key2, value2) - require.NoError(t, err) - - err = testTrie.Commit() - require.NoError(t, err) - - rootHash, err := testTrie.Hash() - require.NoError(t, err) - - if testTrie.RootKey() != nil { - err = trieStorage.PutRootKey(testTrie.RootKey()) - require.NoError(t, err) - } - - err = txn.Write() - require.NoError(t, err) - - verifier := NewTrieVerifier(testDB, logger) - cfg := &TrieConfig{ - Tries: []TrieType{ContractTrieType}, - } - - ctx := context.Background() - err = verifier.Run(ctx, cfg) - assert.NoError(t, err) - - reader, err := trie.NewTrieReaderPedersen(testDB, prefix, starknetTrieHeight) - require.NoError(t, err) - storedHash, err := reader.Hash() - require.NoError(t, err) - assert.True(t, rootHash.Equal(&storedHash)) -} - -func TestTrieVerifier_Run_ValidClassTrie(t *testing.T) { - logger := utils.NewNopZapLogger() - testDB := memory.New() - defer testDB.Close() - - prefix := db.ClassesTrie.Key() - txn := testDB.NewIndexedBatch() - trieStorage := trie.NewStorage(txn, prefix) - - testTrie, err := trie.NewTriePoseidon(txn, prefix, starknetTrieHeight) - require.NoError(t, err) - - key1 := felt.NewFromUint64[felt.Felt](10) - value1 := felt.NewFromUint64[felt.Felt](1000) - _, err = testTrie.Put(key1, value1) - require.NoError(t, err) - - err = testTrie.Commit() - require.NoError(t, err) - - if testTrie.RootKey() != nil { - err = trieStorage.PutRootKey(testTrie.RootKey()) - require.NoError(t, err) - } - - err = txn.Write() - require.NoError(t, err) - - verifier := NewTrieVerifier(testDB, logger) - cfg := &TrieConfig{ - Tries: []TrieType{ClassTrieType}, - } - - ctx := context.Background() - err = verifier.Run(ctx, cfg) - assert.NoError(t, err) -} - -func TestTrieVerifier_Run_CorruptedTrie(t *testing.T) { - logger := utils.NewNopZapLogger() - testDB := memory.New() - defer testDB.Close() - - prefix := db.StateTrie.Key() - txn := testDB.NewIndexedBatch() - trieStorage := trie.NewStorage(txn, prefix) - - testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) - require.NoError(t, err) - - key1 := felt.NewFromUint64[felt.Felt](1) - value1 := felt.NewFromUint64[felt.Felt](100) - _, err = testTrie.Put(key1, value1) - require.NoError(t, err) - - key2 := felt.NewFromUint64[felt.Felt](2) - value2 := felt.NewFromUint64[felt.Felt](200) - _, err = testTrie.Put(key2, value2) - require.NoError(t, err) - - err = testTrie.Commit() - require.NoError(t, err) - - if testTrie.RootKey() != nil { - err = trieStorage.PutRootKey(testTrie.RootKey()) - require.NoError(t, err) - } - - var nodeKey trie.BitArray - nodeKey.SetFelt(starknetTrieHeight, key1) - - node, err := trieStorage.Get(&nodeKey) - require.NoError(t, err) - require.NotNil(t, node) - require.NotNil(t, node.Value) - - assert.True(t, node.Value.Equal(value1), "Expected value1 but got different value") - - corruptedValue := felt.NewFromUint64[felt.Felt](999999) - node.Value = corruptedValue - - err = trieStorage.Put(&nodeKey, node) - require.NoError(t, err) - - err = txn.Write() - require.NoError(t, err) - - verifier := NewTrieVerifier(testDB, logger) - cfg := &TrieConfig{ - Tries: []TrieType{ContractTrieType}, - } - - ctx := context.Background() - err = verifier.Run(ctx, cfg) - require.Error(t, err) - assert.Contains(t, err.Error(), "node corruption detected") -} - -func TestTrieVerifier_Run_MultipleTrieTypes(t *testing.T) { - logger := utils.NewNopZapLogger() - testDB := memory.New() - defer testDB.Close() - - statePrefix := db.StateTrie.Key() - txn := testDB.NewIndexedBatch() - stateTrie, err := trie.NewTriePedersen(txn, statePrefix, starknetTrieHeight) - require.NoError(t, err) - - key1 := felt.NewFromUint64[felt.Felt](1) - value1 := felt.NewFromUint64[felt.Felt](100) - _, err = stateTrie.Put(key1, value1) - require.NoError(t, err) - - err = stateTrie.Commit() - require.NoError(t, err) - - stateStorage := trie.NewStorage(txn, statePrefix) - if stateTrie.RootKey() != nil { - err = stateStorage.PutRootKey(stateTrie.RootKey()) - require.NoError(t, err) - } - - classPrefix := db.ClassesTrie.Key() - classTrie, err := trie.NewTriePoseidon(txn, classPrefix, starknetTrieHeight) - require.NoError(t, err) - - key2 := felt.NewFromUint64[felt.Felt](2) - value2 := felt.NewFromUint64[felt.Felt](200) - _, err = classTrie.Put(key2, value2) - require.NoError(t, err) - - err = classTrie.Commit() - require.NoError(t, err) - - classStorage := trie.NewStorage(txn, classPrefix) - if classTrie.RootKey() != nil { - err = classStorage.PutRootKey(classTrie.RootKey()) - require.NoError(t, err) - } - - err = txn.Write() - require.NoError(t, err) - - verifier := NewTrieVerifier(testDB, logger) - cfg := &TrieConfig{ - Tries: []TrieType{ContractTrieType, ClassTrieType}, - } - - ctx := context.Background() - err = verifier.Run(ctx, cfg) - assert.NoError(t, err) -} - -func TestTrieVerifier_Run_EmptyTrie(t *testing.T) { - logger := utils.NewNopZapLogger() - testDB := memory.New() - defer testDB.Close() - - prefix := db.StateTrie.Key() - txn := testDB.NewIndexedBatch() - trieStorage := trie.NewStorage(txn, prefix) - - testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) - require.NoError(t, err) - - err = testTrie.Commit() - require.NoError(t, err) - - if testTrie.RootKey() != nil { - err = trieStorage.PutRootKey(testTrie.RootKey()) - require.NoError(t, err) - } - - err = txn.Write() - require.NoError(t, err) - - verifier := NewTrieVerifier(testDB, logger) - cfg := &TrieConfig{ - Tries: []TrieType{ContractTrieType}, - } - - ctx := context.Background() - err = verifier.Run(ctx, cfg) - assert.NoError(t, err) -} diff --git a/cmd/juno/verify/verify.go b/cmd/juno/verify/verify.go index c21004664b..3b937b1988 100644 --- a/cmd/juno/verify/verify.go +++ b/cmd/juno/verify/verify.go @@ -7,13 +7,11 @@ import ( "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebblev2" - "github.com/NethermindEth/juno/utils" "github.com/spf13/cobra" ) const ( - verifyDBPathF = "db-path" - verifyTrieType = "type" + verifyDBPathF = "db-path" ) func VerifyCmd(defaultDBPath string) *cobra.Command { @@ -24,7 +22,6 @@ func VerifyCmd(defaultDBPath string) *cobra.Command { } verifyCmd.PersistentFlags().String(verifyDBPathF, defaultDBPath, "Path to the database") - verifyCmd.AddCommand(verifyTrieCmd()) verifyCmd.RunE = verifyAll return verifyCmd @@ -43,18 +40,8 @@ func verifyAll(cmd *cobra.Command, args []string) error { } defer database.Close() - logLevel := utils.NewLogLevel(utils.INFO) - logger, err := utils.NewZapLogger(logLevel, true) - if err != nil { - return fmt.Errorf("failed to create logger: %w", err) - } - - trieVerifier := NewTrieVerifier(database, logger) - verifier := &VerifyRunner{ - Verifiers: []Verifier{ - trieVerifier, - }, + Verifiers: []Verifier{}, } ctx := cmd.Context() From 3377dfbe6f2e0d90753e8a20a17f9f406606aaf9 Mon Sep 17 00:00:00 2001 From: MaksymMalicki <81577596+MaksymMalicki@users.noreply.github.com> Date: Tue, 3 Feb 2026 13:23:55 +0100 Subject: [PATCH 08/15] feat(cmd): trie verifier in the cmd (#3338) * trie verifier * fix unit tests --- cmd/juno/verify/trie.go | 534 +++++++++++++++++++++++++++++++++++ cmd/juno/verify/trie_test.go | 322 +++++++++++++++++++++ cmd/juno/verify/verify.go | 16 +- 3 files changed, 870 insertions(+), 2 deletions(-) create mode 100644 cmd/juno/verify/trie.go create mode 100644 cmd/juno/verify/trie_test.go diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go new file mode 100644 index 0000000000..4aa980666c --- /dev/null +++ b/cmd/juno/verify/trie.go @@ -0,0 +1,534 @@ +package verify + +import ( + "bytes" + "context" + "errors" + "fmt" + "slices" + "sync" + "time" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/utils" + "github.com/spf13/cobra" +) + +const ( + starknetTrieHeight = 251 + concurrencyMaxDepth = 8 + verifyContractAddr = "address" +) + +type TrieType string + +const ( + ContractTrieType TrieType = "contract" + ClassTrieType TrieType = "class" + ContractStorageTrieType TrieType = "contract-storage" +) + +func verifyTrieCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "trie", + Short: "Verify trie integrity", + Long: `Verify trie integrity by rebuilding tries from leaf nodes and comparing root hashes.`, + RunE: runTrieVerify, + } + + cmd.Flags().StringSlice( + verifyTrieType, + nil, + "Trie types to verify (contract, class, contract-storage). Can be specified multiple times. Empty = all.", + ) + + cmd.Flags().String( + verifyContractAddr, + "", + "Contract address to verify (only used with --type contract-storage). If not specified, all contract storage tries are verified.", + ) + + return cmd +} + +func runTrieVerify(cmd *cobra.Command, args []string) error { + dbPath, err := cmd.Flags().GetString(verifyDBPathF) + if err != nil { + return err + } + + database, err := openDB(dbPath) + if err != nil { + return err + } + defer database.Close() + + trieTypes, err := cmd.Flags().GetStringSlice(verifyTrieType) + if err != nil { + return err + } + + contractAddrStr, err := cmd.Flags().GetString(verifyContractAddr) + if err != nil { + return err + } + + cfg := &TrieConfig{} + + if len(trieTypes) > 0 { + cfg.Tries = make([]TrieType, len(trieTypes)) + for i, t := range trieTypes { + cfg.Tries[i] = TrieType(t) + } + } + + if contractAddrStr != "" { + hasContractStorage := slices.Contains(cfg.Tries, ContractStorageTrieType) + if len(cfg.Tries) == 0 { + hasContractStorage = true + } + + if !hasContractStorage { + return fmt.Errorf("--address flag can only be used with --type contract-storage") + } + + var contractAddr felt.Felt + _, err := (&contractAddr).SetString(contractAddrStr) + if err != nil { + return fmt.Errorf("invalid contract address %s: %w", contractAddrStr, err) + } + cfg.ContractAddress = &contractAddr + } + + logLevel := utils.NewLogLevel(utils.INFO) + logger, err := utils.NewZapLogger(logLevel, true) + if err != nil { + return fmt.Errorf("failed to create logger: %w", err) + } + + verifier := NewTrieVerifier(database, logger) + ctx := cmd.Context() + return verifier.Run(ctx, cfg) +} + +type TrieConfig struct { + Tries []TrieType + ContractAddress *felt.Felt +} + +type TrieVerifier struct { + database db.KeyValueStore + logger utils.SimpleLogger +} + +func NewTrieVerifier(database db.KeyValueStore, logger utils.SimpleLogger) *TrieVerifier { + return &TrieVerifier{ + database: database, + logger: logger, + } +} + +func (v *TrieVerifier) Name() string { + return "trie" +} + +func (v *TrieVerifier) DefaultConfig() Config { + return &TrieConfig{ + Tries: nil, + } +} + +type TrieInfo struct { + Name string + prefix []byte + HashFunc trie.NewTrieFunc + HashFn crypto.HashFn + ReaderFunc func(db.KeyValueReader, []byte, uint8) (trie.TrieReader, error) + Height uint8 +} + +func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { + startTime := time.Now() + defer func() { + elapsed := time.Since(startTime) + v.logger.Infow("=== Trie verification finished ===", "total_elapsed", elapsed.Round(time.Second)) + }() + + trieCfg, ok := cfg.(*TrieConfig) + if !ok { + return fmt.Errorf("invalid config type for trie verifier: expected *TrieConfig") + } + + typesToVerify := trieCfg.Tries + if len(typesToVerify) == 0 { + typesToVerify = []TrieType{ContractTrieType, ClassTrieType, ContractStorageTrieType} + } + + typeSet := make(map[TrieType]bool) + for _, t := range typesToVerify { + typeSet[t] = true + } + + if typeSet[ContractTrieType] { + stateTrieInfo := TrieInfo{ + Name: "ContractsTrie", + prefix: db.StateTrie.Key(), + HashFunc: trie.NewTriePedersen, + HashFn: crypto.Pedersen, + ReaderFunc: trie.NewTrieReaderPedersen, + Height: starknetTrieHeight, + } + if err := v.verifyTrieWithLogging(ctx, stateTrieInfo); err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err + } + } + + if typeSet[ClassTrieType] { + classTrieInfo := TrieInfo{ + Name: "ClassesTrie", + prefix: db.ClassesTrie.Key(), + HashFunc: trie.NewTriePoseidon, + HashFn: crypto.Poseidon, + ReaderFunc: trie.NewTrieReaderPoseidon, + Height: starknetTrieHeight, + } + if err := v.verifyTrieWithLogging(ctx, classTrieInfo); err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err + } + } + + if typeSet[ContractStorageTrieType] { + if err := v.verifyContractStorageTries(ctx, trieCfg.ContractAddress); err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err + } + } + + v.logger.Infow("=== Trie verification completed successfully ===") + return nil +} + +func (v *TrieVerifier) collectContractAddresses() []felt.Felt { + contractAddresses := make([]felt.Felt, 0) + stateTriePrefix := db.StateTrie.Key() + + err := v.database.View(func(snap db.Snapshot) error { + it, err := snap.NewIterator(stateTriePrefix, true) + if err != nil { + return err + } + defer it.Close() + + for it.First(); it.Valid(); it.Next() { + keyBytes := it.Key() + if bytes.Equal(keyBytes, stateTriePrefix) { + continue + } + + if !bytes.HasPrefix(keyBytes, stateTriePrefix) { + continue + } + nodeKeyBytes := keyBytes[len(stateTriePrefix):] + + var nodeKey trie.BitArray + if err := nodeKey.UnmarshalBinary(nodeKeyBytes); err != nil { + continue + } + + if nodeKey.Len() == starknetTrieHeight { + contractAddr := nodeKey.Felt() + contractAddresses = append(contractAddresses, contractAddr) + } + } + return nil + }) + if err != nil { + return nil + } + + return contractAddresses +} + +func (v *TrieVerifier) verifyTrieWithLogging(ctx context.Context, trieInfo TrieInfo) error { + err := v.verifyTrie(ctx, trieInfo) + if err != nil { + if errors.Is(err, context.Canceled) { + v.logger.Infow("Verification stopped", "trie", trieInfo.Name) + return err + } + v.logger.Errorw(fmt.Sprintf("%s verification failed", trieInfo.Name), "error", err) + return fmt.Errorf("%s verification failed: %w", trieInfo.Name, err) + } + return nil +} + +func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error { + v.logger.Infow(fmt.Sprintf("=== Starting %s verification ===", trieInfo.Name)) + expectedRoot := felt.Zero + err := v.database.View(func(snap db.Snapshot) error { + reader, err := trieInfo.ReaderFunc(snap, trieInfo.prefix, trieInfo.Height) + if err != nil { + return err + } + if reader.RootKey() == nil { + expectedRoot = felt.Zero + return nil + } + expectedRoot, err = reader.Hash() + return err + }) + if err != nil { + v.logger.Errorw("Failed to get stored root hash", "trie", trieInfo.Name, "error", err) + return fmt.Errorf("failed to get stored root hash for %s: %w", trieInfo.Name, err) + } + + if expectedRoot.IsZero() { + v.logger.Infow("Trie is empty (zero root)", "trie", trieInfo.Name) + return nil + } + + v.logger.Infow("Starting verification", + "trie", + trieInfo.Name, + "expectedRoot", + expectedRoot.String(), + ) + storageReader := trie.NewReadStorage(v.database, trieInfo.prefix) + + err = verifyTrie(ctx, storageReader, starknetTrieHeight, trieInfo.HashFn, &expectedRoot) + if err != nil { + if errors.Is(err, context.Canceled) { + return err + } + v.logger.Errorw("Trie verification failed", "trie", trieInfo.Name, "error", err) + return err + } + + v.logger.Infow("Trie verification successful", + "trie", trieInfo.Name, "root", + expectedRoot.String(), + ) + return nil +} + +func (v *TrieVerifier) verifyContractStorageTries(ctx context.Context, filterAddress *felt.Felt) error { + v.logger.Infow("=== Starting Contract Storage Tries verification ===") + + var contractAddresses []felt.Felt + if filterAddress != nil { + contractAddresses = []felt.Felt{*filterAddress} + v.logger.Infow("Verifying specific contract", "address", filterAddress.String()) + } else { + contractAddresses = v.collectContractAddresses() + if len(contractAddresses) == 0 { + v.logger.Infow("No contract addresses found, skipping contract storage verification") + return nil + } + v.logger.Infow("Found contracts to verify", "count", len(contractAddresses)) + } + + for i, contractAddress := range contractAddresses { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + addrBytes := contractAddress.Marshal() + prefix := db.ContractStorage.Key(addrBytes) + trieInfo := TrieInfo{ + Name: fmt.Sprintf("ContractStorage[%s]", contractAddress.String()), + prefix: db.ContractStorage.Key(addrBytes), + HashFunc: trie.NewTriePedersen, + HashFn: crypto.Pedersen, + ReaderFunc: func(r db.KeyValueReader, _ []byte, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPedersen(r, prefix, height) + }, + Height: starknetTrieHeight, + } + + v.logger.Infow( + "Verifying contract storage", + "contract", + contractAddress.String(), + "progress", + fmt.Sprintf("%d/%d", i+1, len(contractAddresses)), + ) + + err := v.verifyTrie(ctx, trieInfo) + if err != nil { + if errors.Is(err, context.Canceled) { + return err + } + v.logger.Errorw( + "Contract storage verification failed", + "contract", + contractAddress.String(), + "error", + err, + ) + return fmt.Errorf( + "contract storage verification failed for %s: %w", + contractAddress.String(), err, + ) + } + } + + v.logger.Infow("All contract storage tries verified successfully", "count", len(contractAddresses)) + return nil +} + +func verifyTrie( + ctx context.Context, + reader *trie.ReadStorage, + height uint8, + hashFn crypto.HashFn, + expectedRoot *felt.Felt, +) error { + rootKey, err := reader.RootKey() + if err != nil { + return fmt.Errorf("failed to get root key: %w", err) + } + + if rootKey == nil { + return nil + } + + startTime := time.Now() + rootHash, err := verifyNode(ctx, reader, rootKey, nil, height, hashFn) + if err != nil { + return fmt.Errorf("node verification failed: %w", err) + } + + elapsed := time.Since(startTime) + + if rootHash.Cmp(expectedRoot) != 0 { + return fmt.Errorf( + "root hash mismatch: expected %s, got %s (verification took %v)", + expectedRoot, rootHash, elapsed.Round(time.Second), + ) + } + + return nil +} + +func verifyNode( + ctx context.Context, + reader *trie.ReadStorage, + key *trie.BitArray, + parentKey *trie.BitArray, + height uint8, + hashFn crypto.HashFn, +) (*felt.Felt, error) { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("verification cancelled: %w", ctx.Err()) + default: + } + + node, err := reader.Get(key) + if err != nil { + return nil, fmt.Errorf("failed to get node at key %s: %w", key.String(), err) + } + + if key.Len() == height { + p := path(key, parentKey) + h := node.Hash(&p, hashFn) + return &h, nil + } + + useConcurrency := key.Len() <= concurrencyMaxDepth + var leftHash, rightHash *felt.Felt + var leftErr, rightErr error + + if useConcurrency { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if node.Left.IsEmpty() { + zero := felt.Zero + leftHash = &zero + return + } + h, err := verifyNode(ctx, reader, node.Left, key, height, hashFn) + leftHash = h + leftErr = err + }() + + if node.Right.IsEmpty() { + zero := felt.Zero + rightHash = &zero + } else { + h, err := verifyNode(ctx, reader, node.Right, key, height, hashFn) + rightHash = h + rightErr = err + } + + wg.Wait() + + if leftErr != nil { + return nil, leftErr + } + if rightErr != nil { + return nil, rightErr + } + } else { + if node.Left.IsEmpty() { + zero := felt.Zero + leftHash = &zero + } else { + h, err := verifyNode(ctx, reader, node.Left, key, height, hashFn) + if err != nil { + return nil, err + } + leftHash = h + } + + if node.Right.IsEmpty() { + zero := felt.Zero + rightHash = &zero + } else { + h, err := verifyNode(ctx, reader, node.Right, key, height, hashFn) + if err != nil { + return nil, err + } + rightHash = h + } + } + + recomputed := hashFn(leftHash, rightHash) + if recomputed.Cmp(node.Value) != 0 { + return nil, fmt.Errorf( + "node corruption detected at key %s: stored hash=%s, recomputed hash=%s", + key.String(), node.Value.String(), recomputed.String(), + ) + } + + tmp := *node + tmp.Value = &recomputed + + p := path(key, parentKey) + h := tmp.Hash(&p, hashFn) + return &h, nil +} + +func path(key, parentKey *trie.BitArray) trie.BitArray { + if parentKey == nil { + return key.Copy() + } + + var pathKey trie.BitArray + pathKey.LSBs(key, parentKey.Len()+1) + return pathKey +} diff --git a/cmd/juno/verify/trie_test.go b/cmd/juno/verify/trie_test.go new file mode 100644 index 0000000000..0b8c1834d2 --- /dev/null +++ b/cmd/juno/verify/trie_test.go @@ -0,0 +1,322 @@ +package verify + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" + "github.com/NethermindEth/juno/db/pebblev2" + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTrieVerifier_Run_ValidStateTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = testTrie.Put(key2, value2) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + rootHash, err := testTrie.Hash() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger) + cfg := &TrieConfig{ + Tries: []TrieType{ContractTrieType}, + } + + ctx := context.Background() + err = verifier.Run(ctx, cfg) + assert.NoError(t, err) + + reader, err := trie.NewTrieReaderPedersen(testDB, prefix, starknetTrieHeight) + require.NoError(t, err) + storedHash, err := reader.Hash() + require.NoError(t, err) + assert.True(t, rootHash.Equal(&storedHash)) +} + +func TestTrieVerifier_Run_ValidClassTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.ClassesTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePoseidon(txn, prefix, starknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](10) + value1 := felt.NewFromUint64[felt.Felt](1000) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger) + cfg := &TrieConfig{ + Tries: []TrieType{ClassTrieType}, + } + + ctx := context.Background() + err = verifier.Run(ctx, cfg) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_CorruptedTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = testTrie.Put(key2, value2) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + var nodeKey trie.BitArray + nodeKey.SetFelt(starknetTrieHeight, key1) + + node, err := trieStorage.Get(&nodeKey) + require.NoError(t, err) + require.NotNil(t, node) + require.NotNil(t, node.Value) + + assert.True(t, node.Value.Equal(value1), "Expected value1 but got different value") + + corruptedValue := felt.NewFromUint64[felt.Felt](999999) + node.Value = corruptedValue + + err = trieStorage.Put(&nodeKey, node) + require.NoError(t, err) + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger) + cfg := &TrieConfig{ + Tries: []TrieType{ContractTrieType}, + } + + ctx := context.Background() + err = verifier.Run(ctx, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "node corruption detected") +} + +func TestTrieVerifier_Run_MultipleTrieTypes(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + statePrefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + stateTrie, err := trie.NewTriePedersen(txn, statePrefix, starknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = stateTrie.Put(key1, value1) + require.NoError(t, err) + + err = stateTrie.Commit() + require.NoError(t, err) + + stateStorage := trie.NewStorage(txn, statePrefix) + if stateTrie.RootKey() != nil { + err = stateStorage.PutRootKey(stateTrie.RootKey()) + require.NoError(t, err) + } + + classPrefix := db.ClassesTrie.Key() + classTrie, err := trie.NewTriePoseidon(txn, classPrefix, starknetTrieHeight) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = classTrie.Put(key2, value2) + require.NoError(t, err) + + err = classTrie.Commit() + require.NoError(t, err) + + classStorage := trie.NewStorage(txn, classPrefix) + if classTrie.RootKey() != nil { + err = classStorage.PutRootKey(classTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger) + cfg := &TrieConfig{ + Tries: []TrieType{ContractTrieType, ClassTrieType}, + } + + ctx := context.Background() + err = verifier.Run(ctx, cfg) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_EmptyTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger) + cfg := &TrieConfig{ + Tries: []TrieType{ContractTrieType}, + } + + ctx := context.Background() + err = verifier.Run(ctx, cfg) + assert.NoError(t, err) +} + +func TestRunTrieVerify_AddressFlagValidation(t *testing.T) { + tests := []struct { + name string + trieTypes []string + address string + expectError bool + expectedErrMsg string + }{ + { + name: "address with contract-storage type should succeed", + trieTypes: []string{"contract-storage"}, + address: "0x123", + expectError: false, + }, + { + name: "address with contract and class types should fail", + trieTypes: []string{"contract", "class"}, + address: "0x123", + expectError: true, + expectedErrMsg: "--address flag can only be used with --type contract-storage", + }, + { + name: "address with no type specified should succeed (default includes contract-storage)", + trieTypes: []string{}, + address: "0x123", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "test.db") + + testDB, err := pebblev2.New(dbPath) + require.NoError(t, err) + testDB.Close() + + parentCmd := VerifyCmd("") + args := []string{"--db-path", dbPath, "trie"} + + for _, trieType := range tt.trieTypes { + args = append(args, "--type", trieType) + } + + if tt.address != "" { + args = append(args, "--address", tt.address) + } + + parentCmd.SetArgs(args) + parentCmd.SetOut(os.Stderr) + parentCmd.SetErr(os.Stderr) + + err = parentCmd.ExecuteContext(context.Background()) + + if tt.expectError { + require.Error(t, err) + if tt.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrMsg) + } + } else if err != nil { + assert.NotContains(t, err.Error(), "--address flag can only be used with --type contract-storage") + } + }) + } +} diff --git a/cmd/juno/verify/verify.go b/cmd/juno/verify/verify.go index 3b937b1988..c6470ecaf7 100644 --- a/cmd/juno/verify/verify.go +++ b/cmd/juno/verify/verify.go @@ -7,11 +7,13 @@ import ( "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebblev2" + "github.com/NethermindEth/juno/utils" "github.com/spf13/cobra" ) const ( - verifyDBPathF = "db-path" + verifyDBPathF = "db-path" + verifyTrieType = "type" ) func VerifyCmd(defaultDBPath string) *cobra.Command { @@ -22,6 +24,7 @@ func VerifyCmd(defaultDBPath string) *cobra.Command { } verifyCmd.PersistentFlags().String(verifyDBPathF, defaultDBPath, "Path to the database") + verifyCmd.AddCommand(verifyTrieCmd()) verifyCmd.RunE = verifyAll return verifyCmd @@ -40,8 +43,17 @@ func verifyAll(cmd *cobra.Command, args []string) error { } defer database.Close() + logLevel := utils.NewLogLevel(utils.INFO) + logger, err := utils.NewZapLogger(logLevel, true) + if err != nil { + return fmt.Errorf("failed to create logger: %w", err) + } + + trieVerifier := NewTrieVerifier(database, logger) verifier := &VerifyRunner{ - Verifiers: []Verifier{}, + Verifiers: []Verifier{ + trieVerifier, + }, } ctx := cmd.Context() From 1855fd93bedd040eeda9e3784dfbb3b1dbdfe0f2 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Tue, 3 Feb 2026 17:37:50 +0100 Subject: [PATCH 09/15] linter --- cmd/juno/verify/trie.go | 82 ++++++++++++++++++------------------ cmd/juno/verify/trie_test.go | 3 +- 2 files changed, 42 insertions(+), 43 deletions(-) diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go index 4aa980666c..0f59377aad 100644 --- a/cmd/juno/verify/trie.go +++ b/cmd/juno/verify/trie.go @@ -15,6 +15,7 @@ import ( "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/utils" "github.com/spf13/cobra" + "go.uber.org/zap" ) const ( @@ -42,13 +43,15 @@ func verifyTrieCmd() *cobra.Command { cmd.Flags().StringSlice( verifyTrieType, nil, - "Trie types to verify (contract, class, contract-storage). Can be specified multiple times. Empty = all.", + "Trie types to verify (contract, class, contract-storage)."+ + "If not specified, all trie types are verified.", ) cmd.Flags().String( verifyContractAddr, "", - "Contract address to verify (only used with --type contract-storage). If not specified, all contract storage tries are verified.", + "Contract address to verify (only used with --type contract-storage). "+ + "If not specified, all contract storage tries are verified.", ) return cmd @@ -121,10 +124,10 @@ type TrieConfig struct { type TrieVerifier struct { database db.KeyValueStore - logger utils.SimpleLogger + logger utils.StructuredLogger } -func NewTrieVerifier(database db.KeyValueStore, logger utils.SimpleLogger) *TrieVerifier { +func NewTrieVerifier(database db.KeyValueStore, logger utils.StructuredLogger) *TrieVerifier { return &TrieVerifier{ database: database, logger: logger, @@ -154,7 +157,8 @@ func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { startTime := time.Now() defer func() { elapsed := time.Since(startTime) - v.logger.Infow("=== Trie verification finished ===", "total_elapsed", elapsed.Round(time.Second)) + v.logger.Info("=== Trie verification finished ===", + zap.Duration("total_elapsed", elapsed.Round(time.Second))) }() trieCfg, ok := cfg.(*TrieConfig) @@ -215,7 +219,7 @@ func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { } } - v.logger.Infow("=== Trie verification completed successfully ===") + v.logger.Info("=== Trie verification completed successfully ===") return nil } @@ -264,17 +268,17 @@ func (v *TrieVerifier) verifyTrieWithLogging(ctx context.Context, trieInfo TrieI err := v.verifyTrie(ctx, trieInfo) if err != nil { if errors.Is(err, context.Canceled) { - v.logger.Infow("Verification stopped", "trie", trieInfo.Name) + v.logger.Info("Verification stopped", zap.String("trie", trieInfo.Name)) return err } - v.logger.Errorw(fmt.Sprintf("%s verification failed", trieInfo.Name), "error", err) + v.logger.Error(fmt.Sprintf("%s verification failed", trieInfo.Name), zap.Error(err)) return fmt.Errorf("%s verification failed: %w", trieInfo.Name, err) } return nil } func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error { - v.logger.Infow(fmt.Sprintf("=== Starting %s verification ===", trieInfo.Name)) + v.logger.Info(fmt.Sprintf("=== Starting %s verification ===", trieInfo.Name)) expectedRoot := felt.Zero err := v.database.View(func(snap db.Snapshot) error { reader, err := trieInfo.ReaderFunc(snap, trieInfo.prefix, trieInfo.Height) @@ -289,21 +293,19 @@ func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error return err }) if err != nil { - v.logger.Errorw("Failed to get stored root hash", "trie", trieInfo.Name, "error", err) + v.logger.Error("Failed to get stored root hash", + zap.String("trie", trieInfo.Name), zap.Error(err)) return fmt.Errorf("failed to get stored root hash for %s: %w", trieInfo.Name, err) } if expectedRoot.IsZero() { - v.logger.Infow("Trie is empty (zero root)", "trie", trieInfo.Name) + v.logger.Info("Trie is empty (zero root)", zap.String("trie", trieInfo.Name)) return nil } - v.logger.Infow("Starting verification", - "trie", - trieInfo.Name, - "expectedRoot", - expectedRoot.String(), - ) + v.logger.Info("Starting verification", + zap.String("trie", trieInfo.Name), + zap.String("expectedRoot", expectedRoot.String())) storageReader := trie.NewReadStorage(v.database, trieInfo.prefix) err = verifyTrie(ctx, storageReader, starknetTrieHeight, trieInfo.HashFn, &expectedRoot) @@ -311,31 +313,34 @@ func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error if errors.Is(err, context.Canceled) { return err } - v.logger.Errorw("Trie verification failed", "trie", trieInfo.Name, "error", err) + v.logger.Error("Trie verification failed", + zap.String("trie", trieInfo.Name), zap.Error(err)) return err } - v.logger.Infow("Trie verification successful", - "trie", trieInfo.Name, "root", - expectedRoot.String(), - ) + v.logger.Info("Trie verification successful", + zap.String("trie", trieInfo.Name), zap.String("root", expectedRoot.String())) return nil } -func (v *TrieVerifier) verifyContractStorageTries(ctx context.Context, filterAddress *felt.Felt) error { - v.logger.Infow("=== Starting Contract Storage Tries verification ===") +func (v *TrieVerifier) verifyContractStorageTries( + ctx context.Context, filterAddress *felt.Felt, +) error { + v.logger.Info("=== Starting Contract Storage Tries verification ===") var contractAddresses []felt.Felt if filterAddress != nil { contractAddresses = []felt.Felt{*filterAddress} - v.logger.Infow("Verifying specific contract", "address", filterAddress.String()) + v.logger.Info("Verifying specific contract", + zap.String("address", filterAddress.String())) } else { contractAddresses = v.collectContractAddresses() if len(contractAddresses) == 0 { - v.logger.Infow("No contract addresses found, skipping contract storage verification") + v.logger.Info("No contract addresses found, skipping contract storage verification") return nil } - v.logger.Infow("Found contracts to verify", "count", len(contractAddresses)) + v.logger.Info("Found contracts to verify", + zap.Int("count", len(contractAddresses))) } for i, contractAddress := range contractAddresses { @@ -357,26 +362,18 @@ func (v *TrieVerifier) verifyContractStorageTries(ctx context.Context, filterAdd Height: starknetTrieHeight, } - v.logger.Infow( - "Verifying contract storage", - "contract", - contractAddress.String(), - "progress", - fmt.Sprintf("%d/%d", i+1, len(contractAddresses)), - ) + v.logger.Info("Verifying contract storage", + zap.String("contract", contractAddress.String()), + zap.String("progress", fmt.Sprintf("%d/%d", i+1, len(contractAddresses)))) err := v.verifyTrie(ctx, trieInfo) if err != nil { if errors.Is(err, context.Canceled) { return err } - v.logger.Errorw( - "Contract storage verification failed", - "contract", - contractAddress.String(), - "error", - err, - ) + v.logger.Error("Contract storage verification failed", + zap.String("contract", contractAddress.String()), + zap.Error(err)) return fmt.Errorf( "contract storage verification failed for %s: %w", contractAddress.String(), err, @@ -384,7 +381,8 @@ func (v *TrieVerifier) verifyContractStorageTries(ctx context.Context, filterAdd } } - v.logger.Infow("All contract storage tries verified successfully", "count", len(contractAddresses)) + v.logger.Info("All contract storage tries verified successfully", + zap.Int("count", len(contractAddresses))) return nil } diff --git a/cmd/juno/verify/trie_test.go b/cmd/juno/verify/trie_test.go index 0b8c1834d2..7b9133d1bf 100644 --- a/cmd/juno/verify/trie_test.go +++ b/cmd/juno/verify/trie_test.go @@ -315,7 +315,8 @@ func TestRunTrieVerify_AddressFlagValidation(t *testing.T) { assert.Contains(t, err.Error(), tt.expectedErrMsg) } } else if err != nil { - assert.NotContains(t, err.Error(), "--address flag can only be used with --type contract-storage") + addrFlagErr := "--address flag can only be used with --type contract-storage" + assert.NotContains(t, err.Error(), addrFlagErr) } }) } From 024e2cfea7930b0ee9240fe166c98de9a78fd96b Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Thu, 5 Feb 2026 16:05:53 +0100 Subject: [PATCH 10/15] refactor the trie verify cmd --- cmd/juno/verify/trie.go | 469 ++--------------------------------- cmd/juno/verify/trie_test.go | 243 ------------------ cmd/juno/verify/verifier.go | 32 --- cmd/juno/verify/verify.go | 31 ++- verify/trie/traversal.go | 56 +++++ verify/trie/trie_core.go | 117 +++++++++ verify/trie/trie_test.go | 237 ++++++++++++++++++ verify/trie/trie_verifier.go | 257 +++++++++++++++++++ verify/trie/types.go | 29 +++ 9 files changed, 735 insertions(+), 736 deletions(-) delete mode 100644 cmd/juno/verify/verifier.go create mode 100644 verify/trie/traversal.go create mode 100644 verify/trie/trie_core.go create mode 100644 verify/trie/trie_test.go create mode 100644 verify/trie/trie_verifier.go create mode 100644 verify/trie/types.go diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go index 0f59377aad..a954cac822 100644 --- a/cmd/juno/verify/trie.go +++ b/cmd/juno/verify/trie.go @@ -1,43 +1,28 @@ package verify import ( - "bytes" - "context" - "errors" "fmt" "slices" - "sync" - "time" - "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" - "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/utils" + verifytrie "github.com/NethermindEth/juno/verify/trie" "github.com/spf13/cobra" - "go.uber.org/zap" ) const ( - starknetTrieHeight = 251 - concurrencyMaxDepth = 8 - verifyContractAddr = "address" -) - -type TrieType string - -const ( - ContractTrieType TrieType = "contract" - ClassTrieType TrieType = "class" - ContractStorageTrieType TrieType = "contract-storage" + verifyTrieType = "type" + verifyContractAddr = "address" ) func verifyTrieCmd() *cobra.Command { cmd := &cobra.Command{ - Use: "trie", - Short: "Verify trie integrity", - Long: `Verify trie integrity by rebuilding tries from leaf nodes and comparing root hashes.`, - RunE: runTrieVerify, + Use: "trie", + Short: "Verify trie integrity", + Long: `Verify trie integrity by rebuilding tries from leaf nodes and comparing root hashes.`, + RunE: runTrieVerify, + SilenceUsage: true, + SilenceErrors: true, } cmd.Flags().StringSlice( @@ -79,18 +64,18 @@ func runTrieVerify(cmd *cobra.Command, args []string) error { return err } - cfg := &TrieConfig{} - + var tries []verifytrie.TrieType if len(trieTypes) > 0 { - cfg.Tries = make([]TrieType, len(trieTypes)) + tries = make([]verifytrie.TrieType, len(trieTypes)) for i, t := range trieTypes { - cfg.Tries[i] = TrieType(t) + tries[i] = verifytrie.TrieType(t) } } + var contractAddr *felt.Felt if contractAddrStr != "" { - hasContractStorage := slices.Contains(cfg.Tries, ContractStorageTrieType) - if len(cfg.Tries) == 0 { + hasContractStorage := slices.Contains(tries, verifytrie.ContractStorageTrie) + if len(tries) == 0 { hasContractStorage = true } @@ -98,12 +83,12 @@ func runTrieVerify(cmd *cobra.Command, args []string) error { return fmt.Errorf("--address flag can only be used with --type contract-storage") } - var contractAddr felt.Felt - _, err := (&contractAddr).SetString(contractAddrStr) + var addr felt.Felt + _, err := (&addr).SetString(contractAddrStr) if err != nil { return fmt.Errorf("invalid contract address %s: %w", contractAddrStr, err) } - cfg.ContractAddress = &contractAddr + contractAddr = &addr } logLevel := utils.NewLogLevel(utils.INFO) @@ -112,421 +97,7 @@ func runTrieVerify(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to create logger: %w", err) } - verifier := NewTrieVerifier(database, logger) + verifier := verifytrie.NewTrieVerifier(database, logger, tries, contractAddr) ctx := cmd.Context() - return verifier.Run(ctx, cfg) -} - -type TrieConfig struct { - Tries []TrieType - ContractAddress *felt.Felt -} - -type TrieVerifier struct { - database db.KeyValueStore - logger utils.StructuredLogger -} - -func NewTrieVerifier(database db.KeyValueStore, logger utils.StructuredLogger) *TrieVerifier { - return &TrieVerifier{ - database: database, - logger: logger, - } -} - -func (v *TrieVerifier) Name() string { - return "trie" -} - -func (v *TrieVerifier) DefaultConfig() Config { - return &TrieConfig{ - Tries: nil, - } -} - -type TrieInfo struct { - Name string - prefix []byte - HashFunc trie.NewTrieFunc - HashFn crypto.HashFn - ReaderFunc func(db.KeyValueReader, []byte, uint8) (trie.TrieReader, error) - Height uint8 -} - -func (v *TrieVerifier) Run(ctx context.Context, cfg Config) error { - startTime := time.Now() - defer func() { - elapsed := time.Since(startTime) - v.logger.Info("=== Trie verification finished ===", - zap.Duration("total_elapsed", elapsed.Round(time.Second))) - }() - - trieCfg, ok := cfg.(*TrieConfig) - if !ok { - return fmt.Errorf("invalid config type for trie verifier: expected *TrieConfig") - } - - typesToVerify := trieCfg.Tries - if len(typesToVerify) == 0 { - typesToVerify = []TrieType{ContractTrieType, ClassTrieType, ContractStorageTrieType} - } - - typeSet := make(map[TrieType]bool) - for _, t := range typesToVerify { - typeSet[t] = true - } - - if typeSet[ContractTrieType] { - stateTrieInfo := TrieInfo{ - Name: "ContractsTrie", - prefix: db.StateTrie.Key(), - HashFunc: trie.NewTriePedersen, - HashFn: crypto.Pedersen, - ReaderFunc: trie.NewTrieReaderPedersen, - Height: starknetTrieHeight, - } - if err := v.verifyTrieWithLogging(ctx, stateTrieInfo); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - return err - } - } - - if typeSet[ClassTrieType] { - classTrieInfo := TrieInfo{ - Name: "ClassesTrie", - prefix: db.ClassesTrie.Key(), - HashFunc: trie.NewTriePoseidon, - HashFn: crypto.Poseidon, - ReaderFunc: trie.NewTrieReaderPoseidon, - Height: starknetTrieHeight, - } - if err := v.verifyTrieWithLogging(ctx, classTrieInfo); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - return err - } - } - - if typeSet[ContractStorageTrieType] { - if err := v.verifyContractStorageTries(ctx, trieCfg.ContractAddress); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - return err - } - } - - v.logger.Info("=== Trie verification completed successfully ===") - return nil -} - -func (v *TrieVerifier) collectContractAddresses() []felt.Felt { - contractAddresses := make([]felt.Felt, 0) - stateTriePrefix := db.StateTrie.Key() - - err := v.database.View(func(snap db.Snapshot) error { - it, err := snap.NewIterator(stateTriePrefix, true) - if err != nil { - return err - } - defer it.Close() - - for it.First(); it.Valid(); it.Next() { - keyBytes := it.Key() - if bytes.Equal(keyBytes, stateTriePrefix) { - continue - } - - if !bytes.HasPrefix(keyBytes, stateTriePrefix) { - continue - } - nodeKeyBytes := keyBytes[len(stateTriePrefix):] - - var nodeKey trie.BitArray - if err := nodeKey.UnmarshalBinary(nodeKeyBytes); err != nil { - continue - } - - if nodeKey.Len() == starknetTrieHeight { - contractAddr := nodeKey.Felt() - contractAddresses = append(contractAddresses, contractAddr) - } - } - return nil - }) - if err != nil { - return nil - } - - return contractAddresses -} - -func (v *TrieVerifier) verifyTrieWithLogging(ctx context.Context, trieInfo TrieInfo) error { - err := v.verifyTrie(ctx, trieInfo) - if err != nil { - if errors.Is(err, context.Canceled) { - v.logger.Info("Verification stopped", zap.String("trie", trieInfo.Name)) - return err - } - v.logger.Error(fmt.Sprintf("%s verification failed", trieInfo.Name), zap.Error(err)) - return fmt.Errorf("%s verification failed: %w", trieInfo.Name, err) - } - return nil -} - -func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error { - v.logger.Info(fmt.Sprintf("=== Starting %s verification ===", trieInfo.Name)) - expectedRoot := felt.Zero - err := v.database.View(func(snap db.Snapshot) error { - reader, err := trieInfo.ReaderFunc(snap, trieInfo.prefix, trieInfo.Height) - if err != nil { - return err - } - if reader.RootKey() == nil { - expectedRoot = felt.Zero - return nil - } - expectedRoot, err = reader.Hash() - return err - }) - if err != nil { - v.logger.Error("Failed to get stored root hash", - zap.String("trie", trieInfo.Name), zap.Error(err)) - return fmt.Errorf("failed to get stored root hash for %s: %w", trieInfo.Name, err) - } - - if expectedRoot.IsZero() { - v.logger.Info("Trie is empty (zero root)", zap.String("trie", trieInfo.Name)) - return nil - } - - v.logger.Info("Starting verification", - zap.String("trie", trieInfo.Name), - zap.String("expectedRoot", expectedRoot.String())) - storageReader := trie.NewReadStorage(v.database, trieInfo.prefix) - - err = verifyTrie(ctx, storageReader, starknetTrieHeight, trieInfo.HashFn, &expectedRoot) - if err != nil { - if errors.Is(err, context.Canceled) { - return err - } - v.logger.Error("Trie verification failed", - zap.String("trie", trieInfo.Name), zap.Error(err)) - return err - } - - v.logger.Info("Trie verification successful", - zap.String("trie", trieInfo.Name), zap.String("root", expectedRoot.String())) - return nil -} - -func (v *TrieVerifier) verifyContractStorageTries( - ctx context.Context, filterAddress *felt.Felt, -) error { - v.logger.Info("=== Starting Contract Storage Tries verification ===") - - var contractAddresses []felt.Felt - if filterAddress != nil { - contractAddresses = []felt.Felt{*filterAddress} - v.logger.Info("Verifying specific contract", - zap.String("address", filterAddress.String())) - } else { - contractAddresses = v.collectContractAddresses() - if len(contractAddresses) == 0 { - v.logger.Info("No contract addresses found, skipping contract storage verification") - return nil - } - v.logger.Info("Found contracts to verify", - zap.Int("count", len(contractAddresses))) - } - - for i, contractAddress := range contractAddresses { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - addrBytes := contractAddress.Marshal() - prefix := db.ContractStorage.Key(addrBytes) - trieInfo := TrieInfo{ - Name: fmt.Sprintf("ContractStorage[%s]", contractAddress.String()), - prefix: db.ContractStorage.Key(addrBytes), - HashFunc: trie.NewTriePedersen, - HashFn: crypto.Pedersen, - ReaderFunc: func(r db.KeyValueReader, _ []byte, height uint8) (trie.TrieReader, error) { - return trie.NewTrieReaderPedersen(r, prefix, height) - }, - Height: starknetTrieHeight, - } - - v.logger.Info("Verifying contract storage", - zap.String("contract", contractAddress.String()), - zap.String("progress", fmt.Sprintf("%d/%d", i+1, len(contractAddresses)))) - - err := v.verifyTrie(ctx, trieInfo) - if err != nil { - if errors.Is(err, context.Canceled) { - return err - } - v.logger.Error("Contract storage verification failed", - zap.String("contract", contractAddress.String()), - zap.Error(err)) - return fmt.Errorf( - "contract storage verification failed for %s: %w", - contractAddress.String(), err, - ) - } - } - - v.logger.Info("All contract storage tries verified successfully", - zap.Int("count", len(contractAddresses))) - return nil -} - -func verifyTrie( - ctx context.Context, - reader *trie.ReadStorage, - height uint8, - hashFn crypto.HashFn, - expectedRoot *felt.Felt, -) error { - rootKey, err := reader.RootKey() - if err != nil { - return fmt.Errorf("failed to get root key: %w", err) - } - - if rootKey == nil { - return nil - } - - startTime := time.Now() - rootHash, err := verifyNode(ctx, reader, rootKey, nil, height, hashFn) - if err != nil { - return fmt.Errorf("node verification failed: %w", err) - } - - elapsed := time.Since(startTime) - - if rootHash.Cmp(expectedRoot) != 0 { - return fmt.Errorf( - "root hash mismatch: expected %s, got %s (verification took %v)", - expectedRoot, rootHash, elapsed.Round(time.Second), - ) - } - - return nil -} - -func verifyNode( - ctx context.Context, - reader *trie.ReadStorage, - key *trie.BitArray, - parentKey *trie.BitArray, - height uint8, - hashFn crypto.HashFn, -) (*felt.Felt, error) { - select { - case <-ctx.Done(): - return nil, fmt.Errorf("verification cancelled: %w", ctx.Err()) - default: - } - - node, err := reader.Get(key) - if err != nil { - return nil, fmt.Errorf("failed to get node at key %s: %w", key.String(), err) - } - - if key.Len() == height { - p := path(key, parentKey) - h := node.Hash(&p, hashFn) - return &h, nil - } - - useConcurrency := key.Len() <= concurrencyMaxDepth - var leftHash, rightHash *felt.Felt - var leftErr, rightErr error - - if useConcurrency { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - if node.Left.IsEmpty() { - zero := felt.Zero - leftHash = &zero - return - } - h, err := verifyNode(ctx, reader, node.Left, key, height, hashFn) - leftHash = h - leftErr = err - }() - - if node.Right.IsEmpty() { - zero := felt.Zero - rightHash = &zero - } else { - h, err := verifyNode(ctx, reader, node.Right, key, height, hashFn) - rightHash = h - rightErr = err - } - - wg.Wait() - - if leftErr != nil { - return nil, leftErr - } - if rightErr != nil { - return nil, rightErr - } - } else { - if node.Left.IsEmpty() { - zero := felt.Zero - leftHash = &zero - } else { - h, err := verifyNode(ctx, reader, node.Left, key, height, hashFn) - if err != nil { - return nil, err - } - leftHash = h - } - - if node.Right.IsEmpty() { - zero := felt.Zero - rightHash = &zero - } else { - h, err := verifyNode(ctx, reader, node.Right, key, height, hashFn) - if err != nil { - return nil, err - } - rightHash = h - } - } - - recomputed := hashFn(leftHash, rightHash) - if recomputed.Cmp(node.Value) != 0 { - return nil, fmt.Errorf( - "node corruption detected at key %s: stored hash=%s, recomputed hash=%s", - key.String(), node.Value.String(), recomputed.String(), - ) - } - - tmp := *node - tmp.Value = &recomputed - - p := path(key, parentKey) - h := tmp.Hash(&p, hashFn) - return &h, nil -} - -func path(key, parentKey *trie.BitArray) trie.BitArray { - if parentKey == nil { - return key.Copy() - } - - var pathKey trie.BitArray - pathKey.LSBs(key, parentKey.Len()+1) - return pathKey + return verifier.Run(ctx) } diff --git a/cmd/juno/verify/trie_test.go b/cmd/juno/verify/trie_test.go index 7b9133d1bf..2186060458 100644 --- a/cmd/juno/verify/trie_test.go +++ b/cmd/juno/verify/trie_test.go @@ -6,254 +6,11 @@ import ( "path/filepath" "testing" - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/db/memory" "github.com/NethermindEth/juno/db/pebblev2" - "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestTrieVerifier_Run_ValidStateTrie(t *testing.T) { - logger := utils.NewNopZapLogger() - testDB := memory.New() - defer testDB.Close() - - prefix := db.StateTrie.Key() - txn := testDB.NewIndexedBatch() - trieStorage := trie.NewStorage(txn, prefix) - - testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) - require.NoError(t, err) - - key1 := felt.NewFromUint64[felt.Felt](1) - value1 := felt.NewFromUint64[felt.Felt](100) - _, err = testTrie.Put(key1, value1) - require.NoError(t, err) - - key2 := felt.NewFromUint64[felt.Felt](2) - value2 := felt.NewFromUint64[felt.Felt](200) - _, err = testTrie.Put(key2, value2) - require.NoError(t, err) - - err = testTrie.Commit() - require.NoError(t, err) - - rootHash, err := testTrie.Hash() - require.NoError(t, err) - - if testTrie.RootKey() != nil { - err = trieStorage.PutRootKey(testTrie.RootKey()) - require.NoError(t, err) - } - - err = txn.Write() - require.NoError(t, err) - - verifier := NewTrieVerifier(testDB, logger) - cfg := &TrieConfig{ - Tries: []TrieType{ContractTrieType}, - } - - ctx := context.Background() - err = verifier.Run(ctx, cfg) - assert.NoError(t, err) - - reader, err := trie.NewTrieReaderPedersen(testDB, prefix, starknetTrieHeight) - require.NoError(t, err) - storedHash, err := reader.Hash() - require.NoError(t, err) - assert.True(t, rootHash.Equal(&storedHash)) -} - -func TestTrieVerifier_Run_ValidClassTrie(t *testing.T) { - logger := utils.NewNopZapLogger() - testDB := memory.New() - defer testDB.Close() - - prefix := db.ClassesTrie.Key() - txn := testDB.NewIndexedBatch() - trieStorage := trie.NewStorage(txn, prefix) - - testTrie, err := trie.NewTriePoseidon(txn, prefix, starknetTrieHeight) - require.NoError(t, err) - - key1 := felt.NewFromUint64[felt.Felt](10) - value1 := felt.NewFromUint64[felt.Felt](1000) - _, err = testTrie.Put(key1, value1) - require.NoError(t, err) - - err = testTrie.Commit() - require.NoError(t, err) - - if testTrie.RootKey() != nil { - err = trieStorage.PutRootKey(testTrie.RootKey()) - require.NoError(t, err) - } - - err = txn.Write() - require.NoError(t, err) - - verifier := NewTrieVerifier(testDB, logger) - cfg := &TrieConfig{ - Tries: []TrieType{ClassTrieType}, - } - - ctx := context.Background() - err = verifier.Run(ctx, cfg) - assert.NoError(t, err) -} - -func TestTrieVerifier_Run_CorruptedTrie(t *testing.T) { - logger := utils.NewNopZapLogger() - testDB := memory.New() - defer testDB.Close() - - prefix := db.StateTrie.Key() - txn := testDB.NewIndexedBatch() - trieStorage := trie.NewStorage(txn, prefix) - - testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) - require.NoError(t, err) - - key1 := felt.NewFromUint64[felt.Felt](1) - value1 := felt.NewFromUint64[felt.Felt](100) - _, err = testTrie.Put(key1, value1) - require.NoError(t, err) - - key2 := felt.NewFromUint64[felt.Felt](2) - value2 := felt.NewFromUint64[felt.Felt](200) - _, err = testTrie.Put(key2, value2) - require.NoError(t, err) - - err = testTrie.Commit() - require.NoError(t, err) - - if testTrie.RootKey() != nil { - err = trieStorage.PutRootKey(testTrie.RootKey()) - require.NoError(t, err) - } - - var nodeKey trie.BitArray - nodeKey.SetFelt(starknetTrieHeight, key1) - - node, err := trieStorage.Get(&nodeKey) - require.NoError(t, err) - require.NotNil(t, node) - require.NotNil(t, node.Value) - - assert.True(t, node.Value.Equal(value1), "Expected value1 but got different value") - - corruptedValue := felt.NewFromUint64[felt.Felt](999999) - node.Value = corruptedValue - - err = trieStorage.Put(&nodeKey, node) - require.NoError(t, err) - - err = txn.Write() - require.NoError(t, err) - - verifier := NewTrieVerifier(testDB, logger) - cfg := &TrieConfig{ - Tries: []TrieType{ContractTrieType}, - } - - ctx := context.Background() - err = verifier.Run(ctx, cfg) - require.Error(t, err) - assert.Contains(t, err.Error(), "node corruption detected") -} - -func TestTrieVerifier_Run_MultipleTrieTypes(t *testing.T) { - logger := utils.NewNopZapLogger() - testDB := memory.New() - defer testDB.Close() - - statePrefix := db.StateTrie.Key() - txn := testDB.NewIndexedBatch() - stateTrie, err := trie.NewTriePedersen(txn, statePrefix, starknetTrieHeight) - require.NoError(t, err) - - key1 := felt.NewFromUint64[felt.Felt](1) - value1 := felt.NewFromUint64[felt.Felt](100) - _, err = stateTrie.Put(key1, value1) - require.NoError(t, err) - - err = stateTrie.Commit() - require.NoError(t, err) - - stateStorage := trie.NewStorage(txn, statePrefix) - if stateTrie.RootKey() != nil { - err = stateStorage.PutRootKey(stateTrie.RootKey()) - require.NoError(t, err) - } - - classPrefix := db.ClassesTrie.Key() - classTrie, err := trie.NewTriePoseidon(txn, classPrefix, starknetTrieHeight) - require.NoError(t, err) - - key2 := felt.NewFromUint64[felt.Felt](2) - value2 := felt.NewFromUint64[felt.Felt](200) - _, err = classTrie.Put(key2, value2) - require.NoError(t, err) - - err = classTrie.Commit() - require.NoError(t, err) - - classStorage := trie.NewStorage(txn, classPrefix) - if classTrie.RootKey() != nil { - err = classStorage.PutRootKey(classTrie.RootKey()) - require.NoError(t, err) - } - - err = txn.Write() - require.NoError(t, err) - - verifier := NewTrieVerifier(testDB, logger) - cfg := &TrieConfig{ - Tries: []TrieType{ContractTrieType, ClassTrieType}, - } - - ctx := context.Background() - err = verifier.Run(ctx, cfg) - assert.NoError(t, err) -} - -func TestTrieVerifier_Run_EmptyTrie(t *testing.T) { - logger := utils.NewNopZapLogger() - testDB := memory.New() - defer testDB.Close() - - prefix := db.StateTrie.Key() - txn := testDB.NewIndexedBatch() - trieStorage := trie.NewStorage(txn, prefix) - - testTrie, err := trie.NewTriePedersen(txn, prefix, starknetTrieHeight) - require.NoError(t, err) - - err = testTrie.Commit() - require.NoError(t, err) - - if testTrie.RootKey() != nil { - err = trieStorage.PutRootKey(testTrie.RootKey()) - require.NoError(t, err) - } - - err = txn.Write() - require.NoError(t, err) - - verifier := NewTrieVerifier(testDB, logger) - cfg := &TrieConfig{ - Tries: []TrieType{ContractTrieType}, - } - - ctx := context.Background() - err = verifier.Run(ctx, cfg) - assert.NoError(t, err) -} - func TestRunTrieVerify_AddressFlagValidation(t *testing.T) { tests := []struct { name string diff --git a/cmd/juno/verify/verifier.go b/cmd/juno/verify/verifier.go deleted file mode 100644 index 6abe284de4..0000000000 --- a/cmd/juno/verify/verifier.go +++ /dev/null @@ -1,32 +0,0 @@ -package verify - -import ( - "context" - "fmt" -) - -type Config interface{} - -type Verifier interface { - Name() string - DefaultConfig() Config - Run(ctx context.Context, cfg Config) error -} - -type VerifyRunner struct { - Verifiers []Verifier -} - -func (r *VerifyRunner) Run(ctx context.Context, configs map[string]Config) error { - for _, verifier := range r.Verifiers { - cfg, ok := configs[verifier.Name()] - if !ok { - cfg = verifier.DefaultConfig() - } - - if err := verifier.Run(ctx, cfg); err != nil { - return fmt.Errorf("%s: %w", verifier.Name(), err) - } - } - return nil -} diff --git a/cmd/juno/verify/verify.go b/cmd/juno/verify/verify.go index c6470ecaf7..8c91591dc6 100644 --- a/cmd/juno/verify/verify.go +++ b/cmd/juno/verify/verify.go @@ -1,6 +1,7 @@ package verify import ( + "context" "errors" "fmt" "os" @@ -8,13 +9,16 @@ import ( "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebblev2" "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/verify/trie" "github.com/spf13/cobra" ) -const ( - verifyDBPathF = "db-path" - verifyTrieType = "type" -) +type Verifier interface { + Name() string + Run(ctx context.Context) error +} + +const verifyDBPathF = "db-path" func VerifyCmd(defaultDBPath string) *cobra.Command { verifyCmd := &cobra.Command{ @@ -30,7 +34,6 @@ func VerifyCmd(defaultDBPath string) *cobra.Command { return verifyCmd } -// verifyAll runs all verifiers with default scope when no subcommand is specified. func verifyAll(cmd *cobra.Command, args []string) error { dbPath, err := cmd.Flags().GetString(verifyDBPathF) if err != nil { @@ -49,15 +52,19 @@ func verifyAll(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to create logger: %w", err) } - trieVerifier := NewTrieVerifier(database, logger) - verifier := &VerifyRunner{ - Verifiers: []Verifier{ - trieVerifier, - }, + ctx := cmd.Context() + + verifiers := []Verifier{ + trie.NewTrieVerifier(database, logger, nil, nil), + } + + for _, v := range verifiers { + if err := v.Run(ctx); err != nil { + return fmt.Errorf("%s verification failed: %w", v.Name(), err) + } } - ctx := cmd.Context() - return verifier.Run(ctx, nil) + return nil } func openDB(path string) (db.KeyValueStore, error) { diff --git a/verify/trie/traversal.go b/verify/trie/traversal.go new file mode 100644 index 0000000000..29e94c3bb4 --- /dev/null +++ b/verify/trie/traversal.go @@ -0,0 +1,56 @@ +package trie + +import ( + "context" + "sync" +) + +func TraverseBinary[T any]( + ctx context.Context, + depth uint8, + maxConcurrentDepth uint8, + leftFn func(ctx context.Context) (T, error), + rightFn func(ctx context.Context) (T, error), +) (left, right T, err error) { + if depth <= maxConcurrentDepth { + return traverseConcurrently(ctx, leftFn, rightFn) + } + return traverseSequentially(ctx, leftFn, rightFn) +} + +func traverseConcurrently[T any]( + ctx context.Context, + leftFn func(ctx context.Context) (T, error), + rightFn func(ctx context.Context) (T, error), +) (left, right T, err error) { + var leftErr, rightErr error + var wg sync.WaitGroup + + wg.Go(func() { + left, leftErr = leftFn(ctx) + }) + + right, rightErr = rightFn(ctx) + wg.Wait() + + if leftErr != nil { + return left, right, leftErr + } + if rightErr != nil { + return left, right, rightErr + } + return left, right, nil +} + +func traverseSequentially[T any]( + ctx context.Context, + leftFn func(ctx context.Context) (T, error), + rightFn func(ctx context.Context) (T, error), +) (left, right T, err error) { + left, err = leftFn(ctx) + if err != nil { + return left, right, err + } + right, err = rightFn(ctx) + return left, right, err +} diff --git a/verify/trie/trie_core.go b/verify/trie/trie_core.go new file mode 100644 index 0000000000..eb26711772 --- /dev/null +++ b/verify/trie/trie_core.go @@ -0,0 +1,117 @@ +package trie + +import ( + "context" + "fmt" + "time" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" +) + +func VerifyTrie( + ctx context.Context, + reader *trie.ReadStorage, + height uint8, + hashFn crypto.HashFn, + expectedRoot *felt.Felt, +) error { + rootKey, err := reader.RootKey() + if err != nil { + return fmt.Errorf("failed to get root key: %w", err) + } + + if rootKey == nil { + return nil + } + + startTime := time.Now() + rootHash, err := verifyNode(ctx, reader, rootKey, nil, height, hashFn) + if err != nil { + return fmt.Errorf("node verification failed: %w", err) + } + + elapsed := time.Since(startTime) + + if rootHash.Cmp(expectedRoot) != 0 { + return fmt.Errorf( + "root hash mismatch: expected %s, got %s (verification took %v)", + expectedRoot, rootHash, elapsed.Round(time.Second), + ) + } + + return nil +} + +func verifyNode( + ctx context.Context, + reader *trie.ReadStorage, + key *trie.BitArray, + parentKey *trie.BitArray, + height uint8, + hashFn crypto.HashFn, +) (*felt.Felt, error) { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("verification cancelled: %w", ctx.Err()) + default: + } + + node, err := reader.Get(key) + if err != nil { + return nil, fmt.Errorf("failed to get node at key %s: %w", key.String(), err) + } + + if key.Len() == height { + p := path(key, parentKey) + h := node.Hash(&p, hashFn) + return &h, nil + } + + leftFn := func(ctx context.Context) (*felt.Felt, error) { + if node.Left.IsEmpty() { + zero := felt.Zero + return &zero, nil + } + return verifyNode(ctx, reader, node.Left, key, height, hashFn) + } + + rightFn := func(ctx context.Context) (*felt.Felt, error) { + if node.Right.IsEmpty() { + zero := felt.Zero + return &zero, nil + } + return verifyNode(ctx, reader, node.Right, key, height, hashFn) + } + + leftHash, rightHash, err := TraverseBinary(ctx, key.Len(), ConcurrencyMaxDepth, leftFn, rightFn) + if err != nil { + return nil, err + } + + recomputed := hashFn(leftHash, rightHash) + if recomputed.Cmp(node.Value) != 0 { + return nil, fmt.Errorf( + "node corruption detected at key %s: stored hash=%s, recomputed hash=%s", + key.String(), node.Value.String(), recomputed.String(), + ) + } + + tmp := *node + tmp.Value = &recomputed + + p := path(key, parentKey) + h := tmp.Hash(&p, hashFn) + return &h, nil +} + +func path(key, parentKey *trie.BitArray) trie.BitArray { + if parentKey == nil { + return key.Copy() + } + + var pathKey trie.BitArray + pathKey.LSBs(key, parentKey.Len()+1) + return pathKey +} diff --git a/verify/trie/trie_test.go b/verify/trie/trie_test.go new file mode 100644 index 0000000000..5fdc307207 --- /dev/null +++ b/verify/trie/trie_test.go @@ -0,0 +1,237 @@ +package trie + +import ( + "context" + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/memory" + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTrieVerifier_Run_ValidStateTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = testTrie.Put(key2, value2) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + rootHash, err := testTrie.Hash() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) + + reader, err := trie.NewTrieReaderPedersen(testDB, prefix, StarknetTrieHeight) + require.NoError(t, err) + storedHash, err := reader.Hash() + require.NoError(t, err) + assert.True(t, rootHash.Equal(&storedHash)) +} + +func TestTrieVerifier_Run_ValidClassTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.ClassesTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePoseidon(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](10) + value1 := felt.NewFromUint64[felt.Felt](1000) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ClassTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_CorruptedTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = testTrie.Put(key2, value2) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + var nodeKey trie.BitArray + nodeKey.SetFelt(StarknetTrieHeight, key1) + + node, err := trieStorage.Get(&nodeKey) + require.NoError(t, err) + require.NotNil(t, node) + require.NotNil(t, node.Value) + + assert.True(t, node.Value.Equal(value1), "Expected value1 but got different value") + + corruptedValue := felt.NewFromUint64[felt.Felt](999999) + node.Value = corruptedValue + + err = trieStorage.Put(&nodeKey, node) + require.NoError(t, err) + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "node corruption detected") +} + +func TestTrieVerifier_Run_MultipleTrieTypes(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + statePrefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + stateTrie, err := trie.NewTriePedersen(txn, statePrefix, StarknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = stateTrie.Put(key1, value1) + require.NoError(t, err) + + err = stateTrie.Commit() + require.NoError(t, err) + + stateStorage := trie.NewStorage(txn, statePrefix) + if stateTrie.RootKey() != nil { + err = stateStorage.PutRootKey(stateTrie.RootKey()) + require.NoError(t, err) + } + + classPrefix := db.ClassesTrie.Key() + classTrie, err := trie.NewTriePoseidon(txn, classPrefix, StarknetTrieHeight) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = classTrie.Put(key2, value2) + require.NoError(t, err) + + err = classTrie.Commit() + require.NoError(t, err) + + classStorage := trie.NewStorage(txn, classPrefix) + if classTrie.RootKey() != nil { + err = classStorage.PutRootKey(classTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie, ClassTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_EmptyTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} diff --git a/verify/trie/trie_verifier.go b/verify/trie/trie_verifier.go new file mode 100644 index 0000000000..5035b11b9d --- /dev/null +++ b/verify/trie/trie_verifier.go @@ -0,0 +1,257 @@ +package trie + +import ( + "bytes" + "context" + "errors" + "fmt" + "time" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/utils" + "go.uber.org/zap" +) + +type TrieVerifier struct { + database db.KeyValueStore + logger utils.StructuredLogger + tries []TrieType + contractAddress *felt.Felt +} + +func NewTrieVerifier( + database db.KeyValueStore, + logger utils.StructuredLogger, + tries []TrieType, + contractAddress *felt.Felt, +) *TrieVerifier { + if len(tries) == 0 { + tries = []TrieType{ContractTrie, ClassTrie, ContractStorageTrie} + } + return &TrieVerifier{ + database: database, + logger: logger, + tries: tries, + contractAddress: contractAddress, + } +} + +func (v *TrieVerifier) Name() string { + return "trie" +} + +func (v *TrieVerifier) Run(ctx context.Context) error { + startTime := time.Now() + defer func() { + elapsed := time.Since(startTime) + v.logger.Info("=== Trie verification finished ===", + zap.Duration("total_elapsed", elapsed.Round(time.Second))) + }() + + typeSet := make(map[TrieType]bool) + for _, t := range v.tries { + typeSet[t] = true + } + + if typeSet[ContractTrie] { + stateTrieInfo := TrieInfo{ + Name: "ContractsTrie", + Prefix: db.StateTrie.Key(), + HashFunc: trie.NewTriePedersen, + HashFn: crypto.Pedersen, + ReaderFunc: trie.NewTrieReaderPedersen, + Height: StarknetTrieHeight, + } + if err := v.verifyTrieWithLogging(ctx, stateTrieInfo); err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err + } + } + + if typeSet[ClassTrie] { + classTrieInfo := TrieInfo{ + Name: "ClassesTrie", + Prefix: db.ClassesTrie.Key(), + HashFunc: trie.NewTriePoseidon, + HashFn: crypto.Poseidon, + ReaderFunc: trie.NewTrieReaderPoseidon, + Height: StarknetTrieHeight, + } + if err := v.verifyTrieWithLogging(ctx, classTrieInfo); err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err + } + } + + if typeSet[ContractStorageTrie] { + if err := v.verifyContractStorageTries(ctx, v.contractAddress); err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err + } + } + + v.logger.Info("=== Trie verification completed successfully ===") + return nil +} + +func (v *TrieVerifier) collectContractAddresses() []felt.Felt { + contractAddresses := make([]felt.Felt, 0) + stateTriePrefix := db.StateTrie.Key() + + err := v.database.View(func(snap db.Snapshot) error { + it, err := snap.NewIterator(stateTriePrefix, true) + if err != nil { + return err + } + defer it.Close() + + for it.First(); it.Valid(); it.Next() { + keyBytes := it.Key() + if bytes.Equal(keyBytes, stateTriePrefix) { + continue + } + + if !bytes.HasPrefix(keyBytes, stateTriePrefix) { + continue + } + nodeKeyBytes := keyBytes[len(stateTriePrefix):] + + var nodeKey trie.BitArray + if err := nodeKey.UnmarshalBinary(nodeKeyBytes); err != nil { + continue + } + + if nodeKey.Len() == StarknetTrieHeight { + contractAddr := nodeKey.Felt() + contractAddresses = append(contractAddresses, contractAddr) + } + } + return nil + }) + if err != nil { + return nil + } + + return contractAddresses +} + +func (v *TrieVerifier) verifyTrieWithLogging(ctx context.Context, trieInfo TrieInfo) error { + err := v.verifyTrie(ctx, trieInfo) + if err != nil { + if errors.Is(err, context.Canceled) { + v.logger.Info("Verification stopped", zap.String("trie", trieInfo.Name)) + return err + } + v.logger.Error(fmt.Sprintf("%s verification failed: %v", trieInfo.Name, err)) + return fmt.Errorf("%s verification failed: %w", trieInfo.Name, err) + } + return nil +} + +func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error { + v.logger.Info(fmt.Sprintf("=== Starting %s verification ===", trieInfo.Name)) + expectedRoot := felt.Zero + err := v.database.View(func(snap db.Snapshot) error { + reader, err := trieInfo.ReaderFunc(snap, trieInfo.Prefix, trieInfo.Height) + if err != nil { + return err + } + if reader.RootKey() == nil { + expectedRoot = felt.Zero + return nil + } + expectedRoot, err = reader.Hash() + return err + }) + if err != nil { + v.logger.Error(fmt.Sprintf("Failed to get stored root hash for %s: %v", trieInfo.Name, err)) + return fmt.Errorf("failed to get stored root hash for %s: %w", trieInfo.Name, err) + } + + if expectedRoot.IsZero() { + v.logger.Info("Trie is empty (zero root)", zap.String("trie", trieInfo.Name)) + return nil + } + + v.logger.Info("Starting verification", + zap.String("trie", trieInfo.Name), + zap.String("expectedRoot", expectedRoot.String())) + storageReader := trie.NewReadStorage(v.database, trieInfo.Prefix) + + err = VerifyTrie(ctx, storageReader, StarknetTrieHeight, trieInfo.HashFn, &expectedRoot) + if err != nil { + return err + } + + v.logger.Info("Trie verification successful", + zap.String("trie", trieInfo.Name), zap.String("root", expectedRoot.String())) + return nil +} + +func (v *TrieVerifier) verifyContractStorageTries( + ctx context.Context, filterAddress *felt.Felt, +) error { + v.logger.Info("=== Starting Contract Storage Tries verification ===") + + var contractAddresses []felt.Felt + if filterAddress != nil { + contractAddresses = []felt.Felt{*filterAddress} + v.logger.Info("Verifying specific contract", + zap.String("address", filterAddress.String())) + } else { + contractAddresses = v.collectContractAddresses() + if len(contractAddresses) == 0 { + v.logger.Info("No contract addresses found, skipping contract storage verification") + return nil + } + v.logger.Info("Found contracts to verify", + zap.Int("count", len(contractAddresses))) + } + + for i, contractAddress := range contractAddresses { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + addrBytes := contractAddress.Marshal() + prefix := db.ContractStorage.Key(addrBytes) + trieInfo := TrieInfo{ + Name: fmt.Sprintf("ContractStorage[%s]", contractAddress.String()), + Prefix: db.ContractStorage.Key(addrBytes), + HashFunc: trie.NewTriePedersen, + HashFn: crypto.Pedersen, + ReaderFunc: func(r db.KeyValueReader, _ []byte, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPedersen(r, prefix, height) + }, + Height: StarknetTrieHeight, + } + + v.logger.Info("Verifying contract storage", + zap.String("contract", contractAddress.String()), + zap.String("progress", fmt.Sprintf("%d/%d", i+1, len(contractAddresses)))) + + err := v.verifyTrie(ctx, trieInfo) + if err != nil { + if errors.Is(err, context.Canceled) { + return err + } + v.logger.Error(fmt.Sprintf("Contract storage verification failed for %s: %v", + contractAddress.String(), err)) + return fmt.Errorf("contract storage verification failed for %s: %w", contractAddress.String(), err) + } + } + + v.logger.Info("All contract storage tries verified successfully", + zap.Int("count", len(contractAddresses))) + return nil +} diff --git a/verify/trie/types.go b/verify/trie/types.go new file mode 100644 index 0000000000..1f2227f2b8 --- /dev/null +++ b/verify/trie/types.go @@ -0,0 +1,29 @@ +package trie + +import ( + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" +) + +const ( + StarknetTrieHeight uint8 = 251 + ConcurrencyMaxDepth uint8 = 8 +) + +type TrieType string + +const ( + ContractTrie TrieType = "contract" + ClassTrie TrieType = "class" + ContractStorageTrie TrieType = "contract-storage" +) + +type TrieInfo struct { + Name string + Prefix []byte + HashFunc trie.NewTrieFunc + HashFn crypto.HashFn + ReaderFunc func(db.KeyValueReader, []byte, uint8) (trie.TrieReader, error) + Height uint8 +} From 755fa6edf37d501ec3a12e6be770285a6e585d36 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Thu, 5 Feb 2026 16:11:46 +0100 Subject: [PATCH 11/15] linter --- cmd/juno/verify/trie.go | 2 +- verify/trie/trie_verifier.go | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go index a954cac822..156beed154 100644 --- a/cmd/juno/verify/trie.go +++ b/cmd/juno/verify/trie.go @@ -19,7 +19,7 @@ func verifyTrieCmd() *cobra.Command { cmd := &cobra.Command{ Use: "trie", Short: "Verify trie integrity", - Long: `Verify trie integrity by rebuilding tries from leaf nodes and comparing root hashes.`, + Long: `Verify trie integrity by rebuilding tries and comparing root hashes.`, RunE: runTrieVerify, SilenceUsage: true, SilenceErrors: true, diff --git a/verify/trie/trie_verifier.go b/verify/trie/trie_verifier.go index 5035b11b9d..9e2373c3bd 100644 --- a/verify/trie/trie_verifier.go +++ b/verify/trie/trie_verifier.go @@ -247,7 +247,11 @@ func (v *TrieVerifier) verifyContractStorageTries( } v.logger.Error(fmt.Sprintf("Contract storage verification failed for %s: %v", contractAddress.String(), err)) - return fmt.Errorf("contract storage verification failed for %s: %w", contractAddress.String(), err) + return fmt.Errorf( + "contract storage verification failed for %s: %w", + contractAddress.String(), + err, + ) } } From b3c67102727788411bbb85204ceba846f7837b3a Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Sun, 8 Feb 2026 14:47:51 +0100 Subject: [PATCH 12/15] Cleanup the logs --- cmd/juno/verify/trie.go | 12 ++-- verify/trie/trie_core.go | 15 +++-- verify/trie/trie_test.go | 2 +- verify/trie/trie_verifier.go | 115 ++++++++++++++--------------------- verify/trie/types.go | 10 ++- 5 files changed, 74 insertions(+), 80 deletions(-) diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go index 156beed154..0358562789 100644 --- a/cmd/juno/verify/trie.go +++ b/cmd/juno/verify/trie.go @@ -66,9 +66,13 @@ func runTrieVerify(cmd *cobra.Command, args []string) error { var tries []verifytrie.TrieType if len(trieTypes) > 0 { - tries = make([]verifytrie.TrieType, len(trieTypes)) - for i, t := range trieTypes { - tries[i] = verifytrie.TrieType(t) + tries = make([]verifytrie.TrieType, 0, len(trieTypes)) + for _, t := range trieTypes { + tt := verifytrie.TrieType(t) + if !tt.IsValid() { + return fmt.Errorf("invalid trie type %q (allowed: contract, class, contract-storage)", t) + } + tries = append(tries, tt) } } @@ -84,7 +88,7 @@ func runTrieVerify(cmd *cobra.Command, args []string) error { } var addr felt.Felt - _, err := (&addr).SetString(contractAddrStr) + _, err = addr.SetString(contractAddrStr) if err != nil { return fmt.Errorf("invalid contract address %s: %w", contractAddrStr, err) } diff --git a/verify/trie/trie_core.go b/verify/trie/trie_core.go index eb26711772..d281e7c95b 100644 --- a/verify/trie/trie_core.go +++ b/verify/trie/trie_core.go @@ -2,6 +2,7 @@ package trie import ( "context" + "errors" "fmt" "time" @@ -10,6 +11,8 @@ import ( "github.com/NethermindEth/juno/core/trie" ) +var ErrCorruptionDetected = errors.New("corruption detected") + func VerifyTrie( ctx context.Context, reader *trie.ReadStorage, @@ -29,15 +32,15 @@ func VerifyTrie( startTime := time.Now() rootHash, err := verifyNode(ctx, reader, rootKey, nil, height, hashFn) if err != nil { - return fmt.Errorf("node verification failed: %w", err) + return err } elapsed := time.Since(startTime) if rootHash.Cmp(expectedRoot) != 0 { return fmt.Errorf( - "root hash mismatch: expected %s, got %s (verification took %v)", - expectedRoot, rootHash, elapsed.Round(time.Second), + "%w: root hash mismatch, expected %s, got %s (verification took %v)", + ErrCorruptionDetected, expectedRoot, rootHash, elapsed.Round(time.Second), ) } @@ -54,7 +57,7 @@ func verifyNode( ) (*felt.Felt, error) { select { case <-ctx.Done(): - return nil, fmt.Errorf("verification cancelled: %w", ctx.Err()) + return nil, ctx.Err() default: } @@ -93,8 +96,8 @@ func verifyNode( recomputed := hashFn(leftHash, rightHash) if recomputed.Cmp(node.Value) != 0 { return nil, fmt.Errorf( - "node corruption detected at key %s: stored hash=%s, recomputed hash=%s", - key.String(), node.Value.String(), recomputed.String(), + "%w: node at key %s, stored hash=%s, recomputed hash=%s", + ErrCorruptionDetected, key.String(), node.Value.String(), recomputed.String(), ) } diff --git a/verify/trie/trie_test.go b/verify/trie/trie_test.go index 5fdc307207..34f4a293ac 100644 --- a/verify/trie/trie_test.go +++ b/verify/trie/trie_test.go @@ -151,7 +151,7 @@ func TestTrieVerifier_Run_CorruptedTrie(t *testing.T) { ctx := context.Background() err = verifier.Run(ctx) require.Error(t, err) - assert.Contains(t, err.Error(), "node corruption detected") + assert.ErrorIs(t, err, ErrCorruptionDetected) } func TestTrieVerifier_Run_MultipleTrieTypes(t *testing.T) { diff --git a/verify/trie/trie_verifier.go b/verify/trie/trie_verifier.go index 9e2373c3bd..946d517a2a 100644 --- a/verify/trie/trie_verifier.go +++ b/verify/trie/trie_verifier.go @@ -47,7 +47,7 @@ func (v *TrieVerifier) Run(ctx context.Context) error { startTime := time.Now() defer func() { elapsed := time.Since(startTime) - v.logger.Info("=== Trie verification finished ===", + v.logger.Info("Trie verification finished", zap.Duration("total_elapsed", elapsed.Round(time.Second))) }() @@ -60,16 +60,12 @@ func (v *TrieVerifier) Run(ctx context.Context) error { stateTrieInfo := TrieInfo{ Name: "ContractsTrie", Prefix: db.StateTrie.Key(), - HashFunc: trie.NewTriePedersen, HashFn: crypto.Pedersen, ReaderFunc: trie.NewTrieReaderPedersen, Height: StarknetTrieHeight, } - if err := v.verifyTrieWithLogging(ctx, stateTrieInfo); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - return err + if err := v.verifyTrie(ctx, stateTrieInfo); err != nil { + return v.handleResult(err, stateTrieInfo.Name) } } @@ -77,33 +73,43 @@ func (v *TrieVerifier) Run(ctx context.Context) error { classTrieInfo := TrieInfo{ Name: "ClassesTrie", Prefix: db.ClassesTrie.Key(), - HashFunc: trie.NewTriePoseidon, HashFn: crypto.Poseidon, ReaderFunc: trie.NewTrieReaderPoseidon, Height: StarknetTrieHeight, } - if err := v.verifyTrieWithLogging(ctx, classTrieInfo); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - return err + if err := v.verifyTrie(ctx, classTrieInfo); err != nil { + return v.handleResult(err, classTrieInfo.Name) } } if typeSet[ContractStorageTrie] { if err := v.verifyContractStorageTries(ctx, v.contractAddress); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - return err + return v.handleResult(err, "ContractStorageTries") } } - v.logger.Info("=== Trie verification completed successfully ===") + v.logger.Info("Trie verification completed successfully") return nil } -func (v *TrieVerifier) collectContractAddresses() []felt.Felt { +func (v *TrieVerifier) handleResult(err error, trieName string) error { + if errors.Is(err, context.Canceled) { + v.logger.Info("Verification stopped", zap.String("trie", trieName)) + return nil + } + if errors.Is(err, ErrCorruptionDetected) { + v.logger.Info("Corruption detected", + zap.String("trie", trieName), + zap.String("details", err.Error())) + return err + } + v.logger.Error("Verification error", + zap.String("trie", trieName), + zap.Error(err)) + return err +} + +func (v *TrieVerifier) collectContractAddresses() ([]felt.Felt, error) { contractAddresses := make([]felt.Felt, 0) stateTriePrefix := db.StateTrie.Key() @@ -138,27 +144,15 @@ func (v *TrieVerifier) collectContractAddresses() []felt.Felt { return nil }) if err != nil { - return nil + return nil, fmt.Errorf("failed to collect contract addresses: %w", err) } - return contractAddresses -} - -func (v *TrieVerifier) verifyTrieWithLogging(ctx context.Context, trieInfo TrieInfo) error { - err := v.verifyTrie(ctx, trieInfo) - if err != nil { - if errors.Is(err, context.Canceled) { - v.logger.Info("Verification stopped", zap.String("trie", trieInfo.Name)) - return err - } - v.logger.Error(fmt.Sprintf("%s verification failed: %v", trieInfo.Name, err)) - return fmt.Errorf("%s verification failed: %w", trieInfo.Name, err) - } - return nil + return contractAddresses, nil } func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error { - v.logger.Info(fmt.Sprintf("=== Starting %s verification ===", trieInfo.Name)) + v.logger.Info("Starting trie verification", zap.String("trie", trieInfo.Name)) + expectedRoot := felt.Zero err := v.database.View(func(snap db.Snapshot) error { reader, err := trieInfo.ReaderFunc(snap, trieInfo.Prefix, trieInfo.Height) @@ -166,14 +160,12 @@ func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error return err } if reader.RootKey() == nil { - expectedRoot = felt.Zero return nil } expectedRoot, err = reader.Hash() return err }) if err != nil { - v.logger.Error(fmt.Sprintf("Failed to get stored root hash for %s: %v", trieInfo.Name, err)) return fmt.Errorf("failed to get stored root hash for %s: %w", trieInfo.Name, err) } @@ -182,54 +174,51 @@ func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error return nil } - v.logger.Info("Starting verification", + v.logger.Info("Verifying trie", zap.String("trie", trieInfo.Name), zap.String("expectedRoot", expectedRoot.String())) - storageReader := trie.NewReadStorage(v.database, trieInfo.Prefix) - err = VerifyTrie(ctx, storageReader, StarknetTrieHeight, trieInfo.HashFn, &expectedRoot) - if err != nil { + storageReader := trie.NewReadStorage(v.database, trieInfo.Prefix) + if err := VerifyTrie(ctx, storageReader, trieInfo.Height, trieInfo.HashFn, &expectedRoot); err != nil { return err } v.logger.Info("Trie verification successful", - zap.String("trie", trieInfo.Name), zap.String("root", expectedRoot.String())) + zap.String("trie", trieInfo.Name), + zap.String("root", expectedRoot.String())) return nil } func (v *TrieVerifier) verifyContractStorageTries( ctx context.Context, filterAddress *felt.Felt, ) error { - v.logger.Info("=== Starting Contract Storage Tries verification ===") - var contractAddresses []felt.Felt if filterAddress != nil { contractAddresses = []felt.Felt{*filterAddress} - v.logger.Info("Verifying specific contract", - zap.String("address", filterAddress.String())) } else { - contractAddresses = v.collectContractAddresses() - if len(contractAddresses) == 0 { - v.logger.Info("No contract addresses found, skipping contract storage verification") - return nil + var err error + contractAddresses, err = v.collectContractAddresses() + if err != nil { + return err } - v.logger.Info("Found contracts to verify", - zap.Int("count", len(contractAddresses))) } + v.logger.Info("Starting contract storage tries verification", + zap.Int("count", len(contractAddresses))) + for i, contractAddress := range contractAddresses { select { case <-ctx.Done(): return ctx.Err() default: } + addrBytes := contractAddress.Marshal() prefix := db.ContractStorage.Key(addrBytes) trieInfo := TrieInfo{ - Name: fmt.Sprintf("ContractStorage[%s]", contractAddress.String()), - Prefix: db.ContractStorage.Key(addrBytes), - HashFunc: trie.NewTriePedersen, - HashFn: crypto.Pedersen, + Name: fmt.Sprintf("ContractStorage[%s]", contractAddress.String()), + Prefix: prefix, + HashFn: crypto.Pedersen, ReaderFunc: func(r db.KeyValueReader, _ []byte, height uint8) (trie.TrieReader, error) { return trie.NewTrieReaderPedersen(r, prefix, height) }, @@ -240,18 +229,8 @@ func (v *TrieVerifier) verifyContractStorageTries( zap.String("contract", contractAddress.String()), zap.String("progress", fmt.Sprintf("%d/%d", i+1, len(contractAddresses)))) - err := v.verifyTrie(ctx, trieInfo) - if err != nil { - if errors.Is(err, context.Canceled) { - return err - } - v.logger.Error(fmt.Sprintf("Contract storage verification failed for %s: %v", - contractAddress.String(), err)) - return fmt.Errorf( - "contract storage verification failed for %s: %w", - contractAddress.String(), - err, - ) + if err := v.verifyTrie(ctx, trieInfo); err != nil { + return err } } diff --git a/verify/trie/types.go b/verify/trie/types.go index 1f2227f2b8..47d172cb22 100644 --- a/verify/trie/types.go +++ b/verify/trie/types.go @@ -19,10 +19,18 @@ const ( ContractStorageTrie TrieType = "contract-storage" ) +func (t TrieType) IsValid() bool { + switch t { + case ContractTrie, ClassTrie, ContractStorageTrie: + return true + default: + return false + } +} + type TrieInfo struct { Name string Prefix []byte - HashFunc trie.NewTrieFunc HashFn crypto.HashFn ReaderFunc func(db.KeyValueReader, []byte, uint8) (trie.TrieReader, error) Height uint8 From e6e0119fe76091b5f08e70220d93ce0982cd4224 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 9 Feb 2026 10:51:31 +0100 Subject: [PATCH 13/15] comments, refactor --- cmd/juno/verify/trie_test.go | 20 +- cmd/juno/verify/verify.go | 2 +- verify/trie/trie_core.go | 28 ++- verify/trie/trie_verifier.go | 1 + .../{trie_test.go => trie_verifier_test.go} | 217 +++++++++++++++++- 5 files changed, 241 insertions(+), 27 deletions(-) rename verify/trie/{trie_test.go => trie_verifier_test.go} (50%) diff --git a/cmd/juno/verify/trie_test.go b/cmd/juno/verify/trie_test.go index 2186060458..e371162b06 100644 --- a/cmd/juno/verify/trie_test.go +++ b/cmd/juno/verify/trie_test.go @@ -38,6 +38,20 @@ func TestRunTrieVerify_AddressFlagValidation(t *testing.T) { address: "0x123", expectError: false, }, + { + name: "invalid type should fail", + trieTypes: []string{"invalid-type"}, + address: "", + expectError: true, + expectedErrMsg: "invalid trie type", + }, + { + name: "invalid address format should fail", + trieTypes: []string{"contract-storage"}, + address: "not-a-hex", + expectError: true, + expectedErrMsg: "invalid contract address", + }, } for _, tt := range tests { @@ -72,8 +86,12 @@ func TestRunTrieVerify_AddressFlagValidation(t *testing.T) { assert.Contains(t, err.Error(), tt.expectedErrMsg) } } else if err != nil { + // For "success" cases, we're testing flag validation, not full execution. + // The command may fail downstream (empty DB, no data) - that's expected. + // We only verify that the specific flag validation error we're testing didn't occur. addrFlagErr := "--address flag can only be used with --type contract-storage" - assert.NotContains(t, err.Error(), addrFlagErr) + assert.NotContains(t, err.Error(), addrFlagErr, + "flag validation should pass; downstream errors are acceptable") } }) } diff --git a/cmd/juno/verify/verify.go b/cmd/juno/verify/verify.go index 8c91591dc6..8c8fec7dc5 100644 --- a/cmd/juno/verify/verify.go +++ b/cmd/juno/verify/verify.go @@ -60,7 +60,7 @@ func verifyAll(cmd *cobra.Command, args []string) error { for _, v := range verifiers { if err := v.Run(ctx); err != nil { - return fmt.Errorf("%s verification failed: %w", v.Name(), err) + return fmt.Errorf("%s verification stopped: %w", v.Name(), err) } } diff --git a/verify/trie/trie_core.go b/verify/trie/trie_core.go index d281e7c95b..ce95b94156 100644 --- a/verify/trie/trie_core.go +++ b/verify/trie/trie_core.go @@ -40,7 +40,7 @@ func VerifyTrie( if rootHash.Cmp(expectedRoot) != 0 { return fmt.Errorf( "%w: root hash mismatch, expected %s, got %s (verification took %v)", - ErrCorruptionDetected, expectedRoot, rootHash, elapsed.Round(time.Second), + ErrCorruptionDetected, expectedRoot.String(), rootHash.String(), elapsed.Round(time.Second), ) } @@ -54,48 +54,46 @@ func verifyNode( parentKey *trie.BitArray, height uint8, hashFn crypto.HashFn, -) (*felt.Felt, error) { +) (felt.Felt, error) { select { case <-ctx.Done(): - return nil, ctx.Err() + return felt.Zero, ctx.Err() default: } node, err := reader.Get(key) if err != nil { - return nil, fmt.Errorf("failed to get node at key %s: %w", key.String(), err) + return felt.Zero, fmt.Errorf("failed to get node at key %s: %w", key.String(), err) } if key.Len() == height { p := path(key, parentKey) h := node.Hash(&p, hashFn) - return &h, nil + return h, nil } - leftFn := func(ctx context.Context) (*felt.Felt, error) { + leftFn := func(ctx context.Context) (felt.Felt, error) { if node.Left.IsEmpty() { - zero := felt.Zero - return &zero, nil + return felt.Zero, nil } return verifyNode(ctx, reader, node.Left, key, height, hashFn) } - rightFn := func(ctx context.Context) (*felt.Felt, error) { + rightFn := func(ctx context.Context) (felt.Felt, error) { if node.Right.IsEmpty() { - zero := felt.Zero - return &zero, nil + return felt.Zero, nil } return verifyNode(ctx, reader, node.Right, key, height, hashFn) } leftHash, rightHash, err := TraverseBinary(ctx, key.Len(), ConcurrencyMaxDepth, leftFn, rightFn) if err != nil { - return nil, err + return felt.Zero, err } - recomputed := hashFn(leftHash, rightHash) + recomputed := hashFn(&leftHash, &rightHash) if recomputed.Cmp(node.Value) != 0 { - return nil, fmt.Errorf( + return felt.Zero, fmt.Errorf( "%w: node at key %s, stored hash=%s, recomputed hash=%s", ErrCorruptionDetected, key.String(), node.Value.String(), recomputed.String(), ) @@ -106,7 +104,7 @@ func verifyNode( p := path(key, parentKey) h := tmp.Hash(&p, hashFn) - return &h, nil + return h, nil } func path(key, parentKey *trie.BitArray) trie.BitArray { diff --git a/verify/trie/trie_verifier.go b/verify/trie/trie_verifier.go index 946d517a2a..26617a53fa 100644 --- a/verify/trie/trie_verifier.go +++ b/verify/trie/trie_verifier.go @@ -144,6 +144,7 @@ func (v *TrieVerifier) collectContractAddresses() ([]felt.Felt, error) { return nil }) if err != nil { + v.logger.Error("Failed to collect contract addresses", zap.Error(err)) return nil, fmt.Errorf("failed to collect contract addresses: %w", err) } diff --git a/verify/trie/trie_test.go b/verify/trie/trie_verifier_test.go similarity index 50% rename from verify/trie/trie_test.go rename to verify/trie/trie_verifier_test.go index 34f4a293ac..3ec4d70272 100644 --- a/verify/trie/trie_test.go +++ b/verify/trie/trie_verifier_test.go @@ -38,9 +38,6 @@ func TestTrieVerifier_Run_ValidStateTrie(t *testing.T) { err = testTrie.Commit() require.NoError(t, err) - rootHash, err := testTrie.Hash() - require.NoError(t, err) - if testTrie.RootKey() != nil { err = trieStorage.PutRootKey(testTrie.RootKey()) require.NoError(t, err) @@ -54,12 +51,6 @@ func TestTrieVerifier_Run_ValidStateTrie(t *testing.T) { ctx := context.Background() err = verifier.Run(ctx) assert.NoError(t, err) - - reader, err := trie.NewTrieReaderPedersen(testDB, prefix, StarknetTrieHeight) - require.NoError(t, err) - storedHash, err := reader.Hash() - require.NoError(t, err) - assert.True(t, rootHash.Equal(&storedHash)) } func TestTrieVerifier_Run_ValidClassTrie(t *testing.T) { @@ -211,6 +202,148 @@ func TestTrieVerifier_Run_EmptyTrie(t *testing.T) { testDB := memory.New() defer testDB.Close() + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + + testTrie, err := trie.NewTriePedersen(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + // Empty trie has nil RootKey, so no need to store it + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_ValidContractStorageTrie(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + txn := testDB.NewIndexedBatch() + + // Create a state trie with a contract address + // The key at height 251 represents a contract address + statePrefix := db.StateTrie.Key() + stateTrie, err := trie.NewTriePedersen(txn, statePrefix, StarknetTrieHeight) + require.NoError(t, err) + + contractAddr := felt.NewFromUint64[felt.Felt](0x1234) + contractValue := felt.NewFromUint64[felt.Felt](1) + _, err = stateTrie.Put(contractAddr, contractValue) + require.NoError(t, err) + + err = stateTrie.Commit() + require.NoError(t, err) + + stateStorage := trie.NewStorage(txn, statePrefix) + if stateTrie.RootKey() != nil { + err = stateStorage.PutRootKey(stateTrie.RootKey()) + require.NoError(t, err) + } + + // Create a contract storage trie for this contract + storagePrefix := db.ContractStorage.Key(contractAddr.Marshal()) + storageTrie, err := trie.NewTriePedersen(txn, storagePrefix, StarknetTrieHeight) + require.NoError(t, err) + + storageKey := felt.NewFromUint64[felt.Felt](1) + storageValue := felt.NewFromUint64[felt.Felt](100) + _, err = storageTrie.Put(storageKey, storageValue) + require.NoError(t, err) + + err = storageTrie.Commit() + require.NoError(t, err) + + storageTrieStorage := trie.NewStorage(txn, storagePrefix) + if storageTrie.RootKey() != nil { + err = storageTrieStorage.PutRootKey(storageTrie.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractStorageTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_ContractStorageWithFilter(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + txn := testDB.NewIndexedBatch() + + // Create state trie with two contracts + statePrefix := db.StateTrie.Key() + stateTrie, err := trie.NewTriePedersen(txn, statePrefix, StarknetTrieHeight) + require.NoError(t, err) + + contractAddr1 := felt.NewFromUint64[felt.Felt](0x1111) + contractAddr2 := felt.NewFromUint64[felt.Felt](0x2222) + + _, err = stateTrie.Put(contractAddr1, felt.NewFromUint64[felt.Felt](1)) + require.NoError(t, err) + _, err = stateTrie.Put(contractAddr2, felt.NewFromUint64[felt.Felt](2)) + require.NoError(t, err) + + err = stateTrie.Commit() + require.NoError(t, err) + + stateStorage := trie.NewStorage(txn, statePrefix) + if stateTrie.RootKey() != nil { + err = stateStorage.PutRootKey(stateTrie.RootKey()) + require.NoError(t, err) + } + + // Create contract storage trie only for contract 1 + storagePrefix1 := db.ContractStorage.Key(contractAddr1.Marshal()) + storageTrie1, err := trie.NewTriePedersen(txn, storagePrefix1, StarknetTrieHeight) + require.NoError(t, err) + + _, err = storageTrie1.Put( + felt.NewFromUint64[felt.Felt](1), + felt.NewFromUint64[felt.Felt](100), + ) + require.NoError(t, err) + + err = storageTrie1.Commit() + require.NoError(t, err) + + storage1 := trie.NewStorage(txn, storagePrefix1) + if storageTrie1.RootKey() != nil { + err = storage1.PutRootKey(storageTrie1.RootKey()) + require.NoError(t, err) + } + + err = txn.Write() + require.NoError(t, err) + + // Verify only contract 1 using the filter - should succeed + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractStorageTrie}, contractAddr1) + + ctx := context.Background() + err = verifier.Run(ctx) + assert.NoError(t, err) +} + +func TestTrieVerifier_Run_ContextCancellation(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + prefix := db.StateTrie.Key() txn := testDB.NewIndexedBatch() trieStorage := trie.NewStorage(txn, prefix) @@ -218,6 +351,11 @@ func TestTrieVerifier_Run_EmptyTrie(t *testing.T) { testTrie, err := trie.NewTriePedersen(txn, prefix, StarknetTrieHeight) require.NoError(t, err) + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + err = testTrie.Commit() require.NoError(t, err) @@ -231,7 +369,66 @@ func TestTrieVerifier_Run_EmptyTrie(t *testing.T) { verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie}, nil) - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = verifier.Run(ctx) assert.NoError(t, err) } + +func TestTrieVerifier_Run_RootHashMismatch(t *testing.T) { + logger := utils.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + prefix := db.StateTrie.Key() + txn := testDB.NewIndexedBatch() + trieStorage := trie.NewStorage(txn, prefix) + + testTrie, err := trie.NewTriePedersen(txn, prefix, StarknetTrieHeight) + require.NoError(t, err) + + key1 := felt.NewFromUint64[felt.Felt](1) + value1 := felt.NewFromUint64[felt.Felt](100) + _, err = testTrie.Put(key1, value1) + require.NoError(t, err) + + key2 := felt.NewFromUint64[felt.Felt](2) + value2 := felt.NewFromUint64[felt.Felt](200) + _, err = testTrie.Put(key2, value2) + require.NoError(t, err) + + err = testTrie.Commit() + require.NoError(t, err) + + if testTrie.RootKey() != nil { + err = trieStorage.PutRootKey(testTrie.RootKey()) + require.NoError(t, err) + } + + // Get the root key and corrupt the stored root hash (not node values) + rootKey := testTrie.RootKey() + require.NotNil(t, rootKey) + + rootNode, err := trieStorage.Get(rootKey) + require.NoError(t, err) + require.NotNil(t, rootNode) + + // Corrupt the root node's value (stored hash) while keeping node structure intact + // This will cause root hash mismatch, not node corruption during traversal + corruptedHash := felt.NewFromUint64[felt.Felt](999999) + rootNode.Value = corruptedHash + + err = trieStorage.Put(rootKey, rootNode) + require.NoError(t, err) + + err = txn.Write() + require.NoError(t, err) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractTrie}, nil) + + ctx := context.Background() + err = verifier.Run(ctx) + require.Error(t, err) + assert.ErrorIs(t, err, ErrCorruptionDetected) +} From b7a8e79fcfcda62dd41a4210c7195fd7aba72c70 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Mon, 9 Feb 2026 11:08:15 +0100 Subject: [PATCH 14/15] linter --- verify/trie/trie_verifier.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/verify/trie/trie_verifier.go b/verify/trie/trie_verifier.go index 26617a53fa..f745c62719 100644 --- a/verify/trie/trie_verifier.go +++ b/verify/trie/trie_verifier.go @@ -180,7 +180,13 @@ func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error zap.String("expectedRoot", expectedRoot.String())) storageReader := trie.NewReadStorage(v.database, trieInfo.Prefix) - if err := VerifyTrie(ctx, storageReader, trieInfo.Height, trieInfo.HashFn, &expectedRoot); err != nil { + if err := VerifyTrie( + ctx, + storageReader, + trieInfo.Height, + trieInfo.HashFn, + &expectedRoot, + ); err != nil { return err } From b9cdb918a6293530a5502ace7b264fb9bf04a171 Mon Sep 17 00:00:00 2001 From: MaksymMalicki Date: Thu, 23 Apr 2026 11:58:19 +0200 Subject: [PATCH 15/15] chore: self review --- cmd/juno/juno.go | 8 +- cmd/juno/verify/trie.go | 23 +-- cmd/juno/verify/trie_test.go | 20 +-- cmd/juno/verify/verify.go | 36 ---- verify/trie/traversal.go | 28 +-- verify/trie/trie_core.go | 27 +-- verify/trie/trie_verifier.go | 280 +++++++++++++++--------------- verify/trie/trie_verifier_test.go | 76 ++++++-- verify/trie/types.go | 13 +- 9 files changed, 257 insertions(+), 254 deletions(-) diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index cd1a85176c..5af4434870 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -486,8 +486,12 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr junoCmd.Flags().Bool( disableReceivedTxnStreamF, defaultDisableReceivedTxnStream, disableReceivedTxnStreamUsage, ) - junoCmd.AddCommand(GenP2PKeyPair(), DBCmd(defaultDBPath), CompileSierraCmd()) - junoCmd.AddCommand(GenP2PKeyPair(), DBCmd(defaultDBPath), verify.VerifyCmd(defaultDBPath)) + junoCmd.AddCommand( + GenP2PKeyPair(), + DBCmd(defaultDBPath), + CompileSierraCmd(), + verify.VerifyCmd(defaultDBPath), + ) return junoCmd } diff --git a/cmd/juno/verify/trie.go b/cmd/juno/verify/trie.go index 0358562789..8dc125cc2b 100644 --- a/cmd/juno/verify/trie.go +++ b/cmd/juno/verify/trie.go @@ -5,7 +5,7 @@ import ( "slices" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/log" verifytrie "github.com/NethermindEth/juno/verify/trie" "github.com/spf13/cobra" ) @@ -48,12 +48,6 @@ func runTrieVerify(cmd *cobra.Command, args []string) error { return err } - database, err := openDB(dbPath) - if err != nil { - return err - } - defer database.Close() - trieTypes, err := cmd.Flags().GetStringSlice(verifyTrieType) if err != nil { return err @@ -78,11 +72,7 @@ func runTrieVerify(cmd *cobra.Command, args []string) error { var contractAddr *felt.Felt if contractAddrStr != "" { - hasContractStorage := slices.Contains(tries, verifytrie.ContractStorageTrie) - if len(tries) == 0 { - hasContractStorage = true - } - + hasContractStorage := len(tries) == 0 || slices.Contains(tries, verifytrie.ContractStorageTrie) if !hasContractStorage { return fmt.Errorf("--address flag can only be used with --type contract-storage") } @@ -95,8 +85,13 @@ func runTrieVerify(cmd *cobra.Command, args []string) error { contractAddr = &addr } - logLevel := utils.NewLogLevel(utils.INFO) - logger, err := utils.NewZapLogger(logLevel, true) + database, err := openDB(dbPath) + if err != nil { + return err + } + defer database.Close() + + logger, err := log.NewZapLogger(log.NewLevel(log.INFO)) if err != nil { return fmt.Errorf("failed to create logger: %w", err) } diff --git a/cmd/juno/verify/trie_test.go b/cmd/juno/verify/trie_test.go index e371162b06..b5c238635e 100644 --- a/cmd/juno/verify/trie_test.go +++ b/cmd/juno/verify/trie_test.go @@ -3,10 +3,8 @@ package verify import ( "context" "os" - "path/filepath" "testing" - "github.com/NethermindEth/juno/db/pebblev2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -32,6 +30,13 @@ func TestRunTrieVerify_AddressFlagValidation(t *testing.T) { expectError: true, expectedErrMsg: "--address flag can only be used with --type contract-storage", }, + { + name: "address with contract type only should fail", + trieTypes: []string{"contract"}, + address: "0x1", + expectError: true, + expectedErrMsg: "--address flag can only be used with --type contract-storage", + }, { name: "address with no type specified should succeed (default includes contract-storage)", trieTypes: []string{}, @@ -56,15 +61,8 @@ func TestRunTrieVerify_AddressFlagValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tempDir := t.TempDir() - dbPath := filepath.Join(tempDir, "test.db") - - testDB, err := pebblev2.New(dbPath) - require.NoError(t, err) - testDB.Close() - parentCmd := VerifyCmd("") - args := []string{"--db-path", dbPath, "trie"} + args := []string{"--db-path", "ignored", "trie"} for _, trieType := range tt.trieTypes { args = append(args, "--type", trieType) @@ -78,7 +76,7 @@ func TestRunTrieVerify_AddressFlagValidation(t *testing.T) { parentCmd.SetOut(os.Stderr) parentCmd.SetErr(os.Stderr) - err = parentCmd.ExecuteContext(context.Background()) + err := parentCmd.ExecuteContext(context.Background()) if tt.expectError { require.Error(t, err) diff --git a/cmd/juno/verify/verify.go b/cmd/juno/verify/verify.go index 8c8fec7dc5..3611ed2789 100644 --- a/cmd/juno/verify/verify.go +++ b/cmd/juno/verify/verify.go @@ -8,8 +8,6 @@ import ( "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebblev2" - "github.com/NethermindEth/juno/utils" - "github.com/NethermindEth/juno/verify/trie" "github.com/spf13/cobra" ) @@ -29,44 +27,10 @@ func VerifyCmd(defaultDBPath string) *cobra.Command { verifyCmd.PersistentFlags().String(verifyDBPathF, defaultDBPath, "Path to the database") verifyCmd.AddCommand(verifyTrieCmd()) - verifyCmd.RunE = verifyAll return verifyCmd } -func verifyAll(cmd *cobra.Command, args []string) error { - dbPath, err := cmd.Flags().GetString(verifyDBPathF) - if err != nil { - return err - } - - database, err := openDB(dbPath) - if err != nil { - return err - } - defer database.Close() - - logLevel := utils.NewLogLevel(utils.INFO) - logger, err := utils.NewZapLogger(logLevel, true) - if err != nil { - return fmt.Errorf("failed to create logger: %w", err) - } - - ctx := cmd.Context() - - verifiers := []Verifier{ - trie.NewTrieVerifier(database, logger, nil, nil), - } - - for _, v := range verifiers { - if err := v.Run(ctx); err != nil { - return fmt.Errorf("%s verification stopped: %w", v.Name(), err) - } - } - - return nil -} - func openDB(path string) (db.KeyValueStore, error) { _, err := os.Stat(path) if os.IsNotExist(err) { diff --git a/verify/trie/traversal.go b/verify/trie/traversal.go index 29e94c3bb4..f4323780ca 100644 --- a/verify/trie/traversal.go +++ b/verify/trie/traversal.go @@ -2,7 +2,8 @@ package trie import ( "context" - "sync" + + "golang.org/x/sync/errgroup" ) func TraverseBinary[T any]( @@ -23,23 +24,22 @@ func traverseConcurrently[T any]( leftFn func(ctx context.Context) (T, error), rightFn func(ctx context.Context) (T, error), ) (left, right T, err error) { - var leftErr, rightErr error - var wg sync.WaitGroup + eg, gCtx := errgroup.WithContext(ctx) - wg.Go(func() { - left, leftErr = leftFn(ctx) + eg.Go(func() error { + var err error + left, err = leftFn(gCtx) + return err }) - right, rightErr = rightFn(ctx) - wg.Wait() + eg.Go(func() error { + var err error + right, err = rightFn(gCtx) + return err + }) - if leftErr != nil { - return left, right, leftErr - } - if rightErr != nil { - return left, right, rightErr - } - return left, right, nil + err = eg.Wait() + return left, right, err } func traverseSequentially[T any]( diff --git a/verify/trie/trie_core.go b/verify/trie/trie_core.go index ce95b94156..1a89b11e8b 100644 --- a/verify/trie/trie_core.go +++ b/verify/trie/trie_core.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "time" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" @@ -29,18 +28,15 @@ func VerifyTrie( return nil } - startTime := time.Now() rootHash, err := verifyNode(ctx, reader, rootKey, nil, height, hashFn) if err != nil { return err } - elapsed := time.Since(startTime) - if rootHash.Cmp(expectedRoot) != 0 { return fmt.Errorf( - "%w: root hash mismatch, expected %s, got %s (verification took %v)", - ErrCorruptionDetected, expectedRoot.String(), rootHash.String(), elapsed.Round(time.Second), + "%w: root hash mismatch, expected %s, got %s", + ErrCorruptionDetected, expectedRoot.String(), rootHash.String(), ) } @@ -55,10 +51,8 @@ func verifyNode( height uint8, hashFn crypto.HashFn, ) (felt.Felt, error) { - select { - case <-ctx.Done(): - return felt.Zero, ctx.Err() - default: + if err := ctx.Err(); err != nil { + return felt.Zero, err } node, err := reader.Get(key) @@ -66,10 +60,10 @@ func verifyNode( return felt.Zero, fmt.Errorf("failed to get node at key %s: %w", key.String(), err) } + p := path(key, parentKey) + if key.Len() == height { - p := path(key, parentKey) - h := node.Hash(&p, hashFn) - return h, nil + return node.Hash(&p, hashFn), nil } leftFn := func(ctx context.Context) (felt.Felt, error) { @@ -99,12 +93,7 @@ func verifyNode( ) } - tmp := *node - tmp.Value = &recomputed - - p := path(key, parentKey) - h := tmp.Hash(&p, hashFn) - return h, nil + return node.Hash(&p, hashFn), nil } func path(key, parentKey *trie.BitArray) trie.BitArray { diff --git a/verify/trie/trie_verifier.go b/verify/trie/trie_verifier.go index f745c62719..dff5e66012 100644 --- a/verify/trie/trie_verifier.go +++ b/verify/trie/trie_verifier.go @@ -1,7 +1,6 @@ package trie import ( - "bytes" "context" "errors" "fmt" @@ -11,30 +10,52 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/log" "go.uber.org/zap" ) +var ( + stateTrieInfo = TrieInfo{ + Name: "ContractsTrie", + Prefix: db.StateTrie.Key(), + HashFn: crypto.Pedersen, + ReaderFunc: func(r db.KeyValueReader, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPedersen(r, db.StateTrie.Key(), height) + }, + Height: StarknetTrieHeight, + } + + classTrieInfo = TrieInfo{ + Name: "ClassesTrie", + Prefix: db.ClassesTrie.Key(), + HashFn: crypto.Poseidon, + ReaderFunc: func(r db.KeyValueReader, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPoseidon(r, db.ClassesTrie.Key(), height) + }, + Height: StarknetTrieHeight, + } +) + type TrieVerifier struct { database db.KeyValueStore - logger utils.StructuredLogger - tries []TrieType + logger log.StructuredLogger + trieTypes []TrieType contractAddress *felt.Felt } func NewTrieVerifier( database db.KeyValueStore, - logger utils.StructuredLogger, - tries []TrieType, + logger log.StructuredLogger, + trieTypes []TrieType, contractAddress *felt.Felt, ) *TrieVerifier { - if len(tries) == 0 { - tries = []TrieType{ContractTrie, ClassTrie, ContractStorageTrie} + if len(trieTypes) == 0 { + trieTypes = allTrieTypes } return &TrieVerifier{ database: database, logger: logger, - tries: tries, + trieTypes: trieTypes, contractAddress: contractAddress, } } @@ -46,140 +67,96 @@ func (v *TrieVerifier) Name() string { func (v *TrieVerifier) Run(ctx context.Context) error { startTime := time.Now() defer func() { - elapsed := time.Since(startTime) v.logger.Info("Trie verification finished", - zap.Duration("total_elapsed", elapsed.Round(time.Second))) + zap.Duration("total_elapsed", time.Since(startTime).Round(time.Second))) }() - typeSet := make(map[TrieType]bool) - for _, t := range v.tries { - typeSet[t] = true + err := v.database.View(func(snap db.Snapshot) error { + return v.verifyAll(ctx, snap) + }) + if errors.Is(err, context.Canceled) { + return nil } - - if typeSet[ContractTrie] { - stateTrieInfo := TrieInfo{ - Name: "ContractsTrie", - Prefix: db.StateTrie.Key(), - HashFn: crypto.Pedersen, - ReaderFunc: trie.NewTrieReaderPedersen, - Height: StarknetTrieHeight, - } - if err := v.verifyTrie(ctx, stateTrieInfo); err != nil { - return v.handleResult(err, stateTrieInfo.Name) - } + if err != nil { + return err } - if typeSet[ClassTrie] { - classTrieInfo := TrieInfo{ - Name: "ClassesTrie", - Prefix: db.ClassesTrie.Key(), - HashFn: crypto.Poseidon, - ReaderFunc: trie.NewTrieReaderPoseidon, - Height: StarknetTrieHeight, - } - if err := v.verifyTrie(ctx, classTrieInfo); err != nil { - return v.handleResult(err, classTrieInfo.Name) - } - } + v.logger.Info("Trie verification completed successfully") + return nil +} - if typeSet[ContractStorageTrie] { - if err := v.verifyContractStorageTries(ctx, v.contractAddress); err != nil { - return v.handleResult(err, "ContractStorageTries") +func (v *TrieVerifier) verifyAll(ctx context.Context, snap db.Snapshot) error { + for _, t := range v.trieTypes { + var ( + name string + err error + ) + switch t { + case ContractTrie: + name, err = stateTrieInfo.Name, v.verifyTrie(ctx, snap, stateTrieInfo) + case ClassTrie: + name, err = classTrieInfo.Name, v.verifyTrie(ctx, snap, classTrieInfo) + case ContractStorageTrie: + name, err = "ContractStorageTries", v.verifyContractStorageTries(ctx, snap, v.contractAddress) + } + if err != nil { + v.logResult(err, name) + return err } } - - v.logger.Info("Trie verification completed successfully") return nil } -func (v *TrieVerifier) handleResult(err error, trieName string) error { - if errors.Is(err, context.Canceled) { +func (v *TrieVerifier) logResult(err error, trieName string) { + switch { + case errors.Is(err, context.Canceled): v.logger.Info("Verification stopped", zap.String("trie", trieName)) - return nil - } - if errors.Is(err, ErrCorruptionDetected) { - v.logger.Info("Corruption detected", + case errors.Is(err, ErrCorruptionDetected): + v.logger.Error("Corruption detected", zap.String("trie", trieName), zap.String("details", err.Error())) - return err + default: + v.logger.Error("Verification error", + zap.String("trie", trieName), + zap.Error(err)) } - v.logger.Error("Verification error", - zap.String("trie", trieName), - zap.Error(err)) - return err } -func (v *TrieVerifier) collectContractAddresses() ([]felt.Felt, error) { - contractAddresses := make([]felt.Felt, 0) - stateTriePrefix := db.StateTrie.Key() - - err := v.database.View(func(snap db.Snapshot) error { - it, err := snap.NewIterator(stateTriePrefix, true) - if err != nil { - return err - } - defer it.Close() - - for it.First(); it.Valid(); it.Next() { - keyBytes := it.Key() - if bytes.Equal(keyBytes, stateTriePrefix) { - continue - } - - if !bytes.HasPrefix(keyBytes, stateTriePrefix) { - continue - } - nodeKeyBytes := keyBytes[len(stateTriePrefix):] - - var nodeKey trie.BitArray - if err := nodeKey.UnmarshalBinary(nodeKeyBytes); err != nil { - continue - } - - if nodeKey.Len() == StarknetTrieHeight { - contractAddr := nodeKey.Felt() - contractAddresses = append(contractAddresses, contractAddr) - } - } - return nil - }) - if err != nil { - v.logger.Error("Failed to collect contract addresses", zap.Error(err)) - return nil, fmt.Errorf("failed to collect contract addresses: %w", err) +func contractStorageTrieInfo(addr *felt.Felt) TrieInfo { + prefix := db.ContractStorage.Key(addr.Marshal()) + return TrieInfo{ + Name: fmt.Sprintf("ContractStorage[%s]", addr.String()), + Prefix: prefix, + HashFn: crypto.Pedersen, + ReaderFunc: func(r db.KeyValueReader, height uint8) (trie.TrieReader, error) { + return trie.NewTrieReaderPedersen(r, prefix, height) + }, + Height: StarknetTrieHeight, } - - return contractAddresses, nil } -func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error { +func (v *TrieVerifier) verifyTrie(ctx context.Context, snap db.Snapshot, trieInfo TrieInfo) error { v.logger.Info("Starting trie verification", zap.String("trie", trieInfo.Name)) - expectedRoot := felt.Zero - err := v.database.View(func(snap db.Snapshot) error { - reader, err := trieInfo.ReaderFunc(snap, trieInfo.Prefix, trieInfo.Height) - if err != nil { - return err - } - if reader.RootKey() == nil { - return nil - } - expectedRoot, err = reader.Hash() - return err - }) + reader, err := trieInfo.ReaderFunc(snap, trieInfo.Height) if err != nil { - return fmt.Errorf("failed to get stored root hash for %s: %w", trieInfo.Name, err) + return fmt.Errorf("failed to open reader for %s: %w", trieInfo.Name, err) } - - if expectedRoot.IsZero() { - v.logger.Info("Trie is empty (zero root)", zap.String("trie", trieInfo.Name)) + if reader.RootKey() == nil { + v.logger.Info("Trie is empty", zap.String("trie", trieInfo.Name)) return nil } + expectedRoot, err := reader.Hash() + if err != nil { + return fmt.Errorf("failed to get stored root hash for %s: %w", trieInfo.Name, err) + } + v.logger.Info("Verifying trie", zap.String("trie", trieInfo.Name), zap.String("expectedRoot", expectedRoot.String())) - storageReader := trie.NewReadStorage(v.database, trieInfo.Prefix) + storageReader := trie.NewReadStorage(snap, trieInfo.Prefix) if err := VerifyTrie( ctx, storageReader, @@ -197,51 +174,72 @@ func (v *TrieVerifier) verifyTrie(ctx context.Context, trieInfo TrieInfo) error } func (v *TrieVerifier) verifyContractStorageTries( - ctx context.Context, filterAddress *felt.Felt, + ctx context.Context, snap db.Snapshot, filterAddress *felt.Felt, ) error { - var contractAddresses []felt.Felt if filterAddress != nil { - contractAddresses = []felt.Felt{*filterAddress} - } else { - var err error - contractAddresses, err = v.collectContractAddresses() - if err != nil { - return err - } + return v.verifyTrie(ctx, snap, contractStorageTrieInfo(filterAddress)) } - v.logger.Info("Starting contract storage tries verification", - zap.Int("count", len(contractAddresses))) + bucketPrefix := db.ContractStorage.Key() + it, err := snap.NewIterator(bucketPrefix, true) + if err != nil { + return fmt.Errorf("failed to open contract storage iterator: %w", err) + } + defer it.Close() + + v.logger.Info("Starting contract storage tries verification") - for i, contractAddress := range contractAddresses { - select { - case <-ctx.Done(): - return ctx.Err() - default: + count := 0 + addrStart := len(bucketPrefix) + addrEnd := addrStart + felt.Bytes + + for ok := it.First(); ok; { + if err := ctx.Err(); err != nil { + return err } - addrBytes := contractAddress.Marshal() - prefix := db.ContractStorage.Key(addrBytes) - trieInfo := TrieInfo{ - Name: fmt.Sprintf("ContractStorage[%s]", contractAddress.String()), - Prefix: prefix, - HashFn: crypto.Pedersen, - ReaderFunc: func(r db.KeyValueReader, _ []byte, height uint8) (trie.TrieReader, error) { - return trie.NewTrieReaderPedersen(r, prefix, height) - }, - Height: StarknetTrieHeight, + key := it.Key() + if len(key) < addrEnd { + // Unexpected short key in bucket — skip it. + ok = it.Next() + continue } + addrBytes := key[addrStart:addrEnd] + var addr felt.Felt + addr.SetBytes(addrBytes) + + count++ v.logger.Info("Verifying contract storage", - zap.String("contract", contractAddress.String()), - zap.String("progress", fmt.Sprintf("%d/%d", i+1, len(contractAddresses)))) + zap.String("contract", addr.String()), + zap.Int("index", count)) - if err := v.verifyTrie(ctx, trieInfo); err != nil { + if err := v.verifyTrie(ctx, snap, contractStorageTrieInfo(&addr)); err != nil { return err } + + nextAddr := nextLexAddr(addrBytes) + if nextAddr == nil { + break + } + ok = it.Seek(db.ContractStorage.Key(nextAddr)) } v.logger.Info("All contract storage tries verified successfully", - zap.Int("count", len(contractAddresses))) + zap.Int("count", count)) + return nil +} + +// nextLexAddr returns the lexicographically next 32-byte address, or nil if +// addr is the maximum value (all 0xff). +func nextLexAddr(addr []byte) []byte { + out := make([]byte, len(addr)) + copy(out, addr) + for i := len(out) - 1; i >= 0; i-- { + out[i]++ + if out[i] != 0 { + return out + } + } return nil } diff --git a/verify/trie/trie_verifier_test.go b/verify/trie/trie_verifier_test.go index 3ec4d70272..d083ffc314 100644 --- a/verify/trie/trie_verifier_test.go +++ b/verify/trie/trie_verifier_test.go @@ -8,13 +8,13 @@ import ( "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/memory" - "github.com/NethermindEth/juno/utils" + "github.com/NethermindEth/juno/utils/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestTrieVerifier_Run_ValidStateTrie(t *testing.T) { - logger := utils.NewNopZapLogger() + logger := log.NewNopZapLogger() testDB := memory.New() defer testDB.Close() @@ -54,7 +54,7 @@ func TestTrieVerifier_Run_ValidStateTrie(t *testing.T) { } func TestTrieVerifier_Run_ValidClassTrie(t *testing.T) { - logger := utils.NewNopZapLogger() + logger := log.NewNopZapLogger() testDB := memory.New() defer testDB.Close() @@ -89,7 +89,7 @@ func TestTrieVerifier_Run_ValidClassTrie(t *testing.T) { } func TestTrieVerifier_Run_CorruptedTrie(t *testing.T) { - logger := utils.NewNopZapLogger() + logger := log.NewNopZapLogger() testDB := memory.New() defer testDB.Close() @@ -146,7 +146,7 @@ func TestTrieVerifier_Run_CorruptedTrie(t *testing.T) { } func TestTrieVerifier_Run_MultipleTrieTypes(t *testing.T) { - logger := utils.NewNopZapLogger() + logger := log.NewNopZapLogger() testDB := memory.New() defer testDB.Close() @@ -198,7 +198,7 @@ func TestTrieVerifier_Run_MultipleTrieTypes(t *testing.T) { } func TestTrieVerifier_Run_EmptyTrie(t *testing.T) { - logger := utils.NewNopZapLogger() + logger := log.NewNopZapLogger() testDB := memory.New() defer testDB.Close() @@ -224,7 +224,7 @@ func TestTrieVerifier_Run_EmptyTrie(t *testing.T) { } func TestTrieVerifier_Run_ValidContractStorageTrie(t *testing.T) { - logger := utils.NewNopZapLogger() + logger := log.NewNopZapLogger() testDB := memory.New() defer testDB.Close() @@ -279,8 +279,64 @@ func TestTrieVerifier_Run_ValidContractStorageTrie(t *testing.T) { assert.NoError(t, err) } +func TestTrieVerifier_Run_ContractStorageCorruption(t *testing.T) { + logger := log.NewNopZapLogger() + testDB := memory.New() + defer testDB.Close() + + txn := testDB.NewIndexedBatch() + + statePrefix := db.StateTrie.Key() + stateTrie, err := trie.NewTriePedersen(txn, statePrefix, StarknetTrieHeight) + require.NoError(t, err) + + contractAddr := felt.NewFromUint64[felt.Felt](0x1234) + _, err = stateTrie.Put(contractAddr, felt.NewFromUint64[felt.Felt](1)) + require.NoError(t, err) + require.NoError(t, stateTrie.Commit()) + + stateStorage := trie.NewStorage(txn, statePrefix) + if stateTrie.RootKey() != nil { + require.NoError(t, stateStorage.PutRootKey(stateTrie.RootKey())) + } + + storagePrefix := db.ContractStorage.Key(contractAddr.Marshal()) + storageTrie, err := trie.NewTriePedersen(txn, storagePrefix, StarknetTrieHeight) + require.NoError(t, err) + + storageKey1 := felt.NewFromUint64[felt.Felt](1) + storageKey2 := felt.NewFromUint64[felt.Felt](2) + _, err = storageTrie.Put(storageKey1, felt.NewFromUint64[felt.Felt](100)) + require.NoError(t, err) + _, err = storageTrie.Put(storageKey2, felt.NewFromUint64[felt.Felt](200)) + require.NoError(t, err) + require.NoError(t, storageTrie.Commit()) + + storageTrieStorage := trie.NewStorage(txn, storagePrefix) + if storageTrie.RootKey() != nil { + require.NoError(t, storageTrieStorage.PutRootKey(storageTrie.RootKey())) + } + + var leafKey trie.BitArray + leafKey.SetFelt(StarknetTrieHeight, storageKey1) + leafNode, err := storageTrieStorage.Get(&leafKey) + require.NoError(t, err) + require.NotNil(t, leafNode) + + leafNode.Value = felt.NewFromUint64[felt.Felt](999999) + require.NoError(t, storageTrieStorage.Put(&leafKey, leafNode)) + + require.NoError(t, txn.Write()) + + verifier := NewTrieVerifier(testDB, logger, []TrieType{ContractStorageTrie}, contractAddr) + + err = verifier.Run(context.Background()) + require.Error(t, err) + assert.ErrorIs(t, err, ErrCorruptionDetected) +} + func TestTrieVerifier_Run_ContractStorageWithFilter(t *testing.T) { - logger := utils.NewNopZapLogger() + logger := log.NewNopZapLogger() testDB := memory.New() defer testDB.Close() @@ -340,7 +396,7 @@ func TestTrieVerifier_Run_ContractStorageWithFilter(t *testing.T) { } func TestTrieVerifier_Run_ContextCancellation(t *testing.T) { - logger := utils.NewNopZapLogger() + logger := log.NewNopZapLogger() testDB := memory.New() defer testDB.Close() @@ -377,7 +433,7 @@ func TestTrieVerifier_Run_ContextCancellation(t *testing.T) { } func TestTrieVerifier_Run_RootHashMismatch(t *testing.T) { - logger := utils.NewNopZapLogger() + logger := log.NewNopZapLogger() testDB := memory.New() defer testDB.Close() diff --git a/verify/trie/types.go b/verify/trie/types.go index 47d172cb22..859091aaa8 100644 --- a/verify/trie/types.go +++ b/verify/trie/types.go @@ -1,6 +1,8 @@ package trie import ( + "slices" + "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" @@ -19,19 +21,16 @@ const ( ContractStorageTrie TrieType = "contract-storage" ) +var allTrieTypes = []TrieType{ContractTrie, ClassTrie, ContractStorageTrie} + func (t TrieType) IsValid() bool { - switch t { - case ContractTrie, ClassTrie, ContractStorageTrie: - return true - default: - return false - } + return slices.Contains(allTrieTypes, t) } type TrieInfo struct { Name string Prefix []byte HashFn crypto.HashFn - ReaderFunc func(db.KeyValueReader, []byte, uint8) (trie.TrieReader, error) + ReaderFunc func(db.KeyValueReader, uint8) (trie.TrieReader, error) Height uint8 }