diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 9111a098..48137c0c 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -11,7 +11,7 @@ use crate::{ NodeKind, Region, RegionDef, Type, TypeDef, TypeKind, TypeOrConst, Value, VarDecl, print, }; use itertools::{Either, Itertools as _}; -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use smallvec::SmallVec; use std::collections::{BTreeMap, BTreeSet}; use std::num::NonZeroU32; @@ -933,10 +933,20 @@ impl Module { return Err(invalid("OpFunction without matching OpFunctionEnd")); } + let entry_points: FxHashSet<_> = pending_exports + .iter() + .filter_map(|export| match *export { + Export::Linkage { .. } => None, + Export::EntryPoint { func_id, .. } => Some(func_id), + }) + .collect(); + // Process function bodies, having seen the whole module. for func_body in pending_func_bodies { let FuncBody { func_id, func, insts: raw_insts } = func_body; + let func_is_entry_point = entry_points.contains(&func_id); + let func_decl = &mut module.funcs[func]; #[derive(PartialEq, Eq, Hash)] @@ -1119,6 +1129,19 @@ impl Module { .map(|(®ion, details)| (details.label_id, LocalIdDef::BlockLabel(region))), ); + // HACK(eddyb) in order to reduce restructurization costs with lots + // of conditional returns, wherever possible (i.e. in entry-points), + // `OpReturn` gets turned into an `ExitInvocation`, but to avoid + // doing that to the "true return" at the "end" of the function, + // SPIR-V structured merge annotations are used to try to find that + // one block which should end in an `OpReturn` (notably, all of this + // is for SPIR-V that already has some structured control-flow, and + // not the kind of SPIR-V that e.g. Rust-GPU might generate). + // FIXME(eddyb) use a "structured control-flow recovery" analysis to + // make this more principled. + let mut whole_func_merge = + func_def_body.as_ref().map(|func_def_body| func_def_body.body); + // HACK(eddyb) an entire separate traversal is required to find // all inter-block uses, before any blocks get lowered to SPIR-T. let mut cfgssa_use_accumulator = cfgssa_def_map @@ -1169,6 +1192,12 @@ impl Module { // closest dominator of a merge, that merge could contain // uses that don't belong/are illegal in `current_block`. if [wk.OpSelectionMerge, wk.OpLoopMerge].contains(&opcode) { + if whole_func_merge == Some(current_block) + && let Some(&LocalIdDef::BlockLabel(merge_block)) = + local_id_defs.get(&raw_inst.ids[0]) + { + whole_func_merge = Some(merge_block); + } continue; } @@ -1179,6 +1208,14 @@ impl Module { // (which are already special-cased above). if let Some(&LocalIdDef::BlockLabel(target_block)) = local_id_defs.get(&id) { + if whole_func_merge == Some(current_block) { + // HACK(eddyb) always replacing `whole_func_merge` + // detects unstructured control-flow, and avoids + // keeping around some intermediary block which + // happened to be e.g. in an `OpBranch` chain, + // but doesn't end in `OpReturn`/`ExitInvocation`. + whole_func_merge = (opcode == wk.OpBranch).then_some(target_block); + } use_acc.add_edge(current_block, target_block); } else { // HACK(eddyb) this heavily relies on `add_use(_, id)` @@ -1539,6 +1576,13 @@ impl Module { None }; + // HACK(eddyb) see comment on `whole_func_merge`. + let treat_return_as_exit_invocation = opcode == wk.OpReturn + && func_is_entry_point + && whole_func_merge.is_some_and(|whole_func_merge| { + whole_func_merge != current_block.region + }); + let target_thunk = if let Some(selection_kind) = selection_kind { let cases = targets_with_inputs .map(|target_with_inputs| { @@ -1578,7 +1622,9 @@ impl Module { func_def_body.nodes[select_node].outputs.push(select_thunk_var); Value::Var(select_thunk_var) - } else if [wk.OpReturn, wk.OpReturnValue].contains(&opcode) { + } else if [wk.OpReturn, wk.OpReturnValue].contains(&opcode) + && !treat_return_as_exit_invocation + { assert!(targets_with_inputs.len() == 0 && inputs.len() <= 1); build_thunk( func_def_body.at_mut(current_block.region),