From 89b5371d0f66e7426bdd412547267634b83ed9b3 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Tue, 9 Sep 2025 11:42:12 +0300 Subject: [PATCH 1/6] Support attributes on `Node`s (not just their output values). --- src/cf/structurize.rs | 194 +++++++++++++++++++++++++++++++++++++----- src/lib.rs | 2 + src/mem/analyze.rs | 12 ++- src/print/mod.rs | 14 +-- src/spv/lower.rs | 1 + src/transform.rs | 17 ++-- src/visit.rs | 3 +- 7 files changed, 208 insertions(+), 35 deletions(-) diff --git a/src/cf/structurize.rs b/src/cf/structurize.rs index e1d51c36..aa654b6e 100644 --- a/src/cf/structurize.rs +++ b/src/cf/structurize.rs @@ -8,9 +8,9 @@ use crate::cf::unstructured::{ }; use crate::transform::{InnerInPlaceTransform as _, Transformer}; use crate::{ - AttrSet, Const, ConstDef, ConstKind, Context, EntityOrientedDenseMap, FuncDefBody, FxIndexMap, - FxIndexSet, Node, NodeDef, NodeKind, NodeOutputDecl, Region, RegionDef, Type, TypeKind, Value, - spv, + AttrSet, Const, ConstDef, ConstKind, Context, DbgSrcLoc, EntityOrientedDenseMap, FuncDefBody, + FxIndexMap, FxIndexSet, Node, NodeDef, NodeKind, NodeOutputDecl, Region, RegionDef, Type, + TypeKind, Value, spv, }; use itertools::{Either, Itertools}; use smallvec::SmallVec; @@ -190,6 +190,14 @@ enum StructurizeRegionState { /// **Note**: `target` has a generic type `T` to reduce redundancy when it's /// already implied (e.g. by the key in [`DeferredEdgeBundleSet`]'s map). struct IncomingEdgeBundle { + /// Attributes from the original [`ControlInst`]s (likely debuginfo), kept + /// when merging only when exactly identical, which can naturally be the case + /// for debuginfo (e.g. for branches from inside `if`-`else`/`switch` to a + /// common merge point, just after the whole control-flow construct). + // + // FIXME(eddyb) semantically filter these, maybe focus on debuginfo? 
+ attrs: AttrSet, + target: T, accumulated_count: IncomingEdgeCount, @@ -200,8 +208,8 @@ struct IncomingEdgeBundle { impl IncomingEdgeBundle { fn with_target(self, target: U) -> IncomingEdgeBundle { - let IncomingEdgeBundle { target: _, accumulated_count, target_inputs } = self; - IncomingEdgeBundle { target, accumulated_count, target_inputs } + let IncomingEdgeBundle { attrs, target: _, accumulated_count, target_inputs } = self; + IncomingEdgeBundle { attrs, target, accumulated_count, target_inputs } } } @@ -429,6 +437,7 @@ impl DeferredEdgeBundleSet { search_target: DeferredTarget, ) -> Option> { let steal_edge_bundle = |edge_bundle: &mut IncomingEdgeBundle<()>| IncomingEdgeBundle { + attrs: edge_bundle.attrs, target: (), accumulated_count: edge_bundle.accumulated_count, target_inputs: mem::take(&mut edge_bundle.target_inputs), @@ -520,6 +529,7 @@ impl DeferredEdgeBundleSet { DeferredEdgeBundle { condition: LazyCond::False, edge_bundle: IncomingEdgeBundle { + attrs: Default::default(), target: Default::default(), accumulated_count: Default::default(), target_inputs: Default::default(), @@ -674,6 +684,7 @@ impl<'a> Structurizer<'a> { let func_entry_pseudo_edge = { let target = self.func_def_body.body; move || IncomingEdgeBundle { + attrs: Default::default(), target, accumulated_count: IncomingEdgeCount::ONE, target_inputs: [].into_iter().collect(), @@ -900,6 +911,9 @@ impl<'a> Structurizer<'a> { let loop_node = self.func_def_body.nodes.define( self.cx, NodeDef { + // FIXME(eddyb) could it be possible to synthesize attrs + // from `ControlInst`s' attrs and/or `OpLoopMerge`'s? 
+ attrs: AttrSet::default(), kind: NodeKind::Loop { initial_inputs, body, repeat_condition }, outputs: [].into_iter().collect(), } @@ -928,11 +942,13 @@ impl<'a> Structurizer<'a> { } else { target }; - Ok(ClaimedRegion { - structured_body, - structured_body_inputs: edge_bundle.target_inputs, - deferred_edges, - }) + let IncomingEdgeBundle { attrs, target: _, accumulated_count: _, target_inputs } = + edge_bundle; + + // FIXME(eddyb) this loses `attrs`. + let _ = attrs; + + Ok(ClaimedRegion { structured_body, structured_body_inputs: target_inputs, deferred_edges }) } /// Structurize `region` by absorbing into it the entire CFG subgraph which @@ -977,13 +993,11 @@ impl<'a> Structurizer<'a> { let mut deferred_edges = { let ControlInst { attrs, kind, inputs, targets, target_inputs } = control_inst_on_exit; - // FIXME(eddyb) this loses `attrs`. - let _ = attrs; - let target_regions: SmallVec<[_; 8]> = targets .iter() .map(|&target| { self.try_claim_edge_bundle(IncomingEdgeBundle { + attrs: if targets.len() == 1 { attrs } else { AttrSet::default() }, target, accumulated_count: IncomingEdgeCount::ONE, target_inputs: target_inputs.get(&target).cloned().unwrap_or_default(), @@ -1023,6 +1037,9 @@ impl<'a> Structurizer<'a> { match kind { ControlInstKind::Unreachable => { + // FIXME(eddyb) this loses `attrs`. 
+ let _ = attrs; + assert_eq!((inputs.len(), target_regions.len()), (0, 0)); // FIXME(eddyb) this may result in lost optimizations over @@ -1051,6 +1068,7 @@ impl<'a> Structurizer<'a> { let node = self.func_def_body.nodes.define( self.cx, NodeDef { + attrs, kind: NodeKind::ExitInvocation { kind, inputs }, outputs: [].into_iter().collect(), } @@ -1069,6 +1087,7 @@ impl<'a> Structurizer<'a> { DeferredEdgeBundleSet::Always { target: DeferredTarget::Return, edge_bundle: IncomingEdgeBundle { + attrs, accumulated_count: IncomingEdgeCount::default(), target: (), target_inputs: inputs, @@ -1077,11 +1096,11 @@ impl<'a> Structurizer<'a> { } ControlInstKind::Branch => { - assert_eq!((inputs.len(), target_regions.len()), (0, 1)); + assert_eq!(inputs.len(), 0); self.append_maybe_claimed_region( region, - target_regions.into_iter().next().unwrap(), + target_regions.into_iter().exactly_one().ok().unwrap(), ) } @@ -1090,7 +1109,7 @@ impl<'a> Structurizer<'a> { let scrutinee = inputs[0]; - self.structurize_select_into(region, kind, Ok(scrutinee), target_regions) + self.structurize_select_into(region, attrs, kind, Ok(scrutinee), target_regions) } } }; @@ -1108,8 +1127,9 @@ impl<'a> Structurizer<'a> { DeferredTarget::Return => return Err(deferred), }; + let edge_bundle_attrs = edge_bundle.attrs; match self.try_claim_edge_bundle(edge_bundle) { - Ok(claimed_region) => Ok((condition, claimed_region)), + Ok(claimed_region) => Ok((edge_bundle_attrs, condition, claimed_region)), Err(new_edge_bundle) => { let new_target = DeferredTarget::Region(new_edge_bundle.target); @@ -1120,13 +1140,14 @@ impl<'a> Structurizer<'a> { } } }); - let Some((condition, then_region)) = claimed else { + let Some((branch_attrs, condition, then_region)) = claimed else { deferred_edges = else_deferred_edges; break; }; deferred_edges = self.structurize_select_into( region, + branch_attrs, SelectionKind::BoolCond, Err(&condition), [Ok(then_region), Err(else_deferred_edges)].into_iter().collect(), @@ -1168,6 +1189,8 
@@ impl<'a> Structurizer<'a> { fn structurize_select_into( &mut self, parent_region: Region, + // FIXME(eddyb) semantically filter these, maybe focus on debuginfo? + attrs: AttrSet, kind: SelectionKind, scrutinee: Result, mut cases: SmallVec<[Result; 8]>, @@ -1194,7 +1217,7 @@ impl<'a> Structurizer<'a> { // "`Select` node insertion cursor" (into `parent_region`), and // stashing `convergent_case`'s deferred edges to return later. let deferred_edges = - self.structurize_select_into(parent_region, kind, scrutinee, cases); + self.structurize_select_into(parent_region, attrs, kind, scrutinee, cases); assert!(matches!(deferred_edges, DeferredEdgeBundleSet::Unreachable)); // The sole convergent case goes in the `parent_region`, and its @@ -1203,6 +1226,119 @@ impl<'a> Structurizer<'a> { return self.append_maybe_claimed_region(parent_region, convergent_case); } + // Extends a debug location "forward", from `initial_loc` (typically + // the location of the conditional branch/switch being structurized), + // to end at a later location in the same file, before any merge targets, + // but after all of the cases (i.e. returns `None` if the `Select` is + // made up of disjoint source ranges). + let extend_dbg_src_loc = + |this: &Self, + mut initial_loc: DbgSrcLoc, + cases: &[Result]| { + // HACK(eddyb) see comment on `if initial_start_line == start_line` below. + let mut shrink_initial_start_col = initial_loc.start_line_col.1; + + let mut relevant_dbg_src_loc = |attrs: AttrSet| { + attrs + .dbg_src_loc(this.cx) + .filter(|dbg_src_loc| { + // FIXME(eddyb) walk up `inlined_callee_name_and_call_site` + // in case `initial_loc` is e.g. next to some callsite. 
+ dbg_src_loc.file_path == initial_loc.file_path + && dbg_src_loc.inlined_callee_name_and_call_site + == initial_loc.inlined_callee_name_and_call_site + }) + .filter(|dbg_src_loc| { + let (initial_start_line, initial_start_col) = + initial_loc.start_line_col; + let (start_line, start_col) = dbg_src_loc.start_line_col; + + if (initial_start_line, initial_start_col) <= (start_line, start_col) { + return true; + } + + // HACK(eddyb) this only exists because the debuginfo + // emited by Rust-GPU for `if cond { ... } else { ... }`'s + // conditional branch points to the start of `cond`, + // instead of at the `if`, but the merges *do* point + // at the whole `if` (or rather its start). + if initial_start_line == start_line { + shrink_initial_start_col = shrink_initial_start_col.min(start_col); + } + + false + }) + }; + + let max_cases_line_col = cases + .iter() + .filter_map(|case| { + let &ClaimedRegion { structured_body, .. } = case.as_ref().ok()?; + // FIXME(eddyb) maybe there should be a `FuncAt` + // helper for "debug locations from all `Block` `DataInst`s + // and non-`Block` `Node`"? (i.e. only flattening `Block`s) + this.func_def_body + .at(structured_body) + .at_children() + .into_iter() + .flat_map(|func_at_child| { + let child_def = func_at_child.def(); + if let NodeKind::Block { insts } = child_def.kind { + Either::Left( + func_at_child + .at(insts) + .into_iter() + .map(|func_at_inst| func_at_inst.def().attrs), + ) + } else { + Either::Right([child_def.attrs].into_iter()) + } + }) + .rev() + .find_map(&mut relevant_dbg_src_loc) + .map(|dbg_src_loc| dbg_src_loc.end_line_col) + }) + .max(); + let min_merges_line_col = cases + .iter() + .flat_map(|case| { + let case_deferred_edges = match case { + Ok(ClaimedRegion { deferred_edges, .. 
}) | Err(deferred_edges) => { + deferred_edges + } + }; + case_deferred_edges.iter_targets_with_edge_bundle().map(|(_, e)| e.attrs) + }) + .filter_map(&mut relevant_dbg_src_loc) + .map(|dbg_src_loc| dbg_src_loc.start_line_col) + .min(); + + // HACK(eddyb) see comment on `if initial_start_line == start_line` above. + initial_loc.start_line_col.1 = shrink_initial_start_col; + + // HACK(eddyb) prefers merges because otherwise the end location + // ends up pointing into one of the cases (e.g. at the end of + // `expr` in `if ... { ... } else { ... expr }`). + // FIXME(eddyb) this doesn't pan out because the merges point + // at the *start* of the `if` in Rust-GPU-emitted debuginfo + // currently (it's likely a range being shrunk to its start + // point - Rust-GPU's custom debuginfo could probably fix that). + let end_line_col = + min_merges_line_col.or(max_cases_line_col).unwrap_or(initial_loc.end_line_col); + + if let Some(max_cases_line_col) = max_cases_line_col { + // NOTE(eddyb) can only realistically be the case if + // some of the merges are inside a larger high-level + // control-flow construct that doesn't map to fully + // structured control-flow (e.g. `switch` fallthrough). + if max_cases_line_col > end_line_col { + return None; + } + } + + Some(DbgSrcLoc { end_line_col, ..initial_loc }) + }; + // Support lazily defining the `Select` node, as soon as it's necessary // (i.e. 
to plumb per-case dataflow through `Value::NodeOutput`s), // but also if any of the cases actually have non-empty regions, which @@ -1213,6 +1349,13 @@ impl<'a> Structurizer<'a> { let mut non_move_kind = Some(kind); let mut get_or_define_select_node = |this: &mut Self, cases: &[_]| { *cached_select_node.get_or_insert_with(|| { + let mut attrs = attrs; + if let Some(select_dbg_src_loc) = attrs.dbg_src_loc(this.cx) + && let Some(dbg_src_loc) = extend_dbg_src_loc(this, select_dbg_src_loc, cases) + { + attrs.set_dbg_src_loc(this.cx, dbg_src_loc); + } + let kind = non_move_kind.take().unwrap(); let cases = cases .iter() @@ -1235,6 +1378,7 @@ impl<'a> Structurizer<'a> { let select_node = this.func_def_body.nodes.define( this.cx, NodeDef { + attrs, kind: NodeKind::Select { kind, scrutinee, cases }, outputs: [].into_iter().collect(), } @@ -1443,6 +1587,15 @@ impl<'a> Structurizer<'a> { DeferredEdgeBundle { condition, edge_bundle: IncomingEdgeBundle { + // FIXME(eddyb) merge debug locations when attributes differ. + attrs: per_case_deferred + .iter() + .filter_map(|d| d.as_ref().ok()) + .map(|e| e.edge_bundle.attrs) + .unique() + .exactly_one() + .ok() + .unwrap_or_default(), target, accumulated_count: total_edge_count, target_inputs, @@ -1500,7 +1653,8 @@ impl<'a> Structurizer<'a> { .map(|cond| self.materialize_lazy_cond(cond)) .collect(); - let NodeDef { kind, outputs: output_decls } = &mut *self.func_def_body.nodes[node]; + let NodeDef { attrs: _, kind, outputs: output_decls } = + &mut *self.func_def_body.nodes[node]; let cases = match kind { NodeKind::Select { kind, scrutinee, cases } => { assert_eq!(cases.len(), per_case_conds.len()); diff --git a/src/lib.rs b/src/lib.rs index 97875ccb..36493ad2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -853,6 +853,8 @@ pub use context::Node; /// See [`Region`] docs for more on control-flow in SPIR-T. 
#[derive(Clone)] pub struct NodeDef { + pub attrs: AttrSet, + pub kind: NodeKind, /// Outputs from this [`Node`]: diff --git a/src/mem/analyze.rs b/src/mem/analyze.rs index f9b8b28b..4ef753f9 100644 --- a/src/mem/analyze.rs +++ b/src/mem/analyze.rs @@ -781,8 +781,16 @@ impl<'a> GatherAccesses<'a> { .attrs } Value::NodeOutput { node, output_idx } => { - &mut func_def_body.at_mut(node).def().outputs[output_idx as usize] - .attrs + let node_def = func_def_body.at_mut(node).def(); + + // HACK(eddyb) `NodeOutput { output_idx: !0, .. }` + // may be used to attach errors to a whole `Node`. + if output_idx == !0 { + assert!(accesses.is_err()); + &mut node_def.attrs + } else { + &mut node_def.outputs[output_idx as usize].attrs + } } Value::DataInstOutput(data_inst) => { &mut func_def_body.at_mut(data_inst).def().attrs diff --git a/src/print/mod.rs b/src/print/mod.rs index afe81f3d..ee7622ba 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -1499,9 +1499,9 @@ impl<'a> Printer<'a> { intra_region: DbgScopeDefPlaceInRegion { before_node: Some(node) }, }); - define(Use::AlignmentAnchorForNode(node), None); + let NodeDef { attrs, kind, outputs } = func_at_node.def(); - let NodeDef { kind, outputs } = func_at_node.def(); + define(Use::AlignmentAnchorForNode(node), Some(*attrs)); if let NodeKind::Block { insts } = *kind { for func_at_inst in func_def_body.at(insts) { @@ -3739,7 +3739,9 @@ impl Print for FuncAt<'_, Node> { type Output = pretty::Fragment; fn print(&self, printer: &Printer<'_>) -> pretty::Fragment { let node = self.position; - let NodeDef { kind, outputs } = self.def(); + let NodeDef { attrs, kind, outputs } = self.def(); + + let attrs = attrs.print(printer); let outputs_header = if !outputs.is_empty() { let mut outputs = outputs.iter().enumerate().map(|(output_idx, output)| { @@ -3870,11 +3872,11 @@ impl Print for FuncAt<'_, Node> { inputs.iter().map(|v| v.print(printer)), ), }; - pretty::Fragment::new([ + let def_without_name = pretty::Fragment::new([ 
Use::AlignmentAnchorForNode(self.position).print_as_def(printer), - outputs_header, node_body, - ]) + ]); + AttrsAndDef { attrs, def_without_name }.insert_name_before_def(outputs_header) } } diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 7e76e0c1..45c3439e 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -1661,6 +1661,7 @@ impl Module { let block_node = func_def_body.nodes.define( &cx, NodeDef { + attrs: AttrSet::default(), kind: NodeKind::Block { insts: EntityList::empty() }, outputs: SmallVec::new(), } diff --git a/src/transform.rs b/src/transform.rs index 77d7da98..2c92fe7e 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -194,6 +194,9 @@ pub trait Transformer: Sized { fn in_place_transform_func_decl(&mut self, func_decl: &mut FuncDecl) { func_decl.inner_in_place_transform_with(self); } + fn in_place_transform_region_def(&mut self, mut func_at_region: FuncAtMut<'_, Region>) { + func_at_region.inner_in_place_transform_with(self); + } fn in_place_transform_node_def(&mut self, mut func_at_node: FuncAtMut<'_, Node>) { func_at_node.inner_in_place_transform_with(self); } @@ -569,13 +572,13 @@ impl InnerTransform for FuncParam { impl InnerInPlaceTransform for FuncDefBody { fn inner_in_place_transform_with(&mut self, transformer: &mut impl Transformer) { match &self.unstructured_cfg { - None => self.at_mut_body().inner_in_place_transform_with(transformer), + None => transformer.in_place_transform_region_def(self.at_mut_body()), Some(cfg) => { // HACK(eddyb) have to compute this before borrowing any `self` fields. 
let rpo = cfg.rev_post_order(self); for region in rpo { - self.at_mut(region).inner_in_place_transform_with(transformer); + transformer.in_place_transform_region_def(self.at_mut(region)); let cfg = self.unstructured_cfg.as_mut().unwrap(); if let Some(control_inst) = cfg.control_inst_on_exit_from.get_mut(region) { @@ -641,9 +644,11 @@ impl FuncAtMut<'_, Node> { impl InnerInPlaceTransform for FuncAtMut<'_, Node> { fn inner_in_place_transform_with(&mut self, transformer: &mut impl Transformer) { - // HACK(eddyb) handle pre-child-regions parts of `kind` separately to + // HACK(eddyb) handle all pre-child-regions fields separately to // allow reborrowing `FuncAtMut` (for the child region recursion). - match &mut self.reborrow().def().kind { + let NodeDef { attrs, kind, outputs: _ } = self.reborrow().def(); + transformer.transform_attr_set_use(*attrs).apply_to(attrs); + match kind { &mut NodeKind::Block { insts } => { let mut func_at_inst_iter = self.reborrow().at(insts).into_iter(); while let Some(func_at_inst) = func_at_inst_iter.next() { @@ -669,10 +674,10 @@ impl InnerInPlaceTransform for FuncAtMut<'_, Node> { // in a `Vec` (or `SmallVec`), which requires workarounds like this. for child_region_idx in 0..self.child_regions().len() { let child_region = self.child_regions()[child_region_idx]; - self.reborrow().at(child_region).inner_in_place_transform_with(transformer); + transformer.in_place_transform_region_def(self.reborrow().at(child_region)); } - let NodeDef { kind, outputs } = self.reborrow().def(); + let NodeDef { attrs: _, kind, outputs } = self.reborrow().def(); match kind { // Fully handled above, before recursing into any child regions. diff --git a/src/visit.rs b/src/visit.rs index 63372f30..fb493f3f 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -475,8 +475,9 @@ impl<'a> FuncAt<'a, EntityListIter> { // requirement, whereas this has `'a` in `self: FuncAt<'a, Node>`. 
impl<'a> FuncAt<'a, Node> { pub fn inner_visit_with(self, visitor: &mut impl Visitor<'a>) { - let NodeDef { kind, outputs } = self.def(); + let NodeDef { attrs, kind, outputs } = self.def(); + visitor.visit_attr_set_use(*attrs); match kind { NodeKind::Block { insts } => { for func_at_inst in self.at(*insts) { From 546a50933e6844c8628249f53d3cd9e47003e982 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Tue, 9 Sep 2025 11:42:12 +0300 Subject: [PATCH 2/6] Factor out `Node` value inputs from `NodeKind` variants. --- src/cf/structurize.rs | 19 ++++++++++++------- src/lib.rs | 21 ++++++++------------- src/print/mod.rs | 19 ++++++++++--------- src/spv/lift.rs | 19 ++++++++++--------- src/spv/lower.rs | 1 + src/transform.rs | 24 ++++++++++-------------- src/visit.rs | 18 ++++++------------ 7 files changed, 57 insertions(+), 64 deletions(-) diff --git a/src/cf/structurize.rs b/src/cf/structurize.rs index aa654b6e..b4e5241b 100644 --- a/src/cf/structurize.rs +++ b/src/cf/structurize.rs @@ -914,7 +914,8 @@ impl<'a> Structurizer<'a> { // FIXME(eddyb) could it be possible to synthesize attrs // from `ControlInst`s' attrs and/or `OpLoopMerge`'s? 
attrs: AttrSet::default(), - kind: NodeKind::Loop { initial_inputs, body, repeat_condition }, + inputs: initial_inputs, + kind: NodeKind::Loop { body, repeat_condition }, outputs: [].into_iter().collect(), } .into(), @@ -1069,7 +1070,8 @@ impl<'a> Structurizer<'a> { self.cx, NodeDef { attrs, - kind: NodeKind::ExitInvocation { kind, inputs }, + inputs, + kind: NodeKind::ExitInvocation(kind), outputs: [].into_iter().collect(), } .into(), @@ -1379,7 +1381,8 @@ impl<'a> Structurizer<'a> { this.cx, NodeDef { attrs, - kind: NodeKind::Select { kind, scrutinee, cases }, + inputs: [scrutinee].into_iter().collect(), + kind: NodeKind::Select { kind, cases }, outputs: [].into_iter().collect(), } .into(), @@ -1653,23 +1656,25 @@ impl<'a> Structurizer<'a> { .map(|cond| self.materialize_lazy_cond(cond)) .collect(); - let NodeDef { attrs: _, kind, outputs: output_decls } = + let NodeDef { attrs: _, inputs, kind, outputs: output_decls } = &mut *self.func_def_body.nodes[node]; let cases = match kind { - NodeKind::Select { kind, scrutinee, cases } => { + NodeKind::Select { kind, cases } => { assert_eq!(cases.len(), per_case_conds.len()); if let SelectionKind::BoolCond = kind { + let cond = inputs[0]; + let [val_false, val_true] = [self.const_false, self.const_true].map(Value::Const); if per_case_conds[..] == [val_true, val_false] { - return *scrutinee; + return cond; } else if per_case_conds[..] == [val_false, val_true] { // FIXME(eddyb) this could also be special-cased, // at least when called from the topmost level, // where which side is `false`/`true` doesn't // matter (or we could even generate `!cond`?). - let _not_cond = *scrutinee; + let _not_cond = cond; } } diff --git a/src/lib.rs b/src/lib.rs index 36493ad2..20254a15 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -855,6 +855,9 @@ pub use context::Node; pub struct NodeDef { pub attrs: AttrSet, + // FIXME(eddyb) change the inline size of this to fit most nodes. 
+ pub inputs: SmallVec<[Value; 2]>, + pub kind: NodeKind, /// Outputs from this [`Node`]: @@ -885,16 +888,16 @@ pub enum NodeKind { }, /// Choose one [`Region`] out of `cases` to execute, based on a single - /// value input (`scrutinee`) interpreted according to [`SelectionKind`]. + /// value input (`input[0]`) interpreted according to [`SelectionKind`]. /// /// This corresponds to "gamma" (`γ`) nodes in (R)VSDG, though those are /// sometimes limited only to a two-way selection on a boolean condition. - Select { kind: cf::SelectionKind, scrutinee: Value, cases: SmallVec<[Region; 2]> }, + Select { kind: cf::SelectionKind, cases: SmallVec<[Region; 2]> }, /// Execute `body` repeatedly, until `repeat_condition` evaluates to `false`. /// - /// To represent "loop state", `body` can take `inputs`, getting values from: - /// * on the first iteration: `initial_inputs` + /// To represent "loop state", `body` can take inputs, getting values from: + /// * on the first iteration: initial `inputs` (from `NodeDef`) /// * on later iterations: `body`'s own `outputs` (from the last iteration) /// /// As the condition is checked only *after* the body, this type of loop is @@ -903,8 +906,6 @@ pub enum NodeKind { /// /// This corresponds to "theta" (`θ`) nodes in (R)VSDG. Loop { - initial_inputs: SmallVec<[Value; 2]>, - body: Region, // FIXME(eddyb) should this be kept in `body.outputs`? (that would not @@ -917,13 +918,7 @@ pub enum NodeKind { /// indicating a fatal error as well. // // FIXME(eddyb) make this less shader-controlflow-centric. - ExitInvocation { - kind: cf::ExitInvocationKind, - - // FIXME(eddyb) centralize `Value` inputs across `Node`s, - // and only use stricter types for building/traversing the IR. - inputs: SmallVec<[Value; 2]>, - }, + ExitInvocation(cf::ExitInvocationKind), } /// Entity handle for a [`DataInstDef`](crate::DataInstDef) (a leaf instruction). 
diff --git a/src/print/mod.rs b/src/print/mod.rs index ee7622ba..f1bf0a20 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -1499,7 +1499,7 @@ impl<'a> Printer<'a> { intra_region: DbgScopeDefPlaceInRegion { before_node: Some(node) }, }); - let NodeDef { attrs, kind, outputs } = func_at_node.def(); + let NodeDef { attrs, inputs: _, kind, outputs } = func_at_node.def(); define(Use::AlignmentAnchorForNode(node), Some(*attrs)); @@ -3739,7 +3739,7 @@ impl Print for FuncAt<'_, Node> { type Output = pretty::Fragment; fn print(&self, printer: &Printer<'_>) -> pretty::Fragment { let node = self.position; - let NodeDef { attrs, kind, outputs } = self.def(); + let NodeDef { attrs, inputs, kind, outputs } = self.def(); let attrs = attrs.print(printer); @@ -3775,15 +3775,16 @@ impl Print for FuncAt<'_, Node> { .flat_map(|entry| [pretty::Node::ForceLineSeparation.into(), entry]), ) } - NodeKind::Select { kind, scrutinee, cases } => kind.print_with_scrutinee_and_cases( + NodeKind::Select { kind, cases } => kind.print_with_scrutinee_and_cases( printer, kw_style, - *scrutinee, + inputs[0], cases.iter().map(|&case| self.at(case).print(printer)), ), - NodeKind::Loop { initial_inputs, body, repeat_condition } => { + NodeKind::Loop { body, repeat_condition } => { assert!(outputs.is_empty()); + let initial_inputs = inputs; let inputs = &self.at(*body).def().inputs; assert_eq!(initial_inputs.len(), inputs.len()); @@ -3862,10 +3863,10 @@ impl Print for FuncAt<'_, Node> { repeat_condition.print(printer), ]) } - NodeKind::ExitInvocation { - kind: cf::ExitInvocationKind::SpvInst(spv::Inst { opcode, imms }), - inputs, - } => printer.pretty_spv_inst( + NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(spv::Inst { + opcode, + imms, + })) => printer.pretty_spv_inst( kw_style, *opcode, imms, diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 7718c9e0..d95a48f3 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -575,11 +575,12 @@ impl<'a> FuncLifting<'a> { CfgPoint::RegionExit(_) 
=> SmallVec::new(), CfgPoint::NodeEntry(node) => { - match &func_def_body.at(node).def().kind { + let node_def = func_def_body.at(node).def(); + match &node_def.kind { // The backedge of a SPIR-V structured loop points to // the "loop header", i.e. the `Entry` of the `Loop`, // so that's where `body` `inputs` phis have to go. - NodeKind::Loop { initial_inputs, body, .. } => { + NodeKind::Loop { body, .. } => { let loop_body_def = func_def_body.at(*body).def(); let loop_body_inputs = &loop_body_def.inputs; @@ -598,7 +599,7 @@ impl<'a> FuncLifting<'a> { result_id: alloc_id()?, cases: FxIndexMap::default(), - default_value: Some(initial_inputs[i]), + default_value: Some(node_def.inputs[i]), }) }) .collect::>()? @@ -687,12 +688,12 @@ impl<'a> FuncLifting<'a> { unreachable!() } - NodeKind::Select { kind, scrutinee, cases } => Terminator { + NodeKind::Select { kind, cases } => Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cf::unstructured::ControlInstKind::SelectBranch( kind.clone(), )), - inputs: [*scrutinee].into_iter().collect(), + inputs: [node_def.inputs[0]].into_iter().collect(), targets: cases .iter() .map(|&case| CfgPoint::RegionEntry(case)) @@ -701,7 +702,7 @@ impl<'a> FuncLifting<'a> { merge: Some(Merge::Selection(CfgPoint::NodeExit(node))), }, - NodeKind::Loop { initial_inputs: _, body, repeat_condition: _ } => { + NodeKind::Loop { body, repeat_condition: _ } => { Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cf::unstructured::ControlInstKind::Branch), @@ -722,12 +723,12 @@ impl<'a> FuncLifting<'a> { } } - NodeKind::ExitInvocation { kind, inputs } => Terminator { + NodeKind::ExitInvocation(kind) => Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cf::unstructured::ControlInstKind::ExitInvocation( kind.clone(), )), - inputs: inputs.clone(), + inputs: node_def.inputs.clone(), targets: [].into_iter().collect(), target_phi_values: FxIndexMap::default(), merge: None, @@ -763,7 +764,7 @@ impl<'a> FuncLifting<'a> { merge: None, }, - 
NodeKind::Loop { initial_inputs: _, body: _, repeat_condition } => { + NodeKind::Loop { body: _, repeat_condition } => { let backedge = CfgPoint::NodeEntry(parent_node); let target_phi_values = region_outputs .map(|outputs| (backedge, outputs)) diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 45c3439e..87c267cf 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -1662,6 +1662,7 @@ impl Module { &cx, NodeDef { attrs: AttrSet::default(), + inputs: SmallVec::new(), kind: NodeKind::Block { insts: EntityList::empty() }, outputs: SmallVec::new(), } diff --git a/src/transform.rs b/src/transform.rs index 2c92fe7e..e2209e1d 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -646,8 +646,11 @@ impl InnerInPlaceTransform for FuncAtMut<'_, Node> { fn inner_in_place_transform_with(&mut self, transformer: &mut impl Transformer) { // HACK(eddyb) handle all pre-child-regions fields separately to // allow reborrowing `FuncAtMut` (for the child region recursion). - let NodeDef { attrs, kind, outputs: _ } = self.reborrow().def(); + let NodeDef { attrs, inputs, kind, outputs: _ } = self.reborrow().def(); transformer.transform_attr_set_use(*attrs).apply_to(attrs); + for v in inputs { + transformer.transform_value_use(v).apply_to(v); + } match kind { &mut NodeKind::Block { insts } => { let mut func_at_inst_iter = self.reborrow().at(insts).into_iter(); @@ -657,17 +660,10 @@ impl InnerInPlaceTransform for FuncAtMut<'_, Node> { } NodeKind::Select { kind: SelectionKind::BoolCond | SelectionKind::SpvInst(_), - scrutinee, cases: _, - } => { - transformer.transform_value_use(scrutinee).apply_to(scrutinee); - } - NodeKind::Loop { initial_inputs: inputs, body: _, repeat_condition: _ } - | NodeKind::ExitInvocation { kind: cf::ExitInvocationKind::SpvInst(_), inputs } => { - for v in inputs { - transformer.transform_value_use(v).apply_to(v); - } } + | NodeKind::Loop { body: _, repeat_condition: _ } + | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) => {} } // 
FIXME(eddyb) represent the list of child regions without having them @@ -677,15 +673,15 @@ impl InnerInPlaceTransform for FuncAtMut<'_, Node> { transformer.in_place_transform_region_def(self.reborrow().at(child_region)); } - let NodeDef { attrs: _, kind, outputs } = self.reborrow().def(); + let NodeDef { attrs: _, inputs: _, kind, outputs } = self.reborrow().def(); match kind { // Fully handled above, before recursing into any child regions. NodeKind::Block { insts: _ } - | NodeKind::Select { kind: _, scrutinee: _, cases: _ } - | NodeKind::ExitInvocation { kind: cf::ExitInvocationKind::SpvInst(_), inputs: _ } => {} + | NodeKind::Select { kind: _, cases: _ } + | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) => {} - NodeKind::Loop { initial_inputs: _, body: _, repeat_condition } => { + NodeKind::Loop { body: _, repeat_condition } => { transformer.transform_value_use(repeat_condition).apply_to(repeat_condition); } }; diff --git a/src/visit.rs b/src/visit.rs index fb493f3f..683a5102 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -475,9 +475,12 @@ impl<'a> FuncAt<'a, EntityListIter> { // requirement, whereas this has `'a` in `self: FuncAt<'a, Node>`. 
impl<'a> FuncAt<'a, Node> { pub fn inner_visit_with(self, visitor: &mut impl Visitor<'a>) { - let NodeDef { attrs, kind, outputs } = self.def(); + let NodeDef { attrs, inputs, kind, outputs } = self.def(); visitor.visit_attr_set_use(*attrs); + for v in inputs { + visitor.visit_value_use(v); + } match kind { NodeKind::Block { insts } => { for func_at_inst in self.at(*insts) { @@ -486,26 +489,17 @@ impl<'a> FuncAt<'a, Node> { } NodeKind::Select { kind: SelectionKind::BoolCond | SelectionKind::SpvInst(_), - scrutinee, cases, } => { - visitor.visit_value_use(scrutinee); for &case in cases { visitor.visit_region_def(self.at(case)); } } - NodeKind::Loop { initial_inputs, body, repeat_condition } => { - for v in initial_inputs { - visitor.visit_value_use(v); - } + NodeKind::Loop { body, repeat_condition } => { visitor.visit_region_def(self.at(*body)); visitor.visit_value_use(repeat_condition); } - NodeKind::ExitInvocation { kind: cf::ExitInvocationKind::SpvInst(_), inputs } => { - for v in inputs { - visitor.visit_value_use(v); - } - } + NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) => {} } for output in outputs { output.inner_visit_with(visitor); From ca91c03b26045d2b693ad804fa7c3e6461848b3a Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Tue, 9 Sep 2025 11:42:12 +0300 Subject: [PATCH 3/6] Factor out child `Region`s from `NodeKind` variants. 
--- src/cf/structurize.rs | 23 +++++++++--------- src/lib.rs | 22 +++++++++-------- src/print/mod.rs | 18 +++++++------- src/spv/lift.rs | 31 +++++++++++------------- src/spv/lower.rs | 3 ++- src/transform.rs | 55 +++++++++++++------------------------------ src/visit.rs | 33 +++++++++++++------------- 7 files changed, 83 insertions(+), 102 deletions(-) diff --git a/src/cf/structurize.rs b/src/cf/structurize.rs index b4e5241b..e80cd0e7 100644 --- a/src/cf/structurize.rs +++ b/src/cf/structurize.rs @@ -914,8 +914,9 @@ impl<'a> Structurizer<'a> { // FIXME(eddyb) could it be possible to synthesize attrs // from `ControlInst`s' attrs and/or `OpLoopMerge`'s? attrs: AttrSet::default(), + kind: NodeKind::Loop { repeat_condition }, inputs: initial_inputs, - kind: NodeKind::Loop { body, repeat_condition }, + child_regions: [body].into_iter().collect(), outputs: [].into_iter().collect(), } .into(), @@ -1070,8 +1071,9 @@ impl<'a> Structurizer<'a> { self.cx, NodeDef { attrs, - inputs, kind: NodeKind::ExitInvocation(kind), + inputs, + child_regions: [].into_iter().collect(), outputs: [].into_iter().collect(), } .into(), @@ -1381,8 +1383,9 @@ impl<'a> Structurizer<'a> { this.cx, NodeDef { attrs, + kind: NodeKind::Select(kind), inputs: [scrutinee].into_iter().collect(), - kind: NodeKind::Select { kind, cases }, + child_regions: cases, outputs: [].into_iter().collect(), } .into(), @@ -1550,10 +1553,8 @@ impl<'a> Structurizer<'a> { for (case_idx, v) in per_case_target_input.enumerate() { let v = v.unwrap_or_else(|| Value::Const(self.const_undef(ty))); - let case_region = match &self.func_def_body.at(select_node).def().kind { - NodeKind::Select { cases, .. 
} => cases[case_idx], - _ => unreachable!(), - }; + let case_region = + self.func_def_body.at(select_node).def().child_regions[case_idx]; let outputs = &mut self.func_def_body.at_mut(case_region).def().outputs; assert_eq!(outputs.len(), output_idx); outputs.push(v); @@ -1656,11 +1657,11 @@ impl<'a> Structurizer<'a> { .map(|cond| self.materialize_lazy_cond(cond)) .collect(); - let NodeDef { attrs: _, inputs, kind, outputs: output_decls } = + let NodeDef { attrs: _, kind, inputs, child_regions, outputs: output_decls } = &mut *self.func_def_body.nodes[node]; let cases = match kind { - NodeKind::Select { kind, cases } => { - assert_eq!(cases.len(), per_case_conds.len()); + NodeKind::Select(kind) => { + assert_eq!(child_regions.len(), per_case_conds.len()); if let SelectionKind::BoolCond = kind { let cond = inputs[0]; @@ -1678,7 +1679,7 @@ impl<'a> Structurizer<'a> { } } - cases + child_regions } _ => unreachable!(), }; diff --git a/src/lib.rs b/src/lib.rs index 20254a15..01c8a789 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -855,10 +855,13 @@ pub use context::Node; pub struct NodeDef { pub attrs: AttrSet, + pub kind: NodeKind, + // FIXME(eddyb) change the inline size of this to fit most nodes. pub inputs: SmallVec<[Value; 2]>, - pub kind: NodeKind, + // HACK(eddyb) mostly separate to allow the above `kind`-before-`inputs` order. + pub child_regions: SmallVec<[Region; 2]>, /// Outputs from this [`Node`]: /// * accessed using [`Value::NodeOutput`] @@ -887,18 +890,19 @@ pub enum NodeKind { insts: EntityList, }, - /// Choose one [`Region`] out of `cases` to execute, based on a single + /// Choose one [`Region`] out of `child_regions` to execute, based on a single /// value input (`input[0]`) interpreted according to [`SelectionKind`]. /// /// This corresponds to "gamma" (`γ`) nodes in (R)VSDG, though those are /// sometimes limited only to a two-way selection on a boolean condition. 
- Select { kind: cf::SelectionKind, cases: SmallVec<[Region; 2]> }, + Select(cf::SelectionKind), - /// Execute `body` repeatedly, until `repeat_condition` evaluates to `false`. + /// Execute a "body" (`child_regions[0]`) repeatedly, until `repeat_condition` + /// evaluates to `false`. /// - /// To represent "loop state", `body` can take inputs, getting values from: + /// To represent "loop state", the body can take inputs, getting values from: /// * on the first iteration: initial `inputs` (from `NodeDef`) - /// * on later iterations: `body`'s own `outputs` (from the last iteration) + /// * on later iterations: the body's own `outputs` (from the last iteration) /// /// As the condition is checked only *after* the body, this type of loop is /// sometimes described as "tail-controlled", and is also equivalent to the @@ -906,10 +910,8 @@ pub enum NodeKind { /// /// This corresponds to "theta" (`θ`) nodes in (R)VSDG. Loop { - body: Region, - - // FIXME(eddyb) should this be kept in `body.outputs`? (that would not - // have any ambiguity as to whether it can see `body`-computed values) + // FIXME(eddyb) move this to body's `outputs`, removing any ambiguity as + // to whether it can see body-computed values, and simplifying traversals. 
repeat_condition: Value, }, diff --git a/src/print/mod.rs b/src/print/mod.rs index f1bf0a20..37173945 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -1499,7 +1499,8 @@ impl<'a> Printer<'a> { intra_region: DbgScopeDefPlaceInRegion { before_node: Some(node) }, }); - let NodeDef { attrs, inputs: _, kind, outputs } = func_at_node.def(); + let NodeDef { attrs, kind, inputs: _, child_regions: _, outputs } = + func_at_node.def(); define(Use::AlignmentAnchorForNode(node), Some(*attrs)); @@ -3739,7 +3740,7 @@ impl Print for FuncAt<'_, Node> { type Output = pretty::Fragment; fn print(&self, printer: &Printer<'_>) -> pretty::Fragment { let node = self.position; - let NodeDef { attrs, inputs, kind, outputs } = self.def(); + let NodeDef { attrs, kind, inputs, child_regions, outputs } = self.def(); let attrs = attrs.print(printer); @@ -3775,17 +3776,18 @@ impl Print for FuncAt<'_, Node> { .flat_map(|entry| [pretty::Node::ForceLineSeparation.into(), entry]), ) } - NodeKind::Select { kind, cases } => kind.print_with_scrutinee_and_cases( + NodeKind::Select(kind) => kind.print_with_scrutinee_and_cases( printer, kw_style, inputs[0], - cases.iter().map(|&case| self.at(case).print(printer)), + child_regions.iter().map(|&case| self.at(case).print(printer)), ), - NodeKind::Loop { body, repeat_condition } => { + NodeKind::Loop { repeat_condition } => { assert!(outputs.is_empty()); let initial_inputs = inputs; - let inputs = &self.at(*body).def().inputs; + let body = child_regions[0]; + let inputs = &self.at(body).def().inputs; assert_eq!(initial_inputs.len(), inputs.len()); // FIXME(eddyb) this avoids customizing how `body` is printed, @@ -3812,7 +3814,7 @@ impl Print for FuncAt<'_, Node> { ( input, Value::RegionInput { - region: *body, + region: body, input_idx: input_idx.try_into().unwrap(), }, ) @@ -3853,7 +3855,7 @@ impl Print for FuncAt<'_, Node> { inputs_header, " {".into(), pretty::Node::IndentedBlock(vec![pretty::Fragment::new([ - self.at(*body).print(printer), + 
self.at(body).print(printer), body_suffix, ])]) .into(), diff --git a/src/spv/lift.rs b/src/spv/lift.rs index d95a48f3..1561eae2 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -17,7 +17,7 @@ use std::borrow::Cow; use std::collections::{BTreeMap, BTreeSet}; use std::num::NonZeroU32; use std::path::Path; -use std::{io, iter, mem, slice}; +use std::{io, iter, mem}; impl spv::Dialect { fn capability_insts(&self) -> impl Iterator + '_ { @@ -496,16 +496,10 @@ impl FuncAt<'_, Node> { f: &mut impl FnMut(CfgCursor<'_>) -> Result<(), E>, parent: &CfgCursor<'_, ControlParent>, ) -> Result<(), E> { - let child_regions: &[_] = match &self.def().kind { - NodeKind::Block { .. } | NodeKind::ExitInvocation { .. } => &[], - NodeKind::Select { cases, .. } => cases, - NodeKind::Loop { body, .. } => slice::from_ref(body), - }; - let node = self.position; let parent = Some(parent); f(CfgCursor { point: CfgPoint::NodeEntry(node), parent })?; - for &region in child_regions { + for &region in &self.def().child_regions { self.at(region).rev_post_order_try_for_each_inner( f, Some(&CfgCursor { point: ControlParent::Node(node), parent }), @@ -580,13 +574,14 @@ impl<'a> FuncLifting<'a> { // The backedge of a SPIR-V structured loop points to // the "loop header", i.e. the `Entry` of the `Loop`, // so that's where `body` `inputs` phis have to go. - NodeKind::Loop { body, .. } => { - let loop_body_def = func_def_body.at(*body).def(); + NodeKind::Loop { .. 
} => { + let body = node_def.child_regions[0]; + let loop_body_def = func_def_body.at(body).def(); let loop_body_inputs = &loop_body_def.inputs; if !loop_body_inputs.is_empty() { region_inputs_source - .insert(*body, RegionInputsSource::LoopHeaderPhis(node)); + .insert(body, RegionInputsSource::LoopHeaderPhis(node)); } loop_body_inputs @@ -688,13 +683,14 @@ impl<'a> FuncLifting<'a> { unreachable!() } - NodeKind::Select { kind, cases } => Terminator { + NodeKind::Select(kind) => Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cf::unstructured::ControlInstKind::SelectBranch( kind.clone(), )), inputs: [node_def.inputs[0]].into_iter().collect(), - targets: cases + targets: node_def + .child_regions .iter() .map(|&case| CfgPoint::RegionEntry(case)) .collect(), @@ -702,12 +698,13 @@ impl<'a> FuncLifting<'a> { merge: Some(Merge::Selection(CfgPoint::NodeExit(node))), }, - NodeKind::Loop { body, repeat_condition: _ } => { + NodeKind::Loop { repeat_condition: _ } => { + let body = node_def.child_regions[0]; Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cf::unstructured::ControlInstKind::Branch), inputs: [].into_iter().collect(), - targets: [CfgPoint::RegionEntry(*body)].into_iter().collect(), + targets: [CfgPoint::RegionEntry(body)].into_iter().collect(), target_phi_values: FxIndexMap::default(), merge: Some(Merge::Loop { loop_merge: CfgPoint::NodeExit(node), @@ -718,7 +715,7 @@ impl<'a> FuncLifting<'a> { // and it should be valid *but* that had to be // reverted because it's only true in the absence // of divergence within the loop body itself! 
- loop_continue: CfgPoint::RegionExit(*body), + loop_continue: CfgPoint::RegionExit(body), }), } } @@ -764,7 +761,7 @@ impl<'a> FuncLifting<'a> { merge: None, }, - NodeKind::Loop { body: _, repeat_condition } => { + NodeKind::Loop { repeat_condition } => { let backedge = CfgPoint::NodeEntry(parent_node); let target_phi_values = region_outputs .map(|outputs| (backedge, outputs)) diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 87c267cf..7d5d2801 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -1662,8 +1662,9 @@ impl Module { &cx, NodeDef { attrs: AttrSet::default(), - inputs: SmallVec::new(), kind: NodeKind::Block { insts: EntityList::empty() }, + inputs: SmallVec::new(), + child_regions: SmallVec::new(), outputs: SmallVec::new(), } .into(), diff --git a/src/transform.rs b/src/transform.rs index e2209e1d..67b325ca 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -13,7 +13,6 @@ use crate::{ }; use std::cmp::Ordering; use std::rc::Rc; -use std::slice; /// The result of a transformation (which is not in-place). #[must_use] @@ -631,26 +630,12 @@ impl InnerInPlaceTransform for FuncAtMut<'_, EntityListIter> { } } -impl FuncAtMut<'_, Node> { - fn child_regions(&mut self) -> &mut [Region] { - match &mut self.reborrow().def().kind { - NodeKind::Block { .. } | NodeKind::ExitInvocation { .. } => &mut [][..], - - NodeKind::Select { cases, .. } => cases, - NodeKind::Loop { body, .. } => slice::from_mut(body), - } - } -} - impl InnerInPlaceTransform for FuncAtMut<'_, Node> { fn inner_in_place_transform_with(&mut self, transformer: &mut impl Transformer) { - // HACK(eddyb) handle all pre-child-regions fields separately to - // allow reborrowing `FuncAtMut` (for the child region recursion). 
- let NodeDef { attrs, inputs, kind, outputs: _ } = self.reborrow().def(); + let NodeDef { attrs, kind, inputs: _, child_regions: _, outputs: _ } = + self.reborrow().def(); + transformer.transform_attr_set_use(*attrs).apply_to(attrs); - for v in inputs { - transformer.transform_value_use(v).apply_to(v); - } match kind { &mut NodeKind::Block { insts } => { let mut func_at_inst_iter = self.reborrow().at(insts).into_iter(); @@ -658,33 +643,27 @@ impl InnerInPlaceTransform for FuncAtMut<'_, Node> { transformer.in_place_transform_data_inst_def(func_at_inst); } } - NodeKind::Select { - kind: SelectionKind::BoolCond | SelectionKind::SpvInst(_), - cases: _, - } - | NodeKind::Loop { body: _, repeat_condition: _ } + NodeKind::Select(SelectionKind::BoolCond | SelectionKind::SpvInst(_)) + | NodeKind::Loop { repeat_condition: _ } | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) => {} } - // FIXME(eddyb) represent the list of child regions without having them - // in a `Vec` (or `SmallVec`), which requires workarounds like this. - for child_region_idx in 0..self.child_regions().len() { - let child_region = self.child_regions()[child_region_idx]; + for v in &mut self.reborrow().def().inputs { + transformer.transform_value_use(v).apply_to(v); + } + + for child_region_idx in 0..self.reborrow().def().child_regions.len() { + let child_region = self.reborrow().def().child_regions[child_region_idx]; transformer.in_place_transform_region_def(self.reborrow().at(child_region)); } - let NodeDef { attrs: _, inputs: _, kind, outputs } = self.reborrow().def(); + let NodeDef { attrs: _, kind, inputs: _, child_regions: _, outputs } = + self.reborrow().def(); - match kind { - // Fully handled above, before recursing into any child regions. 
- NodeKind::Block { insts: _ } - | NodeKind::Select { kind: _, cases: _ } - | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) => {} - - NodeKind::Loop { body: _, repeat_condition } => { - transformer.transform_value_use(repeat_condition).apply_to(repeat_condition); - } - }; + // HACK(eddyb) semantically, `repeat_condition` is a body region output. + if let NodeKind::Loop { repeat_condition } = kind { + transformer.transform_value_use(repeat_condition).apply_to(repeat_condition); + } for output in outputs { output.inner_transform_with(transformer).apply_to(output); diff --git a/src/visit.rs b/src/visit.rs index 683a5102..ad529559 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -475,32 +475,31 @@ impl<'a> FuncAt<'a, EntityListIter> { // requirement, whereas this has `'a` in `self: FuncAt<'a, Node>`. impl<'a> FuncAt<'a, Node> { pub fn inner_visit_with(self, visitor: &mut impl Visitor<'a>) { - let NodeDef { attrs, inputs, kind, outputs } = self.def(); + let NodeDef { attrs, kind, inputs, child_regions, outputs } = self.def(); visitor.visit_attr_set_use(*attrs); - for v in inputs { - visitor.visit_value_use(v); - } match kind { NodeKind::Block { insts } => { for func_at_inst in self.at(*insts) { visitor.visit_data_inst_def(func_at_inst.def()); } } - NodeKind::Select { - kind: SelectionKind::BoolCond | SelectionKind::SpvInst(_), - cases, - } => { - for &case in cases { - visitor.visit_region_def(self.at(case)); - } - } - NodeKind::Loop { body, repeat_condition } => { - visitor.visit_region_def(self.at(*body)); - visitor.visit_value_use(repeat_condition); - } - NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) => {} + NodeKind::Select(SelectionKind::BoolCond | SelectionKind::SpvInst(_)) + | NodeKind::Loop { repeat_condition: _ } + | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) => {} } + for v in inputs { + visitor.visit_value_use(v); + } + for &region in child_regions { + visitor.visit_region_def(self.at(region)); + } + + // 
HACK(eddyb) semantically, `repeat_condition` is a body region output. + if let NodeKind::Loop { repeat_condition } = kind { + visitor.visit_value_use(repeat_condition); + } + for output in outputs { output.inner_visit_with(visitor); } From 18b72a1f526bd536d088b966976a625718b859e6 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Tue, 9 Sep 2025 11:42:12 +0300 Subject: [PATCH 4/6] [TODO(eddyb): printing/sifting attributes!] WIP: DataInstDef using NodeDef --- src/func_at.rs | 4 +- src/lib.rs | 30 +++++----- src/mem/analyze.rs | 45 ++++++++++----- src/print/mod.rs | 48 +++++++++++----- src/qptr/lift.rs | 133 +++++++++++++++++++++++---------------------- src/qptr/lower.rs | 31 +++++++---- src/spv/lift.rs | 27 ++++++--- src/spv/lower.rs | 22 +++++--- src/transform.rs | 9 +-- src/visit.rs | 9 +-- 10 files changed, 213 insertions(+), 145 deletions(-) diff --git a/src/func_at.rs b/src/func_at.rs index 00fc8718..43f78e3f 100644 --- a/src/func_at.rs +++ b/src/func_at.rs @@ -124,7 +124,9 @@ impl FuncAt<'_, Value> { Value::NodeOutput { node, output_idx } => { self.at(node).def().outputs[output_idx as usize].ty } - Value::DataInstOutput(inst) => self.at(inst).def().output_type.unwrap(), + Value::DataInstOutput { inst, output_idx } => { + self.at(inst).def().outputs[output_idx as usize].ty + } } } } diff --git a/src/lib.rs b/src/lib.rs index 01c8a789..0fef462c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -852,10 +852,13 @@ pub use context::Node; /// /// See [`Region`] docs for more on control-flow in SPIR-T. #[derive(Clone)] -pub struct NodeDef { +pub struct NodeDef< + // HACK(eddyb) generic so `DataInstDef` can reuse it, pre-merger. + K = NodeKind, +> { pub attrs: AttrSet, - pub kind: NodeKind, + pub kind: K, // FIXME(eddyb) change the inline size of this to fit most nodes. pub inputs: SmallVec<[Value; 2]>, @@ -928,19 +931,10 @@ pub use context::DataInst; /// Definition for a [`DataInst`]: a leaf (non-control-flow) instruction. 
// -// FIXME(eddyb) `DataInstKind::FuncCall` should probably be a `NodeKind`, -// but also `DataInst` vs `Node` is a purely artificial distinction. -#[derive(Clone)] -pub struct DataInstDef { - pub attrs: AttrSet, - - pub kind: DataInstKind, - - // FIXME(eddyb) change the inline size of this to fit most instructions. - pub inputs: SmallVec<[Value; 2]>, - - pub output_type: Option<Type>, -} +// HACK(eddyb) temporarily reusing `NodeDef` pre-merger, with: +// - `child_regions` always empty +// - `outputs.len` always <= 1 +pub type DataInstDef = NodeDef<DataInstKind>; #[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] pub enum DataInstKind { @@ -988,5 +982,9 @@ pub enum Value { }, /// The output value of a [`DataInst`]. - DataInstOutput(DataInst), + DataInstOutput { + inst: DataInst, + // HACK(eddyb) temporarily aligned with `NodeDef` pre-merger (always == 0). + output_idx: u32, + }, } diff --git a/src/mem/analyze.rs b/src/mem/analyze.rs index 4ef753f9..1f62a794 100644 --- a/src/mem/analyze.rs +++ b/src/mem/analyze.rs @@ -11,7 +11,7 @@ use crate::{ DeclDef, Diag, EntityList, ExportKey, Exportee, Func, FxIndexMap, GlobalVar, Module, Node, NodeKind, OrdAssertEq, Type, TypeKind, Value, }; -use itertools::Either; +use itertools::{Either, Itertools as _}; use rustc_hash::FxHashMap; use smallvec::SmallVec; use std::mem; @@ -792,8 +792,17 @@ impl<'a> GatherAccesses<'a> { &mut node_def.outputs[output_idx as usize].attrs } } - Value::DataInstOutput(data_inst) => { - &mut func_def_body.at_mut(data_inst).def().attrs + Value::DataInstOutput { inst, output_idx } => { + let inst_def = func_def_body.at_mut(inst).def(); + + // HACK(eddyb) `DataInstOutput { output_idx: !0, .. }` + // may be used to attach errors to a whole `DataInst`. 
+ if output_idx == !0 { + assert!(accesses.is_err()); + &mut inst_def.attrs + } else { + &mut inst_def.outputs[output_idx as usize].attrs + } } }; match accesses { @@ -897,7 +906,9 @@ impl<'a> GatherAccesses<'a> { )); return; } - Value::DataInstOutput(ptr_inst) => { + Value::DataInstOutput { inst: ptr_inst, output_idx } => { + // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. + assert_eq!(output_idx, 0); data_inst_output_accesses.entry(ptr_inst).or_default() } }; @@ -924,25 +935,31 @@ impl<'a> GatherAccesses<'a> { } FuncGatherAccessesState::InProgress => { accesses_or_err_attrs_to_attach.push(( - Value::DataInstOutput(data_inst), + Value::DataInstOutput { inst: data_inst, output_idx: 0 }, Err(AnalysisError(Diag::bug([ "unsupported recursive call".into() ]))), )); } }; - if data_inst_def.output_type.is_some_and(is_qptr) + // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. + if (data_inst_def.outputs.iter().at_most_one().ok().unwrap()) + .is_some_and(|o| is_qptr(o.ty)) && let Some(accesses) = output_accesses { - accesses_or_err_attrs_to_attach - .push((Value::DataInstOutput(data_inst), accesses)); + accesses_or_err_attrs_to_attach.push(( + Value::DataInstOutput { inst: data_inst, output_idx: 0 }, + accesses, + )); } } DataInstKind::Mem(MemOp::FuncLocalVar(_)) => { if let Some(accesses) = output_accesses { - accesses_or_err_attrs_to_attach - .push((Value::DataInstOutput(data_inst), accesses)); + accesses_or_err_attrs_to_attach.push(( + Value::DataInstOutput { inst: data_inst, output_idx: 0 }, + accesses, + )); } } DataInstKind::QPtr(QPtrOp::HandleArrayIndex) => { @@ -1138,7 +1155,7 @@ impl<'a> GatherAccesses<'a> { // HACK(eddyb) `_` will match multiple variants soon. 
#[allow(clippy::match_wildcard_for_single_variants)] let (op_name, access_type) = match op { - MemOp::Load => ("Load", data_inst_def.output_type.unwrap()), + MemOp::Load => ("Load", data_inst_def.outputs[0].ty), MemOp::Store => { ("Store", func_at_inst.at(data_inst_def.inputs[1]).type_of(&cx)) } @@ -1270,8 +1287,10 @@ impl<'a> GatherAccesses<'a> { if has_from_spv_ptr_output_attr { // FIXME(eddyb) merge with `FromSpvPtrOutput`'s `pointee`. if let Some(accesses) = output_accesses { - accesses_or_err_attrs_to_attach - .push((Value::DataInstOutput(data_inst), accesses)); + accesses_or_err_attrs_to_attach.push(( + Value::DataInstOutput { inst: data_inst, output_idx: 0 }, + accesses, + )); } } } diff --git a/src/print/mod.rs b/src/print/mod.rs index 37173945..4374f2dc 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -258,7 +258,10 @@ enum Use { node: Node, output_idx: u32, }, - DataInstOutput(DataInst), + DataInstOutput { + inst: DataInst, + output_idx: u32, + }, // NOTE(eddyb) these overlap somewhat with other cases, but they're always // generated, even when there is no "use", for `multiversion` alignment. @@ -273,7 +276,7 @@ impl From<Value> for Use { Value::Const(ct) => Use::CxInterned(CxInterned::Const(ct)), Value::RegionInput { region, input_idx } => Use::RegionInput { region, input_idx }, Value::NodeOutput { node, output_idx } => Use::NodeOutput { node, output_idx }, - Value::DataInstOutput(inst) => Use::DataInstOutput(inst), + Value::DataInstOutput { inst, output_idx } => Use::DataInstOutput { inst, output_idx }, } } } @@ -292,7 +295,7 @@ impl Use { Self::DbgScope { .. } => ("", "d"), Self::RegionLabel(_) => ("label", "L"), - Self::RegionInput { .. } | Self::NodeOutput { .. } | Self::DataInstOutput(_) => { + Self::RegionInput { .. } | Self::NodeOutput { .. } | Self::DataInstOutput { .. } => { ("", "v") } @@ -1066,7 +1069,7 @@ impl<'a> Printer<'a> { if let Use::RegionLabel(_) | Use::RegionInput { .. } | Use::NodeOutput { .. 
} - | Use::DataInstOutput(_) = use_kind + | Use::DataInstOutput { .. } = use_kind { return (use_kind, UseStyle::Inline); } @@ -1101,7 +1104,7 @@ impl<'a> Printer<'a> { | Use::RegionLabel(_) | Use::RegionInput { .. } | Use::NodeOutput { .. } - | Use::DataInstOutput(_) + | Use::DataInstOutput { .. } | Use::AlignmentAnchorForRegion(_) | Use::AlignmentAnchorForNode(_) | Use::AlignmentAnchorForDataInst(_) => unreachable!(), @@ -1176,7 +1179,7 @@ impl<'a> Printer<'a> { | Use::RegionLabel(_) | Use::RegionInput { .. } | Use::NodeOutput { .. } - | Use::DataInstOutput(_) + | Use::DataInstOutput { .. } | Use::AlignmentAnchorForRegion(_) | Use::AlignmentAnchorForNode(_) | Use::AlignmentAnchorForDataInst(_) => { @@ -1202,7 +1205,7 @@ impl<'a> Printer<'a> { | Use::RegionLabel(_) | Use::RegionInput { .. } | Use::NodeOutput { .. } - | Use::DataInstOutput(_) + | Use::DataInstOutput { .. } | Use::AlignmentAnchorForRegion(_) | Use::AlignmentAnchorForNode(_) | Use::AlignmentAnchorForDataInst(_) => { @@ -1511,10 +1514,13 @@ impl<'a> Printer<'a> { None, ); let inst_def = func_at_inst.def(); - if inst_def.output_type.is_some() { + for (i, output_decl) in inst_def.outputs.iter().enumerate() { define( - Use::DataInstOutput(func_at_inst.position), - Some(inst_def.attrs), + Use::DataInstOutput { + inst: func_at_inst.position, + output_idx: i.try_into().unwrap(), + }, + Some(output_decl.attrs), ); } } @@ -1548,7 +1554,9 @@ impl<'a> Printer<'a> { (&mut region_label_counter, use_styles.get_mut(&use_kind)) } - Use::RegionInput { .. } | Use::NodeOutput { .. } | Use::DataInstOutput(_) => { + Use::RegionInput { .. } + | Use::NodeOutput { .. } + | Use::DataInstOutput { .. } => { (&mut value_counter, use_styles.get_mut(&use_kind)) } @@ -2178,7 +2186,7 @@ impl Use { | Self::RegionLabel(_) | Self::RegionInput { .. } | Self::NodeOutput { .. } - | Self::DataInstOutput(_) => "_".into(), + | Self::DataInstOutput { .. 
} => "_".into(), Self::AlignmentAnchorForRegion(_) | Self::AlignmentAnchorForNode(_) @@ -3910,14 +3918,24 @@ impl Print for NodeOutputDecl { impl Print for FuncAt<'_, DataInst> { type Output = pretty::Fragment; fn print(&self, printer: &Printer<'_>) -> pretty::Fragment { - let DataInstDef { attrs, kind, inputs, output_type } = self.def(); + let DataInstDef { attrs, kind, inputs, child_regions, outputs } = self.def(); + + assert_eq!(child_regions.len(), 0); let attrs = attrs.print(printer); + // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. + let output_type = if !outputs.is_empty() { + assert_eq!(outputs.len(), 1); + Some(outputs[0].ty) + } else { + None + }; + let mut output_use_to_print_as_lhs = - output_type.map(|_| Use::DataInstOutput(self.position)); + output_type.map(|_| Use::DataInstOutput { inst: self.position, output_idx: 0 }); - let mut output_type_to_print = *output_type; + let mut output_type_to_print = output_type; let def_without_type = match kind { &DataInstKind::FuncCall(func) => pretty::Fragment::new([ diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index 7aa5234b..83c81efc 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -7,8 +7,8 @@ use crate::transform::{InnerInPlaceTransform, InnerTransform, Transformed, Trans use crate::{ AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, DataInst, DataInstDef, DataInstKind, DeclDef, Diag, DiagLevel, EntityDefs, EntityOrientedDenseMap, Func, - FuncDecl, FxIndexMap, GlobalVar, GlobalVarDecl, Module, Node, NodeKind, Type, TypeDef, - TypeKind, TypeOrConst, Value, spv, + FuncDecl, FxIndexMap, GlobalVar, GlobalVarDecl, Module, Node, NodeKind, NodeOutputDecl, Type, + TypeDef, TypeKind, TypeOrConst, Value, spv, }; use smallvec::SmallVec; use std::cell::Cell; @@ -431,7 +431,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { // FIXME(eddyb) maybe all this data should be packaged up together in a // type with fields like those of `DeferredPtrNoop` (or even more). 
let type_of_val_as_spv_ptr_with_layout = |v: Value| { - if let Value::DataInstOutput(v_data_inst) = v + if let Value::DataInstOutput { inst: v_data_inst, output_idx: 0 } = v && let Some(ptr_noop) = self.deferred_ptr_noops.get(&v_data_inst) { return Ok(( @@ -460,22 +460,22 @@ impl LiftToSpvPtrInstsInFunc<'_> { } DataInstKind::Mem(MemOp::FuncLocalVar(_mem_layout)) => { - let mem_accesses = self.lifter.find_mem_accesses_attr(data_inst_def.attrs)?; + let mem_accesses = + self.lifter.find_mem_accesses_attr(data_inst_def.outputs[0].attrs)?; // FIXME(eddyb) validate against `mem_layout`! let pointee_type = self.lifter.pointee_type_for_accesses(mem_accesses)?; - DataInstDef { - attrs: self.lifter.strip_mem_accesses_attr(data_inst_def.attrs), - kind: DataInstKind::SpvInst(spv::Inst { - opcode: wk.OpVariable, - imms: [spv::Imm::Short(wk.StorageClass, wk.Function)].into_iter().collect(), - }), - inputs: data_inst_def.inputs.clone(), - output_type: Some( - self.lifter - .spv_ptr_type(AddrSpace::SpvStorageClass(wk.Function), pointee_type), - ), - } + + let mut data_inst_def = data_inst_def.clone(); + data_inst_def.kind = DataInstKind::SpvInst(spv::Inst { + opcode: wk.OpVariable, + imms: [spv::Imm::Short(wk.StorageClass, wk.Function)].into_iter().collect(), + }); + data_inst_def.outputs[0].attrs = + self.lifter.strip_mem_accesses_attr(data_inst_def.outputs[0].attrs); + data_inst_def.outputs[0].ty = + self.lifter.spv_ptr_type(AddrSpace::SpvStorageClass(wk.Function), pointee_type); + data_inst_def } DataInstKind::QPtr(QPtrOp::HandleArrayIndex) => { let (addr_space, layout) = @@ -496,12 +496,11 @@ impl LiftToSpvPtrInstsInFunc<'_> { shapes::Handle::Opaque(ty) => ty, shapes::Handle::Buffer(_, buf) => buf.original_type, }; - DataInstDef { - attrs: data_inst_def.attrs, - kind: DataInstKind::SpvInst(wk.OpAccessChain.into()), - inputs: data_inst_def.inputs.clone(), - output_type: Some(self.lifter.spv_ptr_type(addr_space, handle_type)), - } + + let mut data_inst_def = 
data_inst_def.clone(); + data_inst_def.kind = DataInstKind::SpvInst(wk.OpAccessChain.into()); + data_inst_def.outputs[0].ty = self.lifter.spv_ptr_type(addr_space, handle_type); + data_inst_def } DataInstKind::QPtr(QPtrOp::BufferData) => { let buf_ptr = data_inst_def.inputs[0]; @@ -522,13 +521,11 @@ impl LiftToSpvPtrInstsInFunc<'_> { }, ); - DataInstDef { - kind: QPtrOp::BufferData.into(), - // FIXME(eddyb) avoid the repeated call to `type_of_val`, - // maybe don't even replace the `QPtrOp::BufferData` instruction? - output_type: Some(type_of_val(buf_ptr)), - ..data_inst_def.clone() - } + // FIXME(eddyb) avoid the repeated call to `type_of_val`, + // maybe don't even replace the `QPtrOp::BufferData` instruction? + let mut data_inst_def = data_inst_def.clone(); + data_inst_def.outputs[0].ty = type_of_val(buf_ptr); + data_inst_def } &DataInstKind::QPtr(QPtrOp::BufferDynLen { fixed_base_size, dyn_unit_stride }) => { let buf_ptr = data_inst_def.inputs[0]; @@ -645,6 +642,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { } } + let mut data_inst_def = data_inst_def.clone(); if access_chain_inputs.len() == 1 { self.deferred_ptr_noops.insert( data_inst, @@ -658,21 +656,15 @@ impl LiftToSpvPtrInstsInFunc<'_> { // FIXME(eddyb) avoid the repeated call to `type_of_val`, // maybe don't even replace the `QPtrOp::Offset` instruction? 
- DataInstDef { - kind: QPtrOp::Offset(0).into(), - output_type: Some(type_of_val(base_ptr)), - ..data_inst_def.clone() - } + data_inst_def.kind = QPtrOp::Offset(0).into(); + data_inst_def.outputs[0].ty = type_of_val(base_ptr); } else { - DataInstDef { - attrs: data_inst_def.attrs, - kind: DataInstKind::SpvInst(wk.OpAccessChain.into()), - inputs: access_chain_inputs, - output_type: Some( - self.lifter.spv_ptr_type(addr_space, layout.original_type), - ), - } + data_inst_def.kind = DataInstKind::SpvInst(wk.OpAccessChain.into()); + data_inst_def.inputs = access_chain_inputs; + data_inst_def.outputs[0].ty = + self.lifter.spv_ptr_type(addr_space, layout.original_type); } + data_inst_def } DataInstKind::QPtr(QPtrOp::DynOffset { stride, index_bounds }) => { let base_ptr = data_inst_def.inputs[0]; @@ -745,18 +737,18 @@ impl LiftToSpvPtrInstsInFunc<'_> { Components::Fields { layouts, .. } => layouts[idx].clone(), }; } - DataInstDef { - attrs: data_inst_def.attrs, - kind: DataInstKind::SpvInst(wk.OpAccessChain.into()), - inputs: access_chain_inputs, - output_type: Some(self.lifter.spv_ptr_type(addr_space, layout.original_type)), - } + let mut data_inst_def = data_inst_def.clone(); + data_inst_def.kind = DataInstKind::SpvInst(wk.OpAccessChain.into()); + data_inst_def.inputs = access_chain_inputs; + data_inst_def.outputs[0].ty = + self.lifter.spv_ptr_type(addr_space, layout.original_type); + data_inst_def } DataInstKind::Mem(op @ (MemOp::Load | MemOp::Store)) => { // HACK(eddyb) `_` will match multiple variants soon. 
#[allow(clippy::match_wildcard_for_single_variants)] let (spv_opcode, access_type) = match op { - MemOp::Load => (wk.OpLoad, data_inst_def.output_type.unwrap()), + MemOp::Load => (wk.OpLoad, data_inst_def.outputs[0].ty), MemOp::Store => (wk.OpStore, type_of_val(data_inst_def.inputs[1])), _ => unreachable!(), }; @@ -808,7 +800,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { } new_data_inst_def.inputs[input_idx] = - Value::DataInstOutput(access_chain_data_inst); + Value::DataInstOutput { inst: access_chain_data_inst, output_idx: 0 }; } new_data_inst_def @@ -883,12 +875,12 @@ impl LiftToSpvPtrInstsInFunc<'_> { } new_data_inst_def.inputs[input_idx] = - Value::DataInstOutput(access_chain_data_inst); + Value::DataInstOutput { inst: access_chain_data_inst, output_idx: 0 }; } if let Some((addr_space, pointee_type)) = from_spv_ptr_output { - new_data_inst_def.output_type = - Some(self.lifter.spv_ptr_type(addr_space, pointee_type)); + new_data_inst_def.outputs[0].ty = + self.lifter.spv_ptr_type(addr_space, pointee_type); } new_data_inst_def @@ -1004,7 +996,13 @@ impl LiftToSpvPtrInstsInFunc<'_> { attrs: Default::default(), kind: DataInstKind::SpvInst(wk.OpAccessChain.into()), inputs: access_chain_inputs, - output_type: Some(self.lifter.spv_ptr_type(addr_space, access_type)), + child_regions: [].into_iter().collect(), + outputs: [NodeOutputDecl { + attrs: Default::default(), + ty: self.lifter.spv_ptr_type(addr_space, access_type), + }] + .into_iter() + .collect(), }) } else { None @@ -1020,8 +1018,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { for v in values { // FIXME(eddyb) the loop could theoretically be avoided, but that'd // make tracking use counts harder. 
- while let Value::DataInstOutput(data_inst) = *v { - match self.deferred_ptr_noops.get(&data_inst) { + while let Value::DataInstOutput { inst, output_idx: 0 } = *v { + match self.deferred_ptr_noops.get(&inst) { Some(ptr_noop) => { *v = ptr_noop.output_pointer; } @@ -1035,8 +1033,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { // encoded as `Option` for (dense) map entry reasons. fn add_value_uses(&mut self, values: &[Value]) { for &v in values { - if let Value::DataInstOutput(data_inst) = v { - let count = self.data_inst_use_counts.entry(data_inst); + if let Value::DataInstOutput { inst, .. } = v { + let count = self.data_inst_use_counts.entry(inst); *count = Some( NonZeroU32::new(count.map_or(0, |c| c.get()).checked_add(1).unwrap()).unwrap(), ); @@ -1045,8 +1043,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { } fn remove_value_uses(&mut self, values: &[Value]) { for &v in values { - if let Value::DataInstOutput(data_inst) = v { - let count = self.data_inst_use_counts.entry(data_inst); + if let Value::DataInstOutput { inst, .. 
} = v { + let count = self.data_inst_use_counts.entry(inst); *count = NonZeroU32::new(count.unwrap().get() - 1); } } @@ -1097,12 +1095,15 @@ impl Transformer for LiftToSpvPtrInstsInFunc<'_> { if let DataInstKind::QPtr(_) = data_inst_def.kind { lifted = Err(LiftError(Diag::bug(["unimplemented qptr instruction".into()]))); - } else if let Some(ty) = data_inst_def.output_type - && matches!(self.lifter.cx[ty].kind, TypeKind::QPtr) - { - lifted = Err(LiftError(Diag::bug([ - "unimplemented qptr-producing instruction".into(), - ]))); + } else { + for output in &data_inst_def.outputs { + if matches!(self.lifter.cx[output.ty].kind, TypeKind::QPtr) { + lifted = Err(LiftError(Diag::bug([ + "unimplemented qptr-producing instruction".into(), + ]))); + break; + } + } } } match lifted { diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index e452bed6..9c9f1ade 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -6,9 +6,10 @@ use crate::qptr::{QPtrAttr, QPtrOp}; use crate::transform::{InnerInPlaceTransform, Transformed, Transformer}; use crate::{ AddrSpace, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, DataInst, DataInstDef, - DataInstKind, Diag, FuncDecl, GlobalVarDecl, Node, NodeKind, OrdAssertEq, Type, TypeKind, - TypeOrConst, Value, spv, + DataInstKind, Diag, FuncDecl, GlobalVarDecl, Node, NodeKind, NodeOutputDecl, OrdAssertEq, Type, + TypeKind, TypeOrConst, Value, spv, }; +use itertools::Itertools as _; use smallvec::SmallVec; use std::cell::Cell; use std::num::NonZeroU32; @@ -405,17 +406,19 @@ impl LowerFromSpvPtrInstsInFunc<'_> { let func = func_at_data_inst_frozen.at(()); let mut attrs = data_inst_def.attrs; - let output_type = data_inst_def.output_type; let spv_inst = match &data_inst_def.kind { DataInstKind::SpvInst(spv_inst) => spv_inst, _ => return Ok(Transformed::Unchanged), }; + // FIXME(eddyb) wasteful clone? 
(needed due to borrowing issues) + let outputs = data_inst_def.outputs.clone(); + let replacement_kind_and_inputs = if spv_inst.opcode == wk.OpVariable { assert!(data_inst_def.inputs.len() <= 1); let (_, var_data_type) = - self.lowerer.as_spv_ptr_type(output_type.unwrap()).ok_or_else(|| { + self.lowerer.as_spv_ptr_type(outputs[0].ty).ok_or_else(|| { LowerError(Diag::bug(["output type not an `OpTypePointer`".into()])) })?; match self.lowerer.layout_of(var_data_type)? { @@ -547,7 +550,13 @@ impl LowerFromSpvPtrInstsInFunc<'_> { attrs: Default::default(), kind, inputs, - output_type: Some(self.lowerer.qptr_type()), + child_regions: [].into_iter().collect(), + outputs: [NodeOutputDecl { + attrs: Default::default(), + ty: self.lowerer.qptr_type(), + }] + .into_iter() + .collect(), } .into(), ); @@ -566,14 +575,14 @@ impl LowerFromSpvPtrInstsInFunc<'_> { _ => unreachable!(), } - ptr = Value::DataInstOutput(step_data_inst); + ptr = Value::DataInstOutput { inst: step_data_inst, output_idx: 0 }; } final_step.into_data_inst_kind_and_inputs(ptr) } else if spv_inst.opcode == wk.OpBitcast { let input = data_inst_def.inputs[0]; // Pointer-to-pointer casts are noops on `qptr`. if self.lowerer.as_spv_ptr_type(func.at(input).type_of(cx)).is_some() - && self.lowerer.as_spv_ptr_type(output_type.unwrap()).is_some() + && self.lowerer.as_spv_ptr_type(outputs[0].ty).is_some() { // HACK(eddyb) noop cases should not use any `DataInst`s at all, // but that would require the ability to replace all uses of a `Value`. @@ -596,7 +605,8 @@ impl LowerFromSpvPtrInstsInFunc<'_> { attrs, kind: new_kind, inputs: new_inputs, - output_type, + child_regions: [].into_iter().collect(), + outputs, })) } @@ -634,8 +644,9 @@ impl LowerFromSpvPtrInstsInFunc<'_> { ); } } - if let Some(output_type) = data_inst_def.output_type - && let Some((addr_space, pointee)) = self.lowerer.as_spv_ptr_type(output_type) + // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. 
+ if let Some(output) = data_inst_def.outputs.iter().at_most_one().ok().unwrap() + && let Some((addr_space, pointee)) = self.lowerer.as_spv_ptr_type(output.ty) { old_and_new_attrs.get_or_insert_with(get_old_attrs).attrs.insert( QPtrAttr::FromSpvPtrOutput { diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 1561eae2..615ec344 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -11,6 +11,7 @@ use crate::{ ModuleDialect, Node, NodeKind, NodeOutputDecl, OrdAssertEq, Region, RegionInputDecl, Type, TypeDef, TypeKind, TypeOrConst, Value, }; +use itertools::Itertools; use rustc_hash::FxHashMap; use smallvec::SmallVec; use std::borrow::Cow; @@ -1022,7 +1023,7 @@ impl<'a> FuncLifting<'a> { .values() .flat_map(|block| block.insts.iter().copied()) .flat_map(|insts| func_def_body.at(insts)) - .filter(|&func_at_inst| func_at_inst.def().output_type.is_some()) + .filter(|&func_at_inst| !func_at_inst.def().outputs.is_empty()) .map(|func_at_inst| func_at_inst.position); Ok(Self { @@ -1161,7 +1162,11 @@ impl LazyInst<'_, '_> { [usize::try_from(output_idx).unwrap()] .result_id } - Value::DataInstOutput(inst) => parent_func.data_inst_output_ids[&inst], + Value::DataInstOutput { inst, output_idx } => { + // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. + assert_eq!(output_idx, 0); + parent_func.data_inst_output_ids[&inst] + } }; let (result_id, attrs, _) = self.result_id_attrs_and_import(module, ids); @@ -1318,9 +1323,9 @@ impl LazyInst<'_, '_> { }; spv::InstWithIds { without_ids: inst, - result_type_id: data_inst_def - .output_type - .map(|ty| ids.globals[&Global::Type(ty)]), + // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. 
+ result_type_id: (data_inst_def.outputs.iter().at_most_one().ok().unwrap()) + .map(|o| ids.globals[&Global::Type(o.ty)]), result_id, ids: extra_initial_id_operand .into_iter() @@ -1499,10 +1504,14 @@ impl Module { let data_inst_def = func_at_inst.def(); LazyInst::DataInst { parent_func: func_lifting, - result_id: data_inst_def.output_type.map(|_| { - func_lifting.data_inst_output_ids - [&func_at_inst.position] - }), + // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. + result_id: (data_inst_def.outputs.iter().at_most_one()) + .ok() + .unwrap() + .map(|_| { + func_lifting.data_inst_output_ids + [&func_at_inst.position] + }), data_inst_def, } }), diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 7d5d2801..128c468e 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -7,8 +7,8 @@ use crate::{ AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, DataInstDef, DataInstKind, DbgSrcLoc, DeclDef, Diag, EntityDefs, EntityList, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, GlobalVarDefBody, Import, InternedStr, - Module, NodeDef, NodeKind, Region, RegionDef, RegionInputDecl, Type, TypeDef, TypeKind, - TypeOrConst, Value, print, + Module, NodeDef, NodeKind, NodeOutputDecl, Region, RegionDef, RegionInputDecl, Type, TypeDef, + TypeKind, TypeOrConst, Value, print, }; use rustc_hash::FxHashMap; use smallvec::SmallVec; @@ -1058,11 +1058,12 @@ impl Module { attrs: AttrSet::default(), kind: DataInstKind::SpvInst(wk.OpNop.into()), inputs: [].into_iter().collect(), - output_type: None, + child_regions: [].into_iter().collect(), + outputs: [].into_iter().collect(), } .into(), ); - LocalIdDef::Value(Value::DataInstOutput(inst)) + LocalIdDef::Value(Value::DataInstOutput { inst, output_idx: 0 }) } }; local_id_defs.insert(id, local_id_def); @@ -1625,7 +1626,8 @@ impl Module { } }) .collect::>()?, - output_type: result_id + child_regions: [].into_iter().collect(), + outputs: result_id .map(|_| { 
result_type.ok_or_else(|| { invalid( @@ -1634,11 +1636,17 @@ impl Module { ) }) }) - .transpose()?, + .transpose()? + .into_iter() + .map(|ty| { + // FIXME(eddyb) split attrs between output and inst. + NodeOutputDecl { attrs: AttrSet::default(), ty } + }) + .collect(), }; let inst = match result_id { Some(id) => match local_id_defs[&id] { - LocalIdDef::Value(Value::DataInstOutput(inst)) => { + LocalIdDef::Value(Value::DataInstOutput { inst, .. }) => { // A dummy was defined earlier, to be able to // have an entry in `local_id_defs`. func_def_body.data_insts[inst] = data_inst_def.into(); diff --git a/src/transform.rs b/src/transform.rs index 67b325ca..e8c8ec4d 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -687,15 +687,16 @@ impl InnerTransform for NodeOutputDecl { impl InnerInPlaceTransform for FuncAtMut<'_, DataInst> { fn inner_in_place_transform_with(&mut self, transformer: &mut impl Transformer) { - let DataInstDef { attrs, kind, inputs, output_type } = self.reborrow().def(); + let DataInstDef { attrs, kind, inputs, child_regions, outputs } = self.reborrow().def(); transformer.transform_attr_set_use(*attrs).apply_to(attrs); kind.inner_in_place_transform_with(transformer); for v in inputs { transformer.transform_value_use(v).apply_to(v); } - if let Some(output_type) = output_type { - transformer.transform_type_use(*output_type).apply_to(output_type); + assert_eq!(child_regions.len(), 0); + for output in outputs { + output.inner_transform_with(transformer).apply_to(output); } } } @@ -755,7 +756,7 @@ impl InnerTransform for Value { Self::RegionInput { region: _, input_idx: _ } | Self::NodeOutput { node: _, output_idx: _ } - | Self::DataInstOutput(_) => Transformed::Unchanged, + | Self::DataInstOutput { inst: _, output_idx: _ } => Transformed::Unchanged, } } } diff --git a/src/visit.rs b/src/visit.rs index ad529559..2aa961dc 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -517,15 +517,16 @@ impl InnerVisit for NodeOutputDecl { impl InnerVisit for 
DataInstDef { fn inner_visit_with<'a>(&'a self, visitor: &mut impl Visitor<'a>) { - let Self { attrs, kind, inputs, output_type } = self; + let Self { attrs, kind, inputs, child_regions, outputs } = self; visitor.visit_attr_set_use(*attrs); kind.inner_visit_with(visitor); for v in inputs { visitor.visit_value_use(v); } - if let Some(ty) = *output_type { - visitor.visit_type_use(ty); + assert_eq!(child_regions.len(), 0); + for output in outputs { + output.inner_visit_with(visitor); } } } @@ -582,7 +583,7 @@ impl InnerVisit for Value { Self::Const(ct) => visitor.visit_const_use(ct), Self::RegionInput { region: _, input_idx: _ } | Self::NodeOutput { node: _, output_idx: _ } - | Self::DataInstOutput(_) => {} + | Self::DataInstOutput { inst: _, output_idx: _ } => {} } } } From a7186bd4fc33d5f07aa1fa3ee005bf154f5a9158 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Tue, 9 Sep 2025 11:42:12 +0300 Subject: [PATCH 5/6] Merge `DataInst` and `Node` (while remaining disjoint in `Region` vs `Block`). --- src/cf/mod.rs | 4 +- src/context.rs | 5 +-- src/func_at.rs | 99 +++++----------------------------------------- src/lib.rs | 34 ++++++---------- src/mem/analyze.rs | 5 +++ src/print/mod.rs | 18 +++++++-- src/qptr/lift.rs | 37 +++++++++++------ src/qptr/lower.rs | 16 ++++++-- src/spv/lift.rs | 40 ++++++++++++++++--- src/spv/lower.rs | 15 +++---- src/transform.rs | 60 +++++++++------------------- src/visit.rs | 59 +++++++++------------------ 12 files changed, 162 insertions(+), 230 deletions(-) diff --git a/src/cf/mod.rs b/src/cf/mod.rs index f86d118a..f2131720 100644 --- a/src/cf/mod.rs +++ b/src/cf/mod.rs @@ -10,7 +10,7 @@ pub mod cfgssa; pub mod structurize; pub mod unstructured; -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq, Hash)] pub enum SelectionKind { /// Two-case selection based on boolean condition, i.e. `if`-`else`, with /// the two cases being "then" and "else" (in that order). 
@@ -19,7 +19,7 @@ pub enum SelectionKind { SpvInst(spv::Inst), } -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq, Hash)] pub enum ExitInvocationKind { SpvInst(spv::Inst), } diff --git a/src/context.rs b/src/context.rs index 26e0223f..c20e65f0 100644 --- a/src/context.rs +++ b/src/context.rs @@ -484,11 +484,11 @@ impl, V> std::ops::IndexMut for EntityOrientedDens /// [`EntityListNode`] (to hold the "previous/next node" links). /// /// Fields are private to avoid arbitrary user interactions. -#[derive(Copy, Clone)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] pub struct EntityList(Option>); // HACK(eddyb) this only exists to give field names to the non-empty case. -#[derive(Copy, Clone)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] struct FirstLast { first: F, last: L, @@ -957,5 +957,4 @@ entities! { Func => chunk_size(0x1_0000) crate::FuncDecl, Region => chunk_size(0x1000) crate::RegionDef, Node => chunk_size(0x1000) EntityListNode, - DataInst => chunk_size(0x1000) EntityListNode, } diff --git a/src/func_at.rs b/src/func_at.rs index 43f78e3f..44d62ca7 100644 --- a/src/func_at.rs +++ b/src/func_at.rs @@ -12,8 +12,8 @@ #![allow(clippy::should_implement_trait)] use crate::{ - Context, DataInst, DataInstDef, EntityDefs, EntityList, EntityListIter, FuncDefBody, Node, - NodeDef, Region, RegionDef, Type, Value, + Context, EntityDefs, EntityList, EntityListIter, FuncDefBody, Node, NodeDef, Region, RegionDef, + Type, Value, }; /// Immutable traversal (i.e. visiting) helper for intra-function entities. @@ -24,7 +24,6 @@ use crate::{ pub struct FuncAt<'a, P: Copy> { pub regions: &'a EntityDefs, pub nodes: &'a EntityDefs, - pub data_insts: &'a EntityDefs, pub position: P, } @@ -32,12 +31,7 @@ pub struct FuncAt<'a, P: Copy> { impl<'a, P: Copy> FuncAt<'a, P> { /// Reposition to `new_position`. 
pub fn at(self, new_position: P2) -> FuncAt<'a, P2> { - FuncAt { - regions: self.regions, - nodes: self.nodes, - data_insts: self.data_insts, - position: new_position, - } + FuncAt { regions: self.regions, nodes: self.nodes, position: new_position } } } @@ -82,37 +76,6 @@ impl<'a> FuncAt<'a, Node> { } } -impl<'a> IntoIterator for FuncAt<'a, EntityList> { - type IntoIter = FuncAt<'a, EntityListIter>; - type Item = FuncAt<'a, DataInst>; - fn into_iter(self) -> Self::IntoIter { - self.at(self.position.iter()) - } -} - -impl<'a> Iterator for FuncAt<'a, EntityListIter> { - type Item = FuncAt<'a, DataInst>; - fn next(&mut self) -> Option { - let (next, rest) = self.position.split_first(self.data_insts)?; - self.position = rest; - Some(self.at(next)) - } -} - -impl DoubleEndedIterator for FuncAt<'_, EntityListIter> { - fn next_back(&mut self) -> Option { - let (prev, rest) = self.position.split_last(self.data_insts)?; - self.position = rest; - Some(self.at(prev)) - } -} - -impl<'a> FuncAt<'a, DataInst> { - pub fn def(self) -> &'a DataInstDef { - &self.data_insts[self.position] - } -} - impl FuncAt<'_, Value> { /// Return the [`Type`] of this [`Value`] ([`Context`] used for [`Value::Const`]). pub fn type_of(self, cx: &Context) -> Type { @@ -138,7 +101,6 @@ impl FuncAt<'_, Value> { pub struct FuncAtMut<'a, P: Copy> { pub regions: &'a mut EntityDefs, pub nodes: &'a mut EntityDefs, - pub data_insts: &'a mut EntityDefs, pub position: P, } @@ -146,30 +108,20 @@ pub struct FuncAtMut<'a, P: Copy> { impl<'a, P: Copy> FuncAtMut<'a, P> { /// Emulate a "reborrow", which is automatic only for `&mut` types. pub fn reborrow(&mut self) -> FuncAtMut<'_, P> { - FuncAtMut { - regions: self.regions, - nodes: self.nodes, - data_insts: self.data_insts, - position: self.position, - } + FuncAtMut { regions: self.regions, nodes: self.nodes, position: self.position } } /// Reposition to `new_position`. 
pub fn at(self, new_position: P2) -> FuncAtMut<'a, P2> { - FuncAtMut { - regions: self.regions, - nodes: self.nodes, - data_insts: self.data_insts, - position: new_position, - } + FuncAtMut { regions: self.regions, nodes: self.nodes, position: new_position } } /// Demote to a `FuncAt`, with the same `position`. // // FIXME(eddyb) maybe find a better name for this? pub fn freeze(self) -> FuncAt<'a, P> { - let FuncAtMut { regions, nodes, data_insts, position } = self; - FuncAt { regions, nodes, data_insts, position } + let FuncAtMut { regions, nodes, position } = self; + FuncAt { regions, nodes, position } } } @@ -207,48 +159,15 @@ impl<'a> FuncAtMut<'a, Node> { } } -// HACK(eddyb) can't implement `IntoIterator` because `next` borrows `self`. -impl<'a> FuncAtMut<'a, EntityList> { - pub fn into_iter(self) -> FuncAtMut<'a, EntityListIter> { - let iter = self.position.iter(); - self.at(iter) - } -} - -// HACK(eddyb) can't implement `Iterator` because `next` borrows `self`. -impl FuncAtMut<'_, EntityListIter> { - pub fn next(&mut self) -> Option> { - let (next, rest) = self.position.split_first(self.data_insts)?; - self.position = rest; - Some(self.reborrow().at(next)) - } -} - -impl<'a> FuncAtMut<'a, DataInst> { - pub fn def(self) -> &'a mut DataInstDef { - &mut self.data_insts[self.position] - } -} - impl FuncDefBody { /// Start immutably traversing the function at `position`. pub fn at(&self, position: P) -> FuncAt<'_, P> { - FuncAt { - regions: &self.regions, - nodes: &self.nodes, - data_insts: &self.data_insts, - position, - } + FuncAt { regions: &self.regions, nodes: &self.nodes, position } } /// Start mutably traversing the function at `position`. pub fn at_mut(&mut self, position: P) -> FuncAtMut<'_, P> { - FuncAtMut { - regions: &mut self.regions, - nodes: &mut self.nodes, - data_insts: &mut self.data_insts, - position, - } + FuncAtMut { regions: &mut self.regions, nodes: &mut self.nodes, position } } /// Shorthand for `func_def_body.at(func_def_body.body)`. 
diff --git a/src/lib.rs b/src/lib.rs index 0fef462c..173d769b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -702,7 +702,6 @@ pub struct FuncParam { pub struct FuncDefBody { pub regions: EntityDefs, pub nodes: EntityDefs, - pub data_insts: EntityDefs, /// The [`Region`] representing the whole body of the function. /// @@ -852,13 +851,10 @@ pub use context::Node; /// /// See [`Region`] docs for more on control-flow in SPIR-T. #[derive(Clone)] -pub struct NodeDef< - // HACK(eddyb) generic so `DataInstDef` can reuse it, pre-merger. - K = NodeKind, -> { +pub struct NodeDef { pub attrs: AttrSet, - pub kind: K, + pub kind: NodeKind, // FIXME(eddyb) change the inline size of this to fit most nodes. pub inputs: SmallVec<[Value; 2]>, @@ -881,7 +877,7 @@ pub struct NodeOutputDecl { pub ty: Type, } -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] pub enum NodeKind { /// Linear chain of [`DataInst`]s, executing in sequence. /// @@ -924,22 +920,9 @@ pub enum NodeKind { // // FIXME(eddyb) make this less shader-controlflow-centric. ExitInvocation(cf::ExitInvocationKind), -} - -/// Entity handle for a [`DataInstDef`](crate::DataInstDef) (a leaf instruction). -pub use context::DataInst; -/// Definition for a [`DataInst`]: a leaf (non-control-flow) instruction. -// -// HACK(eddyb) temporarily reusing `NodeDef` pre-merger, with: -// - `child_regions` always empty -// - `outputs.len` always <= 1 -pub type DataInstDef = NodeDef; - -#[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] -pub enum DataInstKind { - // FIXME(eddyb) try to split this into recursive and non-recursive calls, - // to avoid needing special handling for recursion where it's impossible. + // NOTE(eddyb) all variants below used to be in `DataInstKind`. + // FuncCall(Func), /// Memory-specific operations (see [`mem::MemOp`]). 
@@ -958,6 +941,13 @@ pub enum DataInstKind { }, } +// HACK(eddyb) temporarily reusing `Node` pre-merger, with: +// - `child_regions` always empty +// - `outputs.len` always <= 1 +pub type DataInst = Node; +pub type DataInstDef = NodeDef; +pub type DataInstKind = NodeKind; + #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum Value { Const(Const), diff --git a/src/mem/analyze.rs b/src/mem/analyze.rs index 1f62a794..5d593efe 100644 --- a/src/mem/analyze.rs +++ b/src/mem/analyze.rs @@ -922,6 +922,11 @@ impl<'a> GatherAccesses<'a> { }); }; match &data_inst_def.kind { + NodeKind::Block { .. } + | NodeKind::Select(_) + | NodeKind::Loop { .. } + | NodeKind::ExitInvocation(_) => unreachable!(), + &DataInstKind::FuncCall(callee) => { match self.gather_accesses_in_func(module, callee) { FuncGatherAccessesState::Complete(callee_results) => { diff --git a/src/print/mod.rs b/src/print/mod.rs index 4374f2dc..45650348 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -3780,7 +3780,7 @@ impl Print for FuncAt<'_, Node> { pretty::Fragment::new( self.at(*insts) .into_iter() - .map(|func_at_inst| func_at_inst.print(printer)) + .map(|func_at_inst| func_at_inst.print_data_inst(printer)) .flat_map(|entry| [pretty::Node::ForceLineSeparation.into(), entry]), ) } @@ -3882,6 +3882,12 @@ impl Print for FuncAt<'_, Node> { imms, inputs.iter().map(|v| v.print(printer)), ), + + DataInstKind::FuncCall(_) + | DataInstKind::Mem(_) + | DataInstKind::QPtr(_) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. 
} => unreachable!(), }; let def_without_name = pretty::Fragment::new([ Use::AlignmentAnchorForNode(self.position).print_as_def(printer), @@ -3915,9 +3921,8 @@ impl Print for NodeOutputDecl { } } -impl Print for FuncAt<'_, DataInst> { - type Output = pretty::Fragment; - fn print(&self, printer: &Printer<'_>) -> pretty::Fragment { +impl FuncAt<'_, DataInst> { + fn print_data_inst(&self, printer: &Printer<'_>) -> pretty::Fragment { let DataInstDef { attrs, kind, inputs, child_regions, outputs } = self.def(); assert_eq!(child_regions.len(), 0); @@ -3938,6 +3943,11 @@ impl Print for FuncAt<'_, DataInst> { let mut output_type_to_print = output_type; let def_without_type = match kind { + NodeKind::Block { .. } + | NodeKind::Select(_) + | NodeKind::Loop { .. } + | NodeKind::ExitInvocation(_) => unreachable!(), + &DataInstKind::FuncCall(func) => pretty::Fragment::new([ printer.declarative_keyword_style().apply("call").into(), " ".into(), diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index 83c81efc..8f2e92b7 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -448,6 +448,11 @@ impl LiftToSpvPtrInstsInFunc<'_> { Ok((addr_space, self.lifter.layout_of(pointee_type)?)) }; let replacement_data_inst_def = match &data_inst_def.kind { + NodeKind::Block { .. } + | NodeKind::Select(_) + | NodeKind::Loop { .. } + | NodeKind::ExitInvocation(_) => unreachable!(), + &DataInstKind::FuncCall(_callee) => { for &v in &data_inst_def.inputs { if self.lifter.as_spv_ptr_type(type_of_val(v)).is_some() { @@ -781,9 +786,11 @@ impl LiftToSpvPtrInstsInFunc<'_> { let access_chain_data_inst = func_at_data_inst .reborrow() - .data_insts + .nodes .define(cx, access_chain_data_inst_def.into()); + // FIXME(eddyb) comment below should be about `nodes` vs `regions` + // (once `Block` and the `Node`-vs-`DataInst` split are gone). 
// HACK(eddyb) can't really use helpers like `FuncAtMut::def`, // due to the need to borrow `nodes` and `data_insts` // at the same time - perhaps some kind of `FuncAtMut` position @@ -792,9 +799,10 @@ impl LiftToSpvPtrInstsInFunc<'_> { // an actual list entity of its own, should be considered. let data_inst = func_at_data_inst.position; let func = func_at_data_inst.reborrow().at(()); - match &mut func.nodes[parent_block].kind { - NodeKind::Block { insts } => { - insts.insert_before(access_chain_data_inst, data_inst, func.data_insts); + match func.nodes[parent_block].kind { + NodeKind::Block { mut insts } => { + insts.insert_before(access_chain_data_inst, data_inst, func.nodes); + func.nodes[parent_block].kind = NodeKind::Block { insts }; } _ => unreachable!(), } @@ -856,9 +864,11 @@ impl LiftToSpvPtrInstsInFunc<'_> { let access_chain_data_inst = func_at_data_inst .reborrow() - .data_insts + .nodes .define(cx, access_chain_data_inst_def.into()); + // FIXME(eddyb) comment below should be about `nodes` vs `regions` + // (once `Block` and the `Node`-vs-`DataInst` split are gone). // HACK(eddyb) can't really use helpers like `FuncAtMut::def`, // due to the need to borrow `nodes` and `data_insts` // at the same time - perhaps some kind of `FuncAtMut` position @@ -867,9 +877,10 @@ impl LiftToSpvPtrInstsInFunc<'_> { // an actual list entity of its own, should be considered. 
let data_inst = func_at_data_inst.position; let func = func_at_data_inst.reborrow().at(()); - match &mut func.nodes[parent_block].kind { - NodeKind::Block { insts } => { - insts.insert_before(access_chain_data_inst, data_inst, func.data_insts); + match func.nodes[parent_block].kind { + NodeKind::Block { mut insts } => { + insts.insert_before(access_chain_data_inst, data_inst, func.nodes); + func.nodes[parent_block].kind = NodeKind::Block { insts }; } _ => unreachable!(), } @@ -1156,15 +1167,19 @@ impl Transformer for LiftToSpvPtrInstsInFunc<'_> { // use counts of an earlier definition, allowing further removal. for (inst, ptr_noop) in deferred_ptr_noops.into_iter().rev() { if self.data_inst_use_counts.get(inst).is_none() { + // FIXME(eddyb) comment below should be about `nodes` vs `regions` + // (once `Block` and the `Node`-vs-`DataInst` split are gone). // HACK(eddyb) can't really use helpers like `FuncAtMut::def`, // due to the need to borrow `nodes` and `data_insts` // at the same time - perhaps some kind of `FuncAtMut` position // types for "where a list is in a parent entity" could be used // to make this more ergonomic, although the potential need for // an actual list entity of its own, should be considered. 
- match &mut func_def_body.nodes[ptr_noop.parent_block].kind { - NodeKind::Block { insts } => { - insts.remove(inst, &mut func_def_body.data_insts); + match func_def_body.nodes[ptr_noop.parent_block].kind { + NodeKind::Block { mut insts } => { + insts.remove(inst, &mut func_def_body.nodes); + func_def_body.nodes[ptr_noop.parent_block].kind = + NodeKind::Block { insts }; } _ => unreachable!(), } diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index 9c9f1ade..b99cac4a 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -544,7 +544,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { let mut ptr = base_ptr; for step in steps { let (kind, inputs) = step.into_data_inst_kind_and_inputs(ptr); - let step_data_inst = func_at_data_inst.reborrow().data_insts.define( + let step_data_inst = func_at_data_inst.reborrow().nodes.define( cx, DataInstDef { attrs: Default::default(), @@ -561,6 +561,8 @@ impl LowerFromSpvPtrInstsInFunc<'_> { .into(), ); + // FIXME(eddyb) comment below should be about `nodes` vs `regions` + // (once `Block` and the `Node`-vs-`DataInst` split are gone). // HACK(eddyb) can't really use helpers like `FuncAtMut::def`, // due to the need to borrow `nodes` and `data_insts` // at the same time - perhaps some kind of `FuncAtMut` position @@ -568,9 +570,10 @@ impl LowerFromSpvPtrInstsInFunc<'_> { // to make this more ergonomic, although the potential need for // an actual list entity of its own, should be considered. 
let func = func_at_data_inst.reborrow().at(()); - match &mut func.nodes[parent_block].kind { - NodeKind::Block { insts } => { - insts.insert_before(step_data_inst, data_inst, func.data_insts); + match func.nodes[parent_block].kind { + NodeKind::Block { mut insts } => { + insts.insert_before(step_data_inst, data_inst, func.nodes); + func.nodes[parent_block].kind = NodeKind::Block { insts }; } _ => unreachable!(), } @@ -624,6 +627,11 @@ impl LowerFromSpvPtrInstsInFunc<'_> { let func = func_at_data_inst_frozen.at(()); match data_inst_def.kind { + NodeKind::Block { .. } + | NodeKind::Select(_) + | NodeKind::Loop { .. } + | NodeKind::ExitInvocation(_) => unreachable!(), + // Known semantics, no need to preserve SPIR-V pointer information. DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) => return, diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 615ec344..2212d634 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -211,9 +211,14 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { attr.inner_visit_with(self); } - fn visit_data_inst_def(&mut self, data_inst_def: &DataInstDef) { + fn visit_data_inst_def(&mut self, func_at_inst: FuncAt<'_, DataInst>) { #[allow(clippy::match_same_arms)] - match data_inst_def.kind { + match func_at_inst.def().kind { + NodeKind::Block { .. } + | NodeKind::Select(_) + | NodeKind::Loop { .. } + | NodeKind::ExitInvocation(_) => unreachable!(), + // FIXME(eddyb) this should be a proper `Result`-based error instead, // and/or `spv::lift` should mutate the module for legalization. DataInstKind::Mem(_) => { @@ -233,7 +238,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { self.ext_inst_imports.insert(&self.cx[ext_set]); } } - data_inst_def.inner_visit_with(self); + func_at_inst.inner_visit_with(self); } } @@ -432,6 +437,12 @@ impl<'p> FuncAt<'_, CfgCursor<'p>> { NodeKind::Select { .. } | NodeKind::Loop { .. } | NodeKind::ExitInvocation { .. 
} => None, + + DataInstKind::FuncCall(_) + | DataInstKind::Mem(_) + | DataInstKind::QPtr(_) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. } => unreachable!(), }, // Exiting a `Node` chains to a sibling/parent. @@ -731,6 +742,12 @@ impl<'a> FuncLifting<'a> { target_phi_values: FxIndexMap::default(), merge: None, }, + + DataInstKind::FuncCall(_) + | DataInstKind::Mem(_) + | DataInstKind::QPtr(_) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. } => unreachable!(), } } @@ -805,6 +822,12 @@ impl<'a> FuncLifting<'a> { } } } + + DataInstKind::FuncCall(_) + | DataInstKind::Mem(_) + | DataInstKind::QPtr(_) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. } => unreachable!(), } } @@ -1305,8 +1328,15 @@ impl LazyInst<'_, '_> { }, Self::DataInst { parent_func, result_id: _, data_inst_def } => { let (inst, extra_initial_id_operand) = match &data_inst_def.kind { - // Disallowed while visiting. - DataInstKind::Mem(_) | DataInstKind::QPtr(_) => unreachable!(), + NodeKind::Block { .. } + | NodeKind::Select(_) + | NodeKind::Loop { .. } + | NodeKind::ExitInvocation(_) => unreachable!(), + + DataInstKind::Mem(_) | DataInstKind::QPtr(_) => { + // Disallowed while visiting. + unreachable!() + } &DataInstKind::FuncCall(callee) => { (wk.OpFunctionCall.into(), Some(ids.funcs[&callee].func_id)) diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 128c468e..c6a3930b 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -817,7 +817,6 @@ impl Module { DeclDef::Present(FuncDefBody { regions, nodes: Default::default(), - data_insts: Default::default(), body, unstructured_cfg: Some(cf::unstructured::ControlFlowGraph::default()), }) @@ -1052,7 +1051,7 @@ impl Module { } else { // HACK(eddyb) can't get a `DataInst` without // defining it (as a dummy) first. 
- let inst = func_def_body.data_insts.define( + let inst = func_def_body.nodes.define( &cx, DataInstDef { attrs: AttrSet::default(), @@ -1649,13 +1648,13 @@ impl Module { LocalIdDef::Value(Value::DataInstOutput { inst, .. }) => { // A dummy was defined earlier, to be able to // have an entry in `local_id_defs`. - func_def_body.data_insts[inst] = data_inst_def.into(); + func_def_body.nodes[inst] = data_inst_def.into(); inst } _ => unreachable!(), }, - None => func_def_body.data_insts.define(&cx, data_inst_def.into()), + None => func_def_body.nodes.define(&cx, data_inst_def.into()), }; let current_block_node = current_block_region_def @@ -1682,9 +1681,11 @@ impl Module { .insert_last(block_node, &mut func_def_body.nodes); block_node }); - match &mut func_def_body.nodes[current_block_node].kind { - NodeKind::Block { insts } => { - insts.insert_last(inst, &mut func_def_body.data_insts); + match func_def_body.nodes[current_block_node].kind { + NodeKind::Block { mut insts } => { + insts.insert_last(inst, &mut func_def_body.nodes); + func_def_body.nodes[current_block_node].kind = + NodeKind::Block { insts }; } _ => unreachable!(), } diff --git a/src/transform.rs b/src/transform.rs index e8c8ec4d..8b796ed0 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -5,11 +5,11 @@ use crate::func_at::FuncAtMut; use crate::mem::{DataHapp, DataHappKind, MemAccesses, MemAttr, MemOp}; use crate::qptr::{QPtrAttr, QPtrOp}; use crate::{ - AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, DataInst, DataInstDef, - DataInstKind, DbgSrcLoc, DeclDef, EntityListIter, ExportKey, Exportee, Func, FuncDecl, - FuncDefBody, FuncParam, GlobalVar, GlobalVarDecl, GlobalVarDefBody, Import, Module, - ModuleDebugInfo, ModuleDialect, Node, NodeDef, NodeKind, NodeOutputDecl, OrdAssertEq, Region, - RegionDef, RegionInputDecl, Type, TypeDef, TypeKind, TypeOrConst, Value, spv, + AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, DataInst, DataInstKind, + DbgSrcLoc, DeclDef, 
EntityListIter, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, + FuncParam, GlobalVar, GlobalVarDecl, GlobalVarDefBody, Import, Module, ModuleDebugInfo, + ModuleDialect, Node, NodeDef, NodeKind, NodeOutputDecl, OrdAssertEq, Region, RegionDef, + RegionInputDecl, Type, TypeDef, TypeKind, TypeOrConst, Value, spv, }; use std::cmp::Ordering; use std::rc::Rc; @@ -643,9 +643,22 @@ impl InnerInPlaceTransform for FuncAtMut<'_, Node> { transformer.in_place_transform_data_inst_def(func_at_inst); } } + + DataInstKind::FuncCall(func) => transformer.transform_func_use(*func).apply_to(func), + NodeKind::Select(SelectionKind::BoolCond | SelectionKind::SpvInst(_)) | NodeKind::Loop { repeat_condition: _ } - | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) => {} + | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) + | DataInstKind::Mem(MemOp::FuncLocalVar(_) | MemOp::Load | MemOp::Store) + | DataInstKind::QPtr( + QPtrOp::HandleArrayIndex + | QPtrOp::BufferData + | QPtrOp::BufferDynLen { .. } + | QPtrOp::Offset(_) + | QPtrOp::DynOffset { .. }, + ) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. 
} => {} } for v in &mut self.reborrow().def().inputs { @@ -685,41 +698,6 @@ impl InnerTransform for NodeOutputDecl { } } -impl InnerInPlaceTransform for FuncAtMut<'_, DataInst> { - fn inner_in_place_transform_with(&mut self, transformer: &mut impl Transformer) { - let DataInstDef { attrs, kind, inputs, child_regions, outputs } = self.reborrow().def(); - - transformer.transform_attr_set_use(*attrs).apply_to(attrs); - kind.inner_in_place_transform_with(transformer); - for v in inputs { - transformer.transform_value_use(v).apply_to(v); - } - assert_eq!(child_regions.len(), 0); - for output in outputs { - output.inner_transform_with(transformer).apply_to(output); - } - } -} - -impl InnerInPlaceTransform for DataInstKind { - fn inner_in_place_transform_with(&mut self, transformer: &mut impl Transformer) { - match self { - DataInstKind::FuncCall(func) => transformer.transform_func_use(*func).apply_to(func), - DataInstKind::Mem(op) => match op { - MemOp::FuncLocalVar(_) | MemOp::Load | MemOp::Store => {} - }, - DataInstKind::QPtr(op) => match op { - QPtrOp::HandleArrayIndex - | QPtrOp::BufferData - | QPtrOp::BufferDynLen { .. } - | QPtrOp::Offset(_) - | QPtrOp::DynOffset { .. } => {} - }, - DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. 
} => {} - } - } -} - impl InnerInPlaceTransform for cf::unstructured::ControlInst { fn inner_in_place_transform_with(&mut self, transformer: &mut impl Transformer) { let Self { attrs, kind, inputs, targets: _, target_inputs } = self; diff --git a/src/visit.rs b/src/visit.rs index 2aa961dc..25cfabfa 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -5,7 +5,7 @@ use crate::func_at::FuncAt; use crate::mem::{DataHapp, DataHappKind, MemAccesses, MemAttr, MemOp}; use crate::qptr::{QPtrAttr, QPtrOp}; use crate::{ - AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, DataInstDef, DataInstKind, + AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, DataInst, DataInstKind, DbgSrcLoc, DeclDef, DiagMsgPart, EntityListIter, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, GlobalVar, GlobalVarDecl, GlobalVarDefBody, Import, Module, ModuleDebugInfo, ModuleDialect, Node, NodeDef, NodeKind, NodeOutputDecl, OrdAssertEq, Region, @@ -65,8 +65,8 @@ pub trait Visitor<'a>: Sized { fn visit_node_def(&mut self, func_at_node: FuncAt<'a, Node>) { func_at_node.inner_visit_with(self); } - fn visit_data_inst_def(&mut self, data_inst_def: &'a DataInstDef) { - data_inst_def.inner_visit_with(self); + fn visit_data_inst_def(&mut self, func_at_inst: FuncAt<'a, DataInst>) { + func_at_inst.inner_visit_with(self); } fn visit_value_use(&mut self, v: &'a Value) { v.inner_visit_with(self); @@ -128,7 +128,6 @@ impl_visit! 
{ visit_const_def(ConstDef), visit_global_var_decl(GlobalVarDecl), visit_func_decl(FuncDecl), - visit_data_inst_def(DataInstDef), visit_value_use(Value), } forward_to_inner_visit { @@ -481,12 +480,25 @@ impl<'a> FuncAt<'a, Node> { match kind { NodeKind::Block { insts } => { for func_at_inst in self.at(*insts) { - visitor.visit_data_inst_def(func_at_inst.def()); + visitor.visit_data_inst_def(func_at_inst); } } + + &DataInstKind::FuncCall(func) => visitor.visit_func_use(func), + NodeKind::Select(SelectionKind::BoolCond | SelectionKind::SpvInst(_)) | NodeKind::Loop { repeat_condition: _ } - | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) => {} + | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) + | DataInstKind::Mem(MemOp::FuncLocalVar(_) | MemOp::Load | MemOp::Store) + | DataInstKind::QPtr( + QPtrOp::HandleArrayIndex + | QPtrOp::BufferData + | QPtrOp::BufferDynLen { .. } + | QPtrOp::Offset(_) + | QPtrOp::DynOffset { .. }, + ) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. } => {} } for v in inputs { visitor.visit_value_use(v); @@ -515,41 +527,6 @@ impl InnerVisit for NodeOutputDecl { } } -impl InnerVisit for DataInstDef { - fn inner_visit_with<'a>(&'a self, visitor: &mut impl Visitor<'a>) { - let Self { attrs, kind, inputs, child_regions, outputs } = self; - - visitor.visit_attr_set_use(*attrs); - kind.inner_visit_with(visitor); - for v in inputs { - visitor.visit_value_use(v); - } - assert_eq!(child_regions.len(), 0); - for output in outputs { - output.inner_visit_with(visitor); - } - } -} - -impl InnerVisit for DataInstKind { - fn inner_visit_with<'a>(&'a self, visitor: &mut impl Visitor<'a>) { - match self { - &DataInstKind::FuncCall(func) => visitor.visit_func_use(func), - DataInstKind::Mem(op) => match *op { - MemOp::FuncLocalVar(_) | MemOp::Load | MemOp::Store => {} - }, - DataInstKind::QPtr(op) => match *op { - QPtrOp::HandleArrayIndex - | QPtrOp::BufferData - | QPtrOp::BufferDynLen { .. 
} - | QPtrOp::Offset(_) - | QPtrOp::DynOffset { .. } => {} - }, - DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => {} - } - } -} - impl InnerVisit for cf::unstructured::ControlInst { fn inner_visit_with<'a>(&'a self, visitor: &mut impl Visitor<'a>) { let Self { attrs, kind, inputs, targets: _, target_inputs } = self; From 6ffd5a2501593f1a625b8fe4b8d3d0f0d4804513 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Tue, 9 Sep 2025 11:42:12 +0300 Subject: [PATCH 6/6] Remove `NodeKind::Block`, flattening "data insts" into `Region` children. --- src/cf/structurize.rs | 17 +- src/func_at.rs | 3 - src/lib.rs | 19 +- src/mem/analyze.rs | 824 +++++++++++++++++++++--------------------- src/print/mod.rs | 106 ++---- src/qptr/lift.rs | 190 +++++----- src/qptr/lower.rs | 72 ++-- src/spv/lift.rs | 156 ++++---- src/spv/lower.rs | 45 +-- src/transform.rs | 23 +- src/visit.rs | 22 +- 11 files changed, 652 insertions(+), 825 deletions(-) diff --git a/src/cf/structurize.rs b/src/cf/structurize.rs index e80cd0e7..d566595f 100644 --- a/src/cf/structurize.rs +++ b/src/cf/structurize.rs @@ -1278,26 +1278,11 @@ impl<'a> Structurizer<'a> { .iter() .filter_map(|case| { let &ClaimedRegion { structured_body, .. } = case.as_ref().ok()?; - // FIXME(eddyb) maybe there should be a `FuncAt` - // helper for "debug locations from all `Block` `DataInst`s - // and non-`Block` `Node`"? (i.e. 
only flattening `Block`s) this.func_def_body .at(structured_body) .at_children() .into_iter() - .flat_map(|func_at_child| { - let child_def = func_at_child.def(); - if let NodeKind::Block { insts } = child_def.kind { - Either::Left( - func_at_child - .at(insts) - .into_iter() - .map(|func_at_inst| func_at_inst.def().attrs), - ) - } else { - Either::Right([child_def.attrs].into_iter()) - } - }) + .map(|func_at_child| func_at_child.def().attrs) .rev() .find_map(&mut relevant_dbg_src_loc) .map(|dbg_src_loc| dbg_src_loc.end_line_col) diff --git a/src/func_at.rs b/src/func_at.rs index 44d62ca7..4d506fb2 100644 --- a/src/func_at.rs +++ b/src/func_at.rs @@ -87,9 +87,6 @@ impl FuncAt<'_, Value> { Value::NodeOutput { node, output_idx } => { self.at(node).def().outputs[output_idx as usize].ty } - Value::DataInstOutput { inst, output_idx } => { - self.at(inst).def().outputs[output_idx as usize].ty - } } } } diff --git a/src/lib.rs b/src/lib.rs index 173d769b..a73ad782 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -867,6 +867,7 @@ pub struct NodeDef { /// * values provided by `region.outputs`, where `region` is the executed /// child [`Region`]: /// * when this is a `Select`: the case that was chosen + // TODO(eddyb) include former `DataInst`s in above docs. pub outputs: SmallVec<[NodeOutputDecl; 2]>, } @@ -879,16 +880,6 @@ pub struct NodeOutputDecl { #[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] pub enum NodeKind { - /// Linear chain of [`DataInst`]s, executing in sequence. - /// - /// This is only an optimization over keeping [`DataInst`]s in [`Region`] - /// linear chains directly, or even merging [`DataInst`] with [`Node`]. - Block { - // FIXME(eddyb) should empty blocks be allowed? should `DataInst`s be - // linked directly into the `Region` `children` list? - insts: EntityList, - }, - /// Choose one [`Region`] out of `child_regions` to execute, based on a single /// value input (`input[0]`) interpreted according to [`SelectionKind`]. 
/// @@ -966,15 +957,9 @@ pub enum Value { /// * value provided by `region.outputs[output_idx]`, where `region` is the /// executed child [`Region`] (of `node`): /// * when `node` is a `Select`: the case that was chosen + // TODO(eddyb) include former `DataInst`s in above docs. NodeOutput { node: Node, output_idx: u32, }, - - /// The output value of a [`DataInst`]. - DataInstOutput { - inst: DataInst, - // HACK(eddyb) temporarily aligned with `NodeDef` pre-merger (always == 0). - output_idx: u32, - }, } diff --git a/src/mem/analyze.rs b/src/mem/analyze.rs index 5d593efe..83f74d9c 100644 --- a/src/mem/analyze.rs +++ b/src/mem/analyze.rs @@ -7,9 +7,9 @@ use crate::mem::{DataHapp, DataHappKind, MemAccesses, MemAttr, MemOp, shapes}; use crate::qptr::{QPtrAttr, QPtrOp}; use crate::visit::{InnerVisit, Visitor}; use crate::{ - AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstKind, Context, DataInst, DataInstKind, - DeclDef, Diag, EntityList, ExportKey, Exportee, Func, FxIndexMap, GlobalVar, Module, Node, - NodeKind, OrdAssertEq, Type, TypeKind, Value, + AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstKind, Context, DataInstKind, DeclDef, Diag, + ExportKey, Exportee, Func, FxIndexMap, GlobalVar, Module, Node, NodeKind, OrdAssertEq, Type, + TypeKind, Value, }; use itertools::{Either, Itertools as _}; use rustc_hash::FxHashMap; @@ -792,18 +792,6 @@ impl<'a> GatherAccesses<'a> { &mut node_def.outputs[output_idx as usize].attrs } } - Value::DataInstOutput { inst, output_idx } => { - let inst_def = func_def_body.at_mut(inst).def(); - - // HACK(eddyb) `DataInstOutput { output_idx: !0, .. }` - // may be used to attach errors to a whole `DataInst`. 
- if output_idx == !0 { - assert!(accesses.is_err()); - &mut inst_def.attrs - } else { - &mut inst_def.outputs[output_idx as usize].attrs - } - } }; match accesses { Ok(accesses) => { @@ -874,429 +862,437 @@ impl<'a> GatherAccesses<'a> { } }; - let mut all_data_insts = CollectAllDataInsts::default(); - func_def_body.inner_visit_with(&mut all_data_insts); - - let mut data_inst_output_accesses = FxHashMap::default(); - for insts in all_data_insts.0.into_iter().rev() { - for func_at_inst in func_def_body.at(insts).into_iter().rev() { - let data_inst = func_at_inst.position; - let data_inst_def = func_at_inst.def(); - let output_accesses = data_inst_output_accesses.remove(&data_inst).flatten(); - - let mut generate_accesses = |this: &mut Self, ptr: Value, new_accesses| { - let slot = match ptr { - Value::Const(ct) => match cx[ct].kind { - ConstKind::PtrToGlobalVar(gv) => { - this.global_var_accesses.entry(gv).or_default() - } - // FIXME(eddyb) may be relevant? - _ => unreachable!(), - }, - Value::RegionInput { region, input_idx } - if region == func_def_body.body => - { - &mut param_accesses[input_idx as usize] + let mut node_to_per_output_accesses: FxHashMap<_, SmallVec<[Option<_>; 2]>> = + FxHashMap::default(); + + // HACK(eddyb) reversing a post-order traversal to get RPO, which for + // structured control-flow means outside-in/top-down (just like pre-order), + // while post-order and reverse pre-order are inside-out/bottom-up. 
+ let mut post_order_nodes = vec![]; + func_def_body.inner_visit_with(&mut VisitAllNodes { + before: |_| {}, + after: |node| post_order_nodes.push(node), + }); + for node in post_order_nodes.into_iter().rev() { + let per_output_accesses = node_to_per_output_accesses.remove(&node).unwrap_or_default(); + + let node_def = func_def_body.at(node).def(); + let mut generate_accesses = |this: &mut Self, ptr: Value, new_accesses| { + let slot = match ptr { + Value::Const(ct) => match cx[ct].kind { + ConstKind::PtrToGlobalVar(gv) => { + this.global_var_accesses.entry(gv).or_default() + } + // FIXME(eddyb) may be relevant? + _ => unreachable!(), + }, + Value::RegionInput { region, input_idx } if region == func_def_body.body => { + &mut param_accesses[input_idx as usize] + } + Value::RegionInput { .. } => { + // FIXME(eddyb) don't throw away `new_accesses`. + accesses_or_err_attrs_to_attach + .push((ptr, Err(AnalysisError(Diag::bug(["unsupported φ".into()]))))); + return; + } + Value::NodeOutput { node: ptr_node, output_idx } => { + let i = output_idx as usize; + let slots = node_to_per_output_accesses.entry(ptr_node).or_default(); + if i >= slots.len() { + slots.extend((slots.len()..=i).map(|_| None)); } - // FIXME(eddyb) implement - Value::RegionInput { .. } | Value::NodeOutput { .. } => { + &mut slots[i] + } + }; + *slot = Some(match slot.take() { + Some(old) => old.and_then(|old| { + AccessMerger { layout_cache: &this.layout_cache } + .merge(old, new_accesses?) + .into_result() + }), + None => new_accesses, + }); + }; + + match &node_def.kind { + NodeKind::Select(_) | NodeKind::Loop { .. } | NodeKind::ExitInvocation { .. } => { + for (i, accesses) in per_output_accesses.iter().enumerate() { + let output = Value::NodeOutput { node, output_idx: i.try_into().unwrap() }; + if let Some(_accesses) = accesses { + // FIXME(eddyb) don't throw away `accesses`. 
accesses_or_err_attrs_to_attach.push(( - ptr, + output, Err(AnalysisError(Diag::bug(["unsupported φ".into()]))), )); - return; - } - Value::DataInstOutput { inst: ptr_inst, output_idx } => { - // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. - assert_eq!(output_idx, 0); - data_inst_output_accesses.entry(ptr_inst).or_default() } - }; - *slot = Some(match slot.take() { - Some(old) => old.and_then(|old| { - AccessMerger { layout_cache: &this.layout_cache } - .merge(old, new_accesses?) - .into_result() - }), - None => new_accesses, - }); - }; - match &data_inst_def.kind { - NodeKind::Block { .. } - | NodeKind::Select(_) - | NodeKind::Loop { .. } - | NodeKind::ExitInvocation(_) => unreachable!(), - - &DataInstKind::FuncCall(callee) => { - match self.gather_accesses_in_func(module, callee) { - FuncGatherAccessesState::Complete(callee_results) => { - for (&arg, param_accesses) in - data_inst_def.inputs.iter().zip(&callee_results.param_accesses) - { - if let Some(param_accesses) = param_accesses { - generate_accesses(self, arg, param_accesses.clone()); - } + } + + continue; + } + + DataInstKind::FuncCall(_) + | DataInstKind::Mem(_) + | DataInstKind::QPtr(_) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. } => {} + } + + // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. + assert!(per_output_accesses.len() <= 1); + let output_accesses = per_output_accesses.into_iter().next().flatten(); + + // FIXME(eddyb) merge with `match &node_def.kind` above. + let data_inst_def = node_def; + match &data_inst_def.kind { + NodeKind::Select(_) | NodeKind::Loop { .. 
} | NodeKind::ExitInvocation(_) => { + unreachable!() + } + + &DataInstKind::FuncCall(callee) => { + match self.gather_accesses_in_func(module, callee) { + FuncGatherAccessesState::Complete(callee_results) => { + for (&arg, param_accesses) in + data_inst_def.inputs.iter().zip(&callee_results.param_accesses) + { + if let Some(param_accesses) = param_accesses { + generate_accesses(self, arg, param_accesses.clone()); } } - FuncGatherAccessesState::InProgress => { - accesses_or_err_attrs_to_attach.push(( - Value::DataInstOutput { inst: data_inst, output_idx: 0 }, - Err(AnalysisError(Diag::bug([ - "unsupported recursive call".into() - ]))), - )); - } - }; - // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. - if (data_inst_def.outputs.iter().at_most_one().ok().unwrap()) - .is_some_and(|o| is_qptr(o.ty)) - && let Some(accesses) = output_accesses - { + } + FuncGatherAccessesState::InProgress => { accesses_or_err_attrs_to_attach.push(( - Value::DataInstOutput { inst: data_inst, output_idx: 0 }, - accesses, + Value::NodeOutput { node, output_idx: 0 }, + Err(AnalysisError(Diag::bug( + ["unsupported recursive call".into()], + ))), )); } + }; + // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. 
+ if (data_inst_def.outputs.iter().at_most_one().ok().unwrap()) + .is_some_and(|o| is_qptr(o.ty)) + && let Some(accesses) = output_accesses + { + accesses_or_err_attrs_to_attach + .push((Value::NodeOutput { node, output_idx: 0 }, accesses)); } + } - DataInstKind::Mem(MemOp::FuncLocalVar(_)) => { - if let Some(accesses) = output_accesses { - accesses_or_err_attrs_to_attach.push(( - Value::DataInstOutput { inst: data_inst, output_idx: 0 }, - accesses, - )); - } + DataInstKind::Mem(MemOp::FuncLocalVar(_)) => { + if let Some(accesses) = output_accesses { + accesses_or_err_attrs_to_attach + .push((Value::NodeOutput { node, output_idx: 0 }, accesses)); } - DataInstKind::QPtr(QPtrOp::HandleArrayIndex) => { - generate_accesses( - self, - data_inst_def.inputs[0], - output_accesses - .unwrap_or_else(|| { - Err(AnalysisError(Diag::bug([ - "HandleArrayIndex: unknown element".into(), - ]))) - }) - .and_then(|accesses| match accesses { - MemAccesses::Handles(handle) => { - Ok(MemAccesses::Handles(handle)) + } + DataInstKind::QPtr(QPtrOp::HandleArrayIndex) => { + generate_accesses( + self, + data_inst_def.inputs[0], + output_accesses + .unwrap_or_else(|| { + Err(AnalysisError(Diag::bug([ + "HandleArrayIndex: unknown element".into() + ]))) + }) + .and_then(|accesses| match accesses { + MemAccesses::Handles(handle) => Ok(MemAccesses::Handles(handle)), + MemAccesses::Data(_) => Err(AnalysisError(Diag::bug([ + "HandleArrayIndex: cannot be accessed as data".into(), + ]))), + }), + ); + } + DataInstKind::QPtr(QPtrOp::BufferData) => { + generate_accesses( + self, + data_inst_def.inputs[0], + output_accesses.unwrap_or(Ok(MemAccesses::Data(DataHapp::DEAD))).and_then( + |accesses| { + let happ = match accesses { + MemAccesses::Handles(_) => { + return Err(AnalysisError(Diag::bug([ + "BufferData: cannot be accessed as handles".into(), + ]))); } - MemAccesses::Data(_) => Err(AnalysisError(Diag::bug([ - "HandleArrayIndex: cannot be accessed as data".into(), - ]))), - }), - ); - } - 
DataInstKind::QPtr(QPtrOp::BufferData) => { - generate_accesses( - self, - data_inst_def.inputs[0], - output_accesses - .unwrap_or(Ok(MemAccesses::Data(DataHapp::DEAD))) - .and_then(|accesses| { - let happ = match accesses { - MemAccesses::Handles(_) => { - return Err(AnalysisError(Diag::bug([ - "BufferData: cannot be accessed as handles".into(), - ]))); - } - MemAccesses::Data(happ) => happ, - }; - Ok(MemAccesses::Handles(shapes::Handle::Buffer( - AddrSpace::Handles, - happ, - ))) - }), - ); - } - &DataInstKind::QPtr(QPtrOp::BufferDynLen { - fixed_base_size, - dyn_unit_stride, - }) => { - let array_happ = DataHapp { - max_size: None, - kind: DataHappKind::Repeated { - element: Rc::new(DataHapp::DEAD), - stride: dyn_unit_stride, + MemAccesses::Data(happ) => happ, + }; + Ok(MemAccesses::Handles(shapes::Handle::Buffer( + AddrSpace::Handles, + happ, + ))) }, - }; - let buf_data_happ = if fixed_base_size == 0 { - array_happ - } else { - DataHapp { - max_size: None, - kind: DataHappKind::Disjoint(Rc::new( - [(fixed_base_size, array_happ)].into(), - )), - } - }; - generate_accesses( - self, - data_inst_def.inputs[0], - Ok(MemAccesses::Handles(shapes::Handle::Buffer( - AddrSpace::Handles, - buf_data_happ, - ))), - ); - } - &DataInstKind::QPtr(QPtrOp::Offset(offset)) => { - generate_accesses( - self, - data_inst_def.inputs[0], - output_accesses - .unwrap_or(Ok(MemAccesses::Data(DataHapp::DEAD))) - .and_then(|accesses| { - let happ = match accesses { - MemAccesses::Handles(_) => { - return Err(AnalysisError(Diag::bug([format!( - "Offset({offset}): cannot offset in handle memory" - ) - .into()]))); - } - MemAccesses::Data(happ) => happ, - }; - let offset = u32::try_from(offset).ok().ok_or_else(|| { - AnalysisError(Diag::bug([format!( - "Offset({offset}): negative offset" + ), + ); + } + &DataInstKind::QPtr(QPtrOp::BufferDynLen { fixed_base_size, dyn_unit_stride }) => { + let array_happ = DataHapp { + max_size: None, + kind: DataHappKind::Repeated { + element: 
Rc::new(DataHapp::DEAD), + stride: dyn_unit_stride, + }, + }; + let buf_data_happ = if fixed_base_size == 0 { + array_happ + } else { + DataHapp { + max_size: None, + kind: DataHappKind::Disjoint(Rc::new( + [(fixed_base_size, array_happ)].into(), + )), + } + }; + generate_accesses( + self, + data_inst_def.inputs[0], + Ok(MemAccesses::Handles(shapes::Handle::Buffer( + AddrSpace::Handles, + buf_data_happ, + ))), + ); + } + &DataInstKind::QPtr(QPtrOp::Offset(offset)) => { + generate_accesses( + self, + data_inst_def.inputs[0], + output_accesses.unwrap_or(Ok(MemAccesses::Data(DataHapp::DEAD))).and_then( + |accesses| { + let happ = match accesses { + MemAccesses::Handles(_) => { + return Err(AnalysisError(Diag::bug([format!( + "Offset({offset}): cannot offset in handle memory" ) - .into()])) - })?; - - // FIXME(eddyb) these should be normalized - // (e.g. constant-folded) out of existence, - // but while they exist, they should be noops. - if offset == 0 { - return Ok(MemAccesses::Data(happ)); + .into()]))); } + MemAccesses::Data(happ) => happ, + }; + let offset = u32::try_from(offset).ok().ok_or_else(|| { + AnalysisError(Diag::bug([format!( + "Offset({offset}): negative offset" + ) + .into()])) + })?; + + // FIXME(eddyb) these should be normalized + // (e.g. constant-folded) out of existence, + // but while they exist, they should be noops. 
+ if offset == 0 { + return Ok(MemAccesses::Data(happ)); + } - Ok(MemAccesses::Data(DataHapp { - max_size: happ - .max_size - .map(|max_size| { - offset.checked_add(max_size).ok_or_else(|| { - AnalysisError(Diag::bug([format!( - "Offset({offset}): size overflow \ - ({offset}+{max_size})" - ) - .into()])) - }) + Ok(MemAccesses::Data(DataHapp { + max_size: happ + .max_size + .map(|max_size| { + offset.checked_add(max_size).ok_or_else(|| { + AnalysisError(Diag::bug([format!( + "Offset({offset}): size overflow \ + ({offset}+{max_size})" + ) + .into()])) }) - .transpose()?, - // FIXME(eddyb) allocating `Rc>` - // to represent the one-element case, seems - // quite wasteful when it's likely consumed. - kind: DataHappKind::Disjoint(Rc::new( - [(offset, happ)].into(), - )), - })) - }), - ); - } - DataInstKind::QPtr(QPtrOp::DynOffset { stride, index_bounds }) => { - generate_accesses( - self, - data_inst_def.inputs[0], - output_accesses - .unwrap_or(Ok(MemAccesses::Data(DataHapp::DEAD))) - .and_then(|accesses| { - let happ = match accesses { - MemAccesses::Handles(_) => { - return Err(AnalysisError(Diag::bug([ - "DynOffset: cannot offset in handle memory".into(), - ]))); - } - MemAccesses::Data(happ) => happ, - }; - match happ.max_size { - None => { - return Err(AnalysisError(Diag::bug([ - "DynOffset: unsized element".into(), - ]))); - } - // FIXME(eddyb) support this by "folding" - // the HAPP onto itself (i.e. applying - // `%= stride` on all offsets inside). - Some(max_size) if max_size > stride.get() => { - return Err(AnalysisError(Diag::bug([ - "DynOffset: element max_size exceeds stride".into(), - ]))); - } - Some(_) => {} - } - Ok(MemAccesses::Data(DataHapp { - // FIXME(eddyb) does the `None` case allow - // for negative offsets? 
- max_size: index_bounds - .as_ref() - .map(|index_bounds| { - if index_bounds.start < 0 || index_bounds.end < 0 { - return Err(AnalysisError(Diag::bug([ - "DynOffset: potentially negative offset" - .into(), - ]))); - } - let index_bounds_end = - u32::try_from(index_bounds.end).unwrap(); - index_bounds_end - .checked_mul(stride.get()) - .ok_or_else(|| { - AnalysisError(Diag::bug([format!( - "DynOffset: size overflow \ - ({index_bounds_end}*{stride})" - ) - .into()])) - }) - }) - .transpose()?, - kind: DataHappKind::Repeated { - element: Rc::new(happ), - stride: *stride, - }, - })) - }), - ); - } - DataInstKind::Mem(op @ (MemOp::Load | MemOp::Store)) => { - // HACK(eddyb) `_` will match multiple variants soon. - #[allow(clippy::match_wildcard_for_single_variants)] - let (op_name, access_type) = match op { - MemOp::Load => ("Load", data_inst_def.outputs[0].ty), - MemOp::Store => { - ("Store", func_at_inst.at(data_inst_def.inputs[1]).type_of(&cx)) - } - _ => unreachable!(), - }; - generate_accesses( - self, - data_inst_def.inputs[0], - self.layout_cache - .layout_of(access_type) - .map_err(|LayoutError(e)| AnalysisError(e)) - .and_then(|layout| match layout { - TypeLayout::Handle(shapes::Handle::Opaque(ty)) => { - Ok(MemAccesses::Handles(shapes::Handle::Opaque(ty))) - } - TypeLayout::Handle(shapes::Handle::Buffer(..)) => { - Err(AnalysisError(Diag::bug([format!( - "{op_name}: cannot access whole Buffer" - ) - .into()]))) - } - TypeLayout::HandleArray(..) => { - Err(AnalysisError(Diag::bug([format!( - "{op_name}: cannot access whole HandleArray" - ) - .into()]))) + }) + .transpose()?, + // FIXME(eddyb) allocating `Rc>` + // to represent the one-element case, seems + // quite wasteful when it's likely consumed. 
+ kind: DataHappKind::Disjoint(Rc::new([(offset, happ)].into())), + })) + }, + ), + ); + } + DataInstKind::QPtr(QPtrOp::DynOffset { stride, index_bounds }) => { + generate_accesses( + self, + data_inst_def.inputs[0], + output_accesses.unwrap_or(Ok(MemAccesses::Data(DataHapp::DEAD))).and_then( + |accesses| { + let happ = match accesses { + MemAccesses::Handles(_) => { + return Err(AnalysisError(Diag::bug([ + "DynOffset: cannot offset in handle memory".into(), + ]))); } - TypeLayout::Concrete(concrete) - if concrete.mem_layout.dyn_unit_stride.is_some() => - { - Err(AnalysisError(Diag::bug([format!( - "{op_name}: cannot access unsized type" - ) - .into()]))) + MemAccesses::Data(happ) => happ, + }; + match happ.max_size { + None => { + return Err(AnalysisError(Diag::bug([ + "DynOffset: unsized element".into(), + ]))); } - TypeLayout::Concrete(concrete) => { - Ok(MemAccesses::Data(DataHapp { - max_size: Some(concrete.mem_layout.fixed_base.size), - kind: DataHappKind::Direct(access_type), - })) + // FIXME(eddyb) support this by "folding" + // the HAPP onto itself (i.e. applying + // `%= stride` on all offsets inside). + Some(max_size) if max_size > stride.get() => { + return Err(AnalysisError(Diag::bug([ + "DynOffset: element max_size exceeds stride".into(), + ]))); } - }), - ); - } + Some(_) => {} + } + Ok(MemAccesses::Data(DataHapp { + // FIXME(eddyb) does the `None` case allow + // for negative offsets? 
+ max_size: index_bounds + .as_ref() + .map(|index_bounds| { + if index_bounds.start < 0 || index_bounds.end < 0 { + return Err(AnalysisError(Diag::bug([ + "DynOffset: potentially negative offset".into(), + ]))); + } + let index_bounds_end = + u32::try_from(index_bounds.end).unwrap(); + index_bounds_end.checked_mul(stride.get()).ok_or_else( + || { + AnalysisError(Diag::bug([format!( + "DynOffset: size overflow \ + ({index_bounds_end}*{stride})" + ) + .into()])) + }, + ) + }) + .transpose()?, + kind: DataHappKind::Repeated { + element: Rc::new(happ), + stride: *stride, + }, + })) + }, + ), + ); + } + DataInstKind::Mem(op @ (MemOp::Load | MemOp::Store)) => { + // HACK(eddyb) `_` will match multiple variants soon. + #[allow(clippy::match_wildcard_for_single_variants)] + let (op_name, access_type) = match op { + MemOp::Load => ("Load", data_inst_def.outputs[0].ty), + MemOp::Store => { + ("Store", func_def_body.at(data_inst_def.inputs[1]).type_of(&cx)) + } + _ => unreachable!(), + }; + generate_accesses( + self, + data_inst_def.inputs[0], + self.layout_cache + .layout_of(access_type) + .map_err(|LayoutError(e)| AnalysisError(e)) + .and_then(|layout| match layout { + TypeLayout::Handle(shapes::Handle::Opaque(ty)) => { + Ok(MemAccesses::Handles(shapes::Handle::Opaque(ty))) + } + TypeLayout::Handle(shapes::Handle::Buffer(..)) => { + Err(AnalysisError(Diag::bug([format!( + "{op_name}: cannot access whole Buffer" + ) + .into()]))) + } + TypeLayout::HandleArray(..) 
=> { + Err(AnalysisError(Diag::bug([format!( + "{op_name}: cannot access whole HandleArray" + ) + .into()]))) + } + TypeLayout::Concrete(concrete) + if concrete.mem_layout.dyn_unit_stride.is_some() => + { + Err(AnalysisError(Diag::bug([format!( + "{op_name}: cannot access unsized type" + ) + .into()]))) + } + TypeLayout::Concrete(concrete) => Ok(MemAccesses::Data(DataHapp { + max_size: Some(concrete.mem_layout.fixed_base.size), + kind: DataHappKind::Direct(access_type), + })), + }), + ); + } - DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => { - let mut has_from_spv_ptr_output_attr = false; - for attr in &cx[data_inst_def.attrs].attrs { - match *attr { - Attr::QPtr(QPtrAttr::ToSpvPtrInput { input_idx, pointee }) => { - let ty = pointee.0; - generate_accesses( - self, - data_inst_def.inputs[input_idx as usize], - self.layout_cache - .layout_of(ty) - .map_err(|LayoutError(e)| AnalysisError(e)) - .and_then(|layout| { - match layout { - TypeLayout::Handle(handle) => { - let handle = match handle { - shapes::Handle::Opaque(ty) => { - shapes::Handle::Opaque(ty) - } - // NOTE(eddyb) this error is important, - // as the `Block` annotation on the - // buffer type means the type is *not* - // usable anywhere inside buffer data, - // since it would conflict with our - // own `Block`-annotated wrapper. - shapes::Handle::Buffer(..) => { - return Err(AnalysisError( - Diag::bug(["ToSpvPtrInput: \ - whole Buffer ambiguous \ - (handle vs buffer data)" - .into()]), - )); - } - }; - Ok(MemAccesses::Handles(handle)) + DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. 
} => { + let mut has_from_spv_ptr_output_attr = false; + for attr in &cx[data_inst_def.attrs].attrs { + match *attr { + Attr::QPtr(QPtrAttr::ToSpvPtrInput { input_idx, pointee }) => { + let ty = pointee.0; + generate_accesses( + self, + data_inst_def.inputs[input_idx as usize], + self.layout_cache + .layout_of(ty) + .map_err(|LayoutError(e)| AnalysisError(e)) + .and_then(|layout| match layout { + TypeLayout::Handle(handle) => { + let handle = match handle { + shapes::Handle::Opaque(ty) => { + shapes::Handle::Opaque(ty) } - // NOTE(eddyb) because we can't represent - // the original type, in the same way we - // use `DataHappKind::StrictlyTyped` - // for non-handles, we can't guarantee - // a generated type that matches the - // desired `pointee` type. - TypeLayout::HandleArray(..) => { - Err(AnalysisError(Diag::bug([ + // NOTE(eddyb) this error is important, + // as the `Block` annotation on the + // buffer type means the type is *not* + // usable anywhere inside buffer data, + // since it would conflict with our + // own `Block`-annotated wrapper. + shapes::Handle::Buffer(..) => { + return Err(AnalysisError(Diag::bug([ "ToSpvPtrInput: \ - whole handle array \ - unrepresentable" + whole Buffer ambiguous \ + (handle vs buffer data)" .into(), - ]))) + ]))); } - TypeLayout::Concrete(concrete) => { - Ok(MemAccesses::Data(DataHapp { - max_size: if concrete - .mem_layout - .dyn_unit_stride - .is_some() - { - None - } else { - Some( - concrete - .mem_layout - .fixed_base - .size, - ) - }, - kind: DataHappKind::StrictlyTyped(ty), - })) - } - } - }), - ); - } - Attr::QPtr(QPtrAttr::FromSpvPtrOutput { - addr_space: _, - pointee: _, - }) => { - has_from_spv_ptr_output_attr = true; - } - _ => {} + }; + Ok(MemAccesses::Handles(handle)) + } + // NOTE(eddyb) because we can't represent + // the original type, in the same way we + // use `DataHappKind::StrictlyTyped` + // for non-handles, we can't guarantee + // a generated type that matches the + // desired `pointee` type. 
+ TypeLayout::HandleArray(..) => { + Err(AnalysisError(Diag::bug([ + "ToSpvPtrInput: whole handle array \ + unrepresentable" + .into(), + ]))) + } + TypeLayout::Concrete(concrete) => { + Ok(MemAccesses::Data(DataHapp { + max_size: if concrete + .mem_layout + .dyn_unit_stride + .is_some() + { + None + } else { + Some(concrete.mem_layout.fixed_base.size) + }, + kind: DataHappKind::StrictlyTyped(ty), + })) + } + }), + ); + } + Attr::QPtr(QPtrAttr::FromSpvPtrOutput { + addr_space: _, + pointee: _, + }) => { + has_from_spv_ptr_output_attr = true; } + _ => {} } + } - if has_from_spv_ptr_output_attr { - // FIXME(eddyb) merge with `FromSpvPtrOutput`'s `pointee`. - if let Some(accesses) = output_accesses { - accesses_or_err_attrs_to_attach.push(( - Value::DataInstOutput { inst: data_inst, output_idx: 0 }, - accesses, - )); - } + if has_from_spv_ptr_output_attr { + // FIXME(eddyb) merge with `FromSpvPtrOutput`'s `pointee`. + if let Some(accesses) = output_accesses { + accesses_or_err_attrs_to_attach + .push((Value::NodeOutput { node, output_idx: 0 }, accesses)); } } } @@ -1309,9 +1305,12 @@ impl<'a> GatherAccesses<'a> { // HACK(eddyb) this is easier than implementing a proper reverse traversal. #[derive(Default)] -struct CollectAllDataInsts(Vec>); +struct VisitAllNodes { + before: B, + after: A, +} -impl Visitor<'_> for CollectAllDataInsts { +impl Visitor<'_> for VisitAllNodes { // FIXME(eddyb) this is excessive, maybe different kinds of // visitors should exist for module-level and func-level? 
fn visit_attr_set_use(&mut self, _: AttrSet) {} @@ -1321,9 +1320,8 @@ impl Visitor<'_> for CollectAllDataInsts { fn visit_func_use(&mut self, _: Func) {} fn visit_node_def(&mut self, func_at_node: FuncAt<'_, Node>) { - if let NodeKind::Block { insts } = func_at_node.def().kind { - self.0.push(insts); - } + (self.before)(func_at_node.position); func_at_node.inner_visit_with(self); + (self.after)(func_at_node.position); } } diff --git a/src/print/mod.rs b/src/print/mod.rs index 45650348..e5d968ab 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -258,16 +258,11 @@ enum Use { node: Node, output_idx: u32, }, - DataInstOutput { - inst: DataInst, - output_idx: u32, - }, // NOTE(eddyb) these overlap somewhat with other cases, but they're always // generated, even when there is no "use", for `multiversion` alignment. AlignmentAnchorForRegion(Region), AlignmentAnchorForNode(Node), - AlignmentAnchorForDataInst(DataInst), } impl From for Use { @@ -276,7 +271,6 @@ impl From for Use { Value::Const(ct) => Use::CxInterned(CxInterned::Const(ct)), Value::RegionInput { region, input_idx } => Use::RegionInput { region, input_idx }, Value::NodeOutput { node, output_idx } => Use::NodeOutput { node, output_idx }, - Value::DataInstOutput { inst, output_idx } => Use::DataInstOutput { inst, output_idx }, } } } @@ -295,13 +289,11 @@ impl Use { Self::DbgScope { .. } => ("", "d"), Self::RegionLabel(_) => ("label", "L"), - Self::RegionInput { .. } | Self::NodeOutput { .. } | Self::DataInstOutput { .. } => { - ("", "v") - } + Self::RegionInput { .. } | Self::NodeOutput { .. 
} => ("", "v"), - Self::AlignmentAnchorForRegion(_) - | Self::AlignmentAnchorForNode(_) - | Self::AlignmentAnchorForDataInst(_) => ("", Self::ANCHOR_ALIGNMENT_NAME_PREFIX), + Self::AlignmentAnchorForRegion(_) | Self::AlignmentAnchorForNode(_) => { + ("", Self::ANCHOR_ALIGNMENT_NAME_PREFIX) + } } } } @@ -1066,10 +1058,8 @@ impl<'a> Printer<'a> { .iter() .map(|(&use_kind, &use_count)| { // HACK(eddyb) these are assigned later. - if let Use::RegionLabel(_) - | Use::RegionInput { .. } - | Use::NodeOutput { .. } - | Use::DataInstOutput { .. } = use_kind + if let Use::RegionLabel(_) | Use::RegionInput { .. } | Use::NodeOutput { .. } = + use_kind { return (use_kind, UseStyle::Inline); } @@ -1104,10 +1094,8 @@ impl<'a> Printer<'a> { | Use::RegionLabel(_) | Use::RegionInput { .. } | Use::NodeOutput { .. } - | Use::DataInstOutput { .. } | Use::AlignmentAnchorForRegion(_) - | Use::AlignmentAnchorForNode(_) - | Use::AlignmentAnchorForDataInst(_) => unreachable!(), + | Use::AlignmentAnchorForNode(_) => unreachable!(), } if let Some(name) = @@ -1179,10 +1167,8 @@ impl<'a> Printer<'a> { | Use::RegionLabel(_) | Use::RegionInput { .. } | Use::NodeOutput { .. } - | Use::DataInstOutput { .. } | Use::AlignmentAnchorForRegion(_) - | Use::AlignmentAnchorForNode(_) - | Use::AlignmentAnchorForDataInst(_) => { + | Use::AlignmentAnchorForNode(_) => { unreachable!() } }; @@ -1205,10 +1191,8 @@ impl<'a> Printer<'a> { | Use::RegionLabel(_) | Use::RegionInput { .. } | Use::NodeOutput { .. } - | Use::DataInstOutput { .. 
} | Use::AlignmentAnchorForRegion(_) - | Use::AlignmentAnchorForNode(_) - | Use::AlignmentAnchorForDataInst(_) => { + | Use::AlignmentAnchorForNode(_) => { unreachable!() } }; @@ -1502,30 +1486,11 @@ impl<'a> Printer<'a> { intra_region: DbgScopeDefPlaceInRegion { before_node: Some(node) }, }); - let NodeDef { attrs, kind, inputs: _, child_regions: _, outputs } = + let NodeDef { attrs, kind: _, inputs: _, child_regions: _, outputs } = func_at_node.def(); define(Use::AlignmentAnchorForNode(node), Some(*attrs)); - if let NodeKind::Block { insts } = *kind { - for func_at_inst in func_def_body.at(insts) { - define( - Use::AlignmentAnchorForDataInst(func_at_inst.position), - None, - ); - let inst_def = func_at_inst.def(); - for (i, output_decl) in inst_def.outputs.iter().enumerate() { - define( - Use::DataInstOutput { - inst: func_at_inst.position, - output_idx: i.try_into().unwrap(), - }, - Some(output_decl.attrs), - ); - } - } - } - for (i, output_decl) in outputs.iter().enumerate() { define( Use::NodeOutput { node, output_idx: i.try_into().unwrap() }, @@ -1554,15 +1519,11 @@ impl<'a> Printer<'a> { (&mut region_label_counter, use_styles.get_mut(&use_kind)) } - Use::RegionInput { .. } - | Use::NodeOutput { .. } - | Use::DataInstOutput { .. } => { + Use::RegionInput { .. } | Use::NodeOutput { .. 
} => { (&mut value_counter, use_styles.get_mut(&use_kind)) } - Use::AlignmentAnchorForRegion(_) - | Use::AlignmentAnchorForNode(_) - | Use::AlignmentAnchorForDataInst(_) => ( + Use::AlignmentAnchorForRegion(_) | Use::AlignmentAnchorForNode(_) => ( &mut alignment_anchor_counter, Some(use_styles.entry(use_kind).or_insert(UseStyle::Inline)), ), @@ -2118,8 +2079,7 @@ impl Use { suffix.write_escaped_to(&mut anchor).unwrap(); let name = if let Self::AlignmentAnchorForRegion(_) - | Self::AlignmentAnchorForNode(_) - | Self::AlignmentAnchorForDataInst(_) = self + | Self::AlignmentAnchorForNode(_) = self { vec![] } else { @@ -2182,15 +2142,15 @@ impl Use { item.keyword_and_name_prefix().map_or_else(|s| s, |(s, _)| s) )) .into(), + Self::DbgScope { .. } | Self::RegionLabel(_) | Self::RegionInput { .. } - | Self::NodeOutput { .. } - | Self::DataInstOutput { .. } => "_".into(), + | Self::NodeOutput { .. } => "_".into(), - Self::AlignmentAnchorForRegion(_) - | Self::AlignmentAnchorForNode(_) - | Self::AlignmentAnchorForDataInst(_) => unreachable!(), + Self::AlignmentAnchorForRegion(_) | Self::AlignmentAnchorForNode(_) => { + unreachable!() + } }, } } @@ -3774,16 +3734,6 @@ impl Print for FuncAt<'_, Node> { let kw_style = printer.imperative_keyword_style(); let kw = |kw| kw_style.apply(kw).into(); let node_body = match kind { - NodeKind::Block { insts } => { - assert!(outputs.is_empty()); - - pretty::Fragment::new( - self.at(*insts) - .into_iter() - .map(|func_at_inst| func_at_inst.print_data_inst(printer)) - .flat_map(|entry| [pretty::Node::ForceLineSeparation.into(), entry]), - ) - } NodeKind::Select(kind) => kind.print_with_scrutinee_and_cases( printer, kw_style, @@ -3887,7 +3837,14 @@ impl Print for FuncAt<'_, Node> { | DataInstKind::Mem(_) | DataInstKind::QPtr(_) | DataInstKind::SpvInst(_) - | DataInstKind::SpvExtInst { .. } => unreachable!(), + | DataInstKind::SpvExtInst { .. 
} => { + // FIXME(eddyb) `outputs_header` is wastefully built even in + // this case (though ideally the logic would just be shared). + return pretty::Fragment::new([ + pretty::Node::ForceLineSeparation.into(), + self.print_data_inst(printer), + ]); + } }; let def_without_name = pretty::Fragment::new([ Use::AlignmentAnchorForNode(self.position).print_as_def(printer), @@ -3938,15 +3895,14 @@ impl FuncAt<'_, DataInst> { }; let mut output_use_to_print_as_lhs = - output_type.map(|_| Use::DataInstOutput { inst: self.position, output_idx: 0 }); + output_type.map(|_| Use::NodeOutput { node: self.position, output_idx: 0 }); let mut output_type_to_print = output_type; let def_without_type = match kind { - NodeKind::Block { .. } - | NodeKind::Select(_) - | NodeKind::Loop { .. } - | NodeKind::ExitInvocation(_) => unreachable!(), + NodeKind::Select(_) | NodeKind::Loop { .. } | NodeKind::ExitInvocation(_) => { + unreachable!() + } &DataInstKind::FuncCall(func) => pretty::Fragment::new([ printer.declarative_keyword_style().apply("call").into(), @@ -4296,7 +4252,7 @@ impl FuncAt<'_, DataInst> { // FIXME(eddyb) this is quite verbose for prepending. 
let def_without_name = pretty::Fragment::new([ - Use::AlignmentAnchorForDataInst(self.position).print_as_def(printer), + Use::AlignmentAnchorForNode(self.position).print_as_def(printer), def_without_name, ]); diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index 8f2e92b7..40ad4a61 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -7,8 +7,8 @@ use crate::transform::{InnerInPlaceTransform, InnerTransform, Transformed, Trans use crate::{ AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, DataInst, DataInstDef, DataInstKind, DeclDef, Diag, DiagLevel, EntityDefs, EntityOrientedDenseMap, Func, - FuncDecl, FxIndexMap, GlobalVar, GlobalVarDecl, Module, Node, NodeKind, NodeOutputDecl, Type, - TypeDef, TypeKind, TypeOrConst, Value, spv, + FuncDecl, FxIndexMap, GlobalVar, GlobalVarDecl, Module, Node, NodeKind, NodeOutputDecl, Region, + Type, TypeDef, TypeKind, TypeOrConst, Value, spv, }; use smallvec::SmallVec; use std::cell::Cell; @@ -66,6 +66,8 @@ impl<'a> LiftToSpvPtrs<'a> { lifter: self, global_vars: &module.global_vars, + parent_region: None, + deferred_ptr_noops: Default::default(), data_inst_use_counts: Default::default(), @@ -382,6 +384,8 @@ struct LiftToSpvPtrInstsInFunc<'a> { lifter: &'a LiftToSpvPtrs<'a>, global_vars: &'a EntityDefs, + parent_region: Option, + /// Some `QPtr`->`QPtr` `QPtrOp`s must be noops in SPIR-V, but because some /// of them have meaningful semantic differences in SPIR-T, replacement of /// their uses must be deferred until after `try_lift_data_inst_def` has had @@ -411,14 +415,13 @@ struct DeferredPtrNoop { /// except in the case of `QPtrOp::BufferData`. 
output_pointee_layout: TypeLayout, - parent_block: Node, + parent_region: Region, } impl LiftToSpvPtrInstsInFunc<'_> { fn try_lift_data_inst_def( &mut self, mut func_at_data_inst: FuncAtMut<'_, DataInst>, - parent_block: Node, ) -> Result, LiftError> { let wk = self.lifter.wk; let cx = &self.lifter.cx; @@ -431,7 +434,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { // FIXME(eddyb) maybe all this data should be packaged up together in a // type with fields like those of `DeferredPtrNoop` (or even more). let type_of_val_as_spv_ptr_with_layout = |v: Value| { - if let Value::DataInstOutput { inst: v_data_inst, output_idx: 0 } = v + if let Value::NodeOutput { node: v_data_inst, output_idx: 0 } = v && let Some(ptr_noop) = self.deferred_ptr_noops.get(&v_data_inst) { return Ok(( @@ -448,10 +451,9 @@ impl LiftToSpvPtrInstsInFunc<'_> { Ok((addr_space, self.lifter.layout_of(pointee_type)?)) }; let replacement_data_inst_def = match &data_inst_def.kind { - NodeKind::Block { .. } - | NodeKind::Select(_) - | NodeKind::Loop { .. } - | NodeKind::ExitInvocation(_) => unreachable!(), + NodeKind::Select(_) | NodeKind::Loop { .. } | NodeKind::ExitInvocation(_) => { + return Ok(Transformed::Unchanged); + } &DataInstKind::FuncCall(_callee) => { for &v in &data_inst_def.inputs { @@ -522,7 +524,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { output_pointer: buf_ptr, output_pointer_addr_space: addr_space, output_pointee_layout: buf_data_layout, - parent_block, + parent_region: self.parent_region.unwrap(), }, ); @@ -655,7 +657,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { output_pointer: base_ptr, output_pointer_addr_space: addr_space, output_pointee_layout: TypeLayout::Concrete(layout), - parent_block, + parent_region: self.parent_region.unwrap(), }, ); @@ -789,26 +791,22 @@ impl LiftToSpvPtrInstsInFunc<'_> { .nodes .define(cx, access_chain_data_inst_def.into()); - // FIXME(eddyb) comment below should be about `nodes` vs `regions` - // (once `Block` and the `Node`-vs-`DataInst` split are gone). 
// HACK(eddyb) can't really use helpers like `FuncAtMut::def`, - // due to the need to borrow `nodes` and `data_insts` + // due to the need to borrow `regions` and `nodes` // at the same time - perhaps some kind of `FuncAtMut` position // types for "where a list is in a parent entity" could be used // to make this more ergonomic, although the potential need for // an actual list entity of its own, should be considered. let data_inst = func_at_data_inst.position; let func = func_at_data_inst.reborrow().at(()); - match func.nodes[parent_block].kind { - NodeKind::Block { mut insts } => { - insts.insert_before(access_chain_data_inst, data_inst, func.nodes); - func.nodes[parent_block].kind = NodeKind::Block { insts }; - } - _ => unreachable!(), - } + func.regions[self.parent_region.unwrap()].children.insert_before( + access_chain_data_inst, + data_inst, + func.nodes, + ); new_data_inst_def.inputs[input_idx] = - Value::DataInstOutput { inst: access_chain_data_inst, output_idx: 0 }; + Value::NodeOutput { node: access_chain_data_inst, output_idx: 0 }; } new_data_inst_def @@ -867,26 +865,22 @@ impl LiftToSpvPtrInstsInFunc<'_> { .nodes .define(cx, access_chain_data_inst_def.into()); - // FIXME(eddyb) comment below should be about `nodes` vs `regions` - // (once `Block` and the `Node`-vs-`DataInst` split are gone). // HACK(eddyb) can't really use helpers like `FuncAtMut::def`, - // due to the need to borrow `nodes` and `data_insts` + // due to the need to borrow `regions` and `nodes` // at the same time - perhaps some kind of `FuncAtMut` position // types for "where a list is in a parent entity" could be used // to make this more ergonomic, although the potential need for // an actual list entity of its own, should be considered. 
let data_inst = func_at_data_inst.position; let func = func_at_data_inst.reborrow().at(()); - match func.nodes[parent_block].kind { - NodeKind::Block { mut insts } => { - insts.insert_before(access_chain_data_inst, data_inst, func.nodes); - func.nodes[parent_block].kind = NodeKind::Block { insts }; - } - _ => unreachable!(), - } + func.regions[self.parent_region.unwrap()].children.insert_before( + access_chain_data_inst, + data_inst, + func.nodes, + ); new_data_inst_def.inputs[input_idx] = - Value::DataInstOutput { inst: access_chain_data_inst, output_idx: 0 }; + Value::NodeOutput { node: access_chain_data_inst, output_idx: 0 }; } if let Some((addr_space, pointee_type)) = from_spv_ptr_output { @@ -1029,7 +1023,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { for v in values { // FIXME(eddyb) the loop could theoretically be avoided, but that'd // make tracking use counts harder. - while let Value::DataInstOutput { inst, output_idx: 0 } = *v { + while let Value::NodeOutput { node: inst, output_idx: 0 } = *v { match self.deferred_ptr_noops.get(&inst) { Some(ptr_noop) => { *v = ptr_noop.output_pointer; @@ -1044,7 +1038,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { // encoded as `Option` for (dense) map entry reasons. fn add_value_uses(&mut self, values: &[Value]) { for &v in values { - if let Value::DataInstOutput { inst, .. } = v { + if let Value::NodeOutput { node: inst, .. } = v { let count = self.data_inst_use_counts.entry(inst); *count = Some( NonZeroU32::new(count.map_or(0, |c| c.get()).checked_add(1).unwrap()).unwrap(), @@ -1054,7 +1048,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { } fn remove_value_uses(&mut self, values: &[Value]) { for &v in values { - if let Value::DataInstOutput { inst, .. } = v { + if let Value::NodeOutput { node: inst, .. 
} = v { let count = self.data_inst_use_counts.entry(inst); *count = NonZeroU32::new(count.unwrap().get() - 1); } @@ -1085,73 +1079,60 @@ impl Transformer for LiftToSpvPtrInstsInFunc<'_> { v.inner_transform_with(self) } - // HACK(eddyb) while we want to transform `DataInstDef`s, we can't inject - // adjacent instructions without access to the parent `NodeKind::Block`, - // and to fix this would likely require list nodes to carry some handle to - // the list they're part of, either the whole semantic parent, or something - // more contrived, where lists are actually allocated entities of their own, - // perhaps something where an `EntityListDefs` contains both: - // - an `EntityDefs>` (keyed by `DataInst`) - // - an `EntityDefs>` (keyed by `EntityList`) + fn in_place_transform_region_def(&mut self, mut func_at_region: FuncAtMut<'_, Region>) { + let outer_region = self.parent_region.replace(func_at_region.position); + func_at_region.inner_in_place_transform_with(self); + self.parent_region = outer_region; + } + fn in_place_transform_node_def(&mut self, mut func_at_node: FuncAtMut<'_, Node>) { func_at_node.reborrow().inner_in_place_transform_with(self); - let node = func_at_node.position; - if let NodeKind::Block { insts } = func_at_node.reborrow().def().kind { - let mut func_at_inst_iter = func_at_node.reborrow().at(insts).into_iter(); - while let Some(mut func_at_inst) = func_at_inst_iter.next() { - let mut lifted = self.try_lift_data_inst_def(func_at_inst.reborrow(), node); - if let Ok(Transformed::Unchanged) = lifted { - let data_inst_def = func_at_inst.reborrow().def(); - if let DataInstKind::QPtr(_) = data_inst_def.kind { - lifted = - Err(LiftError(Diag::bug(["unimplemented qptr instruction".into()]))); - } else { - for output in &data_inst_def.outputs { - if matches!(self.lifter.cx[output.ty].kind, TypeKind::QPtr) { - lifted = Err(LiftError(Diag::bug([ - "unimplemented qptr-producing instruction".into(), - ]))); - break; - } - } + let mut lifted = 
self.try_lift_data_inst_def(func_at_node.reborrow()); + if let Ok(Transformed::Unchanged) = lifted { + let data_inst_def = func_at_node.reborrow().def(); + if let DataInstKind::QPtr(_) = data_inst_def.kind { + lifted = Err(LiftError(Diag::bug(["unimplemented qptr instruction".into()]))); + } else { + for output in &data_inst_def.outputs { + if matches!(self.lifter.cx[output.ty].kind, TypeKind::QPtr) { + lifted = Err(LiftError(Diag::bug([ + "unimplemented qptr-producing instruction".into(), + ]))); + break; } } - match lifted { - Ok(Transformed::Unchanged) => {} - Ok(Transformed::Changed(new_def)) => { - // HACK(eddyb) this whole dance ensures that use counts - // remain accurate, no matter what rewrites occur. - let data_inst_def = func_at_inst.def(); - self.remove_value_uses(&data_inst_def.inputs); - *data_inst_def = new_def; - self.resolve_deferred_ptr_noop_uses(&mut data_inst_def.inputs); - self.add_value_uses(&data_inst_def.inputs); - } - Err(LiftError(e)) => { - let data_inst_def = func_at_inst.def(); - - // HACK(eddyb) do not add redundant errors to `mem::analyze` bugs. - self.func_has_mem_analysis_bug_diags = self.func_has_mem_analysis_bug_diags - || self.lifter.cx[data_inst_def.attrs].attrs.iter().any(|attr| { - match attr { - Attr::Diagnostics(diags) => { - diags.0.iter().any(|diag| match diag.level { - DiagLevel::Bug(loc) => { - loc.file().ends_with("mem/analyze.rs") - || loc.file().ends_with("mem\\analyze.rs") - } - _ => false, - }) - } - _ => false, - } - }); + } + } + match lifted { + Ok(Transformed::Unchanged) => {} + Ok(Transformed::Changed(new_def)) => { + // HACK(eddyb) this whole dance ensures that use counts + // remain accurate, no matter what rewrites occur. 
+ let data_inst_def = func_at_node.def(); + self.remove_value_uses(&data_inst_def.inputs); + *data_inst_def = new_def; + self.resolve_deferred_ptr_noop_uses(&mut data_inst_def.inputs); + self.add_value_uses(&data_inst_def.inputs); + } + Err(LiftError(e)) => { + let data_inst_def = func_at_node.def(); + + // HACK(eddyb) do not add redundant errors to `mem::analyze` bugs. + self.func_has_mem_analysis_bug_diags = self.func_has_mem_analysis_bug_diags + || self.lifter.cx[data_inst_def.attrs].attrs.iter().any(|attr| match attr { + Attr::Diagnostics(diags) => diags.0.iter().any(|diag| match diag.level { + DiagLevel::Bug(loc) => { + loc.file().ends_with("mem/analyze.rs") + || loc.file().ends_with("mem\\analyze.rs") + } + _ => false, + }), + _ => false, + }); - if !self.func_has_mem_analysis_bug_diags { - data_inst_def.attrs.push_diag(&self.lifter.cx, e); - } - } + if !self.func_has_mem_analysis_bug_diags { + data_inst_def.attrs.push_diag(&self.lifter.cx, e); } } } @@ -1167,22 +1148,15 @@ impl Transformer for LiftToSpvPtrInstsInFunc<'_> { // use counts of an earlier definition, allowing further removal. for (inst, ptr_noop) in deferred_ptr_noops.into_iter().rev() { if self.data_inst_use_counts.get(inst).is_none() { - // FIXME(eddyb) comment below should be about `nodes` vs `regions` - // (once `Block` and the `Node`-vs-`DataInst` split are gone). // HACK(eddyb) can't really use helpers like `FuncAtMut::def`, - // due to the need to borrow `nodes` and `data_insts` + // due to the need to borrow `regions` and `nodes` // at the same time - perhaps some kind of `FuncAtMut` position // types for "where a list is in a parent entity" could be used // to make this more ergonomic, although the potential need for // an actual list entity of its own, should be considered. 
- match func_def_body.nodes[ptr_noop.parent_block].kind { - NodeKind::Block { mut insts } => { - insts.remove(inst, &mut func_def_body.nodes); - func_def_body.nodes[ptr_noop.parent_block].kind = - NodeKind::Block { insts }; - } - _ => unreachable!(), - } + func_def_body.regions[ptr_noop.parent_region] + .children + .remove(inst, &mut func_def_body.nodes); self.remove_value_uses(&func_def_body.at(inst).def().inputs); } diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index b99cac4a..c5b27aa7 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -6,8 +6,8 @@ use crate::qptr::{QPtrAttr, QPtrOp}; use crate::transform::{InnerInPlaceTransform, Transformed, Transformer}; use crate::{ AddrSpace, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, DataInst, DataInstDef, - DataInstKind, Diag, FuncDecl, GlobalVarDecl, Node, NodeKind, NodeOutputDecl, OrdAssertEq, Type, - TypeKind, TypeOrConst, Value, spv, + DataInstKind, Diag, FuncDecl, GlobalVarDecl, Node, NodeKind, NodeOutputDecl, OrdAssertEq, + Region, Type, TypeKind, TypeOrConst, Value, spv, }; use itertools::Itertools as _; use smallvec::SmallVec; @@ -143,7 +143,8 @@ impl<'a> LowerFromSpvPtrs<'a> { // separately - so `LowerFromSpvPtrInstsInFunc` will leave all value defs // (including replaced instructions!) with unchanged `OpTypePointer` // types, that only `EraseSpvPtrs`, later, replaces with `QPtr`. - LowerFromSpvPtrInstsInFunc { lowerer: self }.in_place_transform_func_decl(func_decl); + LowerFromSpvPtrInstsInFunc { lowerer: self, parent_region: None } + .in_place_transform_func_decl(func_decl); EraseSpvPtrs { lowerer: self }.in_place_transform_func_decl(func_decl); } @@ -246,6 +247,8 @@ impl Transformer for EraseSpvPtrs<'_> { struct LowerFromSpvPtrInstsInFunc<'a> { lowerer: &'a LowerFromSpvPtrs<'a>, + + parent_region: Option, } /// One `QPtr`->`QPtr` step used in the lowering of `Op*AccessChain`. 
@@ -393,7 +396,6 @@ impl LowerFromSpvPtrInstsInFunc<'_> { fn try_lower_data_inst_def( &self, mut func_at_data_inst: FuncAtMut<'_, DataInst>, - parent_block: Node, ) -> Result, LowerError> { let cx = &self.lowerer.cx; let wk = self.lowerer.wk; @@ -561,24 +563,20 @@ impl LowerFromSpvPtrInstsInFunc<'_> { .into(), ); - // FIXME(eddyb) comment below should be about `nodes` vs `regions` - // (once `Block` and the `Node`-vs-`DataInst` split are gone). // HACK(eddyb) can't really use helpers like `FuncAtMut::def`, - // due to the need to borrow `nodes` and `data_insts` + // due to the need to borrow `regions` and `nodes` // at the same time - perhaps some kind of `FuncAtMut` position // types for "where a list is in a parent entity" could be used // to make this more ergonomic, although the potential need for // an actual list entity of its own, should be considered. let func = func_at_data_inst.reborrow().at(()); - match func.nodes[parent_block].kind { - NodeKind::Block { mut insts } => { - insts.insert_before(step_data_inst, data_inst, func.nodes); - func.nodes[parent_block].kind = NodeKind::Block { insts }; - } - _ => unreachable!(), - } + func.regions[self.parent_region.unwrap()].children.insert_before( + step_data_inst, + data_inst, + func.nodes, + ); - ptr = Value::DataInstOutput { inst: step_data_inst, output_idx: 0 }; + ptr = Value::NodeOutput { node: step_data_inst, output_idx: 0 }; } final_step.into_data_inst_kind_and_inputs(ptr) } else if spv_inst.opcode == wk.OpBitcast { @@ -627,13 +625,13 @@ impl LowerFromSpvPtrInstsInFunc<'_> { let func = func_at_data_inst_frozen.at(()); match data_inst_def.kind { - NodeKind::Block { .. } - | NodeKind::Select(_) - | NodeKind::Loop { .. } - | NodeKind::ExitInvocation(_) => unreachable!(), - // Known semantics, no need to preserve SPIR-V pointer information. - DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) => return, + NodeKind::Select(_) + | NodeKind::Loop { .. 
} + | NodeKind::ExitInvocation(_) + | DataInstKind::FuncCall(_) + | DataInstKind::Mem(_) + | DataInstKind::QPtr(_) => return, DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => {} } @@ -676,29 +674,21 @@ impl LowerFromSpvPtrInstsInFunc<'_> { } impl Transformer for LowerFromSpvPtrInstsInFunc<'_> { - // HACK(eddyb) while we want to transform `DataInstDef`s, we can't inject - // adjacent instructions without access to the parent `NodeKind::Block`, - // and to fix this would likely require list nodes to carry some handle to - // the list they're part of, either the whole semantic parent, or something - // more contrived, where lists are actually allocated entities of their own, - // perhaps something where an `EntityListDefs` contains both: - // - an `EntityDefs>` (keyed by `DataInst`) - // - an `EntityDefs>` (keyed by `EntityList`) + fn in_place_transform_region_def(&mut self, mut func_at_region: FuncAtMut<'_, Region>) { + let outer_region = self.parent_region.replace(func_at_region.position); + func_at_region.inner_in_place_transform_with(self); + self.parent_region = outer_region; + } + fn in_place_transform_node_def(&mut self, mut func_at_node: FuncAtMut<'_, Node>) { func_at_node.reborrow().inner_in_place_transform_with(self); - let node = func_at_node.position; - if let NodeKind::Block { insts } = func_at_node.reborrow().def().kind { - let mut func_at_inst_iter = func_at_node.reborrow().at(insts).into_iter(); - while let Some(mut func_at_inst) = func_at_inst_iter.next() { - match self.try_lower_data_inst_def(func_at_inst.reborrow(), node) { - Ok(Transformed::Changed(new_def)) => { - *func_at_inst.def() = new_def; - } - result @ (Ok(Transformed::Unchanged) | Err(_)) => { - self.add_fallback_attrs_to_data_inst_def(func_at_inst, result.err()); - } - } + match self.try_lower_data_inst_def(func_at_node.reborrow()) { + Ok(Transformed::Changed(new_def)) => { + *func_at_node.def() = new_def; + } + result @ (Ok(Transformed::Unchanged) | Err(_)) => { + 
self.add_fallback_attrs_to_data_inst_def(func_at_node, result.err()); } } } diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 2212d634..b3728949 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -6,10 +6,10 @@ use crate::spv::{self, spec}; use crate::visit::{InnerVisit, Visitor}; use crate::{ AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, DataInst, DataInstDef, - DataInstKind, DbgSrcLoc, DeclDef, EntityList, ExportKey, Exportee, Func, FuncDecl, FuncParam, - FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDefBody, Import, Module, ModuleDebugInfo, - ModuleDialect, Node, NodeKind, NodeOutputDecl, OrdAssertEq, Region, RegionInputDecl, Type, - TypeDef, TypeKind, TypeOrConst, Value, + DataInstKind, DbgSrcLoc, DeclDef, ExportKey, Exportee, Func, FuncDecl, FuncParam, FxIndexMap, + FxIndexSet, GlobalVar, GlobalVarDefBody, Import, Module, ModuleDebugInfo, ModuleDialect, Node, + NodeKind, NodeOutputDecl, OrdAssertEq, Region, RegionInputDecl, Type, TypeDef, TypeKind, + TypeOrConst, Value, }; use itertools::Itertools; use rustc_hash::FxHashMap; @@ -211,13 +211,10 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { attr.inner_visit_with(self); } - fn visit_data_inst_def(&mut self, func_at_inst: FuncAt<'_, DataInst>) { + fn visit_node_def(&mut self, func_at_node: FuncAt<'_, Node>) { #[allow(clippy::match_same_arms)] - match func_at_inst.def().kind { - NodeKind::Block { .. } - | NodeKind::Select(_) - | NodeKind::Loop { .. } - | NodeKind::ExitInvocation(_) => unreachable!(), + match func_at_node.def().kind { + NodeKind::Select(_) | NodeKind::Loop { .. } | NodeKind::ExitInvocation(_) => {} // FIXME(eddyb) this should be a proper `Result`-based error instead, // and/or `spv::lift` should mutate the module for legalization. 
@@ -238,7 +235,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { self.ext_inst_imports.insert(&self.cx[ext_set]); } } - func_at_inst.inner_visit_with(self); + func_at_node.inner_visit_with(self); } } @@ -289,7 +286,7 @@ enum CfgPoint { struct BlockLifting<'a> { phis: SmallVec<[Phi; 2]>, - insts: SmallVec<[EntityList; 1]>, + insts: SmallVec<[DataInst; 4]>, terminator: Terminator<'a>, } @@ -430,10 +427,6 @@ impl<'p> FuncAt<'_, CfgCursor<'p>> { // Entering a `Node` depends entirely on the `NodeKind`. CfgPoint::NodeEntry(node) => match self.at(node).def().kind { - NodeKind::Block { .. } => { - Some(CfgCursor { point: CfgPoint::NodeExit(node), parent: cursor.parent }) - } - NodeKind::Select { .. } | NodeKind::Loop { .. } | NodeKind::ExitInvocation { .. } => None, @@ -442,7 +435,9 @@ impl<'p> FuncAt<'_, CfgCursor<'p>> { | DataInstKind::Mem(_) | DataInstKind::QPtr(_) | DataInstKind::SpvInst(_) - | DataInstKind::SpvExtInst { .. } => unreachable!(), + | DataInstKind::SpvExtInst { .. } => { + Some(CfgCursor { point: CfgPoint::NodeExit(node), parent: cursor.parent }) + } }, // Exiting a `Node` chains to a sibling/parent. 
@@ -614,28 +609,39 @@ impl<'a> FuncLifting<'a> { _ => SmallVec::new(), } } - CfgPoint::NodeExit(node) => func_def_body - .at(node) - .def() - .outputs - .iter() - .map(|&NodeOutputDecl { attrs, ty }| { - Ok(Phi { - attrs, - ty, + CfgPoint::NodeExit(node) => { + let node_def = func_def_body.at(node).def(); + match &node_def.kind { + NodeKind::Select(_) => node_def + .outputs + .iter() + .map(|&NodeOutputDecl { attrs, ty }| { + Ok(Phi { + attrs, + ty, - result_id: alloc_id()?, - cases: FxIndexMap::default(), - default_value: None, - }) - }) - .collect::>()?, + result_id: alloc_id()?, + cases: FxIndexMap::default(), + default_value: None, + }) + }) + .collect::>()?, + _ => SmallVec::new(), + } + } }; let insts = match point { CfgPoint::NodeEntry(node) => match func_def_body.at(node).def().kind { - NodeKind::Block { insts } => [insts].into_iter().collect(), - _ => SmallVec::new(), + NodeKind::Select(_) | NodeKind::Loop { .. } | NodeKind::ExitInvocation(_) => { + SmallVec::new() + } + + DataInstKind::FuncCall(_) + | DataInstKind::Mem(_) + | DataInstKind::QPtr(_) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. } => [node].into_iter().collect(), }, _ => SmallVec::new(), }; @@ -691,10 +697,6 @@ impl<'a> FuncLifting<'a> { (CfgPoint::NodeEntry(node), None) => { let node_def = func_def_body.at(node).def(); match &node_def.kind { - NodeKind::Block { .. } => { - unreachable!() - } - NodeKind::Select(kind) => Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cf::unstructured::ControlInstKind::SelectBranch( @@ -763,10 +765,6 @@ impl<'a> FuncLifting<'a> { }; match func_def_body.at(parent_node).def().kind { - NodeKind::Block { .. } | NodeKind::ExitInvocation { .. } => { - unreachable!() - } - NodeKind::Select { .. } => Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cf::unstructured::ControlInstKind::Branch), @@ -823,7 +821,8 @@ impl<'a> FuncLifting<'a> { } } - DataInstKind::FuncCall(_) + NodeKind::ExitInvocation { .. 
} + | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) | DataInstKind::SpvInst(_) @@ -832,7 +831,14 @@ impl<'a> FuncLifting<'a> { } // Siblings in the same `Region` (including the - // implied edge from a `Block`'s `Entry` to its `Exit`). + // implied edge from a `DataInst`'s `Entry` to its `Exit`). + // + // FIXME(eddyb) reduce the cost of generating then removing most + // "basic blocks" (as each former-`DataInst` gets *two*!), + // which should be pretty doable in the common case of getting + // `NodeEntry(a), NodeExit(a), NodeEntry(b), NodeExit(b), ...` + // from `rev_post_order_try_for_each` and/or introducing an + // `unique_predecessor` helper (just like `unique_successor`). (_, Some(succ_cursor)) => Terminator { attrs: AttrSet::default(), kind: Cow::Owned(cf::unstructured::ControlInstKind::Branch), @@ -936,7 +942,7 @@ impl<'a> FuncLifting<'a> { } = &blocks[&target]; (phis.is_empty() - && insts.iter().all(|insts| insts.is_empty()) + && insts.is_empty() && *attrs == AttrSet::default() && matches!(**kind, cf::unstructured::ControlInstKind::Branch) && inputs.is_empty() @@ -1045,9 +1051,7 @@ impl<'a> FuncLifting<'a> { let all_insts_with_output = blocks .values() .flat_map(|block| block.insts.iter().copied()) - .flat_map(|insts| func_def_body.at(insts)) - .filter(|&func_at_inst| !func_at_inst.def().outputs.is_empty()) - .map(|func_at_inst| func_at_inst.position); + .filter(|&inst| !func_def_body.at(inst).def().outputs.is_empty()); Ok(Self { func_id, @@ -1181,14 +1185,15 @@ impl LazyInst<'_, '_> { } } Value::NodeOutput { node, output_idx } => { - parent_func.blocks[&CfgPoint::NodeExit(node)].phis - [usize::try_from(output_idx).unwrap()] - .result_id - } - Value::DataInstOutput { inst, output_idx } => { - // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. 
- assert_eq!(output_idx, 0); - parent_func.data_inst_output_ids[&inst] + if let Some(&data_inst_output_id) = parent_func.data_inst_output_ids.get(&node) { + // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. + assert_eq!(output_idx, 0); + data_inst_output_id + } else { + parent_func.blocks[&CfgPoint::NodeExit(node)].phis + [usize::try_from(output_idx).unwrap()] + .result_id + } } }; @@ -1328,10 +1333,9 @@ impl LazyInst<'_, '_> { }, Self::DataInst { parent_func, result_id: _, data_inst_def } => { let (inst, extra_initial_id_operand) = match &data_inst_def.kind { - NodeKind::Block { .. } - | NodeKind::Select(_) - | NodeKind::Loop { .. } - | NodeKind::ExitInvocation(_) => unreachable!(), + NodeKind::Select(_) | NodeKind::Loop { .. } | NodeKind::ExitInvocation(_) => { + unreachable!() + } DataInstKind::Mem(_) | DataInstKind::QPtr(_) => { // Disallowed while visiting. @@ -1525,27 +1529,17 @@ impl Module { phis.iter() .map(|phi| LazyInst::OpPhi { parent_func: func_lifting, phi }), ) - .chain( - insts - .iter() - .copied() - .flat_map(move |insts| func_def_body.unwrap().at(insts)) - .map(move |func_at_inst| { - let data_inst_def = func_at_inst.def(); - LazyInst::DataInst { - parent_func: func_lifting, - // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. - result_id: (data_inst_def.outputs.iter().at_most_one()) - .ok() - .unwrap() - .map(|_| { - func_lifting.data_inst_output_ids - [&func_at_inst.position] - }), - data_inst_def, - } - }), - ) + .chain(insts.iter().copied().map(move |inst| { + let data_inst_def = func_def_body.unwrap().at(inst).def(); + LazyInst::DataInst { + parent_func: func_lifting, + // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. 
+ result_id: (data_inst_def.outputs.iter().at_most_one().ok()) + .unwrap() + .map(|_| func_lifting.data_inst_output_ids[&inst]), + data_inst_def, + } + })) .chain(terminator.merge.map(|merge| { LazyInst::Merge(match merge { Merge::Selection(merge) => { diff --git a/src/spv/lower.rs b/src/spv/lower.rs index c6a3930b..bbed5e1e 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -5,10 +5,10 @@ use crate::spv::{self, spec}; // FIXME(eddyb) import more to avoid `crate::` everywhere. use crate::{ AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, DataInstDef, DataInstKind, - DbgSrcLoc, DeclDef, Diag, EntityDefs, EntityList, ExportKey, Exportee, Func, FuncDecl, - FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, GlobalVarDefBody, Import, InternedStr, - Module, NodeDef, NodeKind, NodeOutputDecl, Region, RegionDef, RegionInputDecl, Type, TypeDef, - TypeKind, TypeOrConst, Value, print, + DbgSrcLoc, DeclDef, Diag, EntityDefs, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, + FuncParam, FxIndexMap, GlobalVarDecl, GlobalVarDefBody, Import, InternedStr, Module, + NodeOutputDecl, Region, RegionDef, RegionInputDecl, Type, TypeDef, TypeKind, TypeOrConst, + Value, print, }; use rustc_hash::FxHashMap; use smallvec::SmallVec; @@ -1062,7 +1062,7 @@ impl Module { } .into(), ); - LocalIdDef::Value(Value::DataInstOutput { inst, output_idx: 0 }) + LocalIdDef::Value(Value::NodeOutput { node: inst, output_idx: 0 }) } }; local_id_defs.insert(id, local_id_def); @@ -1645,7 +1645,7 @@ impl Module { }; let inst = match result_id { Some(id) => match local_id_defs[&id] { - LocalIdDef::Value(Value::DataInstOutput { inst, .. }) => { + LocalIdDef::Value(Value::NodeOutput { node: inst, .. }) => { // A dummy was defined earlier, to be able to // have an entry in `local_id_defs`. 
func_def_body.nodes[inst] = data_inst_def.into(); @@ -1657,38 +1657,7 @@ impl Module { None => func_def_body.nodes.define(&cx, data_inst_def.into()), }; - let current_block_node = current_block_region_def - .children - .iter() - .last - .filter(|&last_node| { - matches!(func_def_body.nodes[last_node].kind, NodeKind::Block { .. }) - }) - .unwrap_or_else(|| { - let block_node = func_def_body.nodes.define( - &cx, - NodeDef { - attrs: AttrSet::default(), - kind: NodeKind::Block { insts: EntityList::empty() }, - inputs: SmallVec::new(), - child_regions: SmallVec::new(), - outputs: SmallVec::new(), - } - .into(), - ); - current_block_region_def - .children - .insert_last(block_node, &mut func_def_body.nodes); - block_node - }); - match func_def_body.nodes[current_block_node].kind { - NodeKind::Block { mut insts } => { - insts.insert_last(inst, &mut func_def_body.nodes); - func_def_body.nodes[current_block_node].kind = - NodeKind::Block { insts }; - } - _ => unreachable!(), - } + current_block_region_def.children.insert_last(inst, &mut func_def_body.nodes); } } diff --git a/src/transform.rs b/src/transform.rs index 8b796ed0..238fddef 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -5,11 +5,11 @@ use crate::func_at::FuncAtMut; use crate::mem::{DataHapp, DataHappKind, MemAccesses, MemAttr, MemOp}; use crate::qptr::{QPtrAttr, QPtrOp}; use crate::{ - AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, DataInst, DataInstKind, - DbgSrcLoc, DeclDef, EntityListIter, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, - FuncParam, GlobalVar, GlobalVarDecl, GlobalVarDefBody, Import, Module, ModuleDebugInfo, - ModuleDialect, Node, NodeDef, NodeKind, NodeOutputDecl, OrdAssertEq, Region, RegionDef, - RegionInputDecl, Type, TypeDef, TypeKind, TypeOrConst, Value, spv, + AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, DataInstKind, DbgSrcLoc, + DeclDef, EntityListIter, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, + GlobalVar, 
GlobalVarDecl, GlobalVarDefBody, Import, Module, ModuleDebugInfo, ModuleDialect, + Node, NodeDef, NodeKind, NodeOutputDecl, OrdAssertEq, Region, RegionDef, RegionInputDecl, Type, + TypeDef, TypeKind, TypeOrConst, Value, spv, }; use std::cmp::Ordering; use std::rc::Rc; @@ -199,9 +199,6 @@ pub trait Transformer: Sized { fn in_place_transform_node_def(&mut self, mut func_at_node: FuncAtMut<'_, Node>) { func_at_node.inner_in_place_transform_with(self); } - fn in_place_transform_data_inst_def(&mut self, mut func_at_data_inst: FuncAtMut<'_, DataInst>) { - func_at_data_inst.inner_in_place_transform_with(self); - } } /// Trait implemented on "transformable" types, to further "elaborate" a type by @@ -637,13 +634,6 @@ impl InnerInPlaceTransform for FuncAtMut<'_, Node> { transformer.transform_attr_set_use(*attrs).apply_to(attrs); match kind { - &mut NodeKind::Block { insts } => { - let mut func_at_inst_iter = self.reborrow().at(insts).into_iter(); - while let Some(func_at_inst) = func_at_inst_iter.next() { - transformer.in_place_transform_data_inst_def(func_at_inst); - } - } - DataInstKind::FuncCall(func) => transformer.transform_func_use(*func).apply_to(func), NodeKind::Select(SelectionKind::BoolCond | SelectionKind::SpvInst(_)) @@ -733,8 +723,7 @@ impl InnerTransform for Value { } => Self::Const(ct)), Self::RegionInput { region: _, input_idx: _ } - | Self::NodeOutput { node: _, output_idx: _ } - | Self::DataInstOutput { inst: _, output_idx: _ } => Transformed::Unchanged, + | Self::NodeOutput { node: _, output_idx: _ } => Transformed::Unchanged, } } } diff --git a/src/visit.rs b/src/visit.rs index 25cfabfa..77e45c69 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -5,11 +5,11 @@ use crate::func_at::FuncAt; use crate::mem::{DataHapp, DataHappKind, MemAccesses, MemAttr, MemOp}; use crate::qptr::{QPtrAttr, QPtrOp}; use crate::{ - AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, DataInst, DataInstKind, - DbgSrcLoc, DeclDef, DiagMsgPart, EntityListIter, 
ExportKey, Exportee, Func, FuncDecl, - FuncDefBody, FuncParam, GlobalVar, GlobalVarDecl, GlobalVarDefBody, Import, Module, - ModuleDebugInfo, ModuleDialect, Node, NodeDef, NodeKind, NodeOutputDecl, OrdAssertEq, Region, - RegionDef, RegionInputDecl, Type, TypeDef, TypeKind, TypeOrConst, Value, spv, + AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, DataInstKind, DbgSrcLoc, + DeclDef, DiagMsgPart, EntityListIter, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, + FuncParam, GlobalVar, GlobalVarDecl, GlobalVarDefBody, Import, Module, ModuleDebugInfo, + ModuleDialect, Node, NodeDef, NodeKind, NodeOutputDecl, OrdAssertEq, Region, RegionDef, + RegionInputDecl, Type, TypeDef, TypeKind, TypeOrConst, Value, spv, }; // FIXME(eddyb) `Sized` bound shouldn't be needed but removing it requires @@ -65,9 +65,6 @@ pub trait Visitor<'a>: Sized { fn visit_node_def(&mut self, func_at_node: FuncAt<'a, Node>) { func_at_node.inner_visit_with(self); } - fn visit_data_inst_def(&mut self, func_at_inst: FuncAt<'a, DataInst>) { - func_at_inst.inner_visit_with(self); - } fn visit_value_use(&mut self, v: &'a Value) { v.inner_visit_with(self); } @@ -478,12 +475,6 @@ impl<'a> FuncAt<'a, Node> { visitor.visit_attr_set_use(*attrs); match kind { - NodeKind::Block { insts } => { - for func_at_inst in self.at(*insts) { - visitor.visit_data_inst_def(func_at_inst); - } - } - &DataInstKind::FuncCall(func) => visitor.visit_func_use(func), NodeKind::Select(SelectionKind::BoolCond | SelectionKind::SpvInst(_)) @@ -559,8 +550,7 @@ impl InnerVisit for Value { match *self { Self::Const(ct) => visitor.visit_const_use(ct), Self::RegionInput { region: _, input_idx: _ } - | Self::NodeOutput { node: _, output_idx: _ } - | Self::DataInstOutput { inst: _, output_idx: _ } => {} + | Self::NodeOutput { node: _, output_idx: _ } => {} } } }