diff --git a/Cargo.lock b/Cargo.lock index ce3797a86..525a4e991 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -526,6 +526,7 @@ version = "1.0.0" dependencies = [ "anyhow", "dyn-clone", + "egglog-ast", "egglog-core-relations", "egglog-numeric-id", "egglog-reports", @@ -1237,6 +1238,13 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proof_api" +version = "0.1.0" +dependencies = [ + "egglog", +] + [[package]] name = "quote" version = "1.0.41" diff --git a/Cargo.toml b/Cargo.toml index 4bd499d61..fdd768b52 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "egglog-bridge", "numeric-id", "union-find", + "tests/proof_api", "src/sort/add_primitive", "wasm-example" ] diff --git a/Makefile b/Makefile index bdc15d364..bca6c4c0e 100644 --- a/Makefile +++ b/Makefile @@ -8,9 +8,9 @@ WWW=${PWD}/target/www all: test nits docs test: - cargo nextest run --release + cargo nextest run --workspace --all-features --release # nextest doesn't run doctests, so do it here - cargo test --doc --release + cargo test --workspace --all-features --doc --release nits: @rustup component add clippy diff --git a/core-relations/src/action/mod.rs b/core-relations/src/action/mod.rs index 0fdf95e91..f8f80c4ee 100644 --- a/core-relations/src/action/mod.rs +++ b/core-relations/src/action/mod.rs @@ -112,7 +112,8 @@ pub(crate) struct Bindings { impl std::ops::Index for Bindings { type Output = [Value]; fn index(&self, var: Variable) -> &[Value] { - self.get(var).unwrap() + self.get(var) + .unwrap_or_else(|| panic!("Bindings missing value for variable {:?}", var)) } } @@ -632,7 +633,8 @@ impl ExecutionState<'_> { bindings[*v][offset] } WriteVal::IncCounter(ctr) => { - Value::from_usize(ctrs.inc(*ctr)) + let res = ctrs.inc(*ctr); + Value::from_usize(res) } WriteVal::CurrentVal(ix) => row[*ix], }; @@ -743,6 +745,37 @@ impl ExecutionState<'_> { } } }, + Instr::InsertIfNe { table, l, r, vals } => match (l, r) { + (QueryEntry::Var(v1), QueryEntry::Var(v2)) => { + 
for_each_binding_with_mask!(mask, vals.as_slice(), bindings, |iter| { + iter.zip(&bindings[*v1]) + .zip(&bindings[*v2]) + .for_each(|((vals, v1), v2)| { + if v1 != v2 { + self.stage_insert(*table, &vals); + } + }) + }) + } + (QueryEntry::Var(v), QueryEntry::Const(c)) + | (QueryEntry::Const(c), QueryEntry::Var(v)) => { + for_each_binding_with_mask!(mask, vals.as_slice(), bindings, |iter| { + iter.zip(&bindings[*v]).for_each(|(vals, cond)| { + if cond != c { + self.stage_insert(*table, &vals); + } + }) + }) + } + (QueryEntry::Const(c1), QueryEntry::Const(c2)) => { + if c1 != c2 { + for_each_binding_with_mask!(mask, vals.as_slice(), bindings, |iter| iter + .for_each(|vals| { + self.stage_insert(*table, &vals); + })) + } + } + }, Instr::Remove { table, args } => { for_each_binding_with_mask!(mask, args.as_slice(), bindings, |iter| { iter.for_each(|args| { @@ -866,6 +899,14 @@ pub(crate) enum Instr { vals: Vec, }, + /// Insert `vals` into `table` if `l` and `r` are not equal. + InsertIfNe { + table: TableId, + l: QueryEntry, + r: QueryEntry, + vals: Vec, + }, + /// Remove the entry corresponding to `args` in `func`. Remove { table: TableId, diff --git a/core-relations/src/base_values/mod.rs b/core-relations/src/base_values/mod.rs index f20bcdfa0..21cf9f584 100644 --- a/core-relations/src/base_values/mod.rs +++ b/core-relations/src/base_values/mod.rs @@ -79,6 +79,17 @@ impl BaseValues { id } + /// Get the underlying `P` represented by the [`Value`] `v`, assuming its type `ty` matches + /// the [`BaseValueId`] for `P`. + /// + /// This method does not panic if `P` was never registered as a type in this [`BaseValues`] + /// instance. + pub fn try_get_as(&self, v: Value, ty: BaseValueId) -> Option

{ + self.type_ids + .get(&TypeId::of::

()) + .and_then(|&id| (id == ty).then_some(self.unwrap::

(v))) + } + /// Get the [`BaseValueId`] for the given base value type `P`. pub fn get_ty(&self) -> BaseValueId { self.type_ids[&TypeId::of::

()] diff --git a/core-relations/src/base_values/tests.rs b/core-relations/src/base_values/tests.rs index d286c35c9..18896b0e8 100644 --- a/core-relations/src/base_values/tests.rs +++ b/core-relations/src/base_values/tests.rs @@ -268,7 +268,7 @@ fn roundtrip_medium_integers_interned() { // Test large isize values (need interning on 64-bit systems) bases.register_type::(); if std::mem::size_of::() == 8 { - for val in [-2147483649isize, isize::MIN, isize::MAX] { + for val in [isize::MIN, isize::MAX] { let boxed = bases.get(val); let unboxed = bases.unwrap::(boxed); assert_eq!(val, unboxed); diff --git a/core-relations/src/common.rs b/core-relations/src/common.rs index 6fd582ca6..761b9d929 100644 --- a/core-relations/src/common.rs +++ b/core-relations/src/common.rs @@ -103,8 +103,8 @@ impl Clear for DenseIdMap { define_id!(pub Value, u32, "A generic identifier representing an egglog value"); impl Value { - pub(crate) fn stale() -> Self { - Value::new(u32::MAX) + pub(crate) const fn stale() -> Self { + Value::new_const(u32::MAX) } /// Values have a special "Stale" value that is used to indicate that the /// value isn't intended to be read. diff --git a/core-relations/src/free_join/mod.rs b/core-relations/src/free_join/mod.rs index 1186c016c..838438ef6 100644 --- a/core-relations/src/free_join/mod.rs +++ b/core-relations/src/free_join/mod.rs @@ -435,7 +435,14 @@ impl Database { /// /// These counters can be used to generate unique ids as part of an action. pub fn add_counter(&mut self) -> CounterId { - self.counters.0.push(AtomicUsize::new(0)) + self.add_counter_with_initial_val(0) + } + + /// Create a new counter for this database starting at the given `val`. + /// + /// These counters can be used to generate unique ids as part of an action. + pub fn add_counter_with_initial_val(&mut self, val: usize) -> CounterId { + self.counters.0.push(AtomicUsize::new(val)) } /// Increment the given counter and return its previous value. 
diff --git a/core-relations/src/query.rs b/core-relations/src/query.rs index e66c2a29e..a85fff1c3 100644 --- a/core-relations/src/query.rs +++ b/core-relations/src/query.rs @@ -455,6 +455,7 @@ impl RuleBuilder<'_, '_> { None } })); + let desc: String = desc.into(); let action_id = self.qb.rsb.rule_set.actions.push(ActionInfo { instrs: Arc::new(self.qb.instrs), used_vars, @@ -462,7 +463,6 @@ impl RuleBuilder<'_, '_> { self.qb.query.action = action_id; // Plan the query let plan = self.qb.rsb.db.plan_query(self.qb.query); - let desc: String = desc.into(); // Add it to the ruleset. self.qb .rsb @@ -601,6 +601,28 @@ impl RuleBuilder<'_, '_> { Ok(()) } + /// Insert the specified values into the given table if `l` and `r` are not + /// equal. + pub fn insert_if_ne( + &mut self, + table: TableId, + l: QueryEntry, + r: QueryEntry, + vals: &[QueryEntry], + ) -> Result<(), QueryError> { + let table_info = self.table_info(table); + self.validate_row(table, table_info, vals)?; + self.qb.instrs.push(Instr::InsertIfNe { + table, + l, + r, + vals: vals.to_vec(), + }); + self.qb + .mark_used(vals.iter().chain(once(&l)).chain(once(&r))); + Ok(()) + } + /// Remove the specified entry from the given table, if it is there. pub fn remove(&mut self, table: TableId, args: &[QueryEntry]) -> Result<(), QueryError> { let table_info = self.table_info(table); diff --git a/core-relations/src/table_spec.rs b/core-relations/src/table_spec.rs index c8bd10223..46fd3fb76 100644 --- a/core-relations/src/table_spec.rs +++ b/core-relations/src/table_spec.rs @@ -135,6 +135,7 @@ pub trait Rebuilder: Send + Sync { } /// A row in a table. +#[derive(Debug)] pub struct Row { /// The id associated with the row. 
pub id: RowId, diff --git a/core-relations/src/tests.rs b/core-relations/src/tests.rs index 3febb2c5e..777366dae 100644 --- a/core-relations/src/tests.rs +++ b/core-relations/src/tests.rs @@ -1082,3 +1082,135 @@ fn call_external_with_fallback() { h_contents.sort(); assert_eq!(h_contents, vec![vec![v(2), v(0)], vec![v(4), v(0)],]); } + +#[test] +fn insert_if_ne_with_vars() { + fn make_two_col_table() -> SortedWritesTable { + SortedWritesTable::new( + 2, + 2, + None, + vec![], + Box::new(move |_, a, b, _| { + if a != b { + panic!("merge not supported") + } else { + false + } + }), + ) + } + let mut db = Database::default(); + let pairs = db.add_table(make_two_col_table(), iter::empty(), iter::empty()); + let output = db.add_table(make_two_col_table(), iter::empty(), iter::empty()); + + { + let mut buf = db.get_table(pairs).new_buffer(); + buf.stage_insert(&[v(1), v(1)]); + buf.stage_insert(&[v(1), v(2)]); + buf.stage_insert(&[v(2), v(3)]); + buf.stage_insert(&[v(4), v(4)]); + } + db.merge_all(); + + let mut rsb = RuleSetBuilder::new(&mut db); + let mut query = rsb.new_rule(); + let x = query.new_var_named("x"); + let y = query.new_var_named("y"); + query.add_atom(pairs, &[x.into(), y.into()], &[]).unwrap(); + let mut rb = query.build(); + rb.insert_if_ne(output, x.into(), y.into(), &[x.into(), y.into()]) + .unwrap(); + rb.build(); + let rs = rsb.build(); + assert!(db.run_rule_set(&rs, ReportLevel::TimeOnly).changed); + + let out = db.get_table(output); + let all = out.all(); + let mut contents = out + .scan(all.as_ref()) + .iter() + .map(|(_, row)| row.to_vec()) + .collect::>(); + contents.sort(); + assert_eq!(contents, vec![vec![v(1), v(2)], vec![v(2), v(3)]]); +} + +#[test] +fn insert_if_ne_consts_and_vars() { + fn make_one_col_table() -> SortedWritesTable { + SortedWritesTable::new( + 1, + 1, + None, + vec![], + Box::new(move |_, a, b, _| { + if a != b { + panic!("merge not supported") + } else { + false + } + }), + ) + } + fn make_two_col_table() -> 
SortedWritesTable { + SortedWritesTable::new( + 2, + 2, + None, + vec![], + Box::new(move |_, a, b, _| { + if a != b { + panic!("merge not supported") + } else { + false + } + }), + ) + } + + let mut db = Database::default(); + let source = db.add_table(make_one_col_table(), iter::empty(), iter::empty()); + let output = db.add_table(make_two_col_table(), iter::empty(), iter::empty()); + { + let mut buf = db.get_table(source).new_buffer(); + buf.stage_insert(&[v(0)]); + buf.stage_insert(&[v(1)]); + } + db.merge_all(); + + let mut rsb = RuleSetBuilder::new(&mut db); + let mut query = rsb.new_rule(); + let x = query.new_var_named("x"); + query.add_atom(source, &[x.into()], &[]).unwrap(); + let mut rb = query.build(); + rb.insert_if_ne(output, x.into(), v(1).into(), &[x.into(), v(10).into()]) + .unwrap(); + rb.insert_if_ne( + output, + v(2).into(), + v(2).into(), + &[v(2).into(), v(2).into()], + ) + .unwrap(); + rb.insert_if_ne( + output, + v(3).into(), + v(4).into(), + &[v(3).into(), v(4).into()], + ) + .unwrap(); + rb.build(); + let rs = rsb.build(); + assert!(db.run_rule_set(&rs, ReportLevel::TimeOnly).changed); + + let output = db.get_table(output); + let all = output.all(); + let mut contents = output + .scan(all.as_ref()) + .iter() + .map(|(_, row)| row.to_vec()) + .collect::>(); + contents.sort(); + assert_eq!(contents, vec![vec![v(0), v(10)], vec![v(3), v(4)]]); +} diff --git a/egglog-ast/src/generic_ast.rs b/egglog-ast/src/generic_ast.rs index c90cf1f67..011167ba0 100644 --- a/egglog-ast/src/generic_ast.rs +++ b/egglog-ast/src/generic_ast.rs @@ -21,6 +21,8 @@ pub enum GenericExpr { Lit(Span, Literal), } +pub type Expr = GenericExpr; + /// Facts are the left-hand side of a [`Command::Rule`]. /// They represent a part of a database query. /// Facts can be expressions or equality constraints between expressions. 
diff --git a/egglog-ast/src/generic_ast_helpers.rs b/egglog-ast/src/generic_ast_helpers.rs index 73db584d6..1a76b7755 100644 --- a/egglog-ast/src/generic_ast_helpers.rs +++ b/egglog-ast/src/generic_ast_helpers.rs @@ -157,7 +157,11 @@ where GenericExpr::Lit(_ann, lit) => write!(f, "{lit}"), GenericExpr::Var(_ann, var) => write!(f, "{var}"), GenericExpr::Call(_ann, op, children) => { - write!(f, "({} {})", op, ListDisplay(children, " ")) + if children.is_empty() { + write!(f, "({})", op) + } else { + write!(f, "({} {})", op, ListDisplay(children, " ")) + } } } } diff --git a/egglog-bridge/Cargo.toml b/egglog-bridge/Cargo.toml index a73715006..3ce787114 100644 --- a/egglog-bridge/Cargo.toml +++ b/egglog-bridge/Cargo.toml @@ -13,6 +13,7 @@ egglog-core-relations = { workspace = true } egglog-numeric-id = { workspace = true } egglog-union-find = { workspace = true } egglog-reports = { workspace = true } +egglog-ast = { workspace = true } hashbrown = { workspace = true } smallvec = { workspace = true } thiserror = { workspace = true } diff --git a/egglog-bridge/examples/ac.rs b/egglog-bridge/examples/ac.rs index e6a83b16e..78bbc2c8a 100644 --- a/egglog-bridge/examples/ac.rs +++ b/egglog-bridge/examples/ac.rs @@ -20,6 +20,7 @@ fn main() { merge: MergeFn::UnionId, name: "num".into(), can_subsume: false, + fiat_reason_only: None, }); let add_table = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id; 3], @@ -27,6 +28,7 @@ fn main() { merge: MergeFn::UnionId, name: "add".into(), can_subsume: false, + fiat_reason_only: None, }); let add_comm = define_rule! 
{ diff --git a/egglog-bridge/examples/ac_tracing.rs b/egglog-bridge/examples/ac_tracing.rs index c15ba0ab0..a08af4c3b 100644 --- a/egglog-bridge/examples/ac_tracing.rs +++ b/egglog-bridge/examples/ac_tracing.rs @@ -13,6 +13,7 @@ fn main() { merge: MergeFn::UnionId, name: "num".into(), can_subsume: false, + fiat_reason_only: None, }); let add_table = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id; 3], @@ -20,6 +21,7 @@ fn main() { merge: MergeFn::UnionId, name: "add".into(), can_subsume: false, + fiat_reason_only: None, }); let add_comm = define_rule! { diff --git a/egglog-bridge/examples/math.rs b/egglog-bridge/examples/math.rs index 8c0eaa13b..fe6223476 100644 --- a/egglog-bridge/examples/math.rs +++ b/egglog-bridge/examples/math.rs @@ -25,6 +25,7 @@ fn main() { merge: MergeFn::UnionId, name: "diff".into(), can_subsume: false, + fiat_reason_only: None, }); let integral = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id, ColumnTy::Id, ColumnTy::Id], @@ -32,6 +33,7 @@ fn main() { merge: MergeFn::UnionId, name: "integral".into(), can_subsume: false, + fiat_reason_only: None, }); let add = egraph.add_table(FunctionConfig { @@ -40,6 +42,7 @@ fn main() { merge: MergeFn::UnionId, name: "add".into(), can_subsume: false, + fiat_reason_only: None, }); let sub = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id, ColumnTy::Id, ColumnTy::Id], @@ -47,6 +50,7 @@ fn main() { merge: MergeFn::UnionId, name: "sub".into(), can_subsume: false, + fiat_reason_only: None, }); let mul = egraph.add_table(FunctionConfig { @@ -55,6 +59,7 @@ fn main() { merge: MergeFn::UnionId, name: "mul".into(), can_subsume: false, + fiat_reason_only: None, }); let div = egraph.add_table(FunctionConfig { @@ -63,6 +68,7 @@ fn main() { merge: MergeFn::UnionId, name: "div".into(), can_subsume: false, + fiat_reason_only: None, }); let pow = egraph.add_table(FunctionConfig { @@ -71,6 +77,7 @@ fn main() { merge: MergeFn::UnionId, name: "pow".into(), can_subsume: false, + 
fiat_reason_only: None, }); let ln = egraph.add_table(FunctionConfig { @@ -79,6 +86,7 @@ fn main() { merge: MergeFn::UnionId, name: "ln".into(), can_subsume: false, + fiat_reason_only: None, }); let sqrt = egraph.add_table(FunctionConfig { @@ -87,6 +95,7 @@ fn main() { merge: MergeFn::UnionId, name: "sqrt".into(), can_subsume: false, + fiat_reason_only: None, }); let sin = egraph.add_table(FunctionConfig { @@ -95,6 +104,7 @@ fn main() { merge: MergeFn::UnionId, name: "sin".into(), can_subsume: false, + fiat_reason_only: None, }); let cos = egraph.add_table(FunctionConfig { @@ -103,6 +113,7 @@ fn main() { merge: MergeFn::UnionId, name: "cos".into(), can_subsume: false, + fiat_reason_only: None, }); let rat = egraph.add_table(FunctionConfig { @@ -111,6 +122,7 @@ fn main() { merge: MergeFn::UnionId, name: "rat".into(), can_subsume: false, + fiat_reason_only: None, }); let var = egraph.add_table(FunctionConfig { @@ -119,6 +131,7 @@ fn main() { merge: MergeFn::UnionId, name: "var".into(), can_subsume: false, + fiat_reason_only: None, }); let zero = egraph.base_value_constant(Rational64::new(0, 1)); diff --git a/egglog-bridge/src/lib.rs b/egglog-bridge/src/lib.rs index 1bddea882..fa23515f2 100644 --- a/egglog-bridge/src/lib.rs +++ b/egglog-bridge/src/lib.rs @@ -24,6 +24,7 @@ use crate::core_relations::{ TableId, TaggedRowBuffer, Value, WrappedTable, }; use crate::numeric_id::{DenseIdMap, DenseIdMapWithReuse, IdVec, NumericId, define_id}; +use egglog_ast::generic_ast::Literal; use egglog_core_relations as core_relations; use egglog_numeric_id as numeric_id; use egglog_reports::{IterationReport, ReportLevel, RuleSetReport}; @@ -31,6 +32,7 @@ use hashbrown::HashMap; use indexmap::{IndexMap, IndexSet, map::Entry}; use log::info; use once_cell::sync::Lazy; +use ordered_float::OrderedFloat; pub use proof_format::{EqProofId, ProofStore, TermProofId}; use proof_spec::{ProofReason, ProofReconstructionState, ReasonSpecId}; use smallvec::SmallVec; @@ -41,10 +43,11 @@ pub mod 
proof_format; pub(crate) mod proof_spec; pub(crate) mod rule; pub mod syntax; +pub mod termdag; #[cfg(test)] mod tests; -pub use rule::{Function, QueryEntry, RuleBuilder}; +pub use rule::{AtomId, Function, QueryEntry, RuleBuilder, VariableId}; pub use syntax::{SourceExpr, SourceSyntax, TopLevelLhsExpr}; use thiserror::Error; @@ -81,6 +84,7 @@ pub struct EGraph { /// bound. panic_funcs: HashMap, proof_specs: IdVec>, + refl_reason: Value, cong_spec: ReasonSpecId, /// Side tables used to store proof information. We initialize these lazily /// as a proof object with a given number of parameters is added. @@ -130,6 +134,10 @@ pub struct EGraph { /// union-find. In this way, it's a subset of the full union-find in the `uf_table` row, only /// used to resolved temporary inconsistencies in cached term values. term_consistency_table: TableId, + /// The reason consistency table is the reason-level analog to the term consistency table. When + /// identical reasons are inserted to a table concurrently, this table tracks which one is + /// canonical. + reason_consistency_table: TableId, tracing: bool, report_level: ReportLevel, } @@ -161,8 +169,14 @@ pub struct FunctionConfig { pub name: String, /// Whether or not subsumption is enabled for this function. pub can_subsume: bool, + /// If present, every write to this table gets an associated "fiat" reason with this label + /// rather than a standard, tree-shaped one. This is only relevant when proofs are enabled. + pub fiat_reason_only: Option, } +pub type BackendFloat = core_relations::Boxed>; +pub type BackendString = core_relations::Boxed; + impl EGraph { /// Create a new EGraph with tracing (aka 'proofs') enabled. /// @@ -180,22 +194,82 @@ impl EGraph { EGraph::create_internal(db, uf_table, true) } + /// Returns whether this e-graph is collecting provenance data for proofs. 
+ pub fn proofs_enabled(&self) -> bool { + self.tracing + } + + pub fn literal_to_value(&self, l: &Literal) -> Value { + match l { + Literal::Int(x) => self.base_values().get::(*x), + Literal::Float(x) => self.base_values().get::(x.into()), + Literal::String(x) => self + .base_values() + .get::(BackendString::new(x.clone())), + Literal::Bool(x) => self.base_values().get::(*x), + Literal::Unit => self.base_values().get::<()>(()), + } + } + + pub fn literal_to_typed_constant(&self, l: &Literal) -> (Value, ColumnTy) { + match l { + Literal::Int(x) => self.base_value_typed_constant::(*x), + Literal::Float(x) => self.base_value_typed_constant::(x.into()), + Literal::String(x) => { + self.base_value_typed_constant::(BackendString::new(x.clone())) + } + Literal::Bool(x) => self.base_value_typed_constant::(*x), + Literal::Unit => self.base_value_typed_constant::<()>(()), + } + } + + pub fn literal_to_entry(&self, l: &Literal) -> QueryEntry { + let (val, ty) = self.literal_to_typed_constant(l); + QueryEntry::Const { val, ty } + } + + /// Render `v` as a [`Literal`], where possible. + /// + /// This method returns None when `v` is backed by a type not supported by the [`Literal`] + /// enum. 
+ pub fn value_to_literal(&self, v: &Value, ty: BaseValueId) -> Option { + let base_values = self.base_values(); + + Some(if let Some(b) = base_values.try_get_as::(*v, ty) { + Literal::Bool(b) + } else if let Some(i) = base_values.try_get_as::(*v, ty) { + Literal::Int(i) + } else if let Some(f) = base_values.try_get_as::(*v, ty) { + Literal::Float(f.into_inner()) + } else if let Some(s) = base_values.try_get_as::(*v, ty) { + Literal::String(s.into_inner()) + } else if base_values.try_get_as::<()>(*v, ty).is_some() { + Literal::Unit + } else { + return None; + }) + } + fn create_internal(mut db: Database, uf_table: TableId, tracing: bool) -> EGraph { let id_counter = db.add_counter(); - let trace_counter = db.add_counter(); + let reason_counter = db.add_counter_with_initial_val(1_000_000); let ts_counter = db.add_counter(); // Start the timestamp counter at 1. db.inc_counter(ts_counter); let mut proof_specs = IdVec::default(); let cong_spec = proof_specs.push(Arc::new(ProofReason::CongRow)); + let refl_spec = proof_specs.push(Arc::new(ProofReason::Refl)); + let refl_reason = Value::new(!0); let term_consistency_table = db.add_table(DisplacedTable::default(), iter::empty(), iter::empty()); + let reason_consistency_table = + db.add_table(DisplacedTable::default(), iter::empty(), iter::empty()); - Self { + let mut res = Self { db, uf_table, id_counter, - reason_counter: trace_counter, + reason_counter, timestamp_counter: ts_counter, rules: Default::default(), funcs: Default::default(), @@ -203,12 +277,28 @@ impl EGraph { panic_funcs: Default::default(), proof_specs, cong_spec, + refl_reason, reason_tables: Default::default(), term_tables: Default::default(), term_consistency_table, + reason_consistency_table, report_level: Default::default(), tracing, + }; + if tracing { + let refl_table = res.reason_table(&ProofReason::Refl); + let refl_reason = res.with_execution_state(|es| { + es.predict_col( + refl_table, + &[Value::new(refl_spec.rep())], + 
iter::once(MergeVal::Counter(reason_counter)), + ColumnId::new(1), + ) + }); + res.flush_updates(); + res.refl_reason = refl_reason; } + res } fn next_ts(&self) -> Timestamp { @@ -280,6 +370,17 @@ impl EGraph { } } + /// Get the [`Value`] and [`ColumnTy`] for a base value. + pub fn base_value_typed_constant(&self, x: T) -> (Value, ColumnTy) + where + T: BaseValue, + { + ( + self.base_values().get(x), + ColumnTy::Base(self.base_values().get_ty::()), + ) + } + pub fn register_external_func( &mut self, func: impl ExternalFunction + 'static, @@ -316,20 +417,6 @@ impl EGraph { } } - fn record_term_consistency( - state: &mut ExecutionState, - table: TableId, - ts_counter: CounterId, - from: Value, - to: Value, - ) { - if from == to { - return; - } - let ts = Value::from_usize(state.read_counter(ts_counter)); - state.stage_insert(table, &[from, to, ts]); - } - fn canonicalize_term_id(&mut self, term_id: Value) -> Value { let table = self.db.get_table(self.term_consistency_table); table @@ -338,40 +425,60 @@ impl EGraph { .unwrap_or(term_id) } + fn canonicalize_reason_id(&mut self, term_id: Value) -> Value { + let table = self.db.get_table(self.reason_consistency_table); + table + .get_row(&[term_id]) + .map(|row| row.vals[1]) + .unwrap_or(term_id) + } + + fn consistency_merge_fn( + uf_table: TableId, + ts_counter: CounterId, + old_id: Value, + new_id: Value, + new: &[Value], + out: &mut Vec, + state: &mut ExecutionState, + ) -> bool { + if new_id == old_id { + return false; + } + let ts = Value::from_usize(state.read_counter(ts_counter)); + state.stage_insert(uf_table, &[old_id, new_id, ts]); + if new_id < old_id { + out.extend(new); + true + } else { + false + } + } + fn term_table(&mut self, table: TableId) -> TableId { let info = self.db.get_table_info(table); let spec = info.spec(); match self.term_tables.entry(spec.n_keys) { Entry::Occupied(o) => *o.get(), Entry::Vacant(v) => { - let term_index = spec.n_keys + 1; let term_consistency_table = self.term_consistency_table; let
ts_counter = self.timestamp_counter; + let term_id_col = spec.n_keys + 1; let table = SortedWritesTable::new( spec.n_keys + 1, // added entry for the tableid spec.n_keys + 1 + 2, // one value for the term id, one for the reason, None, vec![], // no rebuilding needed for term table Box::new(move |state, old, new, out| { - // We want to pick the minimum term value. - let l_term_id = old[term_index]; - let r_term_id = new[term_index]; - // NB: we should only need this merge function when we are executing - // rules in parallel. We could consider a simpler merge function if - // parallelism is disabled. - if r_term_id < l_term_id { - EGraph::record_term_consistency( - state, - term_consistency_table, - ts_counter, - l_term_id, - r_term_id, - ); - out.extend(new); - true - } else { - false - } + Self::consistency_merge_fn( + term_consistency_table, + ts_counter, + old[term_id_col], + new[term_id_col], + new, + out, + state, + ) }), ); let table_id = @@ -387,14 +494,28 @@ impl EGraph { match self.reason_tables.entry(arity) { Entry::Occupied(o) => *o.get(), Entry::Vacant(v) => { + let consistency = self.reason_consistency_table; + let ts_counter = self.timestamp_counter; let table = SortedWritesTable::new( arity, arity + 1, // one value for the reason id None, vec![], // no rebuilding needed for reason tables - Box::new(|_, _, _, _| false), + Box::new(move |state, old, new, out| { + Self::consistency_merge_fn( + consistency, + ts_counter, + *old.last().unwrap(), + *new.last().unwrap(), + new, + out, + state, + ) + }), ); - let table_id = self.db.add_table(table, iter::empty(), iter::empty()); + let table_id = self + .db + .add_table(table, iter::empty(), iter::once(consistency)); *v.insert(table_id) } } @@ -607,14 +728,10 @@ impl EGraph { if self.get_canon_in_uf(id1) != self.get_canon_in_uf(id2) { // These terms aren't equal. Reconstruct the relevant terms so as to // get a nicer error message on the way out. 
- let mut buf = Vec::::new(); let term_id_1 = self.reconstruct_term(id1, ColumnTy::Id, &mut state); let term_id_2 = self.reconstruct_term(id2, ColumnTy::Id, &mut state); - store.termdag.print_term(term_id_1, &mut buf).unwrap(); - let term1 = String::from_utf8(buf).unwrap(); - let mut buf = Vec::::new(); - store.termdag.print_term(term_id_2, &mut buf).unwrap(); - let term2 = String::from_utf8(buf).unwrap(); + let term1 = store.termdag.to_string_pretty_id(term_id_1); + let term2 = store.termdag.to_string_pretty_id(term_id_2); return Err( ProofReconstructionError::EqualityExplanationOfUnequalTerms { term1, term2 }.into(), ); @@ -672,6 +789,29 @@ impl EGraph { drain_buf!(buf); } + /// A basic method for dumping the names and mappings for table information: useful for + /// debugging egglog internals. + pub fn dump_table_info(&self) { + info!("=== View Tables ==="); + for (id, info) in self.funcs.iter() { + info!( + "View Table {name} / {id:?} / {table:?}", + name = info.name, + table = info.table + ); + } + + info!("=== Term Tables ==="); + for (_, table_id) in &self.term_tables { + info!("Term Table {table_id:?}"); + } + + info!("=== Reason Tables ==="); + for (_, table_id) in &self.reason_tables { + info!("Reason Table {table_id:?}"); + } + } + /// A basic method for dumping the state of the database to `log::info!`. /// /// For large tables, this is unlikely to give particularly useful output. 
@@ -679,6 +819,11 @@ impl EGraph { info!("=== View Tables ==="); for (id, info) in self.funcs.iter() { let table = self.db.get_table(info.table); + info!( + "View Table {name} / {id:?} / {table:?}", + name = info.name, + table = info.table + ); self.scan_table(table, |row| { info!( "View Table {name} / {id:?} / {table:?}: {row:?}", @@ -691,6 +836,7 @@ impl EGraph { info!("=== Term Tables ==="); for (_, table_id) in &self.term_tables { let table = self.db.get_table(*table_id); + info!("Term Table {table_id:?}"); self.scan_table(table, |row| { let name = &self.funcs[FunctionId::new(row[0].rep())].name; let row = &row[1..]; @@ -700,6 +846,7 @@ impl EGraph { info!("=== Reason Tables ==="); for (_, table_id) in &self.reason_tables { + info!("Reason Table {table_id:?}"); let table = self.db.get_table(*table_id); self.scan_table(table, |row| { let spec = self.proof_specs[ReasonSpecId::new(row[0].rep())].as_ref(); @@ -707,6 +854,13 @@ impl EGraph { info!("Reason Table {table_id:?}: {spec:?}, {row:?}") }); } + info!("=== UF ==="); + let uf_table = self.db.get_table(self.uf_table); + self.scan_table(uf_table, |row| match row { + [x, y, t] => info!("Displaced {x:?} => {y:?} @ {t:?}"), + [x, y, t, r] => info!("Displaced {x:?} => {y:?} @ {t:?}, reason {r:?}"), + _ => panic!("unexpected format for union-find"), + }); } /// A helper for scanning the entries in a table. 
@@ -731,6 +885,7 @@ impl EGraph { merge, name, can_subsume, + fiat_reason_only, } = config; assert!( !schema.is_empty(), @@ -761,6 +916,26 @@ impl EGraph { to_rebuild, merge_fn, ); + let fiat_reason: Option = fiat_reason_only.as_ref().and_then(|desc| { + self.tracing.then(|| { + let reason = Arc::new(ProofReason::Fiat { + desc: desc.clone().into(), + }); + let reason_table = self.reason_table(&reason); + let reason_spec_id = self.proof_specs.push(reason); + let reason_id = self.with_execution_state(|es| { + es.predict_col( + reason_table, + &[Value::new(reason_spec_id.rep())], + iter::once(MergeVal::Counter(self.reason_counter)), + ColumnId::new(1), + ) + }); + self.flush_updates(); + reason_id + }) + }); + let name: Arc = name.into(); let table_id = self.db.add_table_named( table, @@ -777,6 +952,7 @@ impl EGraph { default_val: default, can_subsume, name, + fiat_reason, }); debug_assert_eq!(res, next_func_id); let incremental_rebuild_rules = self.incremental_rebuild_rules(res, &schema); @@ -1172,6 +1348,8 @@ struct FunctionInfo { default_val: DefaultVal, can_subsume: bool, name: Arc, + #[allow(dead_code)] + fiat_reason: Option, } impl FunctionInfo { @@ -1232,10 +1410,13 @@ impl MergeFn { args.iter() .for_each(|arg| arg.fill_deps(egraph, read_deps, write_deps)); } - UnionId if !egraph.tracing => { + UnionId => { write_deps.insert(egraph.uf_table); + if egraph.tracing { + write_deps.insert(egraph.term_consistency_table); + } } - UnionId | AssertEq | Old | New | Const(..) => {} + AssertEq | Old | New | Const(..) 
=> {} } } @@ -1246,22 +1427,56 @@ impl MergeFn { egraph: &mut EGraph, ) -> Box { assert!( - !egraph.tracing || matches!(self, MergeFn::UnionId), - "proofs aren't supported for non-union merge functions" + !egraph.tracing || matches!(self, MergeFn::UnionId | MergeFn::AssertEq), + "proofs aren't supported for merge functions other than UnionId or AssertEq" ); let resolved = self.resolve(function_name, egraph); - + let refl_reason = egraph.refl_reason; Box::new(move |state, cur, new, out| { let timestamp = new[schema_math.ts_col()]; let mut changed = false; + // This `terms_equal` handling is here to handle a particular edge case: + // + // When proofs are enabled we explicitly plumb through a reason for each `union`. This + // means that explicit unions within a table are load-bearing in a way that they aren't + // without proofs being turned on (they're always the same as lookup+set). + // + // As a result, unions generally don't happen "implicitly" as part of a merge function. + // Instead, all unions or sets are done explicitly with a reason pointing to the rule + // that did the union. There's one edge-case though: + // + // If two separate threads attempt to create a term (say) `(f x)` concurrently, + // they will both create their own term ids (id1 and id2) and insert them into the term + // table and `f`. `f`'s merge function will blindly discard the higher id (id2), + // assuming that id2 was explicitly `union`ed with id1, but id1 hadn't actually been + // inserted yet! This gets even worse if we are inserting something like `(h (f x))` + // because now we will have a row in `h` that could reference `id2`, which effectively + // (and incorrectly) points to an empty e-class. + // + // This can only happen to duplicate / concurrent insertions of `(f x)`: two equal + // e-nodes that correspond to two different terms will still be written to two + // different rows and hence will not rely on implicit unions in this way.
Because the + // only kinds of insertions that fall prey to this are identical terms, we add an + // explicit case for when `old` and `new` represent identical terms with different term + // ids. In that case we can list our reason for the terms being equal as `Refl`, and + // union these two ids in both the main union-find and the term-consistency table. + let terms_equal = schema_math.tracing + && cur[0..schema_math.ret_val_col()] == new[0..schema_math.ret_val_col()]; + let ret_val = { - let cur = cur[schema_math.ret_val_col()]; - let new = new[schema_math.ret_val_col()]; - let out = resolved.run(state, cur, new, timestamp); - changed |= cur != out; + let cur_id = cur[schema_math.ret_val_col()]; + let new_id = new[schema_math.ret_val_col()]; + let out = resolved.run( + state, + cur_id, + new_id, + timestamp, + terms_equal.then_some(refl_reason), + ); + changed |= cur_id != out; out }; @@ -1309,6 +1524,7 @@ impl MergeFn { }, MergeFn::UnionId => ResolvedMergeFn::UnionId { uf_table: egraph.uf_table, + term_consistency: egraph.term_consistency_table, tracing: egraph.tracing, }, // NB: The primitive and function-based merge functions heap allocate a single callback @@ -1362,6 +1578,7 @@ enum ResolvedMergeFn { }, UnionId { uf_table: TableId, + term_consistency: TableId, tracing: bool, }, Primitive { @@ -1377,7 +1594,17 @@ enum ResolvedMergeFn { } impl ResolvedMergeFn { - fn run(&self, state: &mut ExecutionState, cur: Value, new: Value, ts: Value) -> Value { + fn run( + &self, + state: &mut ExecutionState, + cur: Value, + new: Value, + ts: Value, + // If `Some`, the corresponding terms for the two ids are equal, in which case this value + // is the reason id for a `Refl` reason that should be inserted into the given union-find + // table.
+ terms_equal: Option, + ) -> Value { match self { ResolvedMergeFn::Const(v) => *v, ResolvedMergeFn::Old => cur, @@ -1389,7 +1616,11 @@ impl ResolvedMergeFn { } cur } - ResolvedMergeFn::UnionId { uf_table, tracing } => { + ResolvedMergeFn::UnionId { + uf_table, + term_consistency, + tracing, + } => { if cur != new && !tracing { // When proofs are enabled, these are the same term. They are already // equal and we can just do nothing. @@ -1397,7 +1628,17 @@ impl ResolvedMergeFn { // We pick the minimum when unioning. This matches the original egglog // behavior. THIS MUST MATCH THE UNION-FIND IMPLEMENTATION! std::cmp::min(cur, new) + } else if *tracing && cur != new && terms_equal.is_some() { + let Some(refl_reason) = terms_equal else { + unreachable!() + }; + state.stage_insert(*uf_table, &[cur, new, ts, refl_reason]); + state.stage_insert(*term_consistency, &[cur, new, ts]); + std::cmp::min(cur, new) } else { + // If proofs are enabled, but we see two non-equal ids targeting the same row, + // we do nothing. We ensure there is a separate reason being populated in the + // union-find table as part of a separate action. cur } } @@ -1408,7 +1649,7 @@ impl ResolvedMergeFn { ResolvedMergeFn::Primitive { prim, args, panic } => { let args = args .iter() - .map(|arg| arg.run(state, cur, new, ts)) + .map(|arg| arg.run(state, cur, new, ts, terms_equal)) .collect::>(); match state.call_external_func(*prim, &args) { @@ -1428,7 +1669,7 @@ impl ResolvedMergeFn { let args = args .iter() - .map(|arg| arg.run(state, cur, new, ts)) + .map(|arg| arg.run(state, cur, new, ts, terms_equal)) .collect::>(); func.lookup(state, &args).unwrap_or_else(|| { diff --git a/egglog-bridge/src/proof_format.rs b/egglog-bridge/src/proof_format.rs index 8008acae7..e3b4a02ed 100644 --- a/egglog-bridge/src/proof_format.rs +++ b/egglog-bridge/src/proof_format.rs @@ -1,16 +1,15 @@ //! A proof format for egglog programs, based on the Rocq format and checker from Tia Vu, Ryan //! Doegens, and Oliver Flatt.
-use std::{hash::Hash, io, rc::Rc}; +use std::{hash::Hash, io, rc::Rc, sync::Arc}; -use crate::core_relations::Value; -use crate::numeric_id::{DenseIdMap, NumericId, define_id}; use indexmap::IndexSet; -use crate::{FunctionId, rule::VariableId}; +use crate::ColumnTy; +use crate::numeric_id::{NumericId, define_id}; +use crate::termdag::{PrettyPrintConfig, PrettyPrinter, TermDag, TermId}; define_id!(pub TermProofId, u32, "an id identifying proofs of terms within a [`ProofStore`]"); define_id!(pub EqProofId, u32, "an id identifying proofs of equality between two terms within a [`ProofStore`]"); -define_id!(pub TermId, u32, "an id identifying terms within a [`TermDag`]"); #[derive(Clone, Debug)] struct HashCons { @@ -43,95 +42,6 @@ impl HashCons { } } -#[derive(Default, Clone)] -pub struct TermDag { - store: HashCons, -} - -impl TermDag { - /// Print the term in a human-readable format to the given writer. - pub fn print_term(&self, term: TermId, writer: &mut impl io::Write) -> io::Result<()> { - self.print_term_pretty(term, &PrettyPrintConfig::default(), writer) - } - - /// Print the term with pretty-printing configuration. 
- pub fn print_term_pretty( - &self, - term: TermId, - config: &PrettyPrintConfig, - writer: &mut impl io::Write, - ) -> io::Result<()> { - let mut printer = PrettyPrinter::new(writer, config); - self.print_term_with_printer(term, &mut printer) - } - - fn print_term_with_printer( - &self, - term: TermId, - printer: &mut PrettyPrinter, - ) -> io::Result<()> { - let term = self.store.lookup(term).unwrap(); - match term { - Term::Constant { id, rendered } => { - if let Some(rendered) = rendered { - printer.write_str(rendered)?; - } else { - printer.write_str(&format!("c{}", id.index()))?; - } - } - Term::Func { id, args } => { - printer.write_str(&format!("({id:?}"))?; - if !args.is_empty() { - printer.increase_indent(); - for (i, arg) in args.iter().enumerate() { - if i > 0 { - printer.write_str(",")?; - } - printer.write_with_break(" ")?; - self.print_term_with_printer(*arg, printer)?; - } - printer.decrease_indent(); - } - printer.write_str(")")?; - } - } - Ok(()) - } - - /// Add the given [`Term`] to the store, returning its [`TermId`]. - /// - /// The [`TermId`]s in this term should point into this same [`TermDag`]. - pub fn get_or_insert(&mut self, term: &Term) -> TermId { - self.store.get_or_insert(term) - } - - pub(crate) fn proj(&self, term: TermId, arg_idx: usize) -> TermId { - let term = self.store.lookup(term).unwrap(); - match term { - Term::Func { args, .. } => { - if arg_idx < args.len() { - args[arg_idx] - } else { - panic!("Index out of bounds for function arguments") - } - } - _ => panic!("Cannot project a non-function term"), - } - } -} - -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub enum Term { - Constant { - id: Value, - rendered: Option>, - }, - Func { - id: FunctionId, - args: Vec, - }, -} - /// A hash-cons store for proofs and terms related to an egglog program. 
#[derive(Clone, Default)] pub struct ProofStore { @@ -175,10 +85,10 @@ impl ProofStore { new_term, func, } = cong_pf; - printer.write_str(&format!("Cong({func:?}, "))?; - self.termdag.print_term_with_printer(*old_term, printer)?; + printer.write_str(&format!("Cong({func}, "))?; + self.print_term(*old_term, printer)?; printer.write_str(" => ")?; - self.termdag.print_term_with_printer(*new_term, printer)?; + self.print_term(*new_term, printer)?; printer.write_str(" by (")?; printer.increase_indent(); for (i, pf) in pf_args_eq.iter().enumerate() { @@ -220,13 +130,13 @@ impl ProofStore { printer.write_str(&format!("PRule[Equality]({rule_name:?}, Subst {{"))?; printer.increase_indent(); printer.newline()?; - for (i, (var, term)) in subst.iter().enumerate() { + for (i, binding) in subst.iter().enumerate() { if i > 0 { printer.write_str(",")?; } printer.write_with_break(" ")?; - printer.write_str(&format!("{var:?} => "))?; - self.termdag.print_term_with_printer(*term, printer)?; + printer.write_str(&format!("{}: {:?} => ", binding.name, binding.ty))?; + self.print_term(binding.term, printer)?; printer.newline()?; } printer.newline()?; @@ -256,9 +166,9 @@ impl ProofStore { printer.write_with_break("], ")?; printer.newline()?; printer.write_with_break(" Result: ")?; - self.termdag.print_term_with_printer(*result_lhs, printer)?; + self.print_term(*result_lhs, printer)?; printer.write_str(" = ")?; - self.termdag.print_term_with_printer(*result_rhs, printer)?; + self.print_term(*result_rhs, printer)?; printer.write_str(")")?; printer.decrease_indent(); } @@ -266,7 +176,8 @@ impl ProofStore { printer.write_str("PRefl(")?; self.print_term_proof_with_printer(*t_ok_pf, printer)?; printer.write_str(", (term= ")?; - self.termdag.print_term_with_printer(*t, printer)?; + self.termdag + .print_term_with_printer(self.termdag.get(*t), printer)?; printer.write_str("))")? 
} EqProof::PSym { eq_pf } => { @@ -325,13 +236,13 @@ impl ProofStore { printer.write_str(&format!("PRule[Existence]({rule_name:?}, Subst {{"))?; printer.increase_indent(); printer.newline()?; - for (i, (var, term)) in subst.iter().enumerate() { + for (i, binding) in subst.iter().enumerate() { if i > 0 { printer.write_str(",")?; } printer.write_with_break(" ")?; - printer.write_str(&format!("{var:?} => "))?; - self.termdag.print_term_with_printer(*term, printer)?; + printer.write_str(&format!("{}: {:?} => ", binding.name, binding.ty))?; + self.print_term(binding.term, printer)?; printer.newline()?; } printer.newline()?; @@ -359,7 +270,7 @@ impl ProofStore { } printer.decrease_indent(); printer.write_with_break("], Result: ")?; - self.termdag.print_term_with_printer(*result, printer)?; + self.print_term(*result, printer)?; printer.write_str(")") } TermProof::PProj { @@ -378,11 +289,20 @@ impl ProofStore { TermProof::PFiat { desc, term } => { printer.write_str(&format!("PFiat({desc:?}"))?; printer.write_str(", ")?; - self.termdag.print_term_with_printer(*term, printer)?; + self.print_term(*term, printer)?; printer.write_str(")") } } } + fn print_term( + &self, + term: TermId, + printer: &mut PrettyPrinter, + ) -> io::Result<()> { + self.termdag + .print_term_with_printer(self.termdag.get(term), printer) + } + pub(crate) fn intern_term(&mut self, prf: &TermProof) -> TermProofId { self.term_memo.get_or_insert(prf) } @@ -432,7 +352,14 @@ pub struct CongProof { pub pf_f_args_ok: TermProofId, pub old_term: TermId, pub new_term: TermId, - pub func: FunctionId, + pub func: Arc, +} + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct RuleVarBinding { + pub name: Arc, + pub ty: ColumnTy, + pub term: TermId, } #[allow(clippy::enum_variant_names)] @@ -444,7 +371,7 @@ pub enum TermProof { /// the act_pf gives a location in the action of the proposition PRule { rule_name: Rc, - subst: DenseIdMap, + subst: Vec, body_pfs: Vec, result: TermId, }, @@ -469,7 +396,7 @@ pub enum 
EqProof { /// the act_pf gives a location in the action of the proposition PRule { rule_name: Rc, - subst: DenseIdMap, + subst: Vec, body_pfs: Vec, result_lhs: TermId, result_rhs: TermId, @@ -492,77 +419,3 @@ pub enum EqProof { /// pf_f_args_ok is a proof that the term with the lhs children is valid PCong(CongProof), } - -#[derive(Clone, Debug)] -pub struct PrettyPrintConfig { - pub line_width: usize, - pub indent_size: usize, -} - -impl Default for PrettyPrintConfig { - fn default() -> Self { - Self { - line_width: 512, - indent_size: 4, - } - } -} - -struct PrettyPrinter<'w, W: io::Write> { - writer: &'w mut W, - config: &'w PrettyPrintConfig, - current_indent: usize, - current_line_pos: usize, -} - -impl<'w, W: io::Write> PrettyPrinter<'w, W> { - fn new(writer: &'w mut W, config: &'w PrettyPrintConfig) -> Self { - Self { - writer, - config, - current_indent: 0, - current_line_pos: 0, - } - } - - fn write_str(&mut self, s: &str) -> io::Result<()> { - write!(self.writer, "{s}")?; - self.current_line_pos += s.len(); - Ok(()) - } - - fn newline(&mut self) -> io::Result<()> { - writeln!(self.writer)?; - self.current_line_pos = 0; - self.write_indent()?; - Ok(()) - } - - fn write_indent(&mut self) -> io::Result<()> { - for _ in 0..self.current_indent { - write!(self.writer, " ")?; - } - self.current_line_pos = self.current_indent; - Ok(()) - } - - fn increase_indent(&mut self) { - self.current_indent += self.config.indent_size; - } - - fn decrease_indent(&mut self) { - self.current_indent = self.current_indent.saturating_sub(self.config.indent_size); - } - - fn should_break(&self, additional_chars: usize) -> bool { - self.current_line_pos + additional_chars > self.config.line_width - } - - fn write_with_break(&mut self, s: &str) -> io::Result<()> { - if self.should_break(s.len()) && self.current_line_pos > self.current_indent { - self.newline()?; - self.write_indent()?; - } - self.write_str(s) - } -} diff --git a/egglog-bridge/src/proof_spec.rs 
b/egglog-bridge/src/proof_spec.rs index f4d4cb872..2d1584bb3 100644 --- a/egglog-bridge/src/proof_spec.rs +++ b/egglog-bridge/src/proof_spec.rs @@ -1,11 +1,12 @@ use std::{iter, rc::Rc, sync::Arc}; use crate::core_relations::{ - BaseValuePrinter, ColumnId, DisplacedTableWithProvenance, ProofReason as UfProofReason, - ProofStep, RuleBuilder, Value, + ColumnId, DisplacedTableWithProvenance, ProofReason as UfProofReason, ProofStep, RuleBuilder, + Value, }; use crate::numeric_id::{DenseIdMap, NumericId, define_id}; use crate::rule::Variable; +use crate::termdag::TermId; use egglog_reports::ReportLevel; use hashbrown::{HashMap, HashSet}; @@ -13,10 +14,10 @@ use crate::{ ColumnTy, EGraph, FunctionId, GetFirstMatch, QueryEntry, Result, RuleId, SideChannel, SourceExpr, TopLevelLhsExpr, proof_format::{ - CongProof, EqProof, EqProofId, Premise, ProofStore, Term, TermId, TermProof, TermProofId, + CongProof, EqProof, EqProofId, Premise, ProofStore, RuleVarBinding, TermProof, TermProofId, }, rule::{AtomId, Bindings, DstVar, VariableId}, - syntax::{RuleData, SourceSyntax, SyntaxId}, + syntax::{RuleData, SourceSyntax, SourceVar, SyntaxId}, }; define_id!(pub(crate) ReasonSpecId, u32, "A unique identifier for the step in a proof."); @@ -34,6 +35,11 @@ pub(crate) enum ProofReason { Fiat { desc: Arc, }, + /// A proof that a term equals itself. + /// + /// This is generally only used when two identical terms are created, but with a different term + /// id (due to concurrency / "term consistency" reasons). + Refl, } impl ProofReason { @@ -42,7 +48,7 @@ impl ProofReason { 1 + match self { ProofReason::CongRow => 1, ProofReason::Rule(data) => data.n_vars(), - ProofReason::Fiat { .. } => 0, + ProofReason::Refl | ProofReason::Fiat { .. 
} => 0, } } } @@ -116,25 +122,44 @@ impl ProofBuilder { &mut self, func: FunctionId, entries: Vec, + res_id: Option, term_var: VariableId, db: &mut EGraph, - ) -> impl Fn(&mut Bindings, &mut RuleBuilder) -> Result<()> + Clone + use<> { - let func_table = db.funcs[func].table; + ) -> impl Fn(&mut Bindings, &mut RuleBuilder) -> Result + Clone + use<> { + let func_info = &db.funcs[func]; + let func_table = func_info.table; + let fiat_reason = func_info.fiat_reason; let term_table = db.term_table(func_table); let func_val = Value::new(func.rep()); move |inner, rb| { - let reason_var = inner - .lhs_reason - .expect("must have a reason variable for new rows"); - let mut translated = Vec::new(); + let reason_var: DstVar = if let Some(fiat_reason) = fiat_reason { + // This table has been marked as "fiat only", meaning that all proofs for this + // table should have a single hard-coded fiat reason, rather than the ambient used + // for the rule. + fiat_reason.into() + } else { + inner + .lhs_reason + .expect("must have a reason variable for new rows") + }; + let mut translated = Vec::with_capacity(entries.len() + 2); translated.push(func_val.into()); for entry in &entries[0..entries.len() - 1] { translated.push(inner.convert(entry)); } translated.push(inner.mapping[term_var]); translated.push(reason_var); - rb.insert(term_table, &translated)?; - Ok(()) + if let Some(res_id) = res_id { + rb.insert_if_eq( + term_table, + inner.mapping[res_id], + inner.mapping[term_var], + &translated, + )?; + } else { + rb.insert(term_table, &translated)?; + } + Ok(reason_var) } } } @@ -214,14 +239,18 @@ impl EGraph { RuleData { syntax, .. }: &RuleData, vars: &[Value], state: &mut ProofReconstructionState, - ) -> (DenseIdMap, Vec) { + ) -> (Vec, Vec) { // First, reconstruct terms for all the relevant variables. 
- let mut subst_term = DenseIdMap::::new(); + let mut subst_term = Vec::with_capacity(syntax.vars.len()); let mut subst_val = DenseIdMap::::new(); - for ((var, ty), term_id) in syntax.vars.iter().zip(vars) { - subst_val.insert(*var, *term_id); + for (SourceVar { id, ty, name }, term_id) in syntax.vars.iter().zip(vars) { + subst_val.insert(*id, *term_id); let term = self.reconstruct_term(*term_id, *ty, state); - subst_term.insert(*var, term); + subst_term.push(RuleVarBinding { + name: Arc::clone(name), + ty: *ty, + term, + }); } let mut terms = DenseIdMap::::new(); let mut premises = Vec::new(); @@ -262,10 +291,14 @@ impl EGraph { let spec = self.proof_specs[ReasonSpecId::new(reason_row[0].rep())].clone(); let res = match &*spec { ProofReason::Rule(data) => { + debug_assert_eq!( + self.rules[data.rule_id].desc.as_ref(), + data.rule_name.as_ref() + ); let (subst, body_pfs) = self.rule_proof(data, &reason_row[1..], state); let result = self.reconstruct_term(term_id, ColumnTy::Id, state); state.store.intern_term(&TermProof::PRule { - rule_name: String::from(&*self.rules[data.rule_id].desc).into(), + rule_name: Rc::::from(data.rule_name.as_ref()), subst, body_pfs, result, @@ -282,6 +315,11 @@ impl EGraph { term, }) } + ProofReason::Refl => { + panic!( + "Refl cannot be a reason for a term's existence. This is an internal proofs error" + ); + } }; state.in_progress.remove(&term_id); @@ -306,6 +344,7 @@ impl EGraph { ColumnTy::Id => { let term_row = self.get_term_row(key_id); let func = FunctionId::new(term_row[0].rep()); + let func_name = self.funcs[func].name.to_string(); let info = &self.funcs[func]; // NB: this clone is needed because `get_term_row` borrows the whole egraph, though it // really only needs mutable access to `db`. 
This is of course fixable if we wanted to get @@ -313,27 +352,19 @@ impl EGraph { let schema = info.schema.clone(); let mut args = Vec::with_capacity(term_row.len() - 1); for (ty, entry) in schema[0..schema.len() - 1].iter().zip(term_row[1..].iter()) { - args.push(self.reconstruct_term(*entry, *ty, state)); + let term = self.reconstruct_term(*entry, *ty, state); + args.push(state.store.termdag.get(term).clone()); } - state - .store - .termdag - .get_or_insert(&Term::Func { id: func, args }) + let app = state.store.termdag.app(func_name, args); + state.store.termdag.lookup(&app) } ColumnTy::Base(ty) => { - let rendered: Rc = format!( - "{:?}", - BaseValuePrinter { - base: self.db.base_values(), - ty, - val: term_id, - } - ) - .into(); - state.store.termdag.get_or_insert(&Term::Constant { - id: term_id, - rendered: Some(rendered), - }) + let term = if let Some(literal) = self.value_to_literal(&term_id, ty) { + state.store.termdag.lit(literal) + } else { + state.store.termdag.unknown_lit() + }; + state.store.termdag.lookup(&term) } }; @@ -405,8 +436,6 @@ impl EGraph { let old_term_row = self.get_term_row(old_term_id); let new_term_row = self.get_term_row(new_term_id); let old_term_proof = self.explain_term_inner(old_term_id, state); - let old_term = self.reconstruct_term(old_term_id, ColumnTy::Id, state); - let new_term = self.reconstruct_term(new_term_id, ColumnTy::Id, state); let func_id = FunctionId::new(old_term_row[0].rep()); debug_assert_eq!( old_term_row[0], new_term_row[0], @@ -415,6 +444,8 @@ impl EGraph { let info = &self.funcs[func_id]; let schema = info.schema.clone(); let mut args_eq_proofs = Vec::with_capacity(schema.len() - 1); + let old_term = self.reconstruct_term(old_term_id, ColumnTy::Id, state); + let new_term = self.reconstruct_term(new_term_id, ColumnTy::Id, state); for (i, (ty, (lhs, rhs))) in schema[0..schema.len() - 1] .iter() .zip(old_term_row[1..].iter().zip(new_term_row[1..].iter())) @@ -424,24 +455,23 @@ impl EGraph { ColumnTy::Id => 
self.explain_terms_equal_inner(*lhs, *rhs, state), ColumnTy::Base(_) => { assert_eq!(lhs, rhs, "congruence proof must have equal base values"); - // For an equality proof, we first need an existence proof, which we can get by - // doing a projection from the existence proof of `old_term`. let arg_exists = state.store.intern_term(&TermProof::PProj { pf_f_args_ok: old_term_proof, arg_idx: i, }); - let arg_term = state.store.termdag.proj(old_term, i); + let arg_term = state.store.termdag.proj_id(old_term, i).unwrap(); state.store.refl(arg_exists, arg_term) } }; args_eq_proofs.push(eq_proof); } + let func_name = Arc::from(self.funcs[func_id].name.as_ref()); CongProof { pf_args_eq: args_eq_proofs, pf_f_args_ok: old_term_proof, old_term, new_term, - func: func_id, + func: func_name, } } @@ -456,10 +486,14 @@ impl EGraph { let spec = self.proof_specs[ReasonSpecId::new(reason_row[0].rep())].clone(); match &*spec { ProofReason::Rule(data) => { + debug_assert_eq!( + self.rules[data.rule_id].desc.as_ref(), + data.rule_name.as_ref() + ); let (subst, body_pfs) = self.rule_proof(data, &reason_row[1..], state); let l_term = self.reconstruct_term(l, ColumnTy::Id, state); let r_term = self.reconstruct_term(r, ColumnTy::Id, state); - let rule_name = String::from(&*self.rules[data.rule_id].desc).into(); + let rule_name = Rc::::from(data.rule_name.as_ref()); state.store.intern_eq(&EqProof::PRule { rule_name, subst, @@ -476,6 +510,14 @@ impl EGraph { // NB: we could add this if we wanted to. 
panic!("fiat reason being used to explain equality, rather than a row's existence") } + ProofReason::Refl => { + let l = self.canonicalize_term_id(l); + let r = self.canonicalize_term_id(r); + assert_eq!(l, r, "refl justification for two non-equal terms"); + let t_ok_pf = self.explain_term_inner(l, state); + let t = self.reconstruct_term(l, ColumnTy::Id, state); + state.store.intern_eq(&EqProof::PRefl { t_ok_pf, t }) + } } } @@ -518,6 +560,7 @@ impl EGraph { } fn get_reason(&mut self, reason_id: Value) -> Vec { + let reason_id = self.canonicalize_reason_id(reason_id); let mut atom = Vec::::new(); let mut cur = 0; loop { diff --git a/egglog-bridge/src/rule.rs b/egglog-bridge/src/rule.rs index acfd08072..755e1f530 100644 --- a/egglog-bridge/src/rule.rs +++ b/egglog-bridge/src/rule.rs @@ -115,6 +115,46 @@ impl Result<()> + Clone + Send + S dyn_clone::clone_trait_object!(Brc); type BuildRuleCallback = Box; +/// The builders for queries in this module essentially wrap the lower-level +/// builders from the `core_relations` crate. A single egglog rule can turn into +/// N core-relations rules. The code is structured by constructing a series of +/// callbacks that will iteratively build up a low-level rule that looks like +/// the high-level rule, passing along an environment that keeps track of the +/// mappings between low and high-level variables. +#[derive(Clone, Default)] +struct RuleCallbacks { + /// A set of callbacks to run prior to running `build_reason`, which wires up proof metadata. + /// Most instructions rely on this proof metadata -- this is currently only used for + /// `query_prim`. + header: Vec, + /// An optional callback to wire up proof-related metadata before running + /// the RHS of a rule. 
+ build_reason: Option, + add_rule: Vec, +} + +impl RuleCallbacks { + fn add_callback(&mut self, cb: BuildRuleCallback) { + self.add_rule.push(cb); + } + + fn add_header_callback(&mut self, cb: BuildRuleCallback) { + self.header.push(cb); + } + + fn add_build_reason(&mut self, cb: BuildRuleCallback) { + self.build_reason = Some(cb); + } + + fn run(&self, inner: &mut Bindings, rb: &mut CoreRuleBuilder) -> Result<()> { + self.header.iter().try_for_each(|f| f(inner, rb))?; + if let Some(build_reason) = &self.build_reason { + build_reason(inner, rb)?; + } + self.add_rule.iter().try_for_each(|f| f(inner, rb)) + } +} + #[derive(Clone)] pub(crate) struct Query { uf_table: TableId, @@ -126,15 +166,7 @@ pub(crate) struct Query { /// The current proofs that are in scope. atom_proofs: Vec, atoms: Vec<(TableId, Vec, SchemaMath)>, - /// An optional callback to wire up proof-related metadata before running the RHS of a rule. - build_reason: Option, - /// The builders for queries in this module essentially wrap the lower-level - /// builders from the `core_relations` crate. A single egglog rule can turn - /// into N core-relations rules. The code is structured by constructing a - /// series of callbacks that will iteratively build up a low-level rule that - /// looks like the high-level rule, passing along an environment that keeps - /// track of the mappings between low and high-level variables. - add_rule: Vec, + callbacks: RuleCallbacks, /// If set, execute a single rule (rather than O(atoms.len()) rules) during /// seminaive, with the given atom as the focus. 
sole_focus: Option, @@ -167,12 +199,11 @@ impl EGraph { tracing, rule_id, seminaive, - build_reason: None, sole_focus: None, atom_proofs: Default::default(), vars: Default::default(), atoms: Default::default(), - add_rule: Default::default(), + callbacks: Default::default(), plan_strategy: Default::default(), }, } @@ -186,7 +217,11 @@ impl EGraph { impl RuleBuilder<'_> { fn add_callback(&mut self, cb: impl Brc + 'static) { - self.query.add_rule.push(Box::new(cb)); + self.query.callbacks.add_callback(Box::new(cb)); + } + + fn add_header_callback(&mut self, cb: impl Brc + 'static) { + self.query.callbacks.add_header_callback(Box::new(cb)); } /// Access the underlying egraph within the builder. @@ -316,11 +351,13 @@ impl RuleBuilder<'_> { let cb = self .proof_builder .create_reason(syntax.clone(), self.egraph); - self.query.build_reason = Some(Box::new(move |bndgs, rb| { - let reason = cb(bndgs, rb)?; - bndgs.lhs_reason = Some(reason.into()); - Ok(()) - })); + self.query + .callbacks + .add_build_reason(Box::new(move |bndgs, rb| { + let reason = cb(bndgs, rb)?; + bndgs.lhs_reason = Some(reason.into()); + Ok(()) + })); } } let res = self.query.rule_id; @@ -448,12 +485,12 @@ impl RuleBuilder<'_> { let res = self.new_var(ret_ty); // External functions that fail on the RHS of a rule should cause a panic. 
let panic_fn = self.egraph.new_panic_lazy(panic_msg); - self.query.add_rule.push(Box::new(move |inner, rb| { + self.add_callback(move |inner, rb| { let args = inner.convert_all(&args); let var = rb.call_external_with_fallback(func, &args, panic_fn, &[])?; inner.mapping.insert(res.id, var.into()); Ok(()) - })); + }); res } @@ -506,7 +543,7 @@ impl RuleBuilder<'_> { _ret_ty: ColumnTy, ) -> Result<()> { let entries = entries.to_vec(); - self.query.add_rule.push(Box::new(move |inner, rb| { + self.add_header_callback(move |inner, rb| { let mut dst_vars = inner.convert_all(&entries); let expected = dst_vars.pop().expect("must specify a return value"); let var = rb.call_external(func, &dst_vars)?; @@ -518,7 +555,7 @@ impl RuleBuilder<'_> { _ => rb.assert_eq(var.into(), expected), } Ok(()) - })); + }); Ok(()) } @@ -647,9 +684,13 @@ impl RuleBuilder<'_> { let ts_var = self.new_var(ColumnTy::Id); let mut insert_entries = entries.to_vec(); insert_entries.push(res.clone().into()); - let add_proof = - self.proof_builder - .new_row(func, insert_entries, term_var.id, self.egraph); + let add_proof = self.proof_builder.new_row( + func, + insert_entries, + Some(res.id), + term_var.id, + self.egraph, + ); Box::new(move |inner, rb| { let write_vals = get_write_vals(inner); let dst_vars = inner.convert_all(&entries); @@ -682,10 +723,17 @@ impl RuleBuilder<'_> { inner.mapping.insert(term_var.id, term.into()); inner.mapping.insert(res.id, var.into()); inner.mapping.insert(ts_var.id, ts.into()); - rb.assert_eq(var.into(), term.into()); // The following bookeeping is only needed // if the value is new. That only happens if // the main id equals the term id. + // + // We could add this line, but it could cause the rest of the rule to also + // fail to run. + // + // > rb.assert_eq(var.into(), term.into()); + // + // Instead we leverage the `insert_if_eq` instruction inside of + // `add_proof`. 
add_proof(inner, rb)?; Ok(()) }) @@ -742,7 +790,7 @@ impl RuleBuilder<'_> { } } }; - self.query.add_rule.push(cb); + self.query.callbacks.add_callback(cb); res } @@ -793,7 +841,7 @@ impl RuleBuilder<'_> { .context("union") }) }; - self.query.add_rule.push(cb); + self.query.callbacks.add_callback(cb); } /// This method is equivalent to `remove(table, before); set(table, after)` @@ -853,7 +901,7 @@ impl RuleBuilder<'_> { func_cols: info.schema.len(), }; - self.query.add_rule.push(Box::new(move |inner, rb| { + self.add_callback(move |inner, rb| { add_proof(inner, rb)?; let mut dst_vars = inner.convert_all(&after); schema_math.write_table_row( @@ -878,7 +926,7 @@ impl RuleBuilder<'_> { ) .context("rebuild_row_uf")?; rb.insert(table, &dst_vars).context("rebuild_row_table") - })); + }); } /// Set the value of a function in the database. @@ -908,34 +956,119 @@ impl RuleBuilder<'_> { func_cols: info.schema.len(), }; if self.egraph.tracing { - let res = self.lookup(func, &entries[0..entries.len() - 1], || { - "lookup failed during proof-enabled set; this is an internal proofs bug".to_string() - }); - let res_entry = res.clone().into(); - self.union(res.into(), entries.last().unwrap().clone()); - if schema_math.subsume { - // Set the original row but with the passed-in subsumption value. 
- self.add_callback(move |inner, rb| { - let mut dst_vars = inner.convert_all(&entries); - let proof_var = rb.lookup( - table, - &dst_vars[0..schema_math.num_keys()], - ColumnId::from_usize(schema_math.proof_id_col()), - )?; - schema_math.write_table_row( - &mut dst_vars, - RowVals { - timestamp: inner.next_ts(), - proof: Some(proof_var.into()), - subsume: Some(inner.convert(&subsume_entry)), - ret_val: Some(inner.convert(&res_entry)), - }, + match info.default_val { + DefaultVal::FreshId => { + let res = self.lookup(func, &entries[0..entries.len() - 1], || { + "lookup failed during proof-enabled set; this is an internal proofs bug" + .to_string() + }); + let res_entry = res.clone().into(); + self.union(res.into(), entries.last().unwrap().clone()); + if schema_math.subsume + && !matches!( + subsume_entry, + QueryEntry::Const { + val: NOT_SUBSUMED, + .. + }, + ) + { + // Set the original row but with the passed-in subsumption value. + self.add_callback(move |inner, rb| { + let mut dst_vars = inner.convert_all(&entries); + let proof_var = rb.lookup( + table, + &dst_vars[0..schema_math.num_keys()], + ColumnId::from_usize(schema_math.proof_id_col()), + )?; + schema_math.write_table_row( + &mut dst_vars, + RowVals { + timestamp: inner.next_ts(), + proof: Some(proof_var.into()), + subsume: Some(inner.convert(&subsume_entry)), + ret_val: Some(inner.convert(&res_entry)), + }, + ); + rb.insert(table, &dst_vars).context("set") + }); + } + } + DefaultVal::Fail => { + let table = info.table; + let entries = entries.clone(); + let subsume_entry = subsume_entry.clone(); + let subsume_entry_for_write = subsume_entry.clone(); + let id_counter = self.query.id_counter; + let term_var = self.new_var(ColumnTy::Id); + let term_var_id = term_var.id; + let add_proof = self.proof_builder.new_row( + func, + entries.clone(), + None, + term_var_id, + self.egraph, ); - rb.insert(table, &dst_vars).context("set") - }); + let get_write_vals = move |inner: &mut Bindings, ret_val: DstVar| { + let 
mut write_vals = SmallVec::<[WriteVal; 4]>::new(); + for i in schema_math.num_keys()..schema_math.table_columns() { + write_vals.push(if i == schema_math.ts_col() { + inner.next_ts().into() + } else if i == schema_math.ret_val_col() { + ret_val.into() + } else if i == schema_math.proof_id_col() { + WriteVal::IncCounter(id_counter) + } else if schema_math.subsume && i == schema_math.subsume_col() { + inner.convert(&subsume_entry_for_write).into() + } else { + unreachable!() + }); + } + write_vals + }; + let uf_table = self.egraph.uf_table; + self.add_callback(move |inner, rb| { + let mut dst_vars = inner.convert_all(&entries); + let ret_val = dst_vars[schema_math.ret_val_col()]; + let write_vals = get_write_vals(inner, ret_val); + let (key_slice, _) = dst_vars.split_at(schema_math.num_keys()); + let term = rb + .lookup_or_insert( + table, + key_slice, + &write_vals, + ColumnId::from_usize(schema_math.proof_id_col()), + ) + .context("set proof lookup")?; + inner.mapping.insert(term_var_id, term.into()); + + let reason_var = add_proof(inner, rb)?; + // If the term we are creating is different from the id we are setting it + // to, use this rule as the reason why the term and the id are equal. 
+ rb.insert_if_ne( + uf_table, + ret_val, + term.into(), + &[ret_val, term.into(), inner.next_ts(), reason_var], + )?; + schema_math.write_table_row( + &mut dst_vars, + RowVals { + timestamp: inner.next_ts(), + proof: Some(term.into()), + subsume: schema_math.subsume.then(|| inner.convert(&subsume_entry)), + ret_val: None, + }, + ); + rb.insert(table, &dst_vars).context("set") + }); + } + DefaultVal::Const(_) => { + panic!("unsupported constant default value when proofs are enabled") + } } } else { - self.query.add_rule.push(Box::new(move |inner, rb| { + self.add_callback(move |inner, rb| { let mut dst_vars = inner.convert_all(&entries); schema_math.write_table_row( &mut dst_vars, @@ -947,7 +1080,7 @@ impl RuleBuilder<'_> { }, ); rb.insert(table, &dst_vars).context("set") - })); + }); }; } @@ -959,7 +1092,7 @@ impl RuleBuilder<'_> { let dst_vars = inner.convert_all(&entries); rb.remove(table, &dst_vars).context("remove") }); - self.query.add_rule.push(cb); + self.query.callbacks.add_callback(cb); } /// Panic with a given message. @@ -967,11 +1100,11 @@ impl RuleBuilder<'_> { let panic = self.egraph.new_panic(message.clone()); let ret_ty = ColumnTy::Id; let res = self.new_var(ret_ty); - self.query.add_rule.push(Box::new(move |inner, rb| { + self.add_callback(move |inner, rb| { let var = rb.call_external(panic, &[])?; inner.mapping.insert(res.id, var.into()); Ok(()) - })); + }); } } @@ -1007,13 +1140,7 @@ impl Query { ) -> Result { let mut rb = qb.build(); inner.next_ts = Some(rb.read_counter(self.ts_counter).into()); - // Set up proof state if it's configured. 
- if let Some(build_reason) = &self.build_reason { - build_reason(&mut inner, &mut rb)?; - } - self.add_rule - .iter() - .try_for_each(|f| f(&mut inner, &mut rb))?; + self.callbacks.run(&mut inner, &mut rb)?; Ok(rb.build_with_description(desc)) } diff --git a/egglog-bridge/src/syntax.rs b/egglog-bridge/src/syntax.rs index 48d313a9a..df86d589d 100644 --- a/egglog-bridge/src/syntax.rs +++ b/egglog-bridge/src/syntax.rs @@ -12,11 +12,11 @@ use crate::core_relations::{ WriteVal, make_external_func, }; use crate::numeric_id::{DenseIdMap, IdVec, NumericId, define_id}; -use crate::{EGraph, NOT_SUBSUMED, ProofReason, QueryEntry, ReasonSpecId, Result, SchemaMath}; +use crate::{ColumnTy, EGraph, ProofReason, QueryEntry, ReasonSpecId, Result, SchemaMath}; use smallvec::SmallVec; use crate::{ - ColumnTy, FunctionId, RuleId, + FunctionId, RuleId, proof_spec::ProofBuilder, rule::{AtomId, Bindings, VariableId}, }; @@ -41,7 +41,7 @@ pub enum SourceExpr { Var { id: VariableId, ty: ColumnTy, - name: String, + name: Arc, }, /// A call to an external (aka primitive) function. ExternalCall { @@ -50,6 +50,7 @@ pub enum SourceExpr { var: VariableId, ty: ColumnTy, func: ExternalFunctionId, + name: Arc, args: Vec, }, /// A query of an egglog-level function (i.e. a table). @@ -64,12 +65,19 @@ pub enum SourceExpr { }, } +#[derive(Debug, Clone)] +pub(crate) struct SourceVar { + pub id: VariableId, + pub ty: ColumnTy, + pub name: Arc, +} + /// A data-structure representing an egglog query. Essentially, multiple [`SourceExpr`]s, one per /// line, along with a backing store accounting for subterms indexed by [`SyntaxId`]. #[derive(Debug, Clone, Default)] pub struct SourceSyntax { pub(crate) backing: IdVec, - pub(crate) vars: Vec<(VariableId, ColumnTy)>, + pub(crate) vars: Vec, pub(crate) roots: Vec, } @@ -81,8 +89,16 @@ impl SourceSyntax { pub fn add_expr(&mut self, expr: SourceExpr) -> SyntaxId { match &expr { SourceExpr::Const { .. } | SourceExpr::FunctionCall { .. 
} => {} - SourceExpr::Var { id, ty, .. } => self.vars.push((*id, *ty)), - SourceExpr::ExternalCall { var, ty, .. } => self.vars.push((*var, *ty)), + SourceExpr::Var { id, ty, name } => self.vars.push(SourceVar { + id: *id, + ty: *ty, + name: Arc::clone(name), + }), + SourceExpr::ExternalCall { var, ty, name, .. } => self.vars.push(SourceVar { + id: *var, + ty: *ty, + name: Arc::clone(name), + }), }; self.backing.push(expr) } @@ -108,6 +124,7 @@ impl SourceSyntax { #[derive(Debug)] pub(crate) struct RuleData { pub(crate) rule_id: RuleId, + pub(crate) rule_name: Box, pub(crate) syntax: SourceSyntax, } @@ -144,6 +161,7 @@ impl ProofBuilder { let reason_spec = Arc::new(ProofReason::Rule(RuleData { rule_id: self.rule_id, + rule_name: Box::::from(&*self.rule_description), syntax: syntax.clone(), })); let reason_table = egraph.reason_table(&reason_spec); @@ -173,8 +191,8 @@ impl ProofBuilder { // the base substitution of variables into a reason table. let mut row = SmallVec::<[core_relations::QueryEntry; 8]>::new(); row.push(Value::new(reason_spec_id.rep()).into()); - for (var, _) in &syntax.vars { - row.push(bndgs.mapping[*var]); + for SourceVar { id, .. } in &syntax.vars { + row.push(bndgs.mapping[*id]); } Ok(rb.lookup_or_insert( reason_table, @@ -195,14 +213,13 @@ impl ProofBuilder { }; let cong_args = CongArgs { func_table: func, - func_underlying, - schema_math, reason_table: egraph.reason_table(&ProofReason::CongRow), term_table: egraph.term_table(func_underlying), reason_counter: egraph.reason_counter, term_counter: egraph.id_counter, - ts_counter: egraph.timestamp_counter, reason_spec_id: egraph.cong_spec, + ts_counter: egraph.timestamp_counter, + uf_table: egraph.uf_table, }; let build_term = egraph.register_external_func(make_external_func(move |es, vals| { cong_term(&cong_args, es, vals) @@ -285,10 +302,6 @@ impl TermReconstructionState<'_> { struct CongArgs { /// The function that we are applying congruence to. 
func_table: FunctionId, - /// The undcerlying `core_relations` table that this function corresponds to. - func_underlying: TableId, - /// Schema-related offset information needed for writing to the table. - schema_math: SchemaMath, /// The table that will hold the reason justifying the new term, if we need to insert one. reason_table: TableId, /// The table that will hold the new term, if we need to insert one. @@ -297,10 +310,12 @@ struct CongArgs { reason_counter: CounterId, /// The counter that will be incremented when we insert a new term. term_counter: CounterId, - /// The counter that will be used to read the current timestamp for the new row. - ts_counter: CounterId, /// The specification (or schema) for the reason we are writing (congruence, in this case). reason_spec_id: ReasonSpecId, + /// The counter that will be used to read the current timestamp for the new row. + ts_counter: CounterId, + /// The union-find, used to record equality between existing e-class ids and new terms. + uf_table: TableId, } fn cong_term(args: &CongArgs, es: &mut ExecutionState, vals: &[Value]) -> Option { @@ -326,14 +341,9 @@ fn cong_term(args: &CongArgs, es: &mut ExecutionState, vals: &[Value]) -> Option ColumnId::from_usize(term_row.len()), ); - // We should be able to do a raw insert at this point. All conflicting inserts will have the - // same term value, and this function only gets called when a lookup fails. - + // We just created a new term that wasn't previously inserted into the e-graph. We want to + // ensure that this is equal to the existing e-class that this term is in (by congruence). 
let ts = Value::from_usize(es.read_counter(args.ts_counter)); - term_row.resize(args.schema_math.table_columns(), NOT_SUBSUMED); - term_row[args.schema_math.ret_val_col()] = term_val; - term_row[args.schema_math.proof_id_col()] = term_val; - term_row[args.schema_math.ts_col()] = ts; - es.stage_insert(args.func_underlying, &term_row); + es.stage_insert(args.uf_table, &[old_term, term_val, ts, reason]); Some(term_val) } diff --git a/src/termdag.rs b/egglog-bridge/src/termdag.rs similarity index 50% rename from src/termdag.rs rename to egglog-bridge/src/termdag.rs index ee22c652e..11049ad20 100644 --- a/src/termdag.rs +++ b/egglog-bridge/src/termdag.rs @@ -1,5 +1,12 @@ +use egglog_ast::{ + generic_ast::{Expr, GenericExpr, Literal}, + span::Span, +}; + use crate::*; -use std::fmt::Write; +use hashbrown::HashMap; +use indexmap::IndexSet; +use std::{fmt::Write, io}; pub type TermId = usize; @@ -11,6 +18,9 @@ pub type TermId = usize; #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum Term { Lit(Literal), + /// This is a placeholder, used to represent terms that are backed by a base type that we + /// cannot model in the source / AST language. + UnknownLit, Var(String), App(String, Vec), } @@ -35,6 +45,80 @@ macro_rules! 
match_term_app { } } +#[derive(Clone, Debug)] +pub struct PrettyPrintConfig { + pub line_width: usize, + pub indent_size: usize, +} + +impl Default for PrettyPrintConfig { + fn default() -> Self { + Self { + line_width: 512, + indent_size: 4, + } + } +} + +pub(crate) struct PrettyPrinter<'w, W: io::Write> { + writer: &'w mut W, + config: &'w PrettyPrintConfig, + current_indent: usize, + current_line_pos: usize, +} + +impl<'w, W: io::Write> PrettyPrinter<'w, W> { + pub(crate) fn new(writer: &'w mut W, config: &'w PrettyPrintConfig) -> Self { + Self { + writer, + config, + current_indent: 0, + current_line_pos: 0, + } + } + + pub(crate) fn write_str(&mut self, s: &str) -> io::Result<()> { + write!(self.writer, "{s}")?; + self.current_line_pos += s.len(); + Ok(()) + } + + pub(crate) fn newline(&mut self) -> io::Result<()> { + writeln!(self.writer)?; + self.current_line_pos = 0; + self.write_indent()?; + Ok(()) + } + + pub(crate) fn write_indent(&mut self) -> io::Result<()> { + for _ in 0..self.current_indent { + write!(self.writer, " ")?; + } + self.current_line_pos = self.current_indent; + Ok(()) + } + + pub(crate) fn increase_indent(&mut self) { + self.current_indent += self.config.indent_size; + } + + pub(crate) fn decrease_indent(&mut self) { + self.current_indent = self.current_indent.saturating_sub(self.config.indent_size); + } + + pub(crate) fn should_break(&self, additional_chars: usize) -> bool { + self.current_line_pos + additional_chars > self.config.line_width + } + + pub(crate) fn write_with_break(&mut self, s: &str) -> io::Result<()> { + if self.should_break(s.len()) && self.current_line_pos > self.current_indent { + self.newline()?; + self.write_indent()?; + } + self.write_str(s) + } +} + impl TermDag { /// Returns the number of nodes in this DAG. pub fn size(&self) -> usize { @@ -67,6 +151,16 @@ impl TermDag { node } + /// Like [`TermDag::app`], but expects and returns [`TermId`]s instead of + /// [`Term`]s. 
+ pub fn app_id(&mut self, sym: String, children: Vec) -> TermId { + let node = Term::App(sym, children); + + self.add_node(&node); + + self.lookup(&node) + } + /// Make and return a [`Term::Lit`] with the given literal, and insert into /// the DAG if it is not already present. pub fn lit(&mut self, lit: Literal) -> Term { @@ -77,6 +171,19 @@ impl TermDag { node } + pub fn unknown_lit(&mut self) -> Term { + let node = Term::UnknownLit; + self.add_node(&node); + node + } + + /// Like [`TermDag::lit`], but returns a [`TermId`] instead of a [`Term`]. + pub fn lit_id(&mut self, lit: Literal) -> TermId { + let node = Term::Lit(lit); + self.add_node(&node); + self.lookup(&node) + } + /// Make and return a [`Term::Var`] with the given symbol, and insert into /// the DAG if it is not already present. pub fn var(&mut self, sym: String) -> Term { @@ -87,6 +194,12 @@ impl TermDag { node } + pub fn var_id(&mut self, sym: String) -> TermId { + let node = Term::Var(sym); + self.add_node(&node); + self.lookup(&node) + } + fn add_node(&mut self, node: &Term) { if self.nodes.get(node).is_none() { self.nodes.insert(node.clone()); @@ -119,9 +232,11 @@ impl TermDag { /// Recursively converts the given term to an expression. /// - /// Panics if the term contains subterms that are not in the DAG. + /// Panics if the term contains subterms that are not in the DAG or cannot be represented by + /// the current syntax. 
pub fn term_to_expr(&self, term: &Term, span: Span) -> Expr { match term { + Term::UnknownLit => panic!("unknown base value"), Term::Lit(lit) => Expr::Lit(span, lit.clone()), Term::Var(v) => Expr::Var(span, v.clone()), Term::App(op, args) => { @@ -175,6 +290,10 @@ impl TermDag { start_index = Some(result.len()); write!(&mut result, "{v}").unwrap(); } + Term::UnknownLit => { + start_index = Some(result.len()); + write!(&mut result, "").unwrap(); + } } if let Some(start_index) = start_index { @@ -184,87 +303,73 @@ impl TermDag { result } -} -#[cfg(test)] -mod tests { - use super::*; - use crate::{ast::*, span}; - - fn parse_term(s: &str) -> (TermDag, Term) { - let e = Parser::default().get_expr_from_string(None, s).unwrap(); - let mut td = TermDag::default(); - let t = td.expr_to_term(&e); - (td, t) + /// Pretty-print the given term to a string. + pub fn to_string_pretty(&self, term: &Term) -> String { + let mut buf = Vec::new(); + self.print_term_pretty(term, &PrettyPrintConfig::default(), &mut buf) + .expect("pretty printing term failed"); + String::from_utf8(buf).expect("pretty printer emitted invalid UTF-8") } - #[test] - fn test_to_from_expr() { - let s = r#"(f (g x y) x y (g x y))"#; - let e = Parser::default().get_expr_from_string(None, s).unwrap(); - let mut td = TermDag::default(); - assert_eq!(td.size(), 0); - let t = td.expr_to_term(&e); - assert_eq!(td.size(), 4); - // the expression above has 4 distinct subterms. - // in left-to-right, depth-first order, they are: - // x, y, (g x y), and the root call to f - // so we can compute expected answer by hand: - assert_eq!( - td.nodes.as_slice().iter().cloned().collect::>(), - vec![ - Term::Var("x".into()), - Term::Var("y".into()), - Term::App("g".into(), vec![0, 1]), - Term::App("f".into(), vec![2, 0, 1, 2]), - ] - ); - // This is tested using string equality because e1 and e2 have different - let e2 = td.term_to_expr(&t, span!()); - // annotations. 
A better way to test this would be to implement a map_ann - // function for GenericExpr. - assert_eq!(format!("{e}"), format!("{e2}")); // roundtrip + /// Pretty-print the given term to a string by term id. + pub fn to_string_pretty_id(&self, term: TermId) -> String { + self.to_string_pretty(self.get(term)) } - #[test] - fn test_match_term_app() { - let s = r#"(f (g x y) x y (g x y))"#; - let (td, t) = parse_term(s); - match_term_app!(t; { - ("f", [_, x, _, _]) => { - let span = span!(); - assert_eq!( - td.term_to_expr(td.get(*x), span.clone()), - crate::ast::GenericExpr::Var(span, "x".to_owned()) - ) - } - (head, _) => panic!("unexpected head {}, in {}:{}:{}", head, file!(), line!(), column!()) - }) + /// Print the term with pretty-printing configuration. + pub fn print_term_pretty( + &self, + term: &Term, + config: &PrettyPrintConfig, + writer: &mut impl io::Write, + ) -> io::Result<()> { + let mut printer = PrettyPrinter::new(writer, config); + self.print_term_with_printer(term, &mut printer) } - #[test] - fn test_to_string() { - let s = r#"(f (g x y) x y (g x y))"#; - let (td, t) = parse_term(s); - assert_eq!(td.to_string(&t), s); + pub(crate) fn print_term_with_printer( + &self, + term: &Term, + printer: &mut PrettyPrinter, + ) -> io::Result<()> { + match term { + Term::Lit(lit) => { + printer.write_str(&format!("{lit}"))?; + } + Term::UnknownLit => { + printer.write_str("(unsupported-base-val)")?; + } + Term::Var(v) => { + printer.write_str(v)?; + } + Term::App(head, args) => { + printer.write_str(&format!("({head}"))?; + if !args.is_empty() { + printer.increase_indent(); + for arg in args.iter() { + printer.write_with_break(" ")?; + self.print_term_with_printer(self.get(*arg), printer)?; + } + printer.decrease_indent(); + } + printer.write_str(")")?; + } + } + Ok(()) } - #[test] - fn test_lookup() { - let s = r#"(f (g x y) x y (g x y))"#; - let (td, t) = parse_term(s); - assert_eq!(td.lookup(&t), td.size() - 1); + /// Project a particular argument of a term 
by index. + /// Returns None if the term is not an application or the index is out of bounds. + pub fn proj(&self, term: &Term, arg_idx: usize) -> Option { + match term { + Term::App(_hd, args) => args.get(arg_idx).copied(), + _ => None, + } } - #[test] - fn test_app_var_lit() { - let s = r#"(f (g x y) x 7 (g x y))"#; - let (mut td, t) = parse_term(s); - let x = td.var("x".into()); - let y = td.var("y".into()); - let seven = td.lit(7.into()); - let g = td.app("g".into(), vec![x.clone(), y.clone()]); - let t2 = td.app("f".into(), vec![g.clone(), x, seven, g]); - assert_eq!(t, t2); + /// Project a particular argument of a term by index, given the term's id. + pub fn proj_id(&self, term: TermId, arg_idx: usize) -> Option { + self.proj(self.get(term), arg_idx) } } diff --git a/egglog-bridge/src/tests.rs b/egglog-bridge/src/tests.rs index 5c6a33d4b..382731079 100644 --- a/egglog-bridge/src/tests.rs +++ b/egglog-bridge/src/tests.rs @@ -19,7 +19,7 @@ use num_rational::Rational64; use crate::{ ColumnTy, DefaultVal, EGraph, FunctionConfig, FunctionId, MergeFn, ProofStore, QueryEntry, - add_expressions, define_rule, + add_expressions, define_rule, termdag::PrettyPrintConfig, }; /// Run a simple associativity/commutativity test. In addition to testing that the rules properly @@ -43,6 +43,7 @@ fn ac_test(tracing: bool, can_subsume: bool) { merge: MergeFn::UnionId, name: "num".into(), can_subsume, + fiat_reason_only: None, }); let add_table = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id; 3], @@ -50,6 +51,7 @@ fn ac_test(tracing: bool, can_subsume: bool) { merge: MergeFn::UnionId, name: "add".into(), can_subsume, + fiat_reason_only: None, }); let add_comm = define_rule! 
{ @@ -150,6 +152,7 @@ fn ac_fail() { merge: MergeFn::UnionId, name: "num".into(), can_subsume: false, + fiat_reason_only: None, }); let add_table = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id; 3], @@ -157,6 +160,7 @@ fn ac_fail() { merge: MergeFn::UnionId, name: "add".into(), can_subsume: false, + fiat_reason_only: None, }); let add_comm = define_rule! { @@ -255,6 +259,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "diff".into(), can_subsume, + fiat_reason_only: None, }); let integral = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id, ColumnTy::Id, ColumnTy::Id], @@ -262,6 +267,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "integral".into(), can_subsume, + fiat_reason_only: None, }); let add = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id, ColumnTy::Id, ColumnTy::Id], @@ -269,6 +275,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "add".into(), can_subsume, + fiat_reason_only: None, }); let sub = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id, ColumnTy::Id, ColumnTy::Id], @@ -276,6 +283,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "sub".into(), can_subsume, + fiat_reason_only: None, }); let mul = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id, ColumnTy::Id, ColumnTy::Id], @@ -283,6 +291,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "mul".into(), can_subsume, + fiat_reason_only: None, }); let div = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id, ColumnTy::Id, ColumnTy::Id], @@ -290,6 +299,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "div".into(), can_subsume, + fiat_reason_only: None, }); let pow = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id, ColumnTy::Id, ColumnTy::Id], @@ -297,6 +307,7 @@ fn math_test(mut 
egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "pow".into(), can_subsume, + fiat_reason_only: None, }); let ln = egraph.add_table(FunctionConfig { @@ -305,6 +316,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "ln".into(), can_subsume, + fiat_reason_only: None, }); let sqrt = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id, ColumnTy::Id], @@ -312,6 +324,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "sqrt".into(), can_subsume, + fiat_reason_only: None, }); let sin = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id, ColumnTy::Id], @@ -319,6 +332,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "sin".into(), can_subsume, + fiat_reason_only: None, }); let cos = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id, ColumnTy::Id], @@ -326,6 +340,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "cos".into(), can_subsume, + fiat_reason_only: None, }); let rat = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Base(rational_ty), ColumnTy::Id], @@ -333,6 +348,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "rat".into(), can_subsume, + fiat_reason_only: None, }); let var = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Base(string_ty), ColumnTy::Id], @@ -340,6 +356,7 @@ fn math_test(mut egraph: EGraph, can_subsume: bool) { merge: MergeFn::UnionId, name: "var".into(), can_subsume, + fiat_reason_only: None, }); let zero = egraph.base_value_constant(Rational64::new(0, 1)); @@ -599,6 +616,7 @@ fn container_test() { merge: MergeFn::UnionId, name: "num".into(), can_subsume: false, + fiat_reason_only: None, }); let add_table = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id; 3], @@ -606,6 +624,7 @@ fn container_test() { merge: MergeFn::UnionId, name: "add".into(), can_subsume: false, + 
fiat_reason_only: None, }); let vec_table = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Id; 2], @@ -613,6 +632,7 @@ fn container_test() { merge: MergeFn::UnionId, name: "vec".into(), can_subsume: false, + fiat_reason_only: None, }); let int_add = egraph.register_external_func(make_external_func(|exec_state, args| { let [x, y] = args else { panic!() }; @@ -785,6 +805,7 @@ fn rhs_only_rule() { merge: MergeFn::UnionId, name: "num".into(), can_subsume: false, + fiat_reason_only: None, }); let add_data = { let zero = egraph.base_value_constant(0i64); @@ -878,6 +899,7 @@ fn mergefn_arithmetic() { ), name: "f".into(), can_subsume: false, + fiat_reason_only: None, }); let value_0 = egraph.base_value_constant(0i64); @@ -968,6 +990,7 @@ fn mergefn_nested_function() { merge: MergeFn::UnionId, name: "g".into(), can_subsume: true, + fiat_reason_only: None, }); // Create a function f whose merge function is (g (g new new) (g old old)) @@ -984,6 +1007,7 @@ fn mergefn_nested_function() { ), name: "f".into(), can_subsume: true, + fiat_reason_only: None, }); let value_1 = egraph.base_value_constant(1i64); @@ -1090,6 +1114,7 @@ fn constrain_prims_simple() { merge: MergeFn::UnionId, name: "f".into(), can_subsume: false, + fiat_reason_only: None, }); let g_table = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Base(int_base), ColumnTy::Id], @@ -1097,6 +1122,7 @@ fn constrain_prims_simple() { merge: MergeFn::UnionId, name: "g".into(), can_subsume: false, + fiat_reason_only: None, }); let is_even = egraph.register_external_func(core_relations::make_external_func( @@ -1173,6 +1199,7 @@ fn constrain_prims_abstract() { merge: MergeFn::UnionId, name: "f".into(), can_subsume: false, + fiat_reason_only: None, }); let g_table = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Base(int_base), ColumnTy::Id], @@ -1180,6 +1207,7 @@ fn constrain_prims_abstract() { merge: MergeFn::UnionId, name: "g".into(), can_subsume: false, + fiat_reason_only: None, }); let neg = 
egraph.register_external_func(core_relations::make_external_func( @@ -1270,6 +1298,7 @@ fn basic_subsumption() { merge: MergeFn::UnionId, name: "f".into(), can_subsume: true, + fiat_reason_only: None, }); let g_table = egraph.add_table(FunctionConfig { schema: vec![ColumnTy::Base(int_base), ColumnTy::Id], @@ -1277,6 +1306,7 @@ fn basic_subsumption() { merge: MergeFn::UnionId, name: "g".into(), can_subsume: false, + fiat_reason_only: None, }); let value_1 = egraph.base_value_constant(1i64); @@ -1349,6 +1379,7 @@ fn lookup_failure_panics() { merge: MergeFn::UnionId, name: "test".into(), can_subsume: false, + fiat_reason_only: None, }); let to_entry = |val: u32| QueryEntry::Const { @@ -1441,6 +1472,7 @@ fn test_simple_rule_proof_format() { merge: MergeFn::UnionId, name: "bool".into(), can_subsume: false, + fiat_reason_only: None, }); // Add table for not function let not_table = egraph.add_table(FunctionConfig { @@ -1449,6 +1481,7 @@ fn test_simple_rule_proof_format() { merge: MergeFn::UnionId, name: "not".into(), can_subsume: false, + fiat_reason_only: None, }); // Add true/false wrapped terms let true_id = egraph.add_term(bool_table, &[true_val], "true"); @@ -1468,11 +1501,75 @@ fn test_simple_rule_proof_format() { egraph.run_rules(&[not_true_rule, not_false_rule]).unwrap(); // Get proof for not_true = false let mut proof_store = ProofStore::default(); - egraph + let _eq_pf_id = egraph .explain_terms_equal(not_true_id, false_id, &mut proof_store) .unwrap(); } +#[test] +fn fiat_reason_proof_is_shallow() { + let mut egraph = EGraph::with_tracing(); + let int_ty = egraph.base_values_mut().register_type::(); + let num_table = egraph.add_table(FunctionConfig { + schema: vec![ColumnTy::Base(int_ty), ColumnTy::Id], + default: DefaultVal::FreshId, + merge: MergeFn::UnionId, + name: "num".into(), + can_subsume: false, + fiat_reason_only: None, + }); + let add_table = egraph.add_table(FunctionConfig { + schema: vec![ColumnTy::Id; 3], + default: DefaultVal::FreshId, + merge: 
MergeFn::UnionId, + name: "add".into(), + can_subsume: false, + fiat_reason_only: Some("fiat add".to_string()), + }); + + let one_val = egraph.base_values_mut().get(1i64); + let two_val = egraph.base_values_mut().get(2i64); + let _one = egraph.add_term(num_table, &[one_val], "one"); + let _two = egraph.add_term(num_table, &[two_val], "two"); + + let add_rule = define_rule! { + [egraph] + ((-> (num_table x) id_x) (-> (num_table y) id_y)) + => ((set (add_table id_x id_y) id_x)) + }; + egraph.run_rules(&[add_rule]).unwrap(); + + // Grab any row from the fiat-only table and explain it. + let mut row = Vec::new(); + let mut add_id = None; + egraph.for_each(add_table, |func_row| { + row.clear(); + row.extend_from_slice(func_row.vals); + add_id = egraph.lookup_id(add_table, &row[0..row.len() - 1]); + }); + let add_id = add_id.expect("expected at least one add row"); + + let mut proof_store = ProofStore::default(); + let term_pf = egraph.explain_term(add_id, &mut proof_store).unwrap(); + let mut buf = Vec::new(); + proof_store + .print_term_proof_pretty(term_pf, &PrettyPrintConfig::default(), &mut buf) + .unwrap(); + let proof_str = String::from_utf8(buf).unwrap(); + assert!( + proof_str.contains("PFiat"), + "fiat-only table should yield fiat proof: {proof_str}" + ); + assert!( + !proof_str.contains("PRule"), + "fiat-only table proof should be shallow: {proof_str}" + ); + assert!( + !proof_str.contains("PCong"), + "fiat-only table proof should be shallow: {proof_str}" + ); +} + const _: () = { const fn assert_send() {} assert_send::() diff --git a/numeric-id/src/lib.rs b/numeric-id/src/lib.rs index f64aeb5b9..42a895ef4 100644 --- a/numeric-id/src/lib.rs +++ b/numeric-id/src/lib.rs @@ -433,38 +433,12 @@ macro_rules! atomic_of { macro_rules! 
define_id { ($v:vis $name:ident, $repr:tt) => { define_id!($v, $name, $repr, ""); }; ($v:vis $name:ident, $repr:tt, $doc:tt) => { - #[derive(Copy, Clone)] + #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[doc = $doc] $v struct $name { rep: $repr, } - impl PartialEq for $name { - fn eq(&self, other: &Self) -> bool { - self.rep == other.rep - } - } - - impl Eq for $name {} - - impl PartialOrd for $name { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } - } - - impl Ord for $name { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.rep.cmp(&other.rep) - } - } - - impl std::hash::Hash for $name { - fn hash(&self, state: &mut H) { - self.rep.hash(state); - } - } - impl $name { #[allow(unused)] $v const fn new_const(id: $repr) -> Self { diff --git a/numeric-id/src/tests.rs b/numeric-id/src/tests.rs index 6e086d887..a5d0bb231 100644 --- a/numeric-id/src/tests.rs +++ b/numeric-id/src/tests.rs @@ -4,6 +4,7 @@ define_id!(pub(crate) Id, u32, "a unique id"); #[test] #[should_panic] +#[allow(arithmetic_overflow)] fn id_out_of_bounds() { Id::from_usize(u32::MAX as usize + 5); } diff --git a/src/ast/check_shadowing.rs b/src/ast/check_shadowing.rs index 225f9fac9..37b3823ce 100644 --- a/src/ast/check_shadowing.rs +++ b/src/ast/check_shadowing.rs @@ -1,3 +1,4 @@ +use crate::ast::expr::ResolvedExpr; use crate::{util::HashMap, *}; #[derive(Clone, Debug, Default)] diff --git a/src/ast/expr.rs b/src/ast/expr.rs index 6fc1cb582..ca596f0eb 100644 --- a/src/ast/expr.rs +++ b/src/ast/expr.rs @@ -6,6 +6,7 @@ use std::hash::Hasher; use crate::ast::CorrespondingVar; use crate::core::ResolvedCall; use crate::{ArcSort, sort}; +pub use egglog_ast::generic_ast::Expr; #[derive(Debug, Clone)] pub struct ResolvedVar { @@ -43,7 +44,6 @@ impl Display for ResolvedVar { } } -pub type Expr = GenericExpr; /// A generated expression is an expression that is generated by the system /// and does not have annotations. 
pub type ResolvedExpr = GenericExpr; @@ -98,3 +98,9 @@ macro_rules! var { // Rust macro annoyance; see stackoverflow.com/questions/26731243/how-do-i-use-a-macro-across-module-files pub use {call, lit, var}; + +impl ResolvedVar { + pub fn name(&self) -> &str { + &self.name + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index da32af15b..874416d97 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,12 +1,13 @@ pub mod check_shadowing; pub mod desugar; -mod expr; +pub mod expr; mod parse; pub mod remove_globals; use crate::core::{ GenericAtom, GenericAtomTerm, GenericExprExt, HeadOrEq, Query, ResolvedCall, ResolvedCoreRule, }; +use crate::util::IndexMap; use crate::util::sanitize_internal_name; use crate::*; pub use egglog_ast::generic_ast::{ @@ -1053,6 +1054,35 @@ pub(crate) type MappedFact = GenericFact(pub Vec>); +pub(crate) fn collect_query_vars(facts: &[ResolvedFact]) -> Vec<(ResolvedVar, ArcSort)> { + let mut vars: IndexMap = IndexMap::default(); + for fact in facts { + match fact { + ResolvedFact::Fact(expr) => collect_expr_vars(expr, &mut vars), + ResolvedFact::Eq(_, lhs, rhs) => { + collect_expr_vars(lhs, &mut vars); + collect_expr_vars(rhs, &mut vars); + } + } + } + vars.into_iter().map(|(_, entry)| entry).collect() +} + +fn collect_expr_vars(expr: &ResolvedExpr, out: &mut IndexMap) { + match expr { + ResolvedExpr::Var(_, var) => { + out.entry(var.name.clone()) + .or_insert_with(|| (var.clone(), var.sort.clone())); + } + ResolvedExpr::Call(_, _, args) => { + for arg in args { + collect_expr_vars(arg, out); + } + } + ResolvedExpr::Lit(_, _) => {} + } +} + impl Facts where Head: Clone + Display, @@ -1103,6 +1133,51 @@ where } } +/// This is a variant of [`CorrespondingVar`] that tracks `Leaf`s specifically. It is meant to +/// preserve the original variable names for atoms after canonicalization. +#[derive(Clone, Debug)] +pub struct CanonicalizedVar +where + Leaf: Clone + PartialEq + Eq + Hash, +{ + /// The actual variable used in the query. 
+ pub var: Leaf, + /// The original variable used in this position, prior to canonicalization. + pub orig: Leaf, +} + +impl CanonicalizedVar +where + Leaf: Clone + PartialEq + Eq + Hash, +{ + pub fn new_current(var: Leaf) -> Self { + Self { + orig: var.clone(), + var, + } + } +} + +impl PartialEq for CanonicalizedVar +where + Leaf: Clone + PartialEq + Eq + Hash, +{ + fn eq(&self, other: &Self) -> bool { + self.var == other.var + } +} + +impl Eq for CanonicalizedVar where Leaf: Clone + PartialEq + Eq + Hash {} + +impl Hash for CanonicalizedVar +where + Leaf: Clone + PartialEq + Eq + Hash, +{ + fn hash(&self, state: &mut H) { + self.var.hash(state); + } +} + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct CorrespondingVar where diff --git a/src/ast/parse.rs b/src/ast/parse.rs index bf739c386..43bcadd63 100644 --- a/src/ast/parse.rs +++ b/src/ast/parse.rs @@ -1064,6 +1064,8 @@ pub(crate) fn all_sexps(mut ctx: SexpParser) -> Result, ParseError> { #[cfg(test)] mod tests { + use egglog_bridge::match_term_app; + use super::*; #[test] @@ -1073,13 +1075,17 @@ mod tests { assert_eq!(format!("{}", e), s); } + use std::path::{Path, PathBuf}; + #[test] - #[rustfmt::skip] fn rust_span_display() { - let actual = format!("{}", span!()).replace('\\', "/"); - assert!(actual.starts_with("At ")); - assert!(actual.contains(":")); - assert!(actual.ends_with("src/ast/parse.rs")); + // non-platform specific path construction + let expected_path: PathBuf = Path::new("src").join("ast").join("parse.rs"); + + assert_eq!( + format!("{}", span!()), + format!("At {}:27 of {}", line!() - 1, expected_path.display()) + ); } #[test] @@ -1098,4 +1104,86 @@ mod tests { let e = parser.get_expr_from_string(None, s).unwrap(); assert_eq!(format!("{}", e), t); } + + fn parse_term(s: &str) -> (TermDag, Term) { + let e = Parser::default().get_expr_from_string(None, s).unwrap(); + let mut td = TermDag::default(); + let t = td.expr_to_term(&e); + (td, t) + } + + #[test] + fn test_to_from_expr() { + let 
s = r#"(f (g x y) x y (g x y))"#; + let e = Parser::default().get_expr_from_string(None, s).unwrap(); + let mut td = TermDag::default(); + assert_eq!(td.size(), 0); + let t = td.expr_to_term(&e); + assert_eq!(td.size(), 4); + // the expression above has 4 distinct subterms. + // in left-to-right, depth-first order, they are: + // x, y, (g x y), and the root call to f + // so we can compute expected answer by hand: + let mut td2 = TermDag::default(); + td2.var("x".into()); + td2.var("y".into()); + td2.app("g".into(), vec![td2.get(0).clone(), td2.get(1).clone()]); + td2.app( + "f".into(), + vec![ + td2.get(2).clone(), + td2.get(0).clone(), + td2.get(1).clone(), + td2.get(2).clone(), + ], + ); + assert_eq!(td, td2); + // This is tested using string equality because e1 and e2 have different + let e2 = td.term_to_expr(&t, span!()); + // annotations. A better way to test this would be to implement a map_ann + // function for GenericExpr. + assert_eq!(format!("{e}"), format!("{e2}")); // roundtrip + } + + #[test] + fn test_match_term_app() { + let s = r#"(f (g x y) x y (g x y))"#; + let (td, t) = parse_term(s); + match_term_app!(t; { + ("f", [_, x, _, _]) => { + let span = span!(); + assert_eq!( + td.term_to_expr(td.get(*x), span.clone()), + crate::ast::GenericExpr::Var(span, "x".to_owned()) + ) + } + (head, _) => panic!("unexpected head {}, in {}:{}:{}", head, file!(), line!(), column!()) + }) + } + + #[test] + fn test_to_string() { + let s = r#"(f (g x y) x y (g x y))"#; + let (td, t) = parse_term(s); + assert_eq!(td.to_string(&t), s); + } + + #[test] + fn test_lookup() { + let s = r#"(f (g x y) x y (g x y))"#; + let (td, t) = parse_term(s); + assert_eq!(td.lookup(&t), td.size() - 1); + } + + #[test] + fn test_app_var_lit() { + let s = r#"(f (g x y) x 7 (g x y))"#; + let (mut td, t) = parse_term(s); + let x = td.var("x".into()); + let y = td.var("y".into()); + let seven = td.lit(7.into()); + let g = td.app("g".into(), vec![x.clone(), y.clone()]); + let t2 = 
td.app("f".into(), vec![g.clone(), x, seven, g]); + assert_eq!(t, t2); + } } diff --git a/src/ast/remove_globals.rs b/src/ast/remove_globals.rs index 6446d296c..fd885706b 100644 --- a/src/ast/remove_globals.rs +++ b/src/ast/remove_globals.rs @@ -5,6 +5,7 @@ //! When a globally-bound primitive value is used in the actions of a rule, //! we add a new variable to the query bound to the primitive value. +use crate::ast::expr::ResolvedExpr; use crate::*; use crate::{core::ResolvedCall, typechecking::FuncType}; use egglog_ast::generic_ast::{GenericAction, GenericExpr, GenericFact, GenericRule}; diff --git a/src/cli.rs b/src/cli.rs index 9015e83b5..f4ce098a2 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -28,6 +28,9 @@ struct Args { /// Serializes the egraph for each egglog file as an SVG #[clap(long)] to_svg: bool, + /// Enables proof generation and provenance tracking + #[clap(long)] + enable_proofs: bool, /// Splits the serialized egraph into primitives and non-primitives #[clap(long)] serialize_split_primitive_outputs: bool, @@ -67,10 +70,17 @@ pub fn cli(mut egraph: EGraph) { .init(); let args = Args::parse(); + let threads = args.threads; rayon::ThreadPoolBuilder::new() - .num_threads(args.threads) + .num_threads(threads) .build_global() - .unwrap(); + .ok(); + if args.enable_proofs && !egraph.proofs_enabled() { + // NB: this clears any previous settings and state from the e-graph. It is not generally + // safe to enable proofs mid-stream and given the way `cli` is invoked by main, this is not + // a problem. 
+ egraph = EGraph::with_proofs(); + } log::debug!( "Initialized thread pool with {} threads", rayon::current_num_threads() diff --git a/src/constraint.rs b/src/constraint.rs index c91ef83fc..2cd588f8a 100644 --- a/src/constraint.rs +++ b/src/constraint.rs @@ -1,3 +1,4 @@ +use crate::ast::expr::ResolvedExpr; use crate::{ core::{ Atom, CoreAction, CoreRule, GenericCoreActions, GenericCoreRule, HeadOrEq, Query, @@ -1120,7 +1121,7 @@ impl TypeConstraint for AllEqualTypeConstraint { /// A variable is grounded if it appears in a function call or is equal to a grounded variable. /// This pass happens after type resolution and lowering to core rules, but before canonicalization. pub(crate) fn grounded_check( - rule: &GenericCoreRule, ResolvedCall, ResolvedVar>, + rule: &GenericCoreRule, ResolvedCall, ResolvedVar, ResolvedVar>, ) -> Result<(), TypeError> { use crate::core::ResolvedAtomTerm; let body = &rule.body; diff --git a/src/core.rs b/src/core.rs index 8c510e4a8..f4e6cc17f 100644 --- a/src/core.rs +++ b/src/core.rs @@ -13,6 +13,7 @@ use std::hash::Hasher; use std::ops::AddAssign; +use crate::ast::MappedFact; use crate::{constraint::grounded_check, *}; use egglog_ast::generic_ast::{Change, GenericAction, GenericActions, GenericExpr}; use egglog_ast::span::Span; @@ -204,6 +205,7 @@ where pub type AtomTerm = GenericAtomTerm; pub type ResolvedAtomTerm = GenericAtomTerm; +pub type CanonicalizedResolvedAtomTerm = GenericAtomTerm>; impl GenericAtomTerm { pub fn span(&self) -> &Span { @@ -235,6 +237,16 @@ impl ResolvedAtomTerm { } } +impl CanonicalizedResolvedAtomTerm { + pub fn output(&self) -> ArcSort { + match self { + CanonicalizedResolvedAtomTerm::Var(_, v) => v.var.sort.clone(), + CanonicalizedResolvedAtomTerm::Literal(_, l) => literal_sort(l), + CanonicalizedResolvedAtomTerm::Global(_, v) => v.var.sort.clone(), + } + } +} + impl std::fmt::Display for AtomTerm { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -259,7 +271,6 @@ impl 
std::fmt::Display for Atom { write!(f, "({} {}) ", self.head, ListDisplay(&self.args, " ")) } } - impl GenericAtom where Leaf: Clone + Eq + Hash, @@ -272,15 +283,36 @@ where GenericAtomTerm::Global(..) => None, }) } +} + +impl GenericAtom> +where + Leaf: Clone + Eq + Hash, + Head: Clone, +{ + pub fn current_vars(&self) -> impl Iterator + '_ { + self.args.iter().filter_map(|t| match t { + GenericAtomTerm::Var(_, v) => Some(v.var.clone()), + GenericAtomTerm::Literal(..) => None, + GenericAtomTerm::Global(..) => None, + }) + } - fn subst(&mut self, subst: &HashMap>) { + fn subst( + &mut self, + subst: &HashMap, GenericAtomTerm>>, + ) { for arg in self.args.iter_mut() { match arg { - GenericAtomTerm::Var(_, v) => { - if let Some(at) = subst.get(v) { - *arg = at.clone(); + GenericAtomTerm::Var(_, v) => match subst.get(v) { + Some(GenericAtomTerm::Var(_, other)) => { + v.var = other.var.clone(); } - } + Some(x) => { + *arg = x.clone(); + } + None => {} + }, GenericAtomTerm::Literal(..) => (), GenericAtomTerm::Global(..) => (), } @@ -383,6 +415,37 @@ impl std::fmt::Display for Query { } } +impl Query { + pub fn map_leaves(&self, mut f: impl FnMut(&Leaf) -> L2) -> Query + where + Leaf: Clone, + L2: Clone, + { + let atoms = self + .atoms + .iter() + .map(|atom| GenericAtom { + span: atom.span.clone(), + head: atom.head.clone(), + args: atom + .args + .iter() + .map(|arg| match arg { + GenericAtomTerm::Var(span, v) => GenericAtomTerm::Var(span.clone(), f(v)), + GenericAtomTerm::Literal(span, lit) => { + GenericAtomTerm::Literal(span.clone(), lit.clone()) + } + GenericAtomTerm::Global(span, g) => { + GenericAtomTerm::Global(span.clone(), f(g)) + } + }) + .collect(), + }) + .collect(); + Query { atoms } + } +} + impl Query { pub fn filters(&self) -> impl Iterator> + '_ { self.atoms.iter().filter_map(|atom| match &atom.head { @@ -817,40 +880,98 @@ where /// A [`GenericCoreRule`] represents a generalization of lowered form of a rule. 
/// Unlike other `Generic`-prefixed types, [`GenericCoreRule`] takes two `Head` -/// parameters instead of one. This is because the `Head` parameter of `body` and -/// `head` can be different. In particular, early in the compilation pipeline, -/// `body` can contain `Eq` atoms, which denotes equality constraints, so the `Head` -/// for `body` needs to be a `HeadOrEq`, while `head` does not have equality -/// constraints. +/// parameters and two `Leaf` parameters. This is because the `Head` parameter of +/// `body` and `head` can be different. Similarly, query and action leaves may +/// diverge, even though we currently pass the same type for both. +/// In particular, early in the compilation pipeline, `body` can contain `Eq` +/// atoms, which denotes equality constraints, so the `Head` for `body` needs to +/// be a `HeadOrEq`, while `head` does not have equality constraints. #[derive(Debug, Clone)] -pub struct GenericCoreRule { +pub struct GenericCoreRule { pub span: Span, - pub body: Query, - pub head: GenericCoreActions, + pub body: Query, + pub head: GenericCoreActions, +} + +type CanonicalizedCoreRule = + GenericCoreRule, Leaf>; + +pub(crate) type CoreRule = GenericCoreRule; +pub(crate) type ResolvedCoreRule = + GenericCoreRule, ResolvedVar>; + +/// A core rule paired with metadata that records how each original fact maps +/// onto the flattened query. The `rule` field is the canonicalized core rule +/// that downstream lowering consumes, while `mapped_facts` preserves the +/// correspondence to the surface syntax so proof reconstruction can recover the +/// user-facing facts. +pub(crate) struct CoreRuleWithFacts +where + Head: Clone + Display, + Leaf: Clone + PartialEq + Eq + Display + Hash, +{ + /// The canonicalized core rule ready for backend lowering. + pub rule: GenericCoreRule, Head, Leaf, Leaf>, + /// Fact annotations that capture how each original AST fact was flattened. + /// Used when proofs are enabled to reflect source syntax in the backend. 
+ pub mapped_facts: Vec>, } -pub(crate) type CoreRule = GenericCoreRule; -pub(crate) type ResolvedCoreRule = GenericCoreRule; +#[derive(Clone)] +pub(crate) struct CanonicalizedRule { + pub rule: CanonicalizedCoreRule, + pub mapped_facts: Vec>, +} -impl GenericCoreRule +impl GenericCoreRule, Leaf> where Head1: Clone, Head2: Clone, Leaf: Clone + Eq + Hash, { - pub fn subst(&mut self, subst: &HashMap>) { + pub fn subst( + &mut self, + subst: &HashMap, GenericAtomTerm>>, + ) { for atom in &mut self.body.atoms { atom.subst(subst); } - self.head.subst(subst); + let head_subst: HashMap> = subst + .iter() + .map(|(leaf, term)| { + let term = match term { + GenericAtomTerm::Var(span, v) => { + GenericAtomTerm::Var(span.clone(), v.var.clone()) + } + GenericAtomTerm::Literal(span, lit) => { + GenericAtomTerm::Literal(span.clone(), lit.clone()) + } + GenericAtomTerm::Global(span, v) => { + GenericAtomTerm::Global(span.clone(), v.var.clone()) + } + }; + (leaf.var.clone(), term) + }) + .collect(); + self.head.subst(&head_subst); } } -impl GenericCoreRule, Head, Leaf> +impl GenericCoreRule, Head, Leaf, Leaf> where Leaf: Eq + Clone + Hash + Debug, Head: Clone, { + fn init_canon(self) -> CanonicalizedCoreRule, Head, Leaf> { + let body = self + .body + .map_leaves(|leaf| CanonicalizedVar::new_current(leaf.clone())); + GenericCoreRule { + span: self.span, + body, + head: self.head, + } + } /// Transformed a UnresolvedCoreRule into a CanonicalizedCoreRule. /// In particular, it removes equality checks between variables and /// other arguments, and turns equality checks between non-variable arguments @@ -858,16 +979,21 @@ where pub(crate) fn canonicalize( self, // Users need to pass in a substitute for equality constraints. 
- value_eq: impl Fn(&GenericAtomTerm, &GenericAtomTerm) -> Head, - ) -> GenericCoreRule { - let mut result_rule = self; + value_eq: impl Fn( + &GenericAtomTerm>, + &GenericAtomTerm>, + ) -> Head, + ) -> CanonicalizedCoreRule { + // TODO: have canonicalization preserve the original leaf code. + // to check: does correspondingVar work, or do we need a Canonicalized type. + let mut result_rule = self.init_canon(); loop { let mut to_subst = None; for atom in result_rule.body.atoms.iter() { if atom.head.is_eq() && atom.args[0] != atom.args[1] { match &atom.args[..] { [GenericAtomTerm::Var(_, x), y] | [y, GenericAtomTerm::Var(_, x)] => { - to_subst = Some((x, y)); + to_subst = Some((x.clone(), y.clone())); break; } _ => (), @@ -903,7 +1029,10 @@ where } else { Some(GenericAtom { span: atom.span.clone(), - head: value_eq(&atom.args[0], &atom.args[1]), + head: value_eq( + &atom.args[0], + &atom.args[1], + ), args: vec![ atom.args[0].clone(), atom.args[1].clone(), @@ -934,7 +1063,7 @@ pub(crate) trait GenericRuleExt { &self, typeinfo: &TypeInfo, fresh_gen: &mut impl FreshGen, - ) -> Result, Head, Leaf>, TypeError> + ) -> Result, TypeError> where Head: Clone + Display + IsFunc, Leaf: Clone + PartialEq + Eq + Display + Hash + Debug; @@ -949,20 +1078,23 @@ where &self, typeinfo: &TypeInfo, fresh_gen: &mut impl FreshGen, - ) -> Result, Head, Leaf>, TypeError> + ) -> Result, TypeError> where Head: Clone + Display + IsFunc, Leaf: Clone + PartialEq + Eq + Display + Hash + Debug, { - let (body, _correspondence) = Facts(self.body.clone()).to_query(typeinfo, fresh_gen); + let (body, correspondence) = Facts(self.body.clone()).to_query(typeinfo, fresh_gen); let mut binding = body.get_vars(); let (head, _correspondence) = self.head .to_core_actions(typeinfo, &mut binding, fresh_gen)?; - Ok(GenericCoreRule { - span: self.span.clone(), - body, - head, + Ok(CoreRuleWithFacts { + rule: GenericCoreRule { + span: self.span.clone(), + body, + head, + }, + mapped_facts: correspondence, }) } } 
@@ -972,7 +1104,7 @@ pub(crate) trait ResolvedRuleExt { &self, typeinfo: &TypeInfo, fresh_gen: &mut SymbolGen, - ) -> Result; + ) -> Result; } impl ResolvedRuleExt for ResolvedRule { @@ -980,9 +1112,10 @@ impl ResolvedRuleExt for ResolvedRule { &self, typeinfo: &TypeInfo, fresh_gen: &mut SymbolGen, - ) -> Result { + ) -> Result { let value_eq = &typeinfo.get_prims("value-eq").unwrap()[0]; - let value_eq = |at1: &ResolvedAtomTerm, at2: &ResolvedAtomTerm| { + let value_eq = |at1: &CanonicalizedResolvedAtomTerm, + at2: &CanonicalizedResolvedAtomTerm| { ResolvedCall::Primitive(SpecializedPrimitive { primitive: value_eq.clone(), input: vec![at1.output(), at2.output()], @@ -990,7 +1123,7 @@ impl ResolvedRuleExt for ResolvedRule { }) }; - let rule = self.to_core_rule(typeinfo, fresh_gen)?; + let CoreRuleWithFacts { rule, mapped_facts } = self.to_core_rule(typeinfo, fresh_gen)?; // The groundedness check happens before canonicalization, because canonicalization // may turn ungrounded variables in a query to unbounded variables in actions (e.g., @@ -999,6 +1132,6 @@ impl ResolvedRuleExt for ResolvedRule { let rule = rule.canonicalize(value_eq); - Ok(rule) + Ok(CanonicalizedRule { rule, mapped_facts }) } } diff --git a/src/egraph_operations.rs b/src/egraph_operations.rs new file mode 100644 index 000000000..d61ad257a --- /dev/null +++ b/src/egraph_operations.rs @@ -0,0 +1,268 @@ +use crate::span; +use egglog_ast::span::{RustSpan, Span}; +use egglog_core_relations::Value; + +use crate::{ + EGraph, Error, ProofStore, TermProofId, + ast::{Facts, RunConfig, Schedule, collect_query_vars}, + util::{FreshGen, IndexMap}, +}; + +/// Represents a single match of a query, containing values for all query variables. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QueryMatch { + /// Maps variable names to their values in this match + bindings: IndexMap, +} + +impl QueryMatch { + /// Get the value bound to a variable name in this match. 
+ pub fn get(&self, var_name: &str) -> Option { + self.bindings.get(var_name).copied() + } + + /// Get all variable names in this match. + pub fn vars(&self) -> impl Iterator { + self.bindings.keys().map(|s| s.as_str()) + } + + /// Get the number of variables in this match. + pub fn len(&self) -> usize { + self.bindings.len() + } + + /// Check if this match has no variables. + pub fn is_empty(&self) -> bool { + self.bindings.is_empty() + } +} + +impl EGraph { + /// Returns all matches for the given query as a vector of QueryMatch structs. + /// + /// Each QueryMatch contains bindings only for user-defined variables in the query. + /// Internal variables generated during canonicalization (starting with $) are excluded. + /// + /// **Note**: This method requires proofs to be enabled. Create the EGraph with + /// `EGraph::with_proofs()` to use this feature. + /// + /// TODO this implementation is in-progress. + /// + /// # Example + /// ``` + /// # use egglog::prelude::*; + /// # let mut egraph = EGraph::with_proofs(); + /// egraph.parse_and_run_program(None, " + /// (datatype Math + /// (Num i64) + /// (Add Math Math)) + /// (Add (Num 1) (Num 2)) + /// ").unwrap(); + /// + /// // Query for all Add expressions + /// let matches = egraph.get_matches(facts![(= lhs (Add x y))]).unwrap(); + /// + /// // We found 1 match with lhs, x, and y bound + /// assert_eq!(matches.len(), 1); + /// assert!(matches[0].get("x").is_some()); + /// assert!(matches[0].get("y").is_some()); + /// assert!(matches[0].get("lhs").is_some()); + /// assert_eq!(matches[0].len(), 3); + /// ``` + pub fn get_matches(&mut self, facts: Facts) -> Result, Error> { + let Facts(query_facts) = facts; + if !self.backend.proofs_enabled() { + return Err(Error::BackendError( + "get_matches requires proofs to be enabled. 
Create the EGraph with EGraph::with_proofs().".to_string(), + )); + } + + let span = span!(); + + let resolved_facts = self + .type_info + .typecheck_facts(&mut self.parser.symbol_gen, &query_facts)?; + let query_vars = collect_query_vars(&resolved_facts); + + let constructor_name = self.parser.symbol_gen.fresh("get_matches_ctor"); + let relation_name = self.parser.symbol_gen.fresh("get_matches_rel"); + let ruleset_name = self.parser.symbol_gen.fresh("get_matches_ruleset"); + let rule_name = self.parser.symbol_gen.fresh("get_matches_rule"); + let match_sort_name = self.parser.symbol_gen.fresh("get_matches_sort"); + + let constructor_schema = { + let inputs = query_vars + .iter() + .map(|(_, sort)| sort.name().to_string()) + .collect::>(); + crate::ast::Schema::new(inputs, match_sort_name.clone()) + }; + + let mut program = vec![ + crate::ast::Command::Push(1), + crate::ast::Command::Sort(span.clone(), match_sort_name.clone(), None), + crate::ast::Command::Constructor { + span: span.clone(), + name: constructor_name.clone(), + schema: constructor_schema, + cost: None, + unextractable: true, + }, + crate::ast::Command::Relation { + span: span.clone(), + name: relation_name.clone(), + inputs: query_vars + .iter() + .map(|(_, sort)| sort.name().to_string()) + .collect(), + }, + crate::ast::Command::AddRuleset(span.clone(), ruleset_name.clone()), + ]; + + let body_facts = query_facts.clone(); + let action_expr = { + let args = query_vars + .iter() + .map(|(var, _)| crate::ast::Expr::Var(span.clone(), var.name.clone())) + .collect(); + crate::ast::Expr::Call(span.clone(), constructor_name.clone(), args) + }; + + let rule_actions = + crate::ast::GenericActions(vec![crate::ast::Action::Expr(span.clone(), action_expr)]); + + program.push(crate::ast::Command::Rule { + rule: crate::ast::GenericRule { + span: span.clone(), + head: rule_actions, + body: body_facts, + name: rule_name.clone(), + ruleset: ruleset_name.clone(), + }, + }); + + 
program.push(crate::ast::Command::RunSchedule(Schedule::Run( + span.clone(), + RunConfig { + ruleset: ruleset_name.clone(), + until: None, + }, + ))); + self.run_program(program)?; + + let constructor_function = self + .functions + .get(&constructor_name) + .expect("constructor should exist"); + + let mut results = Vec::new(); + self.backend + .for_each(constructor_function.backend_id, |row| { + let mut bindings = IndexMap::default(); + for ((var, _), value) in query_vars.iter().zip(row.vals.iter()) { + bindings.insert(var.name.clone(), *value); + } + results.push(QueryMatch { bindings }); + }); + + self.run_program(vec![crate::ast::Command::Pop(span.clone(), 1)])?; + + Ok(results) + } + /// Runs the query and produces a proof for any bound variable; returns the first proof found. + pub fn get_proof( + &mut self, + facts: Facts, + store: &mut ProofStore, + ) -> Result, Error> { + let Facts(query_facts) = facts; + if !self.backend.proofs_enabled() { + return Err(Error::BackendError( + "get_proof requires proofs to be enabled. Create the EGraph with EGraph::with_proofs()." 
+ .to_string(), + )); + } + + let span = span!(); + let resolved_facts = self + .type_info + .typecheck_facts(&mut self.parser.symbol_gen, &query_facts)?; + let query_vars = collect_query_vars(&resolved_facts); + + for (target_var, target_sort) in query_vars { + let constructor_name = self.parser.symbol_gen.fresh("get_proof_ctor"); + let ruleset_name = self.parser.symbol_gen.fresh("get_proof_ruleset"); + let rule_name = self.parser.symbol_gen.fresh("get_proof_rule"); + let proof_sort_name = self.parser.symbol_gen.fresh("get_proof_sort"); + + let constructor_schema = crate::ast::Schema::new( + vec![target_sort.name().to_string()], + proof_sort_name.clone(), + ); + + let mut program = vec![ + crate::ast::Command::Push(1), + crate::ast::Command::Sort(span.clone(), proof_sort_name.clone(), None), + crate::ast::Command::Constructor { + span: span.clone(), + name: constructor_name.clone(), + schema: constructor_schema, + cost: None, + unextractable: true, + }, + crate::ast::Command::AddRuleset(span.clone(), ruleset_name.clone()), + ]; + + let body_facts = query_facts.clone(); + let action_expr = crate::ast::Expr::Call( + span.clone(), + constructor_name.clone(), + vec![crate::ast::Expr::Var(span.clone(), target_var.name.clone())], + ); + let rule_actions = crate::ast::GenericActions(vec![crate::ast::Action::Expr( + span.clone(), + action_expr, + )]); + + program.push(crate::ast::Command::Rule { + rule: crate::ast::GenericRule { + span: span.clone(), + head: rule_actions, + body: body_facts, + name: rule_name.clone(), + ruleset: ruleset_name.clone(), + }, + }); + + program.push(crate::ast::Command::RunSchedule(Schedule::Run( + span.clone(), + RunConfig { + ruleset: ruleset_name.clone(), + until: None, + }, + ))); + self.run_program(program)?; + + let mut captured = None; + if let Some(constructor_function) = self.functions.get(&constructor_name) { + self.backend + .for_each(constructor_function.backend_id, |row| { + if captured.is_none() { + captured = 
row.vals.first().copied(); + } + }); + } + + self.run_program(vec![crate::ast::Command::Pop(span.clone(), 1)])?; + + if let Some(value) = captured { + let proof = self + .explain_term(value, store) + .map_err(|e| Error::BackendError(e.to_string()))?; + return Ok(Some(proof)); + } + } + + Ok(None) + } +} diff --git a/src/extract.rs b/src/extract.rs index 7bcd3125a..b2ce197e1 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -1,6 +1,6 @@ -use crate::termdag::{Term, TermDag}; use crate::util::{HashMap, HashSet}; use crate::*; +use crate::{Term, TermDag}; use std::collections::VecDeque; /// An interface for custom cost model. diff --git a/src/lib.rs b/src/lib.rs index 2d450abb7..e3e7f54e7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,12 +18,12 @@ mod cli; mod command_macro; pub mod constraint; mod core; +pub mod egraph_operations; pub mod extract; pub mod prelude; pub mod scheduler; mod serialize; pub mod sort; -mod termdag; mod typechecking; pub mod util; pub use command_macro::{CommandMacro, CommandMacroRegistry}; @@ -31,14 +31,13 @@ pub use command_macro::{CommandMacro, CommandMacroRegistry}; // This is used to allow the `add_primitive` macro to work in // both this crate and other crates by referring to `::egglog`. 
extern crate self as egglog; +use ast::CanonicalizedVar; use ast::*; pub use ast::{ResolvedExpr, ResolvedFact, ResolvedVar}; #[cfg(feature = "bin")] pub use cli::*; use constraint::{Constraint, Problem, SimpleTypeConstraint, TypeConstraint}; -use core::ResolvedAtomTerm; -pub use core::{Atom, AtomTerm}; -pub use core::{ResolvedCall, SpecializedPrimitive}; +pub use core::{Atom, AtomTerm, ResolvedCall, SpecializedPrimitive}; pub use core_relations::{BaseValue, ContainerValue, ExecutionState, Value}; use core_relations::{ExternalFunctionId, make_external_func}; use csv::Writer; @@ -47,11 +46,16 @@ use egglog_ast::generic_ast::{Change, GenericExpr, Literal}; use egglog_ast::span::Span; use egglog_ast::util::ListDisplay; pub use egglog_bridge::FunctionRow; -use egglog_bridge::{ColumnTy, QueryEntry}; +pub use egglog_bridge::match_term_app; +pub use egglog_bridge::proof_format::{EqProofId, ProofStore, TermProofId}; +use egglog_bridge::syntax::SyntaxId; +pub use egglog_bridge::termdag::{Term, TermDag, TermId}; +use egglog_bridge::{ColumnTy, QueryEntry, SourceExpr, SourceSyntax, TopLevelLhsExpr}; use egglog_core_relations as core_relations; use egglog_numeric_id as numeric_id; use egglog_reports::{ReportLevel, RunReport}; use extract::{CostModel, DefaultCost, Extractor, TreeAdditiveCostModel}; +use indexmap::IndexSet; use indexmap::map::Entry; use log::{Level, log_enabled}; use numeric_id::DenseIdMap; @@ -68,13 +72,13 @@ use std::ops::Deref; use std::path::PathBuf; use std::str::FromStr; use std::sync::Arc; -pub use termdag::{Term, TermDag, TermId}; use thiserror::Error; pub use typechecking::TypeError; pub use typechecking::TypeInfo; use util::*; use crate::ast::desugar::desugar_command; +use crate::ast::{CorrespondingVar, MappedExpr, MappedFact}; use crate::core::{GenericActionsExt, ResolvedRuleExt}; pub const GLOBAL_NAME_PREFIX: &str = "$"; @@ -290,10 +294,10 @@ impl Debug for Function { } } -impl Default for EGraph { - fn default() -> Self { +impl EGraph { + fn 
new_from_backend(backend: egglog_bridge::EGraph) -> Self { let mut eg = Self { - backend: Default::default(), + backend, parser: Default::default(), names: Default::default(), pushed_egraph: Default::default(), @@ -341,6 +345,25 @@ impl Default for EGraph { eg } + + /// Create a fresh e-graph with proof tracing enabled. + /// + /// Proofs are disabled by default. Use this constructor to enable proofs and provenance + /// tracking. + pub fn with_proofs() -> Self { + Self::new_from_backend(egglog_bridge::EGraph::with_tracing()) + } + + /// Returns `true` if this e-graph was constructed with proofs enabled. + pub fn proofs_enabled(&self) -> bool { + self.backend.proofs_enabled() + } +} + +impl Default for EGraph { + fn default() -> Self { + Self::new_from_backend(Default::default()) + } } #[derive(Debug, Error)] @@ -462,7 +485,7 @@ impl EGraph { ) -> Result { match expr { GenericExpr::Lit(_, literal) => { - let val = literal_to_value(&self.backend, literal); + let val = self.backend.literal_to_value(literal); Ok(egglog_bridge::MergeFn::Const(val)) } GenericExpr::Var(span, resolved_var) => match resolved_var.name.as_str() { @@ -539,6 +562,7 @@ impl EGraph { }, name: decl.name.to_string(), can_subsume, + fiat_reason_only: decl.let_binding.then(|| format!("global({})", decl.name)), }); let function = Function { @@ -828,15 +852,16 @@ impl EGraph { } fn add_rule(&mut self, rule: ast::ResolvedRule) -> Result { - let core_rule = + let canonical = rule.to_canonicalized_core_rule(&self.type_info, &mut self.parser.symbol_gen)?; - let (query, actions) = (&core_rule.body, &core_rule.head); + let (query, actions) = (&canonical.rule.body, &canonical.rule.head); let rule_id = { let mut translator = BackendRule::new( self.backend.new_rule(&rule.name, self.seminaive), &self.functions, &self.type_info, + canonical.mapped_facts.clone(), ); translator.query(query, false); translator.actions(actions)?; @@ -851,7 +876,7 @@ impl EGraph { let name = rule.name; panic!("Rule '{name}' was 
already present") } - indexmap::map::Entry::Vacant(e) => e.insert((core_rule, rule_id)), + indexmap::map::Entry::Vacant(e) => e.insert((canonical.rule, rule_id)), }; Ok(rule.name) } @@ -873,6 +898,7 @@ impl EGraph { self.backend.new_rule("eval_actions", false), &self.functions, &self.type_info, + Vec::new(), ); translator.actions(&actions)?; let id = translator.build(); @@ -919,6 +945,7 @@ impl EGraph { self.backend.new_rule("eval_resolved_expr", false), &self.functions, &self.type_info, + Vec::new(), ); let result_var = ResolvedVar { @@ -940,7 +967,10 @@ impl EGraph { .0; translator.actions(&actions)?; - let arg = translator.entry(&ResolvedAtomTerm::Var(span.clone(), result_var)); + let arg = translator.entry(&core::CanonicalizedResolvedAtomTerm::Var( + span.clone(), + CanonicalizedVar::new_current(result_var), + )); translator.rb.call_external_func( ext_id, &[arg], @@ -959,7 +989,9 @@ impl EGraph { let result = result.lock().unwrap().unwrap(); Ok(result) } +} +impl EGraph { fn add_combined_ruleset(&mut self, name: String, rulesets: Vec) { match self.rulesets.entry(name.clone()) { Entry::Occupied(_) => panic!("Ruleset '{name}' was already present"), @@ -984,9 +1016,9 @@ impl EGraph { name: fresh_name.clone(), ruleset: fresh_ruleset.clone(), }; - let core_rule = + let canonical_rule = rule.to_canonicalized_core_rule(&self.type_info, &mut self.parser.symbol_gen)?; - let query = core_rule.body; + let query = canonical_rule.rule.body.clone(); let ext_sc = egglog_bridge::SideChannel::default(); let ext_sc_ref = ext_sc.clone(); @@ -1001,6 +1033,7 @@ impl EGraph { self.backend.new_rule("check_facts", false), &self.functions, &self.type_info, + canonical_rule.mapped_facts, ); translator.query(&query, true); translator @@ -1567,13 +1600,266 @@ impl EGraph { self.backend .get_canon_repr(val, sort.column_ty(&self.backend)) } + + /// Generate a proof explaining how a term was constructed in the e-graph. 
+ /// + /// This method requires that the e-graph was created with [`EGraph::with_proofs()`]. + /// The proof is stored in the provided `ProofStore` and can be printed or inspected. + /// + /// # Arguments + /// * `id` - The value representing the term to explain + /// * `store` - A mutable reference to a `ProofStore` where the proof will be stored + /// + /// # Returns + /// A `TermProofId` that can be used to print or traverse the proof. + /// + /// # Errors + /// Returns an error if: + /// - Proofs are not enabled (e-graph not created with `with_proofs()`) + /// - The proof cannot be reconstructed + /// + /// # Example + /// ``` + /// # use egglog::prelude::*; + /// # use egglog::ProofStore; + /// let mut egraph = EGraph::with_proofs(); + /// egraph.parse_and_run_program(None, " + /// (datatype Math (Num i64) (Add Math Math)) + /// (let x (Add (Num 1) (Num 2))) + /// ").unwrap(); + /// + /// // Get the value for x + /// let (_, x_value) = egraph.eval_expr(&expr!(x)).unwrap(); + /// + /// // Generate a proof explaining how x was constructed + /// let mut store = ProofStore::default(); + /// let proof_id = egraph.explain_term(x_value, &mut store).unwrap(); + /// + /// // Print the proof + /// store.print_term_proof(proof_id, &mut std::io::stdout()).unwrap(); + /// ``` + pub fn explain_term( + &mut self, + id: Value, + store: &mut ProofStore, + ) -> egglog_bridge::Result { + self.backend.explain_term(id, store) + } + + /// Generate a proof explaining why two terms are equal in the e-graph. + /// + /// This method requires that the e-graph was created with [`EGraph::with_proofs()`]. + /// The proof shows the sequence of rewrites that establish the equality between the two terms. 
+ /// + /// # Arguments + /// * `id1` - The value representing the first term + /// * `id2` - The value representing the second term + /// * `store` - A mutable reference to a `ProofStore` where the proof will be stored + /// + /// # Returns + /// An `EqProofId` that can be used to print or traverse the equality proof. + /// + /// # Errors + /// Returns an error if: + /// - Proofs are not enabled (e-graph not created with `with_proofs()`) + /// - The two terms are not actually equal in the e-graph + /// - The proof cannot be reconstructed + /// + /// # Example + /// ``` + /// # use egglog::prelude::*; + /// # use egglog::ProofStore; + /// let mut egraph = EGraph::with_proofs(); + /// egraph.parse_and_run_program(None, " + /// (datatype Math (Num i64) (Add Math Math)) + /// (rule ((Add x y)) ((union (Add x y) (Add y x)))) + /// (let a (Add (Num 1) (Num 2))) + /// (let b (Add (Num 2) (Num 1))) + /// (run 1) + /// ").unwrap(); + /// + /// // Get the values for a and b + /// let (_, a_value) = egraph.eval_expr(&expr!(a)).unwrap(); + /// let (_, b_value) = egraph.eval_expr(&expr!(b)).unwrap(); + /// + /// // Generate a proof that a and b are equal + /// let mut store = ProofStore::default(); + /// let proof_id = egraph.explain_terms_equal(a_value, b_value, &mut store).unwrap(); + /// + /// // Print the proof + /// store.print_eq_proof(proof_id, &mut std::io::stdout()).unwrap(); + /// ``` + pub fn explain_terms_equal( + &mut self, + id1: Value, + id2: Value, + store: &mut ProofStore, + ) -> egglog_bridge::Result { + self.backend.explain_terms_equal(id1, id2, store) + } } struct BackendRule<'a> { rb: egglog_bridge::RuleBuilder<'a>, - entries: HashMap, + entries: HashMap, functions: &'a IndexMap, type_info: &'a TypeInfo, + var_types: DenseIdMap, + atoms: DenseIdMap, + resolved_var_entries: HashMap, + proof_state: Option, +} + +#[derive(Debug)] +struct FunctionInfo { + atom: egglog_bridge::AtomId, + backend_id: egglog_bridge::FunctionId, +} + +#[derive(Debug)] +struct 
PrimInfo { + var: egglog_bridge::VariableId, + ty: ColumnTy, + func: ExternalFunctionId, + name: Arc, +} + +struct ProofState { + facts: Vec>, + function_info: HashMap, + prim_info: HashMap, +} + +struct SyntaxBuilder<'a> { + env: SourceSyntax, + proof_state: &'a ProofState, + egraph: &'a egglog_bridge::EGraph, + var_map: &'a HashMap, + var_types: &'a DenseIdMap, +} + +impl SyntaxBuilder<'_> { + fn reconstruct_syntax(mut self) -> SourceSyntax { + for fact in &self.proof_state.facts { + // To use the handy "?". See the comment on `reconstruct_expr` for why this can fail. + || -> Option<()> { + match fact { + GenericFact::Eq(_, l, r) => { + let l_id = self.reconstruct_expr(l)?; + let r_id = self.reconstruct_expr(r)?; + self.env.add_toplevel_expr(TopLevelLhsExpr::Eq(l_id, r_id)); + } + GenericFact::Fact(expr) => { + let id = self.reconstruct_expr(expr)?; + self.env.add_toplevel_expr(TopLevelLhsExpr::Exists(id)); + } + } + Some(()) + }(); + } + self.env + } + + /// Generate the corresponding [`SyntaxId`] for an expression. + /// + /// This returns an Option because the source syntax can have expressions involving + /// primitives where no actual Syntax needs to be passed down to proofs. For example, a + /// primitive call asserting that `(= 2 (+ 1 1))` will have no corresponding substitution + /// available, and even something like `(Num x) (= x (+ x x))` can easily be run at proof-check + /// time, with no added auxiliary variables needed to be stored in the DB for the (= x (+ x x)) + /// atom. + /// + /// In cases like this `reconstruct_expr` returns None and the rest of the process + /// short-circuits. 
+ fn reconstruct_expr( + &mut self, + expr: &GenericExpr, ResolvedVar>, + ) -> Option { + Some(match expr { + GenericExpr::Var(span, var) => { + let Some(qe) = self.var_map.get(&core::CanonicalizedResolvedAtomTerm::Var( + span.clone(), + CanonicalizedVar::new_current(var.clone()), + )) else { + panic!("no mapping found for variable {var} [span={span}]") + }; + let QueryEntry::Var(v) = qe else { + panic!( + "found a non-variable entry mapped from a variable {var} [span={span}], instead found {qe:?}" + ); + }; + let ty = self.var_types[v.id]; + self.env.add_expr(SourceExpr::Var { + id: v.id, + ty, + name: var.name().into(), + }) + } + GenericExpr::Lit(_, lit) => { + let (val, ty) = self.egraph.literal_to_typed_constant(lit); + self.env.add_expr(SourceExpr::Const { ty, val }) + } + + GenericExpr::Call(_, CorrespondingVar { head, to }, children) => { + let mut any_failed = false; + let args: Vec<_> = children + .iter() + .filter_map(|child| { + let res = self.reconstruct_expr(child); + any_failed |= res.is_none(); + res + }) + .collect(); + if any_failed { + return None; + } + match head { + ResolvedCall::Func(_) => { + let FunctionInfo { atom, backend_id } = + self.proof_state.function_info[to.name()]; + self.env.add_expr(SourceExpr::FunctionCall { + func: backend_id, + atom, + args, + }) + } + ResolvedCall::Primitive(_) => { + let PrimInfo { + var, + ty, + func, + name, + } = self.proof_state.prim_info.get(to.name())?; + self.env.add_expr(SourceExpr::ExternalCall { + var: *var, + ty: *ty, + func: *func, + name: name.clone(), + args, + }) + } + } + } + }) + } +} + +impl ProofState { + fn new(facts: Vec>) -> ProofState { + ProofState { + facts, + function_info: Default::default(), + prim_info: Default::default(), + } + } + + fn record_call(&mut self, res_var: String, func: FunctionInfo) { + self.function_info.insert(res_var, func); + } + + fn record_prim(&mut self, res_var: String, prim: PrimInfo) { + self.prim_info.insert(res_var, prim); + } } impl<'a> BackendRule<'a> 
{ @@ -1581,24 +1867,38 @@ impl<'a> BackendRule<'a> { rb: egglog_bridge::RuleBuilder<'a>, functions: &'a IndexMap, type_info: &'a TypeInfo, + mapped_facts: Vec>, ) -> BackendRule<'a> { + let proofs_enabled = rb.egraph().proofs_enabled(); + BackendRule { rb, functions, type_info, entries: Default::default(), + var_types: Default::default(), + atoms: Default::default(), + resolved_var_entries: Default::default(), + proof_state: proofs_enabled.then(move || ProofState::new(mapped_facts)), } } - fn entry(&mut self, x: &core::ResolvedAtomTerm) -> QueryEntry { + fn entry(&mut self, x: &core::CanonicalizedResolvedAtomTerm) -> QueryEntry { self.entries .entry(x.clone()) .or_insert_with(|| match x { - core::GenericAtomTerm::Var(_, v) => self - .rb - .new_var_named(v.sort.column_ty(self.rb.egraph()), &v.name), - core::GenericAtomTerm::Literal(_, l) => literal_to_entry(self.rb.egraph(), l), - core::GenericAtomTerm::Global(..) => { + core::GenericAtomTerm::Var(_, v) => { + let ty = v.var.sort.column_ty(self.rb.egraph()); + let entry = self.rb.new_var_named(ty, &v.var.name); + if let QueryEntry::Var(var) = &entry { + self.var_types.insert(var.id, ty); + self.resolved_var_entries + .insert(v.var.clone(), entry.clone()); + } + entry + } + core::GenericAtomTerm::Literal(_, l) => self.rb.egraph().literal_to_entry(l), + core::GenericAtomTerm::Global(_, _) => { panic!("Globals should have been desugared") } }) @@ -1612,12 +1912,14 @@ impl<'a> BackendRule<'a> { fn prim( &mut self, prim: &core::SpecializedPrimitive, - args: &[core::ResolvedAtomTerm], + args: &[core::CanonicalizedResolvedAtomTerm], ) -> (ExternalFunctionId, Vec, ColumnTy) { let mut qe_args = self.args(args); if prim.name() == "unstable-fn" { - let core::ResolvedAtomTerm::Literal(_, Literal::String(ref name)) = args[0] else { + let core::CanonicalizedResolvedAtomTerm::Literal(_, Literal::String(ref name)) = + args[0] + else { panic!("expected string literal after `unstable-fn`") }; let id = if let Some(f) = 
self.type_info.get_func_type(name) { @@ -1666,26 +1968,90 @@ impl<'a> BackendRule<'a> { fn args<'b>( &mut self, - args: impl IntoIterator, + args: impl IntoIterator, ) -> Vec { args.into_iter().map(|x| self.entry(x)).collect() } - fn query(&mut self, query: &core::Query, include_subsumed: bool) { + fn canon_term(&self, term: &core::ResolvedAtomTerm) -> core::CanonicalizedResolvedAtomTerm { + match term { + core::GenericAtomTerm::Var(span, v) => { + core::GenericAtomTerm::Var(span.clone(), CanonicalizedVar::new_current(v.clone())) + } + core::GenericAtomTerm::Literal(span, lit) => { + core::GenericAtomTerm::Literal(span.clone(), lit.clone()) + } + core::GenericAtomTerm::Global(span, v) => core::GenericAtomTerm::Global( + span.clone(), + CanonicalizedVar::new_current(v.clone()), + ), + } + } + + fn canon_args<'b>( + &self, + args: impl IntoIterator, + ) -> Vec { + args.into_iter().map(|a| self.canon_term(a)).collect() + } + + fn query( + &mut self, + query: &core::Query>, + include_subsumed: bool, + ) { for atom in &query.atoms { match &atom.head { ResolvedCall::Func(f) => { - let f = self.func(f); + let f_id = self.func(f); let args = self.args(&atom.args); let is_subsumed = match include_subsumed { true => None, false => Some(false), }; - self.rb.query_table(f, &args, is_subsumed).unwrap(); + let atom_id = self.rb.query_table(f_id, &args, is_subsumed).unwrap(); + self.proof_state.as_mut().map(|ps| -> Option<()> { + let last = atom.args.last()?; + let core::CanonicalizedResolvedAtomTerm::Var(_span, var) = last else { + panic!("expected unique variable as last argument to a query, instead got {last:?}, part of {atom:?}"); + }; + ps.record_call( + var.orig.name().into(), + FunctionInfo { atom: atom_id, backend_id:f_id }, + ); + Some(()) + }); + if let Some(QueryEntry::Var(var)) = args.last() { + self.atoms.insert(var.id, atom_id); + } } ResolvedCall::Primitive(p) => { - let (p, args, ty) = self.prim(p, &atom.args); - self.rb.query_prim(p, &args, ty).unwrap() + let 
(ext_id, args, ty) = self.prim(p, &atom.args); + self.proof_state.as_mut().map(|ps| -> Option<()> { + let Some(QueryEntry::Var(var)) = args.last() else { + // This is an assertion about the output of some primitive. The + // primitive itself is not returning a useful variable for other atoms + // in the query. The proof checker as a result does not need to know + // anything more about this call. + return None; + }; + let core::CanonicalizedResolvedAtomTerm::Var(_span, res_var) = + atom.args.last()? + else { + panic!("expected unique variable as last argument to a prim query, instead got {atom:?}"); + }; + ps.record_prim( + res_var.orig.name().into(), + PrimInfo { + var: var.id, + ty, + func: ext_id, + name: p.name().into(), + }, + ); + Some(()) + }); + self.rb.query_prim(ext_id, &args, ty).unwrap(); } } } @@ -1695,12 +2061,14 @@ impl<'a> BackendRule<'a> { for action in &actions.0 { match action { core::GenericCoreAction::Let(span, v, f, args) => { - let v = core::GenericAtomTerm::Var(span.clone(), v.clone()); + let canon_v = CanonicalizedVar::new_current(v.clone()); + let v = core::GenericAtomTerm::Var(span.clone(), canon_v.clone()); + let canon_args = self.canon_args(args); let y = match f { ResolvedCall::Func(f) => { let name = f.name.clone(); let f = self.func(f); - let args = self.args(args); + let args = self.args(canon_args.iter()); let span = span.clone(); self.rb.lookup(f, &args, move || { format!("{span}: lookup of function {name} failed") @@ -1708,7 +2076,7 @@ impl<'a> BackendRule<'a> { } ResolvedCall::Primitive(p) => { let name = p.name().to_owned(); - let (p, args, ty) = self.prim(p, args); + let (p, args, ty) = self.prim(p, &canon_args); let span = span.clone(); self.rb.call_external_func(p, &args, ty, move || { format!("{span}: call of primitive {name} failed") @@ -1718,15 +2086,19 @@ impl<'a> BackendRule<'a> { self.entries.insert(v, y.into()); } core::GenericCoreAction::LetAtomTerm(span, v, x) => { - let v = core::GenericAtomTerm::Var(span.clone(), 
v.clone()); - let x = self.entry(x); + let v = core::GenericAtomTerm::Var( + span.clone(), + CanonicalizedVar::new_current(v.clone()), + ); + let x = self.entry(&self.canon_term(x)); self.entries.insert(v, x); } core::GenericCoreAction::Set(_, f, xs, y) => match f { ResolvedCall::Primitive(..) => panic!("runtime primitive set!"), ResolvedCall::Func(f) => { let f = self.func(f); - let args = self.args(xs.iter().chain([y])); + let canon_args = self.canon_args(xs.iter().chain([y])); + let args = self.args(canon_args.iter()); self.rb.set(f, &args) } }, @@ -1736,7 +2108,8 @@ impl<'a> BackendRule<'a> { let name = f.name.clone(); let can_subsume = self.functions[&f.name].can_subsume; let f = self.func(f); - let args = self.args(args); + let canon_args = self.canon_args(args); + let args = self.args(canon_args.iter()); match change { Change::Delete => self.rb.remove(f, &args), Change::Subsume if can_subsume => self.rb.subsume(f, &args), @@ -1747,8 +2120,8 @@ impl<'a> BackendRule<'a> { } }, core::GenericCoreAction::Union(_, x, y) => { - let x = self.entry(x); - let y = self.entry(y); + let x = self.entry(&self.canon_term(x)); + let y = self.entry(&self.canon_term(y)); self.rb.union(x, y) } core::GenericCoreAction::Panic(_, message) => self.rb.panic(message.clone()), @@ -1758,27 +2131,19 @@ impl<'a> BackendRule<'a> { } fn build(self) -> egglog_bridge::RuleId { - self.rb.build() - } -} - -fn literal_to_entry(egraph: &egglog_bridge::EGraph, l: &Literal) -> QueryEntry { - match l { - Literal::Int(x) => egraph.base_value_constant::(*x), - Literal::Float(x) => egraph.base_value_constant::(x.into()), - Literal::String(x) => egraph.base_value_constant::(sort::S::new(x.clone())), - Literal::Bool(x) => egraph.base_value_constant::(*x), - Literal::Unit => egraph.base_value_constant::<()>(()), - } -} - -fn literal_to_value(egraph: &egglog_bridge::EGraph, l: &Literal) -> Value { - match l { - Literal::Int(x) => egraph.base_values().get::(*x), - Literal::Float(x) => 
egraph.base_values().get::(x.into()), - Literal::String(x) => egraph.base_values().get::(sort::S::new(x.clone())), - Literal::Bool(x) => egraph.base_values().get::(*x), - Literal::Unit => egraph.base_values().get::<()>(()), + if let Some(proof_state) = self.proof_state.as_ref() { + let syntax = SyntaxBuilder { + proof_state, + egraph: self.rb.egraph(), + var_map: &self.entries, + var_types: &self.var_types, + env: Default::default(), + } + .reconstruct_syntax(); + self.rb.build_with_syntax(syntax) + } else { + self.rb.build() + } } } diff --git a/src/prelude.rs b/src/prelude.rs index 48af1367b..3ee280882 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -11,6 +11,7 @@ use std::any::{Any, TypeId}; // Re-exports in `prelude` for convenience. pub use egglog::ast::{Action, Fact, Facts, GenericActions, RustSpan, Span}; +pub use egglog::egraph_operations::QueryMatch; pub use egglog::sort::{BigIntSort, BigRatSort, BoolSort, F64Sort, I64Sort, StringSort, UnitSort}; pub use egglog::{CommandMacro, CommandMacroRegistry}; pub use egglog::{EGraph, span}; @@ -637,6 +638,7 @@ pub fn add_relation( #[macro_export] macro_rules! datatype { ($egraph:expr, (datatype $sort:ident $(($name:ident $($args:ident)* $(:cost $cost:expr)?))*)) => { + use $crate::ast::Schema; add_sort($egraph, stringify!($sort))?; $(add_constructor( $egraph, diff --git a/src/scheduler.rs b/src/scheduler.rs index a8ef800fa..237f5381e 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -8,7 +8,12 @@ use egglog_bridge::{ use egglog_reports::RunReport; use numeric_id::define_id; -use crate::{ast::ResolvedVar, core::GenericAtomTerm, core::ResolvedCoreRule, util::IndexMap, *}; +use crate::{ + ast::{CanonicalizedVar, ResolvedVar}, + core::{GenericAtomTerm, ResolvedCoreRule}, + util::IndexMap, + *, +}; /// A scheduler decides which matches to be applied for a rule. 
/// @@ -341,6 +346,7 @@ impl SchedulerRuleInfo { merge: MergeFn::AssertEq, name: "backend".to_string(), can_subsume: false, + fiat_reason_only: None, }); // Step 1: build the query rule @@ -348,11 +354,17 @@ impl SchedulerRuleInfo { egraph.backend.new_rule(name, true), &egraph.functions, &egraph.type_info, + Vec::new(), ); qrule_builder.query(&rule.body, true); let entries = free_vars .iter() - .map(|fv| qrule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone()))) + .map(|fv| { + qrule_builder.entry(&GenericAtomTerm::Var( + span!(), + CanonicalizedVar::new_current(fv.clone()), + )) + }) .collect::>(); let _var = qrule_builder.rb.call_external_func( collect_matches, @@ -367,10 +379,16 @@ impl SchedulerRuleInfo { egraph.backend.new_rule(name, false), &egraph.functions, &egraph.type_info, + Vec::new(), ); let mut entries = free_vars .iter() - .map(|fv| arule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone()))) + .map(|fv| { + arule_builder.entry(&GenericAtomTerm::Var( + span!(), + CanonicalizedVar::new_current(fv.clone()), + )) + }) .collect::>(); entries.push(unit_entry); arule_builder diff --git a/src/typechecking.rs b/src/typechecking.rs index 4188569ba..3ec6a1882 100644 --- a/src/typechecking.rs +++ b/src/typechecking.rs @@ -131,6 +131,7 @@ impl EGraph { let prim = Arc::new(x.clone()); let ext = self.backend.register_external_func(Wrapper(x)); + self.type_info .primitives .entry(prim.name().to_owned()) diff --git a/tests/basic-let.egg b/tests/basic-let.egg new file mode 100644 index 000000000..71e0aaaa9 --- /dev/null +++ b/tests/basic-let.egg @@ -0,0 +1,10 @@ +(datatype Expr + (Num i64 :cost 1) + (Add Expr Expr :cost 5)) + + +(let e1 (Num 1)) +(let e2 (Add e1 e1)) + +(check (= e1 e1)) + diff --git a/tests/fibonacci-demand.egg b/tests/fibonacci-demand.egg index 6a6d6afeb..f92843efc 100644 --- a/tests/fibonacci-demand.egg +++ b/tests/fibonacci-demand.egg @@ -5,10 +5,8 @@ (constructor Fib (i64) Expr :cost 10) (rewrite (Add (Num a) (Num b)) (Num (+ a b))) 
-(rewrite (Fib x) (Add (Fib (- x 1)) (Fib (- x 2))) - :when ((> x 1))) -(rewrite (Fib x) (Num x) - :when ((<= x 1))) +(rewrite (Fib x) (Add (Fib (- x 1)) (Fib (- x 2))) :when ((> x 1))) +(rewrite (Fib x) (Num x) :when ((<= x 1))) (let f7 (Fib 7)) (run 1000) @@ -16,4 +14,4 @@ (extract f7) (check (= f7 (Num 13))) - \ No newline at end of file + diff --git a/tests/files.rs b/tests/files.rs index 454b5f25f..f6036b9b4 100644 --- a/tests/files.rs +++ b/tests/files.rs @@ -1,5 +1,7 @@ -use std::path::PathBuf; +use hashbrown::HashSet; +use std::path::{Path, PathBuf}; +use egglog::ast::{Command, Parser}; use egglog::*; use libtest_mimic::Trial; @@ -7,6 +9,7 @@ use libtest_mimic::Trial; struct Run { path: PathBuf, desugar: bool, + proofs: bool, } impl Run { @@ -40,7 +43,11 @@ impl Run { } fn test_program(&self, filename: Option, program: &str, message: &str) { - let mut egraph = EGraph::default(); + let mut egraph = if self.proofs { + EGraph::with_proofs() + } else { + EGraph::default() + }; match egraph.parse_and_run_program(filename, program) { Ok(msgs) => { if self.should_fail() { @@ -96,6 +103,9 @@ impl Run { let stem = self.0.path.file_stem().unwrap(); let stem_str = stem.to_string_lossy().replace(['.', '-', ' '], "_"); write!(f, "{stem_str}")?; + if self.0.proofs { + write!(f, "_with_proofs")?; + } if self.0.desugar { write!(f, "_desugar")?; } @@ -115,16 +125,29 @@ fn generate_tests(glob: &str) -> Vec { let mut push_trial = |run: Run| trials.push(run.into_trial()); for entry in glob::glob(glob).unwrap() { + let path = entry.unwrap().clone(); + let program = std::fs::read_to_string(&path) + .unwrap_or_else(|err| panic!("Couldn't read {:?}: {:?}", path, err)); + let proofs_ok = proofs_supported(&path, &program); + let run = Run { - path: entry.unwrap().clone(), + path: path.clone(), desugar: false, + proofs: false, }; let should_fail = run.should_fail(); push_trial(run.clone()); + if !should_fail && proofs_ok { + push_trial(Run { + proofs: true, + ..run.clone() + }); + } 
if !should_fail { push_trial(Run { desugar: true, + proofs: false, ..run.clone() }); } @@ -133,6 +156,116 @@ fn generate_tests(glob: &str) -> Vec { trials } +fn proofs_supported(path: &Path, program: &str) -> bool { + let mut parser = Parser::default(); + let mut visited = HashSet::new(); + proofs_supported_inner(&mut parser, path, path, program, &mut visited) +} + +fn proofs_supported_inner( + parser: &mut Parser, + root: &Path, + path: &Path, + program: &str, + visited: &mut HashSet, +) -> bool { + let canonical = path.canonicalize().unwrap_or_else(|_| root.join(path)); + if !visited.insert(canonical.clone()) { + return true; + } + + let filename = path.to_string_lossy().into_owned(); + let commands = match parser.get_program_from_string(Some(filename), program) { + Ok(cmds) => cmds, + Err(_) => return false, + }; + let base_dir = path.parent().unwrap_or_else(|| Path::new(".")); + + commands + .into_iter() + .all(|command| command_allows_proofs(parser, root, base_dir, command, visited)) +} + +fn command_allows_proofs( + parser: &mut Parser, + root: &Path, + base_dir: &Path, + command: Command, + visited: &mut HashSet, +) -> bool { + match command { + Command::Function { merge: Some(_), .. 
} => false, + Command::Fail(_, inner) => command_allows_proofs(parser, root, base_dir, *inner, visited), + Command::Include(_, file) => { + let include_path = { + let candidate = Path::new(&file); + if candidate.is_absolute() { + candidate.to_path_buf() + } else { + base_dir.join(candidate) + } + }; + let Ok(contents) = std::fs::read_to_string(&include_path) else { + return false; + }; + proofs_supported_inner(parser, root, &include_path, &contents, visited) + } + Command::Rewrite(_, rewrite, _) | Command::BiRewrite(_, rewrite) => { + expr_allows_proofs(parser, root, base_dir, &rewrite.lhs, visited) + && expr_allows_proofs(parser, root, base_dir, &rewrite.rhs, visited) + } + Command::Rule { rule } => { + rule.body + .iter() + .all(|fact| fact_allows_proofs(parser, root, base_dir, fact, visited)) + && rule.head.iter().all(|action| { + let mut ok = true; + action.clone().visit_exprs(&mut |expr| { + ok &= expr_allows_proofs(parser, root, base_dir, &expr, visited); + expr + }); + ok + }) + } + _ => true, + } +} + +fn fact_allows_proofs( + parser: &mut Parser, + root: &Path, + base_dir: &Path, + fact: &egglog::ast::Fact, + visited: &mut HashSet, +) -> bool { + match fact { + egglog::ast::Fact::Fact(expr) => expr_allows_proofs(parser, root, base_dir, expr, visited), + egglog::ast::Fact::Eq(_, lhs, rhs) => { + expr_allows_proofs(parser, root, base_dir, lhs, visited) + && expr_allows_proofs(parser, root, base_dir, rhs, visited) + } + } +} + +fn expr_allows_proofs( + _parser: &mut Parser, + _root: &Path, + _base_dir: &Path, + expr: &egglog::ast::Expr, + _visited: &mut HashSet, +) -> bool { + match expr { + egglog::ast::Expr::Call(_, head, args) => { + if head.starts_with("unstable-") || head == "+" { + return false; + } + args.iter() + .all(|arg| expr_allows_proofs(_parser, _root, _base_dir, arg, _visited)) + } + _ => true, + } +} + fn main() { let args = libtest_mimic::Arguments::from_args(); let tests = generate_tests("tests/**/*.egg"); diff --git 
a/tests/proof_api/Cargo.toml b/tests/proof_api/Cargo.toml new file mode 100644 index 000000000..626ed7308 --- /dev/null +++ b/tests/proof_api/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "proof_api" +version = "0.1.0" +edition = "2024" + +[lib] +path = "src/lib.rs" + +[dependencies] +egglog = { path = "../..", features = ["bin"] } diff --git a/tests/proof_api/src/lib.rs b/tests/proof_api/src/lib.rs new file mode 100644 index 000000000..33c37b031 --- /dev/null +++ b/tests/proof_api/src/lib.rs @@ -0,0 +1,44 @@ +use egglog::ast::Schema; +use egglog::prelude::*; + +pub fn build_program() -> Result { + let mut egraph = EGraph::default(); + + add_sort(&mut egraph, "Expr")?; + + add_constructor( + &mut egraph, + "Num", + Schema { + input: vec!["i64".into()], + output: "Expr".into(), + }, + None, + false, + )?; + + add_constructor( + &mut egraph, + "Add", + Schema { + input: vec!["Expr".into(), "Expr".into()], + output: "Expr".into(), + }, + None, + false, + )?; + + let ruleset = "rules"; + add_ruleset(&mut egraph, ruleset)?; + rule( + &mut egraph, + ruleset, + facts![ + (= (Add (Num x) (Num y)) (Num n)) + (= (+ x y) n) + ], + actions![], + )?; + + Ok(egraph) +} diff --git a/tests/proof_api/tests/api.rs b/tests/proof_api/tests/api.rs new file mode 100644 index 000000000..561729dbd --- /dev/null +++ b/tests/proof_api/tests/api.rs @@ -0,0 +1,12 @@ +use egglog::prelude::*; +use proof_api::build_program; + +#[test] +fn api_program_builds_without_proofs() { + let mut egraph = build_program().expect("building program succeeds"); + + let results = query(&mut egraph, vars![x: i64], facts![ (= x 0) ]).expect("query runs"); + + // Ensure we can iterate over the results returned by the API. 
+ let _count = results.iter().count(); +} diff --git a/tests/proof_api/tests/get_matches.rs b/tests/proof_api/tests/get_matches.rs new file mode 100644 index 000000000..964bf36ec --- /dev/null +++ b/tests/proof_api/tests/get_matches.rs @@ -0,0 +1,100 @@ +use egglog::prelude::*; + +#[test] +fn test_get_matches_basic() { + let mut egraph = EGraph::with_proofs(); + egraph + .parse_and_run_program( + None, + " + (datatype Math + (Num i64) + (Add Math Math)) + + (Add (Num 1) (Num 2)) + (Add (Num 3) (Num 4)) + (Add (Num 5) (Num 6)) + ", + ) + .unwrap(); + + let matches = egraph.get_matches(facts![(Add x y)]).unwrap(); + + assert_eq!(matches.len(), 3); + for m in &matches { + assert_eq!(m.len(), 2); + assert!(m.get("x").is_some()); + assert!(m.get("y").is_some()); + } +} + +#[test] +fn test_get_matches_with_equality() { + let mut egraph = EGraph::with_proofs(); + egraph + .parse_and_run_program( + None, + " + (datatype Math + (Num i64) + (Add Math Math)) + + (let a (Num 1)) + (let b (Num 2)) + (let sum (Add a b)) + (union a (Num 10)) + ", + ) + .unwrap(); + + let matches = egraph.get_matches(facts![(Add x y) (= x (Num 1))]).unwrap(); + + assert!(!matches.is_empty()); +} + +#[test] +fn test_get_matches_lhs_equality() { + let mut egraph = EGraph::with_proofs(); + egraph + .parse_and_run_program( + None, + " + (datatype Math + (Num i64) + (Add Math Math)) + + (let lhs (Add (Num 1) (Num 2))) + ", + ) + .unwrap(); + + let matches = egraph.get_matches(facts![(= lhs (Add x y))]).unwrap(); + + assert_eq!(matches.len(), 1); + let bindings = &matches[0]; + assert!(bindings.get("lhs").is_some()); + assert!(bindings.get("x").is_some()); + assert!(bindings.get("y").is_some()); + assert_eq!(bindings.len(), 3); +} + +#[test] +fn test_get_matches_empty_result() { + let mut egraph = EGraph::with_proofs(); + egraph + .parse_and_run_program( + None, + " + (datatype Math + (Num i64) + (Add Math Math)) + + (Num 1) + (Num 2) + ", + ) + .unwrap(); + + let matches = 
egraph.get_matches(facts![(Add x y)]).unwrap(); + assert_eq!(matches.len(), 0); +} diff --git a/tests/proof_api/tests/query_and_explain.rs b/tests/proof_api/tests/query_and_explain.rs new file mode 100644 index 000000000..9f4373d91 --- /dev/null +++ b/tests/proof_api/tests/query_and_explain.rs @@ -0,0 +1,94 @@ +use egglog::ProofStore; +use egglog::prelude::*; + +#[test] +#[allow(clippy::disallowed_macros)] +fn test_query_and_explain_match() { + // Create an e-graph with proofs enabled + let mut egraph = EGraph::with_proofs(); + + egraph + .parse_and_run_program( + None, + " + (datatype Math + (Num i64) + (Add Math Math)) + + ; Create a commutative rule for addition + (rule ((Add x y)) + ((union (Add x y) (Add y x)))) + + ; Add some expressions + (let a (Add (Num 1) (Num 2))) + (let b (Add (Num 3) (Num 4))) + (let c (Add (Num 5) (Num 6))) + + ; Run rules to establish equalities + (run 1) + ", + ) + .unwrap(); + + // Query for all Add expressions with a name bound to them + // This will match both the original expressions and the ones created by the commutative rule + let matches = egraph + .get_matches(facts![(Fact (= lhs (Add x y)))]) + .unwrap(); + + println!("Found {} matches", matches.len()); + // We get 6 matches because the commutative rule creates both (Add a b) and (Add b a) + assert_eq!(matches.len(), 6); + + // Get the first match + let first_match = &matches[0]; + println!("\nFirst match has {} variables", first_match.len()); + // Note: We only get x and y, not lhs, because lhs is equal to the pattern (Add x y) + // and equality constraints don't create new variables in the match + assert_eq!(first_match.len(), 2); // x, y + + // Get the values from the match + let x_value = first_match.get("x").expect("x should be bound"); + let y_value = first_match.get("y").expect("y should be bound"); + + println!("x = {:?}, y = {:?}", x_value, y_value); + + // To get lhs, we need to evaluate (Add x y) in the egraph + // But for this test, let's just explain the 
operands + + // Explain how the operands were constructed + let mut store = ProofStore::default(); + let x_proof = egraph.explain_term(x_value, &mut store).unwrap(); + + println!("\nProof for x:"); + store + .print_term_proof(x_proof, &mut std::io::stdout()) + .unwrap(); + + // Now let's look at a few matches to see both original and rule-generated ones + println!("\n--- First 4 matches ---"); + + for (i, m) in matches.iter().take(4).enumerate() { + let x = m.get("x").unwrap(); + let y = m.get("y").unwrap(); + + println!("\nMatch {}: x={:?}, y={:?}", i, x, y); + println!("Proof for x:"); + let proof_x = egraph.explain_term(x, &mut store).unwrap(); + store + .print_term_proof(proof_x, &mut std::io::stdout()) + .unwrap(); + + println!("Proof for y:"); + let proof_y = egraph.explain_term(y, &mut store).unwrap(); + store + .print_term_proof(proof_y, &mut std::io::stdout()) + .unwrap(); + } + + println!("\n--- Note ---"); + println!("The matches include both:"); + println!(" - Original (Add x y) expressions from the program"); + println!(" - Commuted (Add y x) expressions created by the rule"); + println!("\n✓ Successfully queried for matches and explained all their proofs!"); +} diff --git a/tests/proof_api/tests/query_proof.rs b/tests/proof_api/tests/query_proof.rs new file mode 100644 index 000000000..3e2474cb2 --- /dev/null +++ b/tests/proof_api/tests/query_proof.rs @@ -0,0 +1,48 @@ +use egglog::ProofStore; +use egglog::prelude::*; + +#[test] +fn proof_none_when_no_match() { + let mut egraph = EGraph::with_proofs(); + egraph + .parse_and_run_program( + None, + " + (datatype Math + (Num i64) + (Add Math Math)) + ", + ) + .unwrap(); + + let mut store = ProofStore::default(); + let proof = egraph.get_proof(facts![(Add x y)], &mut store).unwrap(); + + assert!(proof.is_none()); +} + +#[test] +fn proof_for_single_match() { + let mut egraph = EGraph::with_proofs(); + egraph + .parse_and_run_program( + None, + " + (datatype Math + (Num i64) + (Add Math Math)) + (let lhs 
(Add (Num 1) (Num 2))) + ", + ) + .unwrap(); + + let mut store = ProofStore::default(); + let proof = egraph + .get_proof(facts![(= lhs (Add x y))], &mut store) + .unwrap() + .expect("expected proof"); + + store + .print_term_proof(proof, &mut std::io::stdout()) + .unwrap(); +} diff --git a/tests/proof_api/tests/simple_math.egg b/tests/proof_api/tests/simple_math.egg new file mode 100644 index 000000000..b139d9bb1 --- /dev/null +++ b/tests/proof_api/tests/simple_math.egg @@ -0,0 +1,11 @@ +(datatype Expr + (Num i64) + (Add Expr Expr)) + +(rule + ((= lhs (Add x y))) + ((union (Add x y) (Add y x)))) + +(Add (Num 2) (Num 3)) + +(run 1) diff --git a/tests/proof_api/tests/with_proofs.rs b/tests/proof_api/tests/with_proofs.rs new file mode 100644 index 000000000..b52e8124e --- /dev/null +++ b/tests/proof_api/tests/with_proofs.rs @@ -0,0 +1,27 @@ +use std::io::stdout; + +use egglog::Error; +use egglog::ProofStore; +use egglog::prelude::*; + +#[test] +fn proofs_from_egg_file() -> Result<(), Error> { + let program = include_str!("simple_math.egg"); + let mut egraph = EGraph::with_proofs(); + + egraph.parse_and_run_program(None, program)?; + + let mut store = ProofStore::default(); + + // Evaluate a small term + let (_, term_value) = egraph + .eval_expr(&expr!((Add (Num 3) (Num 2)))) + .expect("evaluate expression"); + let term_pf = egraph + .explain_term(term_value, &mut store) + .expect("term proof"); + store + .print_term_proof(term_pf, &mut stdout()) + .expect("print term proof"); + Ok(()) +}