From d838528d5a24caf84334d308eb47bb5637dbaf26 Mon Sep 17 00:00:00 2001 From: Kai Mast Date: Fri, 20 Feb 2026 14:58:40 -0800 Subject: [PATCH 1/4] build: check that errors are printed using `flatten_error` --- Cargo.lock | 2 + Cargo.toml | 7 + build.rs | 267 ++++++++++++++++++++---------------- node/src/client/mod.rs | 2 +- node/sync/src/block_sync.rs | 2 +- 5 files changed, 163 insertions(+), 117 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4dcfbbe419..5ae85988a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4591,6 +4591,7 @@ dependencies = [ "built", "clap", "locktick", + "proc-macro2", "rusty-hook", "snarkos-account", "snarkos-cli", @@ -4604,6 +4605,7 @@ dependencies = [ "snarkos-node-sync", "snarkos-node-tcp", "snarkvm", + "syn 2.0.115", "tikv-jemallocator", "toml 0.9.12+spec-1.1.0", "tracing", diff --git a/Cargo.toml b/Cargo.toml index dabbbaa0d6..e9289fe64b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -365,6 +365,13 @@ version = "0.9" [build-dependencies.walkdir] version = "2" +[build-dependencies.proc-macro2] +version = "1" +features = ["span-locations"] + +[build-dependencies.syn] +version = "2" + [profile.release] opt-level = 3 lto = "thin" diff --git a/build.rs b/build.rs index 26fe810f2b..a4e42e07ca 100644 --- a/build.rs +++ b/build.rs @@ -22,8 +22,9 @@ use std::{ process, str, }; +use syn::{ExprMacro, Macro, StmtMacro, spanned::Spanned, visit::Visit}; use toml::Value; -use walkdir::WalkDir; +use walkdir::{DirEntry, WalkDir}; // The following license text that should be present at the beginning of every source file. const EXPECTED_LICENSE_TEXT: &[u8] = include_bytes!(".resources/license_header"); @@ -38,35 +39,66 @@ enum ImportOfInterest { Tokio, } -fn check_locktick_imports>(path: P) { +fn should_skip_dir(entry: &DirEntry) -> bool { + let entry_type = entry.file_type(); + if !entry_type.is_dir() { + return false; + } + // Skip root-level dot folders (e.g. .git, .github, .cargo, .ci). 
+ if entry.depth() == 1 && entry.file_name().to_str().is_some_and(|n| n.starts_with('.')) { + return true; + } + // Skip the specified directories at any depth. + DIRS_TO_SKIP.contains(&entry.file_name().to_str().unwrap_or("")) +} + +/// Checks license headers, locktick import balance, and forbidden error formatting in a single +/// directory walk to avoid reading every source file more than once. +fn check_source_files>(path: P) { + // Perform the license year check if on Linux. + if cfg!(target_os = "linux") { + let os_year = process::Command::new("date").arg("+%Y").output().expect("Failed to execute 'date' command"); + let current_year = str::from_utf8(&os_year.stdout).expect("Date output was not valid UTF-8").trim(); + let license_year = str::from_utf8(&EXPECTED_LICENSE_TEXT[22..][..4]).unwrap(); + assert_eq!(license_year, current_year, "The license year doesn't match the current OS year"); + } + + let mut error_formatting_violations: Vec<(String, usize, String)> = Vec::new(); + let mut iter = WalkDir::new(path).into_iter(); while let Some(entry) = iter.next() { let entry = entry.unwrap(); - let entry_type = entry.file_type(); - // Skip the specified directories. - if entry_type.is_dir() && DIRS_TO_SKIP.contains(&entry.file_name().to_str().unwrap_or("")) { + if should_skip_dir(&entry) { iter.skip_current_dir(); + continue; + } + // Only process .rs files. 
+ if !entry.file_type().is_file() || entry.path().extension() != Some(OsStr::new("rs")) { continue; } let path = entry.path(); - // Ignore non-rs - if path.extension() != Some(OsStr::new("rs")) { - continue; + // --- License check (reads only the header bytes) --- + { + let file = File::open(path).unwrap(); + let mut contents = Vec::with_capacity(EXPECTED_LICENSE_TEXT.len()); + file.take(EXPECTED_LICENSE_TEXT.len() as u64).read_to_end(&mut contents).unwrap(); + assert!( + contents == EXPECTED_LICENSE_TEXT, + "The license in \"{}\" is either missing or it doesn't match the expected string!", + path.display() + ); } - // Read the entire file. - let file = fs::read_to_string(path).unwrap(); + // Read the full file once for the remaining checks. + let src = fs::read_to_string(path).unwrap(); - // Prepare a filtered line iterator. - let lines = file - .lines() - .filter(|l| !l.is_empty()) // Ignore empty lines. - .skip_while(|l| !l.starts_with("use")) // Skip the license etc. - .take_while(|l| { // Process the section containing import statements. + // --- Locktick import balance check --- + { + let lines = src.lines().filter(|l| !l.is_empty()).skip_while(|l| !l.starts_with("use")).take_while(|l| { l.starts_with("use") || l.starts_with("#[cfg") || l.starts_with("//") @@ -74,113 +106,74 @@ fn check_locktick_imports>(path: P) { || l.starts_with(|c: char| c.is_ascii_whitespace()) }); - // The currently processed import of interest. - let mut import_of_interest: Option = None; - // This value not being zero at the end of the imports suggests a missing locktick import. - let mut lock_balance: i8 = 0; - - // Process the filtered lines. - for line in lines { - // Check if this is a lock-related import. 
- if import_of_interest.is_none() { - if line.starts_with("use locktick::") { - import_of_interest = Some(ImportOfInterest::Locktick); - } else if line.starts_with("use parking_lot::") { - import_of_interest = Some(ImportOfInterest::ParkingLot); - } else if line.starts_with("use tokio::") { - import_of_interest = Some(ImportOfInterest::Tokio); + let mut import_of_interest: Option = None; + let mut lock_balance: i8 = 0; + + for line in lines { + if import_of_interest.is_none() { + if line.starts_with("use locktick::") { + import_of_interest = Some(ImportOfInterest::Locktick); + } else if line.starts_with("use parking_lot::") { + import_of_interest = Some(ImportOfInterest::ParkingLot); + } else if line.starts_with("use tokio::") { + import_of_interest = Some(ImportOfInterest::Tokio); + } } - } - // Skip irrelevant imports. - let Some(ioi) = import_of_interest else { - continue; - }; + let Some(ioi) = import_of_interest else { + continue; + }; - // Modify the lock balance based on the type of the relevant import. - if [ImportOfInterest::ParkingLot, ImportOfInterest::Tokio].contains(&ioi) { - if line.contains("Mutex") { - lock_balance += 1; - } - if line.contains("RwLock") { - lock_balance += 1; - } - } else if ioi == ImportOfInterest::Locktick { - // Use `matches` instead of just `contains` here, as more than a single - // lock type entry is possible in a locktick import. - for _hit in line.matches("Mutex") { - lock_balance -= 1; - } - for _hit in line.matches("RwLock") { - lock_balance -= 1; - } - // A correction in case of the `use tokio::Mutex as TMutex` convention. 
- if line.contains("TMutex") { - lock_balance += 1; + if [ImportOfInterest::ParkingLot, ImportOfInterest::Tokio].contains(&ioi) { + if line.contains("Mutex") { + lock_balance += 1; + } + if line.contains("RwLock") { + lock_balance += 1; + } + } else if ioi == ImportOfInterest::Locktick { + // Use `matches` instead of just `contains` here, as more than a single + // lock type entry is possible in a locktick import. + for _hit in line.matches("Mutex") { + lock_balance -= 1; + } + for _hit in line.matches("RwLock") { + lock_balance -= 1; + } + // A correction in case of the `use tokio::Mutex as TMutex` convention. + if line.contains("TMutex") { + lock_balance += 1; + } } - } - // Register the end of an import statement. - if line.ends_with(";") { - import_of_interest = None; + if line.ends_with(";") { + import_of_interest = None; + } } - } - - // If the file has a lock import "imbalance", print it out and increment the counter. - assert!( - lock_balance == 0, - "The locks in \"{}\" don't seem to have `locktick` counterparts!", - entry.path().display() - ); - } -} -fn check_file_licenses>(path: P) { - let path = path.as_ref(); - - // Perform the license year check if on Linux. - if cfg!(target_os = "linux") { - // Get the current year from the OS - let os_year = process::Command::new("date") - .arg("+%Y") // ask only for the year - .output() - .expect("Failed to execute 'date' command"); - let current_year = str::from_utf8(&os_year.stdout).expect("Date output was not valid UTF-8").trim(); - - // Check if the end of the year range in the license matches the OS year. - let license_year = str::from_utf8(&EXPECTED_LICENSE_TEXT[22..][..4]).unwrap(); - assert_eq!(license_year, current_year, "The license year doesn't match the current OS year"); - } - - let mut iter = WalkDir::new(path).into_iter(); - while let Some(entry) = iter.next() { - let entry = entry.unwrap(); - let entry_type = entry.file_type(); - - // Skip the root-level dot folders (e.g. 
.git, .github, .cargo, .ci). - if entry_type.is_dir() && entry.depth() == 1 && entry.file_name().to_str().is_some_and(|n| n.starts_with('.')) { - iter.skip_current_dir(); - continue; + assert!( + lock_balance == 0, + "The locks in \"{}\" don't seem to have `locktick` counterparts!", + path.display() + ); } - // Skip the specified directories (any depth). - if entry_type.is_dir() && DIRS_TO_SKIP.contains(&entry.file_name().to_str().unwrap_or("")) { - iter.skip_current_dir(); - continue; + // --- Error formatting check --- + { + let ast = syn::parse_file(&src).unwrap(); + let mut checker = ErrorChecker { violations: Vec::new() }; + checker.visit_file(&ast); + error_formatting_violations + .extend(checker.violations.into_iter().map(|(line, code)| (path.display().to_string(), line, code))); } + } - // Check all files with the ".rs" extension. - if entry_type.is_file() && entry.file_name().to_str().unwrap_or("").ends_with(".rs") { - let file = File::open(entry.path()).unwrap(); - let mut contents = Vec::with_capacity(EXPECTED_LICENSE_TEXT.len()); - file.take(EXPECTED_LICENSE_TEXT.len() as u64).read_to_end(&mut contents).unwrap(); - - assert!( - contents == EXPECTED_LICENSE_TEXT, - "The license in \"{}\" is either missing or it doesn't match the expected string!", - entry.path().display() - ); + if !error_formatting_violations.is_empty() { + eprintln!("Forbidden error formatting found! Use `{{:#}}` or helper like `full_chain()`:"); + for (file, line, code) in error_formatting_violations { + eprintln!("{file}:{line} -> {code}"); } + panic!("Build failed due to forbidden error formatting."); } } @@ -292,12 +285,56 @@ fn check_tokio_console_flags() { } } +/// List of allowed wrapper function names +const ALLOWED_WRAPPERS: &[&str] = &["flatten_error"]; +/// Common variable names used for errors (used to detect captured-identifier format syntax). 
+const ERROR_VAR_NAMES: &[&str] = &["error", "err", "e"]; + +struct ErrorChecker { + violations: Vec<(usize, String)>, +} + +impl ErrorChecker { + fn check_macro(&mut self, mac: &Macro) { + let mac_name = mac.path.segments.last().unwrap().ident.to_string(); + if !["println", "format", "error", "warn", "info", "debug", "trace"].contains(&mac_name.as_str()) { + return; + } + + let tokens = mac.tokens.to_string(); + + // Heuristic: detect raw error formatting via `"{}"` with an error variable, + // or via the captured-identifier syntax `"{error}"` / `"{err}"` / `"{e}"`. + // + // For the plain-placeholder case, check that the variable name appears as a + // standalone identifier (not as a substring of a longer word like "current_round"). + let has_plain_placeholder = tokens.contains("\"{}\"") + && ERROR_VAR_NAMES + .iter() + .any(|var| tokens.split(|c: char| !c.is_alphanumeric() && c != '_').any(|word| word == *var)); + let has_captured_error = ERROR_VAR_NAMES.iter().any(|var| tokens.contains(&format!("\"{{{var}}}\""))); + if (has_plain_placeholder || has_captured_error) && !ALLOWED_WRAPPERS.iter().any(|f| tokens.contains(f)) { + self.violations.push((mac.span().start().line, tokens)); + } + } +} + +impl<'ast> Visit<'ast> for ErrorChecker { + fn visit_expr_macro(&mut self, node: &'ast ExprMacro) { + self.check_macro(&node.mac); + syn::visit::visit_expr_macro(self, node); + } + + fn visit_stmt_macro(&mut self, node: &'ast StmtMacro) { + self.check_macro(&node.mac); + syn::visit::visit_stmt_macro(self, node); + } +} + // The build script. fn main() { - // Check licenses in the current folder. - check_file_licenses("."); - // Ensure that lock imports have locktick counterparts. - check_locktick_imports("."); + // Single walk: check licenses, locktick imports, and error formatting for all source files. + check_source_files("."); // Check if locktick feature is correctly enabled. check_locktick_profile(); // Check if the tokio_console feature is correctly enabled. 
diff --git a/node/src/client/mod.rs b/node/src/client/mod.rs index 790991e758..34f9d05c12 100644 --- a/node/src/client/mod.rs +++ b/node/src/client/mod.rs @@ -344,7 +344,7 @@ impl> Client { let has_new_blocks = match self.sync.try_advancing_block_synchronization().await { Ok(val) => val, Err(err) => { - error!("{err}"); + error!("{}", flatten_error(err)); return; } }; diff --git a/node/sync/src/block_sync.rs b/node/sync/src/block_sync.rs index 036072f1e4..40a92d5806 100644 --- a/node/sync/src/block_sync.rs +++ b/node/sync/src/block_sync.rs @@ -1281,7 +1281,7 @@ impl BlockSync { for height in start_height..end_height { // Ensure the current height is not in the ledger or already requested. if let Err(err) = self.check_block_request(height) { - trace!("{err}"); + trace!("{}", flatten_error(err)); // If the sequence of block requests is interrupted, then return early. // Otherwise, continue until the first start height that is new. From 0e6034747e1ad24baeca1eb3864bf9a31b99bd7a Mon Sep 17 00:00:00 2001 From: Kai Mast Date: Fri, 20 Feb 2026 15:50:08 -0800 Subject: [PATCH 2/4] build: ensure errors are chained and not concatenated --- build.rs | 50 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/build.rs b/build.rs index a4e42e07ca..eda8f15138 100644 --- a/build.rs +++ b/build.rs @@ -39,6 +39,7 @@ enum ImportOfInterest { Tokio, } +/// Determines whether a directory contains auxiliary files, not source code, and should be skipped. fn should_skip_dir(entry: &DirEntry) -> bool { let entry_type = entry.file_type(); if !entry_type.is_dir() { @@ -169,7 +170,7 @@ } if !error_formatting_violations.is_empty() { - eprintln!("Forbidden error formatting found! Use `{{:#}}` or helper like `full_chain()`:"); + eprintln!("Forbidden error formatting found! 
Use `{{:#}}` in log macros or chain errors via `.context()`:"); for (file, line, code) in error_formatting_violations { eprintln!("{file}:{line} -> {code}"); } @@ -177,6 +178,8 @@ fn check_source_files>(path: P) { } } +/// Verifies that, if the locktick feature is enabled, the build profile includes the required settings +/// (`line-tables-only` and `strip = "none"`). fn check_locktick_profile() { let locktick_enabled = env::var("CARGO_FEATURE_LOCKTICK").is_ok(); if locktick_enabled { @@ -297,24 +300,37 @@ struct ErrorChecker { impl ErrorChecker { fn check_macro(&mut self, mac: &Macro) { let mac_name = mac.path.segments.last().unwrap().ident.to_string(); - if !["println", "format", "error", "warn", "info", "debug", "trace"].contains(&mac_name.as_str()) { - return; - } - let tokens = mac.tokens.to_string(); + let line = mac.span().start().line; + + // Check logging macros for raw error display — should use `{:#}` or a helper. + if ["println", "format", "error", "warn", "info", "debug", "trace"].contains(&mac_name.as_str()) { + // Detect `"{}"` with a standalone error variable, or captured-identifier syntax + // `"{error}"` / `"{err}"` / `"{e}"`. Use word-boundary splitting for the plain + // case to avoid matching substrings like "current_round" when checking for "err". + let has_plain_placeholder = tokens.contains("\"{}\"") + && ERROR_VAR_NAMES + .iter() + .any(|var| tokens.split(|c: char| !c.is_alphanumeric() && c != '_').any(|word| word == *var)); + let has_captured_error = ERROR_VAR_NAMES.iter().any(|var| tokens.contains(&format!("\"{{{var}}}\""))); + if (has_plain_placeholder || has_captured_error) && !ALLOWED_WRAPPERS.iter().any(|f| tokens.contains(f)) { + self.violations.push((line, format!("{mac_name}!({tokens})"))); + } + } - // Heuristic: detect raw error formatting via `"{}"` with an error variable, - // or via the captured-identifier syntax `"{error}"` / `"{err}"` / `"{e}"`. 
- // - // For the plain-placeholder case, check that the variable name appears as a - // standalone identifier (not as a substring of a longer word like "current_round"). - let has_plain_placeholder = tokens.contains("\"{}\"") - && ERROR_VAR_NAMES - .iter() - .any(|var| tokens.split(|c: char| !c.is_alphanumeric() && c != '_').any(|word| word == *var)); - let has_captured_error = ERROR_VAR_NAMES.iter().any(|var| tokens.contains(&format!("\"{{{var}}}\""))); - if (has_plain_placeholder || has_captured_error) && !ALLOWED_WRAPPERS.iter().any(|f| tokens.contains(f)) { - self.violations.push((mac.span().start().line, tokens)); + // Check error-construction macros (anyhow!, bail!, format_err!) for embedded error + // variables — errors should be chained via `.context()`/`.with_context()` instead. + // Use `anyhow_concat!` or `bail_concat!` to explicitly opt in to concatenation. + if ["anyhow", "bail", "format_err"].contains(&mac_name.as_str()) { + let has_captured_error = ERROR_VAR_NAMES.iter().any(|var| tokens.contains(&format!("\"{{{var}}}\""))); + // Also catch `"...{var}..."` where the var appears inside a longer format string. 
+ let has_embedded_error = !has_captured_error + && ERROR_VAR_NAMES.iter().any(|var| { + tokens.contains(&format!("{{{var}}}")) || tokens.contains(&format!("{{{var}:")) + }); + if has_captured_error || has_embedded_error { + self.violations.push((line, format!("{mac_name}!({tokens})"))); + } } } } From 7285d8a52fd972f8403381d79bb9d004616b0627 Mon Sep 17 00:00:00 2001 From: Kai Mast Date: Fri, 20 Feb 2026 15:50:37 -0800 Subject: [PATCH 3/4] chore: fix all cases where errors are concatinated and context may be lost --- Cargo.lock | 1 + build.rs | 6 +-- cli/src/commands/account.rs | 6 +-- cli/src/commands/start.rs | 21 +++++---- cli/src/helpers/args.rs | 4 +- cli/src/helpers/logger.rs | 17 ++++---- display/src/lib.rs | 14 +++--- node/bft/ledger-service/src/ledger.rs | 54 +++++++++++------------ node/bft/src/gateway.rs | 33 +++++++------- node/bft/src/helpers/channels.rs | 20 +++++---- node/bft/src/helpers/proposal_cache.rs | 6 +-- node/bft/src/helpers/storage.rs | 4 +- node/cdn/src/blocks.rs | 57 ++++++++++++------------- node/rest/src/routes.rs | 11 +++-- node/router/src/inbound.rs | 23 +++++----- node/src/node.rs | 7 ++- utilities/Cargo.toml | 3 ++ utilities/src/lib.rs | 59 ++++++++++++++++++++++++++ 18 files changed, 204 insertions(+), 142 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5ae85988a7..71229858c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5052,6 +5052,7 @@ dependencies = [ name = "snarkos-utilities" version = "4.2.1" dependencies = [ + "anyhow", "locktick", "parking_lot", "tokio", diff --git a/build.rs b/build.rs index eda8f15138..9cd32d011d 100644 --- a/build.rs +++ b/build.rs @@ -325,9 +325,9 @@ impl ErrorChecker { let has_captured_error = ERROR_VAR_NAMES.iter().any(|var| tokens.contains(&format!("\"{{{var}}}\""))); // Also catch `"...{var}..."` where the var appears inside a longer format string. 
let has_embedded_error = !has_captured_error - && ERROR_VAR_NAMES.iter().any(|var| { - tokens.contains(&format!("{{{var}}}")) || tokens.contains(&format!("{{{var}:")) - }); + && ERROR_VAR_NAMES + .iter() + .any(|var| tokens.contains(&format!("{{{var}}}")) || tokens.contains(&format!("{{{var}:"))); if has_captured_error || has_embedded_error { self.violations.push((line, format!("{mac_name}!({tokens})"))); } diff --git a/cli/src/commands/account.rs b/cli/src/commands/account.rs index e753fa7fd2..cf3484056f 100644 --- a/cli/src/commands/account.rs +++ b/cli/src/commands/account.rs @@ -26,7 +26,7 @@ use snarkvm::console::{ types::Field, }; -use anyhow::{Result, anyhow, bail}; +use anyhow::{Context, Result, anyhow, bail}; use clap::Parser; use colored::Colorize; use core::str::FromStr; @@ -231,9 +231,7 @@ impl Account { // Recover the seed. let seed = match seed { // Recover the field element deterministically. - Some(seed) => { - Field::new(::Field::from_str(&seed).map_err(|e| anyhow!("Invalid seed - {e}"))?) - } + Some(seed) => Field::new(::Field::from_str(&seed).with_context(|| "Invalid session")?), // Sample a random field element. None => Field::rand(&mut ChaChaRng::from_entropy()), }; diff --git a/cli/src/commands/start.rs b/cli/src/commands/start.rs index 52ae706de1..24c0772e90 100644 --- a/cli/src/commands/start.rs +++ b/cli/src/commands/start.rs @@ -1095,17 +1095,16 @@ fn load_or_compute_genesis( fn resolve_potential_hostnames(ip_or_hostname: &str) -> Result { let trimmed = ip_or_hostname.trim(); - match trimmed.to_socket_addrs() { - Ok(mut ip_iter) => { - // A hostname might resolve to multiple IP addresses. We will use only the first one, - // assuming this aligns with the user's expectations. 
- let Some(ip) = ip_iter.next() else { - return Err(anyhow!("The supplied trusted hostname ('{trimmed}') does not reference any ip.")); - }; - Ok(ip) - } - Err(e) => Err(anyhow!("The supplied trusted hostname or IP ('{trimmed}') is malformed: {e}")), - } + let mut ip_iter = trimmed + .to_socket_addrs() + .with_context(|| format!("The supplied trusted hostname or IP ('{trimmed}') is malformed"))?; + + // A hostname might resolve to multiple IP addresses. We will use only the first one, + // assuming this aligns with the user's expectations. + let Some(ip) = ip_iter.next() else { + return Err(anyhow!("The supplied trusted hostname ('{trimmed}') does not reference any ip.")); + }; + Ok(ip) } #[cfg(test)] diff --git a/cli/src/helpers/args.rs b/cli/src/helpers/args.rs index d98931606c..6388ed6d0f 100644 --- a/cli/src/helpers/args.rs +++ b/cli/src/helpers/args.rs @@ -23,7 +23,7 @@ use snarkvm::{ }; use aleo_std::aleo_dir; -use anyhow::{Context, Result, anyhow}; +use anyhow::{Context, Result}; use clap::builder::RangedU64ValueParser; use std::{path::PathBuf, str::FromStr}; use ureq::http::{Uri, uri}; @@ -65,7 +65,7 @@ pub(crate) fn parse_private_key( let key_str = if let Some(keystr) = cmdline { keystr } else if let Some(file_name) = file_name { - let path = file_name.parse::().map_err(|e| anyhow!("Invalid path - {e}"))?; + let path = file_name.parse::().with_context(|| "Invalid private key path")?; std::fs::read_to_string(path).with_context(|| "Failed to read private key from disk")?.trim().to_string() } else { unreachable!(); diff --git a/cli/src/helpers/logger.rs b/cli/src/helpers/logger.rs index 66101a855e..ac411ba0b7 100644 --- a/cli/src/helpers/logger.rs +++ b/cli/src/helpers/logger.rs @@ -15,8 +15,7 @@ use crate::helpers::{DynamicFormatter, LogWriter}; -use anyhow::{Result, bail}; - +use anyhow::{Context, Result, bail}; use crossterm::tty::IsTty; use std::{ fs::File, @@ -145,15 +144,15 @@ pub fn initialize_logger>( let Some(logfile_dir) = logfile.as_ref().parent() 
else { bail!("Root directory passed as a logfile") }; if !logfile_dir.exists() { - if let Err(err) = std::fs::create_dir_all(logfile_dir) { - bail!("Failed to create a directory: '{}' ({err})", logfile_dir.display()); - } + std::fs::create_dir_all(logfile_dir) + .with_context(|| format!("Failed to create a directory: '{}'", logfile_dir.display()))?; } // Create a file to write logs to. - let logfile = match File::options().append(true).create(true).open(logfile) { - Ok(logfile) => logfile, - Err(err) => bail!("Failed to open the file for writing logs: {err}"), - }; + let logfile = File::options() + .append(true) + .create(true) + .open(logfile) + .with_context(|| "Failed to open the file for writing logs")?; // Initialize the log channel. let (log_sender, log_receiver) = mpsc::channel(1024); diff --git a/display/src/lib.rs b/display/src/lib.rs index 9fb79994a7..9ca992d51d 100644 --- a/display/src/lib.rs +++ b/display/src/lib.rs @@ -26,7 +26,7 @@ use snarkos_utilities::Stoppable; use snarkvm::prelude::Network; -use anyhow::{Result, anyhow}; +use anyhow::Result; use crossterm::{ event::{self, DisableMouseCapture, EnableMouseCapture, Event, KeyCode}, execute, @@ -102,12 +102,16 @@ impl Display { impl Display { /// Renders the display. - fn render(&mut self, terminal: &mut Terminal, stoppable: Arc) -> Result<()> { + fn render(&mut self, terminal: &mut Terminal, stoppable: Arc) -> Result<()> + where + B::Error: Into, + { let mut last_tick = Instant::now(); loop { - if let Err(err) = terminal.draw(|f| self.draw(f)) { - return Err(anyhow!("{err}").context("Failed to draw terminal UI")); - } + terminal.draw(|f| self.draw(f)).map_err(|err| { + let err: anyhow::Error = err.into(); + err.context("Failed to draw terminal UI") + })?; // Determine how long to wait for an input event, before we redraw. 
let timeout = self.tick_rate.saturating_sub(last_tick.elapsed()); diff --git a/node/bft/ledger-service/src/ledger.rs b/node/bft/ledger-service/src/ledger.rs index d1d5829735..0c8ff80595 100644 --- a/node/bft/ledger-service/src/ledger.rs +++ b/node/bft/ledger-service/src/ledger.rs @@ -46,7 +46,7 @@ use snarkvm::{ }, }; -use anyhow::ensure; +use anyhow::{Context, ensure}; use indexmap::IndexMap; #[cfg(feature = "locktick")] use locktick::parking_lot::RwLock; @@ -260,32 +260,28 @@ impl> LedgerService for CoreLedgerService< TransmissionID::Solution(expected_solution_id, expected_checksum), Transmission::Solution(solution_data), ) => { - match solution_data.clone().deserialize_blocking() { - Ok(solution) => { - if solution.id() != expected_solution_id { - bail!( - "Received mismatching solution ID - expected {}, found {}", - fmt_id(expected_solution_id), - fmt_id(solution.id()), - ); - } - - // Ensure the transmission checksum matches the expected checksum. - let checksum = solution_data.to_checksum::()?; - if checksum != expected_checksum { - bail!( - "Received mismatching checksum for solution {} - expected {expected_checksum} but found {checksum}", - fmt_id(expected_solution_id) - ); - } - - // Update the transmission with the deserialized solution. - *solution_data = Data::Object(solution); - } - Err(err) => { - bail!("Failed to deserialize solution: {err}"); - } + let solution = + solution_data.clone().deserialize_blocking().with_context(|| "Failed to deserialize solution")?; + + if solution.id() != expected_solution_id { + bail!( + "Received mismatching solution ID - expected {}, found {}", + fmt_id(expected_solution_id), + fmt_id(solution.id()), + ); } + + // Ensure the transmission checksum matches the expected checksum. 
+ let checksum = solution_data.to_checksum::()?; + if checksum != expected_checksum { + bail!( + "Received mismatching checksum for solution {} - expected {expected_checksum} but found {checksum}", + fmt_id(expected_solution_id) + ); + } + + // Update the transmission with the deserialized solution. + *solution_data = Data::Object(solution); } _ => { bail!("Mismatching `(transmission_id, transmission)` pair"); @@ -321,10 +317,8 @@ impl> LedgerService for CoreLedgerService< // Ensure that the solution is valid for the given epoch. let puzzle = self.ledger.puzzle().clone(); - match spawn_blocking!(puzzle.check_solution(&solution, epoch_hash, proof_target)) { - Ok(()) => Ok(()), - Err(e) => bail!("Invalid solution '{}' for the current epoch - {e}", fmt_id(solution_id)), - } + spawn_blocking!(puzzle.check_solution(&solution, epoch_hash, proof_target)) + .with_context(|| format!("Invalid solution '{}' for the current epoch", fmt_id(solution_id))) } /// Checks the given transaction is well-formed and unique. diff --git a/node/bft/src/gateway.rs b/node/bft/src/gateway.rs index 1d326b0607..2c1d9fc273 100644 --- a/node/bft/src/gateway.rs +++ b/node/bft/src/gateway.rs @@ -61,7 +61,7 @@ use snarkos_node_tcp::{ Tcp, protocols::{Disconnect, Handshake, OnConnect, Reading, Writing}, }; -use snarkos_utilities::NodeDataDir; +use snarkos_utilities::{NodeDataDir, prefix_error}; use snarkvm::{ console::prelude::*, ledger::{ @@ -71,6 +71,7 @@ use snarkvm::{ prelude::{Address, Field}, }; +use anyhow::Context; use colored::Colorize; use futures::{SinkExt, future::join_all}; use indexmap::IndexMap; @@ -598,16 +599,17 @@ impl Gateway { let self_ = self.clone(); let blocks = match task::spawn_blocking(move || { // Retrieve the blocks within the requested range. 
- match self_.ledger.get_blocks(start_height..end_height) { - Ok(blocks) => Ok(DataBlocks(blocks)), - Err(error) => bail!("Missing blocks {start_height} to {end_height} from ledger - {error}"), - } + let blocks = self_ + .ledger + .get_blocks(start_height..end_height) + .with_context(|| format!("Missing blocks {start_height} to {end_height} from ledger"))?; + Ok(DataBlocks(blocks)) }) .await { Ok(Ok(blocks)) => blocks, Ok(Err(error)) => return Err(error), - Err(error) => return Err(anyhow!("[BlockRequest] {error}")), + Err(error) => return Err(prefix_error("BlockRequest", error.into())), }; let self_ = self.clone(); @@ -632,14 +634,12 @@ impl Gateway { // this on a blocking task, but on a rayon thread pool. let (send, recv) = tokio::sync::oneshot::channel(); rayon::spawn_fifo(move || { - let blocks = blocks.deserialize_blocking().map_err(|error| anyhow!("[BlockResponse] {error}")); - let _ = send.send(blocks); + let _ = send.send(blocks.deserialize_blocking()); }); - let blocks = match recv.await { - Ok(Ok(blocks)) => blocks, - Ok(Err(error)) => bail!("Peer '{peer_ip}' sent an invalid block response - {error}"), - Err(error) => bail!("Peer '{peer_ip}' sent an invalid block response - {error}"), - }; + let blocks = recv + .await + .with_context(|| format!("Peer '{peer_ip}' sent an invalid block response"))? + .with_context(|| format!("Peer '{peer_ip}' sent an invalid block response"))?; // Ensure the block response is well-formed. blocks.ensure_response_is_well_formed(peer_ip, request.start_height, request.end_height)?; @@ -705,9 +705,10 @@ impl Gateway { // Update the peer locators. Except for some tests, there is always a sync sender. if let Some(sync_sender) = self.sync_sender.get() { // Check the block locators are valid, and update the validators in the sync module. 
- if let Err(error) = sync_sender.update_peer_locators(peer_ip, block_locators).await { - bail!("Validator '{peer_ip}' sent invalid block locators - {error}"); - } + sync_sender + .update_peer_locators(peer_ip, block_locators) + .await + .with_context(|| format!("Validator '{peer_ip}' sent invalid block locators"))?; } // Send the batch certificates to the primary. diff --git a/node/bft/src/helpers/channels.rs b/node/bft/src/helpers/channels.rs index e23b612682..0964a7f3fd 100644 --- a/node/bft/src/helpers/channels.rs +++ b/node/bft/src/helpers/channels.rs @@ -281,19 +281,21 @@ impl SyncSender { // This `tx_block_sync_advance_with_sync_blocks.send()` call // causes the `rx_block_sync_advance_with_sync_blocks.recv()` call // in one of the loops in [`Sync::run()`] to return. - if let Err(err) = self - .tx_block_sync_insert_block_response + self.tx_block_sync_insert_block_response .send((peer_ip, blocks, latest_consensus_version, callback_sender)) .await - { - return Err(anyhow!("Failed to send block response - {err}").into()); - } + .map_err(|err| { + let err: anyhow::Error = err.into(); + let err = err.context("Failed to send block response to '{peer_ip}'"); + InsertBlockResponseError::Other(err) + })?; // Await the callback to continue. - match callback_receiver.await { - Ok(result) => result, - Err(err) => Err(anyhow!("Failed to wait for block response insertion - {err}").into()), - } + callback_receiver.await.map_err(|err| { + let err: anyhow::Error = err.into(); + let err = err.context("Failed to wait for block response insertion"); + InsertBlockResponseError::Other(err) + })? 
} } diff --git a/node/bft/src/helpers/proposal_cache.rs b/node/bft/src/helpers/proposal_cache.rs index 7dc2496f39..a1db92df06 100644 --- a/node/bft/src/helpers/proposal_cache.rs +++ b/node/bft/src/helpers/proposal_cache.rs @@ -19,9 +19,10 @@ use snarkos_utilities::NodeDataDir; use snarkvm::{ console::{account::Address, network::Network, program::SUBDAG_CERTIFICATES_DEPTH}, ledger::narwhal::BatchCertificate, - prelude::{FromBytes, IoResult, Read, Result, ToBytes, Write, anyhow, bail, error}, + prelude::{FromBytes, IoResult, Read, Result, ToBytes, Write, bail, error}, }; +use anyhow::Context; use indexmap::IndexSet; use std::{fs, path::PathBuf}; @@ -102,8 +103,7 @@ impl ProposalCache { // Serialize the proposal cache. let bytes = self.to_bytes_le()?; // Store the proposal cache to the file system. - fs::write(&path, bytes) - .map_err(|err| anyhow!("Couldn't write the proposal cache to {} - {err}", path.display()))?; + fs::write(&path, bytes).with_context(|| format!("Couldn't write the proposal cache to {}", path.display()))?; Ok(()) } diff --git a/node/bft/src/helpers/storage.rs b/node/bft/src/helpers/storage.rs index 359763b2ce..a08c6894a1 100644 --- a/node/bft/src/helpers/storage.rs +++ b/node/bft/src/helpers/storage.rs @@ -21,7 +21,7 @@ use snarkvm::{ block::{Block, Transaction}, narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID}, }, - prelude::{Address, Field, Network, Result, anyhow, bail, ensure}, + prelude::{Address, Field, Network, Result, bail, ensure}, utilities::{cfg_into_iter, cfg_iter, cfg_sorted_by, flatten_error}, }; @@ -473,7 +473,7 @@ impl Storage { let missing_transmissions = self .transmissions .find_missing_transmissions(batch_header, transmissions, aborted_transmissions) - .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?; + .with_context(|| format!("Failed to find missing transmission(s) for round {round} {gc_log}"))?; // Compute the previous round. 
let previous_round = round.saturating_sub(1); diff --git a/node/cdn/src/blocks.rs b/node/cdn/src/blocks.rs index 938dfa5250..b2e88942c6 100644 --- a/node/cdn/src/blocks.rs +++ b/node/cdn/src/blocks.rs @@ -108,7 +108,7 @@ impl CdnBlockSync { bail!("CDN task was already awaited"); }; - let result = hdl.await.map_err(|err| anyhow!("Failed to wait for CDN task: {err}")); + let result = hdl.await.with_context(|| "Failed to wait for CDN task"); self.done.store(true, Ordering::SeqCst); result } @@ -182,7 +182,8 @@ pub async fn load_blocks( let client = match Client::builder().use_rustls_tls().build() { Ok(client) => client, Err(error) => { - return Err((start_height.saturating_sub(1), anyhow!("Failed to create a CDN request client - {error}"))); + let error: anyhow::Error = error.into(); + return Err((start_height.saturating_sub(1), error.context("Failed to create a CDN request client"))); } }; @@ -437,28 +438,26 @@ async fn cdn_height(client: &Client, base_url: &http // Prepare the URL. let latest_json_url = format!("{base_url}/latest.json"); // Send the request. - let response = match client.get(latest_json_url).send().await { - Ok(response) => response, - Err(error) => bail!("Failed to fetch the CDN height - {error}"), - }; + let response = client.get(latest_json_url).send().await.with_context(|| "Failed to fetch the CDN height")?; // Parse the response. - let bytes = match response.bytes().await { - Ok(bytes) => bytes, - Err(error) => bail!("Failed to parse the CDN height response - {error}"), - }; + let bytes = response.bytes().await.with_context(|| "Failed to parse the CDN height response")?; // Parse the bytes for the string. 
let latest_state_string = match bincode::deserialize::(&bytes) { Ok(string) => string, Err(error) => { + let error: anyhow::Error = error.into(); let bytes_as_string = String::from_utf8_lossy(&bytes); - bail!("Failed to deserialize the CDN height response - {error} - {bytes_as_string}") + bail!( + "Failed to deserialize the CDN height response - {full_error} - {bytes_as_string}", + full_error = flatten_error(&error) + ); } }; // Parse the string for the tip. - let tip = match serde_json::from_str::(&latest_state_string) { - Ok(latest) => latest.exclusive_height, - Err(error) => bail!("Failed to extract the CDN height response - {error}"), - }; + let tip = serde_json::from_str::(&latest_state_string) + .with_context(|| "Failed to extract the CDN height response")? + .exclusive_height; + // Decrement the tip by a few blocks to ensure the CDN is caught up. let tip = tip.saturating_sub(10); // Adjust the tip to the closest subsequent multiple of BLOCKS_PER_FILE. @@ -468,25 +467,23 @@ async fn cdn_height(client: &Client, base_url: &http /// Retrieves the objects from the CDN with the given URL. async fn cdn_get(client: Client, url: &str, ctx: &str) -> Result { // Fetch the bytes from the given URL. - let response = match client.get(url).send().await { - Ok(response) => response, - Err(error) => bail!("Failed to fetch {ctx} - {error}"), - }; + let response = client.get(url).send().await.with_context(|| format!("Failed to fetch {ctx}"))?; // Parse the response. - let bytes = match response.bytes().await { - Ok(bytes) => bytes, - Err(error) => bail!("Failed to parse {ctx} - {error}"), - }; + let bytes = response.bytes().await.with_context(|| format!("Failed to parse {ctx}"))?; // Parse the objects. 
- match tokio::task::spawn_blocking(move || (bincode::deserialize::(&bytes), bytes)).await { - Ok((Ok(objects), _)) => Ok(objects), - Ok((Err(error), response_bytes)) => { + match tokio::task::spawn_blocking(move || (bincode::deserialize::(&bytes), bytes)) + .await + .with_context(|| format!("Failed to join task for {ctx}"))? + { + (Ok(objects), _) => Ok(objects), + (Err(error), response_bytes) => { let bytes_as_string = String::from_utf8_lossy(&response_bytes); - bail!("Failed to deserialize {ctx} - {error} - {bytes_as_string}") - } - Err(error) => { - bail!("Failed to join task for {ctx} - {error}") + let error: anyhow::Error = error.into(); + Err(anyhow!( + "Failed to deserialize {ctx} - {full_error} - {bytes_as_string}", + full_error = flatten_error(&error) + )) } } } diff --git a/node/rest/src/routes.rs b/node/rest/src/routes.rs index 504f4fa3b4..9004953a01 100644 --- a/node/rest/src/routes.rs +++ b/node/rest/src/routes.rs @@ -480,7 +480,10 @@ impl, R: Routing> Rest { Ok(ErasedJson::pretty(mapping_values)) } Ok(Err(err)) => Err(RestError::internal_server_error(err.context("Unable to read mapping"))), - Err(err) => Err(RestError::internal_server_error(anyhow!("Tokio error: {err}"))), + Err(err) => { + let err: anyhow::Error = err.into(); + Err(RestError::internal_server_error(err.context("Tokio error"))) + } } } @@ -686,7 +689,8 @@ impl, R: Routing> Rest { Ok(json) => json, Err(JsonRejection::JsonDataError(err)) => { // For JsonDataError, return 422 to let transaction validation handle it - return Err(RestError::unprocessable_entity(anyhow!("Invalid transaction data: {err}"))); + let err: anyhow::Error = err.into(); + return Err(RestError::unprocessable_entity(err.context("Invalid transaction data"))); } Err(other_rejection) => return Err(other_rejection.into()), }; @@ -867,7 +871,8 @@ impl, R: Routing> Rest { }; } Err(err) => { - return Err(RestError::internal_server_error(anyhow!("Tokio error: {err}"))); + let err: anyhow::Error = err.into(); + return 
Err(RestError::internal_server_error(err.context("Tokio error"))); } }; // Release the slot. diff --git a/node/router/src/inbound.rs b/node/router/src/inbound.rs index b6e52d1ee8..a153556240 100644 --- a/node/router/src/inbound.rs +++ b/node/router/src/inbound.rs @@ -29,6 +29,7 @@ use crate::{ }, }; use snarkos_node_tcp::protocols::Reading; +use snarkos_utilities::{bail_concat, prefix_error}; use snarkvm::prelude::{ ConsensusVersion, Network, @@ -36,7 +37,7 @@ use snarkvm::prelude::{ puzzle::Solution, }; -use anyhow::{Result, anyhow, bail}; +use anyhow::{Context, Result, bail}; use std::net::SocketAddr; use tokio::task::spawn_blocking; @@ -136,14 +137,14 @@ pub trait Inbound: Reading + Outbound { // this on a blocking task, but on a rayon thread pool. let (send, recv) = tokio::sync::oneshot::channel(); rayon::spawn_fifo(move || { - let blocks = blocks.deserialize_blocking().map_err(|error| anyhow!("[BlockResponse] {error}")); - let _ = send.send(blocks); + let result = blocks.deserialize_blocking(); + let _ = send.send(result); }); - let blocks = match recv.await { - Ok(Ok(blocks)) => blocks, - Ok(Err(error)) => bail!("Peer '{peer_ip}' sent an invalid block response - {error}"), - Err(error) => bail!("Peer '{peer_ip}' sent an invalid block response - {error}"), - }; + + let blocks = recv + .await + .with_context(|| format!("Peer '{peer_ip}' sent an invalid block response"))? + .with_context(|| format!("Peer '{peer_ip}' sent an invalid block response"))?; // Ensure the block response is well-formed. blocks.ensure_response_is_well_formed(peer_ip, request.start_height, request.end_height)?; @@ -233,7 +234,7 @@ pub trait Inbound: Reading + Outbound { // Perform the deferred non-blocking deserialization of the block header. let header = match message.block_header.deserialize().await { Ok(header) => header, - Err(error) => bail!("[PuzzleResponse] {error}"), + Err(error) => return Err(prefix_error("PuzzleResponse", error)), }; // Process the puzzle response. 
match self.puzzle_response(peer_ip, message.epoch_hash, header) { @@ -260,7 +261,7 @@ pub trait Inbound: Reading + Outbound { // Perform the deferred non-blocking deserialization of the solution. let solution = match message.solution.deserialize().await { Ok(solution) => solution, - Err(error) => bail!("[UnconfirmedSolution] {error}"), + Err(error) => bail_concat!("[UnconfirmedSolution] {error}"), }; // Check that the solution parameters match. if message.solution_id != solution.id() { @@ -291,7 +292,7 @@ pub trait Inbound: Reading + Outbound { // Perform the deferred non-blocking deserialization of the transaction. let transaction = match message.transaction.deserialize().await { Ok(transaction) => transaction, - Err(error) => bail!("[UnconfirmedTransaction] {error}"), + Err(error) => bail_concat!("[UnconfirmedTransaction] {error}"), }; // Check that the transaction parameters match. if message.transaction_id != transaction.id() { diff --git a/node/src/node.rs b/node/src/node.rs index b6b889c02b..f59cc51f28 100644 --- a/node/src/node.rs +++ b/node/src/node.rs @@ -38,7 +38,7 @@ use snarkvm::prelude::{ }; use aleo_std::{StorageMode, aleo_ledger_dir}; -use anyhow::{Result, bail}; +use anyhow::{Context, Result, bail}; #[cfg(feature = "locktick")] use locktick::parking_lot::RwLock; @@ -331,9 +331,8 @@ impl Node { // Ensure that the target path exists as a folder or create it. 
if !auto_checkpoint_path.exists() { - if let Err(e) = fs::create_dir_all(&auto_checkpoint_path) { - bail!("Couldn't create the specified path for the automatic ledger checkpoints: {e}"); - } + fs::create_dir_all(&auto_checkpoint_path) + .with_context(|| "Couldn't create the specified path for the automatic ledger checkpoints")?; } else if auto_checkpoint_path.exists() && !auto_checkpoint_path.is_dir() { bail!("The specified path for automatic ledger checkpoints is not a directory"); } diff --git a/utilities/Cargo.toml b/utilities/Cargo.toml index c6e39c6596..7a519fe4c2 100644 --- a/utilities/Cargo.toml +++ b/utilities/Cargo.toml @@ -17,6 +17,9 @@ categories = [ "cryptography", "cryptography::cryptocurrencies", "os" ] license = "Apache-2.0" edition = "2024" +[dependencies.anyhow] +workspace = true + [dependencies.tokio] workspace = true features = [ "macros", "signal", "sync" ] diff --git a/utilities/src/lib.rs b/utilities/src/lib.rs index 14c6e08aff..61a5ceb516 100644 --- a/utilities/src/lib.rs +++ b/utilities/src/lib.rs @@ -13,6 +13,65 @@ // See the License for the specific language governing permissions and // limitations under the License. +/// A convenience macro that explicitly concatenates an error into an `anyhow::Error` message. +/// +/// Use this instead of `anyhow!("... {err}")` to make the intent clear and satisfy the +/// build-time check that disallows silent error concatenation. Prefer `.context()` / +/// `.with_context()` when the original error should be preserved as a cause chain. +#[macro_export] +macro_rules! anyhow_concat { + ($($arg:tt)*) => { ::anyhow::anyhow!($($arg)*) }; +} + +/// A convenience macro that explicitly concatenates an error into a `bail!` message. +/// +/// Use this instead of `bail!("... {err}")` to make the intent clear and satisfy the +/// build-time check that disallows silent error concatenation. Prefer `.context()` / +/// `.with_context()` when the original error should be preserved as a cause chain. 
+#[macro_export] +macro_rules! bail_concat { + ($($arg:tt)*) => { return Err(::anyhow::anyhow!($($arg)*).into()) }; +} + +/// Prepends a prefix to the message of the top-level `anyhow::Error` while keeping its source +/// chain intact. Unlike `.context()`, this does not add an extra wrapping layer — the prefix is +/// folded into the existing top-level message. +/// +/// # Example +/// ```ignore +/// let err = some_fallible_call().map_err(|e| prefix_error("[BlockResponse]", e))?; +/// ``` +pub fn prefix_error(prefix: &str, error: anyhow::Error) -> anyhow::Error { + // Collect the source chain *before* consuming `error`. + // We stop before the top-level message because we are replacing it. + let causes: Vec = { + let mut chain = Vec::new(); + let mut src: Option<&dyn std::error::Error> = error.source(); + while let Some(cause) = src { + chain.push(cause.to_string()); + src = cause.source(); + } + chain + }; + + // Build the new top-level message. + let new_msg = format!("[{prefix}] {error}"); + + // If there are no causes we are done. + if causes.is_empty() { + return anyhow::anyhow!("{new_msg}"); + } + + // Rebuild from the deepest cause upward, then wrap with the new top message. + // We use string-based reconstruction because `std::error::Error` sources are not + // `Send + Sync + 'static` and cannot be re-owned generically. + let mut rebuilt = anyhow::anyhow!("{}", causes.last().unwrap()); + for cause in causes.iter().rev().skip(1) { + rebuilt = rebuilt.context(cause.clone()); + } + rebuilt.context(new_msg) +} + /// Utilities for signal and shutdown handling. 
pub mod signals; From fa56a43c4ffc080a9a58168bbcb4bf2642c7a477 Mon Sep 17 00:00:00 2001 From: Kai Mast Date: Fri, 20 Feb 2026 16:29:18 -0800 Subject: [PATCH 4/4] build: rewrite locktick checks to use syntax tree --- Cargo.lock | 1 + Cargo.toml | 3 + build.rs | 380 ++++++++++++++++++++++++++++++----------------------- 3 files changed, 217 insertions(+), 167 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 71229858c9..838283a496 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4588,6 +4588,7 @@ dependencies = [ name = "snarkos" version = "4.4.0" dependencies = [ + "anyhow", "built", "clap", "locktick", diff --git a/Cargo.toml b/Cargo.toml index e9289fe64b..3b27fd0af7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -355,6 +355,9 @@ workspace = true [dev-dependencies.rusty-hook] version = "0.11.2" +[build-dependencies.anyhow] +workspace = true + [build-dependencies.built] version = "0.8" features = [ "git2" ] diff --git a/build.rs b/build.rs index 9cd32d011d..b5765438d7 100644 --- a/build.rs +++ b/build.rs @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::{Context, bail, ensure}; use std::{ env, ffi::OsStr, @@ -22,7 +23,19 @@ use std::{ process, str, }; -use syn::{ExprMacro, Macro, StmtMacro, spanned::Spanned, visit::Visit}; +use syn::{ + ExprMacro, + ItemUse, + Macro, + StmtMacro, + UseGroup, + UseName, + UsePath, + UseRename, + UseTree, + spanned::Spanned, + visit::Visit, +}; use toml::Value; use walkdir::{DirEntry, WalkDir}; @@ -32,13 +45,6 @@ const EXPECTED_LICENSE_TEXT: &[u8] = include_bytes!(".resources/license_header") // The following directories will be excluded from the license scan. const DIRS_TO_SKIP: [&str; 3] = ["examples", "js", "target"]; -#[derive(Clone, Copy, PartialEq, Eq)] -enum ImportOfInterest { - Locktick, - ParkingLot, - Tokio, -} - /// Determines, if a directory contains auxiliary files, not source code, and should be skipped. 
fn should_skip_dir(entry: &DirEntry) -> bool { let entry_type = entry.file_type(); @@ -55,7 +61,7 @@ fn should_skip_dir(entry: &DirEntry) -> bool { /// Checks license headers, locktick import balance, and forbidden error formatting in a single /// directory walk to avoid reading every source file more than once. -fn check_source_files>(path: P) { +fn check_source_files>(path: P) -> anyhow::Result<()> { // Perform the license year check if on Linux. if cfg!(target_os = "linux") { let os_year = process::Command::new("date").arg("+%Y").output().expect("Failed to execute 'date' command"); @@ -65,6 +71,7 @@ fn check_source_files>(path: P) { } let mut error_formatting_violations: Vec<(String, usize, String)> = Vec::new(); + let mut locktick_violations: Vec<(String, usize, String)> = Vec::new(); let mut iter = WalkDir::new(path).into_iter(); while let Some(entry) = iter.next() { @@ -87,86 +94,35 @@ fn check_source_files>(path: P) { let file = File::open(path).unwrap(); let mut contents = Vec::with_capacity(EXPECTED_LICENSE_TEXT.len()); file.take(EXPECTED_LICENSE_TEXT.len() as u64).read_to_end(&mut contents).unwrap(); - assert!( + ensure!( contents == EXPECTED_LICENSE_TEXT, "The license in \"{}\" is either missing or it doesn't match the expected string!", path.display() ); } - // Read the full file once for the remaining checks. + // Read the full file once and run all AST-based checks in a single pass. 
let src = fs::read_to_string(path).unwrap(); + let ast = syn::parse_file(&src).unwrap(); - // --- Locktick import balance check --- - { - let lines = src.lines().filter(|l| !l.is_empty()).skip_while(|l| !l.starts_with("use")).take_while(|l| { - l.starts_with("use") - || l.starts_with("#[cfg") - || l.starts_with("//") - || *l == "};" - || l.starts_with(|c: char| c.is_ascii_whitespace()) - }); - - let mut import_of_interest: Option = None; - let mut lock_balance: i8 = 0; - - for line in lines { - if import_of_interest.is_none() { - if line.starts_with("use locktick::") { - import_of_interest = Some(ImportOfInterest::Locktick); - } else if line.starts_with("use parking_lot::") { - import_of_interest = Some(ImportOfInterest::ParkingLot); - } else if line.starts_with("use tokio::") { - import_of_interest = Some(ImportOfInterest::Tokio); - } - } - - let Some(ioi) = import_of_interest else { - continue; - }; - - if [ImportOfInterest::ParkingLot, ImportOfInterest::Tokio].contains(&ioi) { - if line.contains("Mutex") { - lock_balance += 1; - } - if line.contains("RwLock") { - lock_balance += 1; - } - } else if ioi == ImportOfInterest::Locktick { - // Use `matches` instead of just `contains` here, as more than a single - // lock type entry is possible in a locktick import. - for _hit in line.matches("Mutex") { - lock_balance -= 1; - } - for _hit in line.matches("RwLock") { - lock_balance -= 1; - } - // A correction in case of the `use tokio::Mutex as TMutex` convention. 
- if line.contains("TMutex") { - lock_balance += 1; - } - } + let mut checker = FileChecker::default(); + checker.visit_file(&ast); + checker.finalize_lock_check(); - if line.ends_with(";") { - import_of_interest = None; - } - } + let file_str = path.display().to_string(); + locktick_violations + .extend(checker.lock_violations.into_iter().map(|(line, code)| (file_str.clone(), line, code))); + error_formatting_violations + .extend(checker.error_violations.into_iter().map(|(line, code)| (file_str.clone(), line, code))); + } - assert!( - lock_balance == 0, - "The locks in \"{}\" don't seem to have `locktick` counterparts!", - path.display() - ); + if !locktick_violations.is_empty() { + eprintln!("Lock imports without `locktick` counterparts found:"); + for (file, line, code) in locktick_violations { + eprintln!("{file}:{line} -> {code}"); } - // --- Error formatting check --- - { - let ast = syn::parse_file(&src).unwrap(); - let mut checker = ErrorChecker { violations: Vec::new() }; - checker.visit_file(&ast); - error_formatting_violations - .extend(checker.violations.into_iter().map(|(line, code)| (path.display().to_string(), line, code))); - } + bail!("Build failed due to missing locktick counterparts."); } if !error_formatting_violations.is_empty() { @@ -174,130 +130,186 @@ fn check_source_files>(path: P) { for (file, line, code) in error_formatting_violations { eprintln!("{file}:{line} -> {code}"); } - panic!("Build failed due to forbidden error formatting."); + + bail!("Build failed due to forbidden error formatting."); } + + Ok(()) } /// Verifies that, if the locktick feature is enabled, the build profile includes the required settings /// (`line-tables-only` and `strip = "none"`). -fn check_locktick_profile() { +fn check_locktick_profile() -> anyhow::Result<()> { let locktick_enabled = env::var("CARGO_FEATURE_LOCKTICK").is_ok(); - if locktick_enabled { - // First check the env variables that can override the TOML values. 
- let (mut valid_debug_override, mut valid_strip_override) = (false, false); - - if let Ok(val) = env::var("CARGO_PROFILE_RELEASE_DEBUG") { - if val != "line-tables-only" { - eprintln!( - "🔴 When enabling the locktick feature, CARGO_PROFILE_RELEASE_DEBUG may only be set to `line-tables-only`." - ); - process::exit(1); - } else { - valid_debug_override = true; - } + if !locktick_enabled { + // Nohting to check. + return Ok(()); + } + + // First check the env variables that can override the TOML values. + let (mut valid_debug_override, mut valid_strip_override) = (false, false); + + if let Ok(val) = env::var("CARGO_PROFILE_RELEASE_DEBUG") { + if val != "line-tables-only" { + bail!( + "🔴 When enabling the locktick feature, CARGO_PROFILE_RELEASE_DEBUG may only be set to `line-tables-only`." + ); + } else { + valid_debug_override = true; } - if let Ok(val) = env::var("CARGO_PROFILE_RELEASE_STRIP") { - if val != "none" { - eprintln!( - "🔴 When enabling the locktick feature, CARGO_PROFILE_RELEASE_STRIP may only be set to `none`." - ); - process::exit(1); - } else { - valid_strip_override = true; - } + } + if let Ok(val) = env::var("CARGO_PROFILE_RELEASE_STRIP") { + if val != "none" { + bail!("🔴 When enabling the locktick feature, CARGO_PROFILE_RELEASE_STRIP may only be set to `none`."); + } else { + valid_strip_override = true; } + } - if valid_debug_override && valid_strip_override { - // Both overrides are compatible with locktick, no need to check the TOML. - return; - } + if valid_debug_override && valid_strip_override { + // Both overrides are compatible with locktick, no need to check the TOML. + return Ok(()); + } - // If the relevant overrides were either invalid or not present, check the TOML. 
- let profile = env::var("PROFILE").unwrap_or_else(|_| "".to_string()); - let manifest = Path::new(&env::var("CARGO_MANIFEST_DIR").unwrap()).join("Cargo.toml"); - let contents = fs::read_to_string(&manifest).expect("failed to read Cargo.toml"); - let doc: Value = toml::from_str(&contents).expect("invalid TOML in Cargo.toml"); - - let profile_table = doc.get("profile").and_then(|p| p.get(profile)); - if let Some(Value::Table(profile_settings)) = profile_table { - if let Some(debug) = profile_settings.get("debug") { - match debug { - Value::String(s) if s == "line-tables-only" => { - println!("cargo:info=manifest has debuginfo=line-tables-only"); - } - _ => { - eprintln!( - "🔴 When enabling the locktick feature, the profile must have debug set to `line-tables-only`. Uncomment the relevant lines in Cargo.toml." - ); - process::exit(1); - } + // If the relevant overrides were either invalid or not present, check the TOML. + let profile = env::var("PROFILE").unwrap_or_else(|_| "".to_string()); + let manifest = Path::new(&env::var("CARGO_MANIFEST_DIR").unwrap()).join("Cargo.toml"); + let contents = fs::read_to_string(&manifest).expect("failed to read Cargo.toml"); + let doc: Value = toml::from_str(&contents).expect("invalid TOML in Cargo.toml"); + + let profile_table = doc.get("profile").and_then(|p| p.get(profile)); + if let Some(Value::Table(profile_settings)) = profile_table { + if let Some(debug) = profile_settings.get("debug") { + match debug { + Value::String(s) if s == "line-tables-only" => { + println!("cargo:info=manifest has debuginfo=line-tables-only"); + } + _ => { + bail!( + "🔴 When enabling the locktick feature, the profile must have debug set to `line-tables-only`. Uncomment the relevant lines in Cargo.toml." + ); } - } else { - eprintln!( - "🔴 When enabling the locktick feature, the profile must have `debug` set to `line-tables-only`. Uncomment the relevant lines in Cargo.toml." 
- ); - process::exit(1); } - if let Some(debug) = profile_settings.get("strip") { - match debug { - Value::String(s) if s == "none" => { - println!("cargo:info=manifest has strip=none"); - } - _ => { - eprintln!( - "🔴 When enabling the locktick feature, the profile must have `strip` set to `none`. Uncomment the relevant lines in Cargo.toml." - ); - process::exit(1); - } + } else { + bail!( + "🔴 When enabling the locktick feature, the profile must have `debug` set to `line-tables-only`. Uncomment the relevant lines in Cargo.toml." + ); + } + if let Some(debug) = profile_settings.get("strip") { + match debug { + Value::String(s) if s == "none" => { + println!("cargo:info=manifest has strip=none"); + } + _ => { + bail!( + "🔴 When enabling the locktick feature, the profile must have `strip` set to `none`. Uncomment the relevant lines in Cargo.toml." + ); } } } } + + Ok(()) } fn is_clippy() -> bool { env::var("RUSTC_WORKSPACE_WRAPPER").is_ok_and(|var| var.contains("clippy")) } -fn check_tokio_console_flags() { +fn check_tokio_console_flags() -> anyhow::Result<()> { // Don't run this check under clippy, otherwise it will cause issues with --all-features. if is_clippy() { - return; + return Ok(()); } // Skip if the feature is not used. let feature_enabled = env::var("CARGO_FEATURE_TOKIO_CONSOLE").is_ok(); if !feature_enabled { - return; + return Ok(()); } // Check for the presence of RUSTFLAGS. let Ok(rustflags) = env::var("CARGO_ENCODED_RUSTFLAGS") else { - eprintln!( - "🔴 When enabling the tokio_console feature, you must run with `RUSTFLAGS=\"--cfg tokio_unstable\"`." - ); - process::exit(1); + bail!("🔴 When enabling the tokio_console feature, you must run with `RUSTFLAGS=\"--cfg tokio_unstable\"`."); }; // Check for the presence of `tokio_unstable` within RUSTFLAGS. - if !rustflags.contains("tokio_unstable") { - eprintln!( - "🔴 When enabling the tokio_console feature, you must run with `RUSTFLAGS=\"--cfg tokio_unstable\"`." 
- ); - process::exit(1); - } + ensure!( + rustflags.contains("tokio_unstable"), + "🔴 When enabling the tokio_console feature, you must run with `RUSTFLAGS=\"--cfg tokio_unstable\"`." + ); + + Ok(()) } -/// List of allowed wrapper function names const ALLOWED_WRAPPERS: &[&str] = &["flatten_error"]; -/// Common variable names used for errors (used to detect captured-identifier format syntax). const ERROR_VAR_NAMES: &[&str] = &["error", "err", "e"]; +const LOCK_TYPES: &[&str] = &["Mutex", "RwLock"]; + +/// Visits a single source file, checking both locktick import balance and forbidden error +/// formatting in one AST pass. +#[derive(Default)] +struct FileChecker { + /// Lock types imported from `parking_lot` or `tokio`: (line, type_name). + non_locktick_locks: Vec<(usize, String)>, + /// Lock types imported from `locktick`: (line, type_name). + locktick_locks: Vec<(usize, String)>, + lock_violations: Vec<(usize, String)>, + error_violations: Vec<(usize, String)>, + /// Depth counter for `#[cfg(test)]`-gated modules; imports inside are ignored. + test_module_depth: usize, +} -struct ErrorChecker { - violations: Vec<(usize, String)>, +impl FileChecker { + /// After visiting the file, compare the two lock sets and populate `lock_violations` + /// with any type that appears in one side but not the other. + fn finalize_lock_check(&mut self) { + let non_locktick_types: std::collections::HashSet<&str> = + self.non_locktick_locks.iter().map(|(_, t)| t.as_str()).collect(); + let locktick_types: std::collections::HashSet<&str> = + self.locktick_locks.iter().map(|(_, t)| t.as_str()).collect(); + + // parking_lot/tokio imports with no locktick counterpart. + for (line, ty) in &self.non_locktick_locks { + if !locktick_types.contains(ty.as_str()) { + self.lock_violations.push((*line, format!("{ty} imported without a locktick counterpart"))); + } + } + // locktick imports with no parking_lot/tokio counterpart. 
+ for (line, ty) in &self.locktick_locks { + if !non_locktick_types.contains(ty.as_str()) { + self.lock_violations + .push((*line, format!("{ty} imported from locktick without a non-locktick counterpart"))); + } + } + } } -impl ErrorChecker { +impl FileChecker { + /// Collects lock-type names (`Mutex`, `RwLock`) found within a `UseTree` into `out`. + fn collect_lock_types_in_tree(module: Option<&str>, tree: &UseTree, line: usize, out: &mut Vec<(usize, String)>) { + match tree { + UseTree::Name(UseName { ident, .. }) | UseTree::Rename(UseRename { ident, .. }) => { + let name = ident.to_string(); + if LOCK_TYPES.contains(&name.as_str()) { + // At this point we should know if it is `tokio` or `parking_lot`. + let module = module.expect("module name is missing"); + out.push((line, format!("{module}::{name}"))); + } + } + UseTree::Group(UseGroup { items, .. }) => { + for item in items { + Self::collect_lock_types_in_tree(module, item, line, out); + } + } + UseTree::Path(UsePath { tree, ident, .. }) => { + let module = if let Some(module) = module { module } else { &ident.to_string() }; + Self::collect_lock_types_in_tree(Some(module), tree, line, out); + } + UseTree::Glob(_) => {} + } + } + fn check_macro(&mut self, mac: &Macro) { let mac_name = mac.path.segments.last().unwrap().ident.to_string(); let tokens = mac.tokens.to_string(); @@ -314,7 +326,7 @@ impl ErrorChecker { .any(|var| tokens.split(|c: char| !c.is_alphanumeric() && c != '_').any(|word| word == *var)); let has_captured_error = ERROR_VAR_NAMES.iter().any(|var| tokens.contains(&format!("\"{{{var}}}\""))); if (has_plain_placeholder || has_captured_error) && !ALLOWED_WRAPPERS.iter().any(|f| tokens.contains(f)) { - self.violations.push((line, format!("{mac_name}!({tokens})"))); + self.error_violations.push((line, format!("{mac_name}!({tokens})"))); } } @@ -323,19 +335,51 @@ impl ErrorChecker { // Use `anyhow_concat!` or `bail_concat!` to explicitly opt in to concatenation. 
if ["anyhow", "bail", "format_err"].contains(&mac_name.as_str()) { let has_captured_error = ERROR_VAR_NAMES.iter().any(|var| tokens.contains(&format!("\"{{{var}}}\""))); - // Also catch `"...{var}..."` where the var appears inside a longer format string. let has_embedded_error = !has_captured_error && ERROR_VAR_NAMES .iter() .any(|var| tokens.contains(&format!("{{{var}}}")) || tokens.contains(&format!("{{{var}:"))); if has_captured_error || has_embedded_error { - self.violations.push((line, format!("{mac_name}!({tokens})"))); + self.error_violations.push((line, format!("{mac_name}!({tokens})"))); } } } } -impl<'ast> Visit<'ast> for ErrorChecker { +impl<'ast> Visit<'ast> for FileChecker { + fn visit_item_mod(&mut self, node: &'ast syn::ItemMod) { + let is_test_mod = node.attrs.iter().any(|attr| { + attr.path().is_ident("cfg") + && attr.meta.require_list().ok().is_some_and(|l| l.tokens.to_string().contains("test")) + }); + if is_test_mod { + self.test_module_depth += 1; + } + syn::visit::visit_item_mod(self, node); + if is_test_mod { + self.test_module_depth -= 1; + } + } + + fn visit_item_use(&mut self, node: &'ast ItemUse) { + if self.test_module_depth > 0 { + // Ignore test code. + return; + } + + let line = node.span().start().line; + if let UseTree::Path(UsePath { ident, tree, .. }) = &node.tree { + match ident.to_string().as_str() { + "parking_lot" => { + Self::collect_lock_types_in_tree(Some("parking_lot"), tree, line, &mut self.non_locktick_locks) + } + "tokio" => Self::collect_lock_types_in_tree(Some("tokio"), tree, line, &mut self.non_locktick_locks), + "locktick" => Self::collect_lock_types_in_tree(None, tree, line, &mut self.locktick_locks), + _ => {} + } + } + } + fn visit_expr_macro(&mut self, node: &'ast ExprMacro) { self.check_macro(&node.mac); syn::visit::visit_expr_macro(self, node); @@ -348,14 +392,16 @@ impl<'ast> Visit<'ast> for ErrorChecker { } // The build script. 
-fn main() { +fn main() -> anyhow::Result<()> { // Single walk: check licenses, locktick imports, and error formatting for all source files. - check_source_files("."); + check_source_files(".")?; // Check if locktick feature is correctly enabled. - check_locktick_profile(); + check_locktick_profile()?; // Check if the tokio_console feature is correctly enabled. - check_tokio_console_flags(); + check_tokio_console_flags()?; // Register build-time information. - built::write_built_file().expect("Failed to acquire build-time information"); + built::write_built_file().with_context(|| "Failed to acquire build-time information")?; + + Ok(()) }