diff --git a/Cargo.lock b/Cargo.lock index d8fb8986..010d6d86 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,12 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + [[package]] name = "bytemuck" version = "1.18.0" @@ -195,6 +201,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc_apfloat" +version = "0.2.3+llvm-462a31f5a5ab" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "486c2179b4796f65bfe2ee33679acf0927ac83ecf583ad6c91c3b4570911b9ad" +dependencies = [ + "bitflags", + "smallvec", +] + [[package]] name = "rustc_version" version = "0.4.1" @@ -271,6 +287,7 @@ dependencies = [ "lazy_static", "longest-increasing-subsequence", "rustc-hash", + "rustc_apfloat", "serde", "serde_json", "smallvec", diff --git a/Cargo.toml b/Cargo.toml index 58eb221d..f03c058a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ itertools = "0.10.3" lazy_static = "1.4.0" longest-increasing-subsequence = "0.1.0" rustc-hash = "1.1.0" +rustc_apfloat = "0.2.3" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" smallvec = { version = "1.7.0", features = ["serde", "union"] } diff --git a/README.md b/README.md index 9db56ee4..6628eaae 100644 --- a/README.md +++ b/README.md @@ -137,10 +137,10 @@ global_var GV0 in spv.StorageClass.Output: s32 func F0() -> spv.OpTypeVoid { (_: s32, _: s32, v0: s32) = loop(v1: s32 <- 1s32, v2: s32 <- 1s32, _: s32 <- undef: s32) { - v3 = spv.OpSLessThan(v2, 10s32): bool + v3 = s.lt(v2, 10s32): bool (v4: s32, v5: s32) = if v3 { - v6 = spv.OpIMul(v1, v2): s32 - v7 = 
spv.OpIAdd(v2, 1s32): s32 + v6 = i.mul(v1, v2): s32 + v7 = i.add(v2, 1s32): s32 (v6, v7) } else { (undef: s32, undef: s32) diff --git a/deny.toml b/deny.toml index 26f4d0d2..af60bf9c 100644 --- a/deny.toml +++ b/deny.toml @@ -20,6 +20,7 @@ allow = [ "MIT", "Apache-2.0", "Unicode-DFS-2016", + "Apache-2.0 WITH LLVM-exception", ] # This section is considered when running `cargo deny check bans`. diff --git a/src/cf/mod.rs b/src/cf/mod.rs index f2131720..406660e5 100644 --- a/src/cf/mod.rs +++ b/src/cf/mod.rs @@ -2,7 +2,7 @@ // // FIXME(eddyb) consider moving more definitions into this module. -use crate::spv; +use crate::{scalar, spv}; // NOTE(eddyb) all the modules are declared here, but they're documented "inside" // (i.e. using inner doc comments). @@ -10,13 +10,24 @@ pub mod cfgssa; pub mod structurize; pub mod unstructured; +// FIXME(eddyb) consider interning this. #[derive(Clone, PartialEq, Eq, Hash)] pub enum SelectionKind { /// Two-case selection based on boolean condition, i.e. `if`-`else`, with /// the two cases being "then" and "else" (in that order). BoolCond, - SpvInst(spv::Inst), + /// `N+1`-case selection based on comparing an integer scrutinee against + /// `N` constants, i.e. `switch`, with the last case being the "default" + /// (making it the only case without a matching entry in `case_consts`). + Switch { + // FIXME(eddyb) avoid some of the `scalar::Const` overhead here, as there + // is only a single type and we shouldn't need to store more bits per case, + // than the actual width of the integer type. + // FIXME(eddyb) consider storing this more like sorted compressed keyset, + // as there can be no duplicates, and in many cases it may be contiguous. 
case_consts: Vec<scalar::Const>,
} else { std::slice::from_ref(&func_decl.ret_type) } diff --git a/src/context.rs b/src/context.rs index 1e1a28dd..a0679232 100644 --- a/src/context.rs +++ b/src/context.rs @@ -608,6 +608,49 @@ impl>, D> EntityList { } } + /// Insert `new_node` (defined in `defs`) into `self`, after `prev`. + // + // FIXME(eddyb) unify this with the other insert methods, maybe with a new + // "insert position" type? + #[track_caller] + pub fn insert_after(&mut self, new_node: E, prev: E, defs: &mut EntityDefs) { + let next = defs[prev].next.replace(new_node); + + let new_node_def = &mut defs[new_node]; + assert!( + new_node_def.next.is_none() && new_node_def.prev.is_none(), + "EntityList::insert_before: new node already linked into a (different?) list" + ); + + new_node_def.next = next; + new_node_def.prev = Some(prev); + + match next { + Some(next) => { + let old_next_prev = defs[next].prev.replace(new_node); + + // FIXME(eddyb) this situation should be impossible anyway, as it + // involves the `EntityListNode`s links, which should be unforgeable. + assert!( + old_next_prev == Some(prev), + "invalid EntityListNode: `node->next->prev != node`" + ); + } + None => { + // FIXME(eddyb) this situation should be impossible anyway, as it + // involves the `EntityListNode`s links, which should be unforgeable, + // but it's still possible to keep around outdated `EntityList`s + // (should `EntityList` not implement `Copy`/`Clone` *at all*?) + assert!( + self.0.map(|this| this.last) == Some(prev), + "invalid EntityList: `node->next == None` but `node != last`" + ); + + self.0.as_mut().unwrap().last = new_node; + } + } + } + /// Insert all of `list_to_prepend`'s nodes at the start of `self`. 
#[track_caller] pub fn prepend(&mut self, list_to_prepend: Self, defs: &mut EntityDefs) { diff --git a/src/lib.rs b/src/lib.rs index c5b3a120..76731bb4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -169,6 +169,7 @@ pub mod passes { } pub mod mem; pub mod qptr; +pub mod scalar; pub mod spv; use smallvec::SmallVec; @@ -526,16 +527,23 @@ impl Ord for OrdAssertEq { pub use context::Type; /// Definition for a [`Type`]. -// -// FIXME(eddyb) maybe special-case some basic types like integers. #[derive(PartialEq, Eq, Hash)] pub struct TypeDef { pub attrs: AttrSet, pub kind: TypeKind, } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] pub enum TypeKind { + /// Scalar (`bool`, integer, and floating-point) type, with limitations + /// on the supported bit-widths (power-of-two multiples of a byte). + /// + /// **Note**: pointers are never scalars (like SPIR-V, but unlike other IRs). + /// + /// See also the [`scalar`] module for more documentation and definitions. + #[from] + Scalar(scalar::Type), + /// "Quasi-pointer", an untyped pointer-like abstract scalar that can represent /// both memory locations (in any address space) and other kinds of locations /// (e.g. SPIR-V `OpVariable`s in non-memory "storage classes"). @@ -566,12 +574,18 @@ pub enum TypeKind { SpvStringLiteralForExtInst, } -// HACK(eddyb) this behaves like an implicit conversion for `cx.intern(...)`. -impl context::InternInCx for TypeKind { - fn intern_in_cx(self, cx: &Context) -> Type { - cx.intern(TypeDef { attrs: Default::default(), kind: self }) +// HACK(eddyb) this behaves like an implicit conversion for `cx.intern(...)`, +// and the macro is only used because coherence bans `impl>`. +macro_rules! impl_intern_type_kind { + ($($kind:ty),+ $(,)?) 
=> { + $(impl context::InternInCx for $kind { + fn intern_in_cx(self, cx: &Context) -> Type { + cx.intern(TypeDef { attrs: Default::default(), kind: self.into() }) + } + })+ } } +impl_intern_type_kind!(TypeKind, scalar::Type); // HACK(eddyb) this is like `Either`, only used in `TypeKind::SpvInst`, // and only because SPIR-V type definitions can references both types and consts. @@ -581,6 +595,16 @@ pub enum TypeOrConst { Const(Const), } +// HACK(eddyb) on `Type` instead of `TypeDef` for ergonomics reasons. +impl Type { + pub fn as_scalar(self, cx: &Context) -> Option { + match cx[self].kind { + TypeKind::Scalar(ty) => Some(ty), + _ => None, + } + } +} + /// Interned handle for a [`ConstDef`](crate::ConstDef) (a constant value). pub use context::Const; @@ -594,7 +618,7 @@ pub struct ConstDef { pub kind: ConstKind, } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] pub enum ConstKind { /// Undeterminate value (i.e. SPIR-V `OpUndef`, LLVM `undef`). // @@ -602,6 +626,18 @@ pub enum ConstKind { // model, without being forced to never lift back to `OpUndef`? Undef, + /// Scalar (`bool`, integer, and floating-point) constant, which must have + /// a type of [`TypeKind::Scalar`] (of the same [`scalar::Type`]). + /// + /// See also the [`scalar`] module for more documentation and definitions. + // + // FIXME(eddyb) maybe document the 128-bit limitation?. + // FIXME(eddyb) this technically makes the `scalar::Type` redundant, could + // it get out of sync? (perhaps "forced canonicalization" could be used to + // enforce that interning simply doesn't allow such scenarios?). + #[from] + Scalar(scalar::Const), + // FIXME(eddyb) maybe merge these? 
however, their connection is somewhat // tenuous (being one of the LLVM-isms SPIR-V inherited, among other things), // there's still the need to rename "global variable" post-`Var`-refactor, @@ -622,6 +658,34 @@ pub enum ConstKind { SpvStringLiteralForExtInst(InternedStr), } +// HACK(eddyb) this behaves like an implicit conversion for `cx.intern(...)`, +// like the `TypeKind` one, but this one is even weirder because it also interns +// the inherent type of the constant, as a `Type` (with empty attributes). +macro_rules! impl_intern_const_kind { + ($($kind:ty),+ $(,)?) => { + $(impl context::InternInCx for $kind { + fn intern_in_cx(self, cx: &Context) -> Const { + cx.intern(ConstDef { + attrs: Default::default(), + ty: cx.intern(self.ty()), + kind: self.into(), + }) + } + })+ + } +} +impl_intern_const_kind!(scalar::Const); + +// HACK(eddyb) on `Const` instead of `ConstDef` for ergonomics reasons. +impl Const { + pub fn as_scalar(self, cx: &Context) -> Option<&scalar::Const> { + match &cx[self].kind { + ConstKind::Scalar(ct) => Some(ct), + _ => None, + } + } +} + /// Declarations ([`GlobalVarDecl`], [`FuncDecl`]) can contain a full definition, /// or only be an import of a definition (e.g. from another module). #[derive(Clone)] @@ -898,6 +962,12 @@ pub enum NodeKind { // NOTE(eddyb) all variants below used to be in `DataInstKind`. // + /// Scalar (`bool`, integer, and floating-point) pure operations. + /// + /// See also the [`scalar`] module for more documentation and definitions. + #[from] + Scalar(scalar::Op), + FuncCall(Func), /// Memory-specific operations (see [`mem::MemOp`]). 
diff --git a/src/mem/analyze.rs b/src/mem/analyze.rs index 762ff2a8..e1551fbf 100644 --- a/src/mem/analyze.rs +++ b/src/mem/analyze.rs @@ -931,7 +931,8 @@ impl<'a> GatherAccesses<'a> { continue; } - DataInstKind::FuncCall(_) + DataInstKind::Scalar(_) + | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) | DataInstKind::ThunkBind(_) @@ -950,6 +951,8 @@ impl<'a> GatherAccesses<'a> { unreachable!() } + DataInstKind::Scalar(_) => {} + &DataInstKind::FuncCall(callee) => { match self.gather_accesses_in_func(module, callee) { FuncGatherAccessesState::Complete(callee_results) => { diff --git a/src/mem/layout.rs b/src/mem/layout.rs index 9687afab..43c68c17 100644 --- a/src/mem/layout.rs +++ b/src/mem/layout.rs @@ -2,7 +2,7 @@ use crate::mem::shapes; use crate::{ - AddrSpace, Attr, Const, ConstKind, Context, Diag, FxIndexMap, Type, TypeKind, TypeOrConst, spv, + AddrSpace, Attr, Const, Context, Diag, FxIndexMap, Type, TypeKind, TypeOrConst, scalar, spv, }; use itertools::Either; use smallvec::SmallVec; @@ -182,18 +182,10 @@ impl<'a> LayoutCache<'a> { Self { cx, wk: &spv::spec::Spec::get().well_known, config, cache: Default::default() } } - // FIXME(eddyb) properly distinguish between zero-extension and sign-extension. fn const_as_u32(&self, ct: Const) -> Option { - if let ConstKind::SpvInst { spv_inst_and_const_inputs } = &self.cx[ct].kind { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - if spv_inst.opcode == self.wk.OpConstant && spv_inst.imms.len() == 1 { - match spv_inst.imms[..] { - [spv::Imm::Short(_, x)] => return Some(x), - _ => unreachable!(), - } - } - } - None + // HACK(eddyb) lossless roundtrip through `i32` is most conservative + // option (only `0..=i32::MAX`, i.e. `0 <= x < 2**32, is allowed). + u32::try_from(ct.as_scalar(&self.cx)?.int_as_i32()?).ok() } /// Attempt to compute a `TypeLayout` for a given (SPIR-V) `Type`. 
@@ -202,29 +194,16 @@ impl<'a> LayoutCache<'a> { return Ok(cached); } + let layout = self.layout_of_uncached(ty)?; + self.cache.borrow_mut().insert(ty, layout.clone()); + Ok(layout) + } + + fn layout_of_uncached(&self, ty: Type) -> Result { let cx = &self.cx; let wk = self.wk; let ty_def = &cx[ty]; - let (spv_inst, type_and_const_inputs) = match &ty_def.kind { - // FIXME(eddyb) treat `QPtr`s as scalars. - TypeKind::QPtr => { - return Err(LayoutError(Diag::bug( - ["`layout_of(qptr)` (already lowered?)".into()], - ))); - } - TypeKind::Thunk => { - return Err(LayoutError(Diag::bug(["`layout_of(thunk)`".into()]))); - } - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { - (spv_inst, type_and_const_inputs) - } - TypeKind::SpvStringLiteralForExtInst => { - return Err(LayoutError(Diag::bug([ - "`layout_of(type_of(OpString<\"...\">))`".into() - ]))); - } - }; let scalar_with_size_and_align = |(size, align)| { TypeLayout::Concrete(Rc::new(MemTypeLayout { @@ -343,25 +322,46 @@ impl<'a> LayoutCache<'a> { } } }; - let short_imm_at = |i| match spv_inst.imms[i] { - spv::Imm::Short(_, x) => x, - _ => unreachable!(), - }; // FIXME(eddyb) !!! what if... types had a min/max size and then... // that would allow surrounding offsets to limit their size... but... ugh... // ugh this doesn't make any sense. maybe if the front-end specifies // offsets with "abstract types", it must configure `mem::layout`? - let layout = if spv_inst.opcode == wk.OpTypeBool { - // FIXME(eddyb) make this properly abstract instead of only configurable. - scalar_with_size_and_align(self.config.abstract_bool_size_align) - } else if spv_inst.opcode == wk.OpTypePointer { + + let (spv_inst, type_and_const_inputs) = match &ty_def.kind { + TypeKind::Scalar(scalar::Type::Bool) => { + // FIXME(eddyb) make this properly abstract instead of only configurable. 
+ return Ok(scalar_with_size_and_align(self.config.abstract_bool_size_align)); + } + TypeKind::Scalar(ty) => return Ok(scalar(ty.bit_width())), + + // FIXME(eddyb) treat `QPtr`s as scalars. + TypeKind::QPtr => { + return Err(LayoutError(Diag::bug( + ["`layout_of(qptr)` (already lowered?)".into()], + ))); + } + TypeKind::Thunk => { + return Err(LayoutError(Diag::bug(["`layout_of(thunk)`".into()]))); + } + TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { + (spv_inst, type_and_const_inputs) + } + TypeKind::SpvStringLiteralForExtInst => { + return Err(LayoutError(Diag::bug([ + "`layout_of(type_of(OpString<\"...\">))`".into() + ]))); + } + }; + let short_imm_at = |i| match spv_inst.imms[i] { + spv::Imm::Short(_, x) => x, + _ => unreachable!(), + }; + Ok(if spv_inst.opcode == wk.OpTypePointer { // FIXME(eddyb) make this properly abstract instead of only configurable. // FIXME(eddyb) categorize `OpTypePointer` by storage class and split on // logical vs physical here. scalar_with_size_and_align(self.config.logical_ptr_size_align) - } else if [wk.OpTypeInt, wk.OpTypeFloat].contains(&spv_inst.opcode) { - scalar(short_imm_at(0)) } else if [wk.OpTypeVector, wk.OpTypeMatrix].contains(&spv_inst.opcode) { let len = short_imm_at(0); let (min_legacy_align, legacy_align_multiplier) = if spv_inst.opcode == wk.OpTypeVector @@ -646,8 +646,6 @@ impl<'a> LayoutCache<'a> { spv_inst.opcode.name() ) .into()]))); - }; - self.cache.borrow_mut().insert(ty, layout.clone()); - Ok(layout) + }) } } diff --git a/src/print/mod.rs b/src/print/mod.rs index c7358129..c868763e 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -31,7 +31,7 @@ use crate::{ EntityOrientedDenseMap, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDecl, GlobalVarDefBody, Import, InternedStr, Module, ModuleDebugInfo, ModuleDialect, Node, NodeDef, NodeKind, OrdAssertEq, Region, - RegionDef, Type, TypeDef, TypeKind, TypeOrConst, Value, Var, VarDecl, spv, + 
RegionDef, Type, TypeDef, TypeKind, TypeOrConst, Value, Var, VarDecl, scalar, spv, }; use arrayvec::ArrayVec; use itertools::Either; @@ -1091,17 +1091,12 @@ impl<'a> Printer<'a> { // here and `TypeDef`'s `Print` impl. let has_compact_print_or_is_leaf = match &ty_def.kind { TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { - [ - wk.OpTypeBool, - wk.OpTypeInt, - wk.OpTypeFloat, - wk.OpTypeVector, - ] - .contains(&spv_inst.opcode) + spv_inst.opcode == wk.OpTypeVector || type_and_const_inputs.is_empty() } - TypeKind::QPtr + TypeKind::Scalar(_) + | TypeKind::QPtr | TypeKind::Thunk | TypeKind::SpvStringLiteralForExtInst => true, }; @@ -1112,28 +1107,16 @@ impl<'a> Printer<'a> { CxInterned::Const(ct) => { let ct_def = &cx[ct]; - // FIXME(eddyb) remove the duplication between - // here and `ConstDef`'s `Print` impl. - let (has_compact_print, has_nested_consts) = match &ct_def.kind - { + let has_nested_consts = match &ct_def.kind { ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, const_inputs) = + let (_spv_inst, const_inputs) = &**spv_inst_and_const_inputs; - ( - [ - wk.OpConstantFalse, - wk.OpConstantTrue, - wk.OpConstant, - ] - .contains(&spv_inst.opcode), - !const_inputs.is_empty(), - ) + !const_inputs.is_empty() } - _ => (false, false), + _ => false, }; - ct_def.attrs == AttrSet::default() - && (has_compact_print || !has_nested_consts) + ct_def.attrs == AttrSet::default() && !has_nested_consts } } } @@ -3066,30 +3049,13 @@ impl Print for TypeDef { // FIXME(eddyb) should this be done by lowering SPIR-V types to SPIR-T? let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); + #[allow(irrefutable_let_patterns)] let compact_def = if let &TypeKind::SpvInst { spv_inst: spv::Inst { opcode, ref imms }, ref type_and_const_inputs, } = kind { - if opcode == wk.OpTypeBool { - Some(kw("bool".into())) - } else if opcode == wk.OpTypeInt { - let (width, signed) = match imms[..] 
{ - [spv::Imm::Short(_, width), spv::Imm::Short(_, signedness)] => { - (width, signedness != 0) - } - _ => unreachable!(), - }; - - Some(if signed { kw(format!("s{width}")) } else { kw(format!("u{width}")) }) - } else if opcode == wk.OpTypeFloat { - let width = match imms[..] { - [spv::Imm::Short(_, width)] => width, - _ => unreachable!(), - }; - - Some(kw(format!("f{width}"))) - } else if opcode == wk.OpTypeVector { + if opcode == wk.OpTypeVector { let (elem_ty, elem_count) = match (&imms[..], &type_and_const_inputs[..]) { (&[spv::Imm::Short(_, elem_count)], &[TypeOrConst::Type(elem_ty)]) => { (elem_ty, elem_count) @@ -3115,6 +3081,16 @@ impl Print for TypeDef { def } else { match kind { + TypeKind::Scalar(ty) => { + let width = ty.bit_width(); + kw(match ty { + scalar::Type::Bool => "bool".into(), + scalar::Type::SInt(_) => format!("s{width}"), + scalar::Type::UInt(_) => format!("u{width}"), + scalar::Type::Float(_) => format!("f{width}"), + }) + } + // FIXME(eddyb) should this be shortened to `qtr`? TypeKind::QPtr => printer.declarative_keyword_style().apply("qptr").into(), @@ -3150,87 +3126,32 @@ impl Print for ConstDef { let wk = &spv::spec::Spec::get().well_known; let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); - let literal_ty_suffix = |ty| { - pretty::Styles { - // HACK(eddyb) the exact type detracts from the value. - color_opacity: Some(0.4), - subscript: true, - ..printer.declarative_keyword_style() - } - .apply(ty) - }; - let compact_def = if let ConstKind::SpvInst { spv_inst_and_const_inputs } = kind { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - let &spv::Inst { opcode, ref imms } = spv_inst; - - if opcode == wk.OpConstantFalse { - Some(kw("false")) - } else if opcode == wk.OpConstantTrue { - Some(kw("true")) - } else if opcode == wk.OpConstant { - // HACK(eddyb) it's simpler to only handle a limited subset of - // integer/float bit-widths, for now. - let raw_bits = match imms[..] 
{ - [spv::Imm::Short(_, x)] => Some(u64::from(x)), - [spv::Imm::LongStart(_, lo), spv::Imm::LongCont(_, hi)] => { - Some(u64::from(lo) | (u64::from(hi) << 32)) - } - _ => None, - }; - - if let ( - Some(raw_bits), - &TypeKind::SpvInst { - spv_inst: spv::Inst { opcode: ty_opcode, imms: ref ty_imms }, - .. - }, - ) = (raw_bits, &printer.cx[*ty].kind) - { - if ty_opcode == wk.OpTypeInt { - let (width, signed) = match ty_imms[..] { - [spv::Imm::Short(_, width), spv::Imm::Short(_, signedness)] => { - (width, signedness != 0) - } - _ => unreachable!(), - }; - - if width <= 64 { - let (printed_value, ty) = if signed { - let sext_raw_bits = - (raw_bits as u128 as i128) << (128 - width) >> (128 - width); - // FIXME(eddyb) consider supporting negative hex. - ( - if sext_raw_bits >= 0 { - printer.pretty_numeric_literal_as_dec_or_hex( - sext_raw_bits as u128, - ) - } else { - printer - .numeric_literal_style() - .apply(format!("{sext_raw_bits}")) - .into() - }, - format!("s{width}"), - ) - } else { - ( - printer.pretty_numeric_literal_as_dec_or_hex(raw_bits.into()), - format!("u{width}"), - ) - }; - Some(pretty::Fragment::new([ - printed_value, - literal_ty_suffix(ty).into(), - ])) - } else { - None - } - } else if ty_opcode == wk.OpTypeFloat { - let width = match ty_imms[..] { - [spv::Imm::Short(_, width)] => width, - _ => unreachable!(), - }; + let def_without_name = match kind { + ConstKind::Undef => pretty::Fragment::new([ + printer.imperative_keyword_style().apply("undef").into(), + printer.pretty_type_ascription_suffix(*ty), + ]), + ConstKind::Scalar(scalar::Const::FALSE) => kw("false"), + ConstKind::Scalar(scalar::Const::TRUE) => kw("true"), + ConstKind::Scalar(ct) => { + let ty = ct.ty(); + let width = ty.bit_width(); + let (maybe_printed_value, ty_prefix) = match ty { + scalar::Type::Bool => unreachable!(), + scalar::Type::SInt(_) => ( + // FIXME(eddyb) consider supporting negative hex. 
+ ct.int_as_i128().map(|x| match u128::try_from(x) { + Ok(x) => printer.pretty_numeric_literal_as_dec_or_hex(x), + Err(_) => printer.numeric_literal_style().apply(x.to_string()).into(), + }), + 's', + ), + scalar::Type::UInt(_) => ( + ct.int_as_u128().map(|x| printer.pretty_numeric_literal_as_dec_or_hex(x)), + 'u', + ), + scalar::Type::Float(_) => { /// Check that parsing the result of printing produces /// the original bits of the floating-point value, and /// only return `Some` if that is the case. @@ -3250,72 +3171,80 @@ impl Print for ConstDef { }) } - let printed_value = match width { - 32 => bitwise_roundtrip_float_print( - raw_bits as u32, - f32::from_bits, - f32::to_bits, - ), - 64 => bitwise_roundtrip_float_print( - raw_bits, - f64::from_bits, - f64::to_bits, - ), - _ => None, - }; - printed_value.map(|s| { - pretty::Fragment::new([ - printer.numeric_literal_style().apply(s), - literal_ty_suffix(format!("f{width}")), - ]) - }) - } else { - None + ( + match width { + 32 => bitwise_roundtrip_float_print( + ct.bits() as u32, + f32::from_bits, + f32::to_bits, + ), + 64 => bitwise_roundtrip_float_print( + ct.bits() as u64, + f64::from_bits, + f64::to_bits, + ), + _ => None, + } + .map(|s| printer.numeric_literal_style().apply(s).into()), + 'f', + ) } - } else { - None + }; + match maybe_printed_value { + Some(printed_value) => { + let literal_ty_suffix = pretty::Styles { + // HACK(eddyb) the exact type detracts from the value. + color_opacity: Some(0.4), + subscript: true, + ..printer.declarative_keyword_style() + } + .apply(format!("{ty_prefix}{width}")); + pretty::Fragment::new([printed_value, literal_ty_suffix.into()]) + } + // HACK(eddyb) fallback using the bitwise representation. 
+ None => pretty::Fragment::new([ + printer + .demote_style_for_namespace_prefix(printer.declarative_keyword_style()) + .apply(format!("{ty_prefix}{width}.")) + .into(), + printer.declarative_keyword_style().apply("from_bits").into(), + pretty::join_comma_sep( + "(", + [ + // FIXME(eddyb) consider padding this with enough + // leading zeroes for its respective width. + printer.numeric_literal_style().apply(format!("0x{:x}", ct.bits())), + ], + ")", + ), + ]), } - } else { - None } - } else { - None - }; + &ConstKind::PtrToGlobalVar(gv) => { + pretty::Fragment::new(["&".into(), gv.print(printer)]) + } + &ConstKind::PtrToFunc(func) => pretty::Fragment::new(["&".into(), func.print(printer)]), - AttrsAndDef { - attrs: attrs.print(printer), - def_without_name: compact_def.unwrap_or_else(|| match kind { - ConstKind::Undef => pretty::Fragment::new([ - printer.imperative_keyword_style().apply("undef").into(), + ConstKind::SpvInst { spv_inst_and_const_inputs } => { + let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; + pretty::Fragment::new([ + printer.pretty_spv_inst( + printer.spv_op_style(), + spv_inst.opcode, + &spv_inst.imms, + const_inputs.iter().map(|ct| ct.print(printer)), + ), printer.pretty_type_ascription_suffix(*ty), - ]), - &ConstKind::PtrToGlobalVar(gv) => { - pretty::Fragment::new(["&".into(), gv.print(printer)]) - } - &ConstKind::PtrToFunc(func) => { - pretty::Fragment::new(["&".into(), func.print(printer)]) - } - - ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; - pretty::Fragment::new([ - printer.pretty_spv_inst( - printer.spv_op_style(), - spv_inst.opcode, - &spv_inst.imms, - const_inputs.iter().map(|ct| ct.print(printer)), - ), - printer.pretty_type_ascription_suffix(*ty), - ]) - } - &ConstKind::SpvStringLiteralForExtInst(s) => pretty::Fragment::new([ - printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), - "(".into(), - 
printer.pretty_string_literal(&printer.cx[s]), - ")".into(), - ]), - }), - } + ]) + } + &ConstKind::SpvStringLiteralForExtInst(s) => pretty::Fragment::new([ + printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), + "(".into(), + printer.pretty_string_literal(&printer.cx[s]), + ")".into(), + ]), + }; + AttrsAndDef { attrs: attrs.print(printer), def_without_name } } } @@ -3783,7 +3712,7 @@ impl Print for FuncAt<'_, Node> { "(", self.at(Either::Left(body)) .print_var_defs(printer) - .zip(initial_inputs) + .zip_eq(initial_inputs) .map(|(lhs, initial)| { pretty::Fragment::new([ lhs, @@ -3834,7 +3763,8 @@ impl Print for FuncAt<'_, Node> { inputs.iter().map(|v| v.print(printer)), ), - DataInstKind::FuncCall(_) + DataInstKind::Scalar(_) + | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) | DataInstKind::ThunkBind(_) @@ -3881,6 +3811,19 @@ impl FuncAt<'_, DataInst> { unreachable!() } + &DataInstKind::Scalar(op) => { + let name = op.name(); + let (namespace_prefix, name) = name.split_at(name.find('.').unwrap() + 1); + pretty::Fragment::new([ + printer + .demote_style_for_namespace_prefix(printer.declarative_keyword_style()) + .apply(namespace_prefix) + .into(), + printer.declarative_keyword_style().apply(name).into(), + pretty::join_comma_sep("(", inputs.iter().map(|v| v.print(printer)), ")"), + ]) + } + &DataInstKind::FuncCall(func) => pretty::Fragment::new([ printer.declarative_keyword_style().apply("call").into(), " ".into(), @@ -4110,21 +4053,16 @@ impl FuncAt<'_, DataInst> { match &printer.cx[ct].kind { ConstKind::Undef | ConstKind::PtrToGlobalVar(_) - | ConstKind::PtrToFunc(_) => {} + | ConstKind::PtrToFunc(_) + | ConstKind::SpvInst { .. 
} => {} &ConstKind::SpvStringLiteralForExtInst(s) => { return Some(PseudoImm::Str(&printer.cx[s])); } - ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - if spv_inst.opcode == wk.OpConstant - && let [spv::Imm::Short(_, x)] = spv_inst.imms[..] - { - // HACK(eddyb) only allow unambiguously positive values. - if i32::try_from(x).and_then(u32::try_from) == Ok(x) { - return Some(PseudoImm::U32(x)); - } - } + // HACK(eddyb) lossless roundtrip through `i32` is most conservative + // option (only `0..=i32::MAX`, i.e. `0 <= x < 2**32, is allowed). + ConstKind::Scalar(ct) => { + return Some(PseudoImm::U32(u32::try_from(ct.int_as_i32()?).ok()?)); } } } @@ -4278,7 +4216,7 @@ impl SelectionKind { mut cases: impl ExactSizeIterator, ) -> pretty::Fragment { let kw = |kw| kw_style.apply(kw).into(); - match *self { + match self { SelectionKind::BoolCond => { assert_eq!(cases.len(), 2); let [then_case, else_case] = [cases.next().unwrap(), cases.next().unwrap()]; @@ -4295,27 +4233,36 @@ impl SelectionKind { "}".into(), ]) } - SelectionKind::SpvInst(spv::Inst { opcode, ref imms }) => { - let header = printer.pretty_spv_inst( - kw_style, - opcode, - imms, - [Some(scrutinee.print(printer))] - .into_iter() - .chain((0..cases.len()).map(|_| None)), - ); + SelectionKind::Switch { case_consts } => { + assert_eq!(cases.len(), case_consts.len() + 1); + + let case_patterns = case_consts + .iter() + .map(|&ct| { + let int_to_string = (ct.int_as_u128().map(|x| x.to_string())) + .or_else(|| ct.int_as_i128().map(|x| x.to_string())); + match int_to_string { + Some(v) => printer.numeric_literal_style().apply(v).into(), + None => { + let ct: Const = printer.cx.intern(ct); + ct.print(printer) + } + } + }) + .chain(["_".into()]); pretty::Fragment::new([ - header, + kw("switch"), + " ".into(), + scrutinee.print(printer), " {".into(), pretty::Node::IndentedBlock( - cases - .map(|case| { + case_patterns + .zip_eq(cases) + 
.map(|(case_pattern, case)| { pretty::Fragment::new([ pretty::Node::ForceLineSeparation.into(), - // FIXME(eddyb) this should pull information out - // of the instruction to be more precise. - kw("case"), + case_pattern, " => {".into(), pretty::Node::IndentedBlock(vec![case]).into(), "}".into(), diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index 48eba658..804bfe0a 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -8,11 +8,10 @@ use crate::{ AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, DataInst, DataInstDef, DataInstKind, DeclDef, Diag, DiagLevel, EntityDefs, EntityOrientedDenseMap, Func, FuncDecl, FxIndexMap, GlobalVar, GlobalVarDecl, Module, Node, NodeKind, Region, Type, TypeDef, - TypeKind, TypeOrConst, Value, Var, VarDecl, spv, + TypeKind, TypeOrConst, Value, Var, VarDecl, scalar, spv, }; use itertools::Either; use smallvec::SmallVec; -use std::cell::Cell; use std::mem; use std::num::NonZeroU32; use std::rc::Rc; @@ -30,8 +29,6 @@ pub struct LiftToSpvPtrs<'a> { cx: Rc, wk: &'static spv::spec::WellKnown, layout_cache: LayoutCache<'a>, - - cached_u32_type: Cell>, } impl<'a> LiftToSpvPtrs<'a> { @@ -40,7 +37,6 @@ impl<'a> LiftToSpvPtrs<'a> { cx: cx.clone(), wk: &spv::spec::Spec::get().well_known, layout_cache: LayoutCache::new(cx, layout_config), - cached_u32_type: Default::default(), } } @@ -295,7 +291,9 @@ impl<'a> LiftToSpvPtrs<'a> { spv_inst: spv_opcode.into(), type_and_const_inputs: [TypeOrConst::Type(element_type)] .into_iter() - .chain(fixed_len.map(|len| TypeOrConst::Const(self.const_u32(len)))) + .chain(fixed_len.map(|len| { + TypeOrConst::Const(self.cx.intern(scalar::Const::from_u32(len))) + })) .collect(), }, })) @@ -333,48 +331,6 @@ impl<'a> LiftToSpvPtrs<'a> { })) } - /// Get the (likely cached) `u32` type. 
- fn u32_type(&self) -> Type { - if let Some(cached) = self.cached_u32_type.get() { - return cached; - } - let wk = self.wk; - let ty = self.cx.intern(TypeKind::SpvInst { - spv_inst: spv::Inst { - opcode: wk.OpTypeInt, - imms: [ - spv::Imm::Short(wk.LiteralInteger, 32), - spv::Imm::Short(wk.LiteralInteger, 0), - ] - .into_iter() - .collect(), - }, - type_and_const_inputs: [].into_iter().collect(), - }); - self.cached_u32_type.set(Some(ty)); - ty - } - - fn const_u32(&self, x: u32) -> Const { - let wk = self.wk; - - self.cx.intern(ConstDef { - attrs: AttrSet::default(), - ty: self.u32_type(), - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new(( - spv::Inst { - opcode: wk.OpConstant, - imms: [spv::Imm::Short(wk.LiteralContextDependentNumber, x)] - .into_iter() - .collect(), - }, - [].into_iter().collect(), - )), - }, - }) - } - /// Attempt to compute a `TypeLayout` for a given (SPIR-V) `Type`. fn layout_of(&self, ty: Type) -> Result { self.layout_cache.layout_of(ty).map_err(|LayoutError(err)| LiftError(err)) @@ -455,6 +411,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { return Ok(Transformed::Unchanged); } + DataInstKind::Scalar(_) => return Ok(Transformed::Unchanged), + &DataInstKind::FuncCall(_callee) => { for &v in &data_inst_def.inputs { if self.lifter.as_spv_ptr_type(type_of_val(v)).is_some() { @@ -640,7 +598,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { ])) })?; access_chain_inputs - .push(Value::Const(self.lifter.const_u32(idx_as_i32 as u32))); + .push(Value::Const(cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)))); match &layout.components { Components::Scalar => unreachable!(), @@ -746,7 +704,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { ])) })?; access_chain_inputs - .push(Value::Const(self.lifter.const_u32(idx_as_i32 as u32))); + .push(Value::Const(cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)))); layout = match &layout.components { Components::Scalar => unreachable!(), @@ -960,7 +918,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { let mut 
access_chain_inputs: SmallVec<_> = [ptr].into_iter().collect(); if let TypeLayout::HandleArray(handle, _) = pointee_layout { - access_chain_inputs.push(Value::Const(self.lifter.const_u32(0))); + access_chain_inputs + .push(Value::Const(self.lifter.cx.intern(scalar::Const::from_u32(0)))); pointee_layout = TypeLayout::Handle(handle); } match (pointee_layout, access_layout) { @@ -1029,8 +988,9 @@ impl LiftToSpvPtrInstsInFunc<'_> { format!("{idx} not representable as a positive s32").into() ])) })?; - access_chain_inputs - .push(Value::Const(self.lifter.const_u32(idx_as_i32 as u32))); + access_chain_inputs.push(Value::Const( + self.lifter.cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)), + )); pointee_layout = match &pointee_layout.components { Components::Scalar => unreachable!(), diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index 01d893fb..f8a2e1b4 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -182,18 +182,10 @@ impl<'a> LowerFromSpvPtrs<'a> { } } - // FIXME(eddyb) properly distinguish between zero-extension and sign-extension. fn const_as_u32(&self, ct: Const) -> Option { - if let ConstKind::SpvInst { spv_inst_and_const_inputs } = &self.cx[ct].kind { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - if spv_inst.opcode == self.wk.OpConstant && spv_inst.imms.len() == 1 { - match spv_inst.imms[..] { - [spv::Imm::Short(_, x)] => return Some(x), - _ => unreachable!(), - } - } - } - None + // HACK(eddyb) lossless roundtrip through `i32` is most conservative + // option (only `0..=i32::MAX`, i.e. `0 <= x < 2**31`, is allowed). + u32::try_from(ct.as_scalar(&self.cx)?.int_as_i32()?).ok() } /// Get the (likely cached) `QPtr` type. @@ -635,6 +627,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { NodeKind::Select(_) | NodeKind::Loop { ..
} | NodeKind::ExitInvocation(_) + | DataInstKind::Scalar(_) | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) diff --git a/src/scalar.rs b/src/scalar.rs new file mode 100644 index 00000000..814dfb7e --- /dev/null +++ b/src/scalar.rs @@ -0,0 +1,960 @@ +//! Scalar (`bool`, integer, and floating-point) types and associated functionality. +//! +//! **Note**: pointers are never scalars (like SPIR-V, but unlike other IRs). + +use arrayvec::ArrayVec; +use itertools::Itertools; + +// HACK(eddyb) this could be some `struct` with private fields, but this `enum` +// is only 2 bytes in size, and has better ergonomics overall. +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum Type { + Bool, + SInt(IntWidth), + UInt(IntWidth), + + // FIXME(eddyb) SPIR-V added a "Floating Point Encoding" optional operand + // to `OpTypeFloat`, for non-IEEE floating-point formats, find a way to + // also support those here (maybe replacing `FloatWidth` entirely?). + Float(FloatWidth), +} + +impl Type { + // HACK(eddyb) only common widths, as a convenience, expand as-needed. + pub const S32: Type = Type::SInt(IntWidth::I32); + pub const U32: Type = Type::UInt(IntWidth::I32); + pub const F16: Type = Type::Float(FloatWidth::F16); + pub const F32: Type = Type::Float(FloatWidth::F32); + pub const F64: Type = Type::Float(FloatWidth::F64); + + pub const fn bit_width(self) -> u32 { + match self { + Type::Bool => 1, + Type::SInt(w) | Type::UInt(w) => w.bits(), + Type::Float(w) => w.bits(), + } + } +} + +/// Bit-width of a supported integer type (only power-of-two multiples of a byte). +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct IntWidth { + // HACK(eddyb) this is so compact that only 3 bits of this byte are used + // to encode integer types from `i8` to `i128`, and so `Type` could all fit + // in one byte, but that'd need a new `enum` for `Bool`/`{S,U}Int`/`Float`. 
+ log2_bytes: u8, +} + +impl IntWidth { + pub const I8: Self = Self::try_from_bits_unwrap(8); + pub const I16: Self = Self::try_from_bits_unwrap(16); + pub const I32: Self = Self::try_from_bits_unwrap(32); + pub const I64: Self = Self::try_from_bits_unwrap(64); + pub const I128: Self = Self::try_from_bits_unwrap(128); + + // FIXME(eddyb) remove when `Option::unwrap` is stabilized. + const fn try_from_bits_unwrap(bits: u32) -> Self { + match Self::try_from_bits(bits) { + Some(w) => w, + None => unreachable!(), + } + } + + pub const fn try_from_bits(bits: u32) -> Option { + if !bits.is_multiple_of(8) { + return None; + } + let bytes = bits / 8; + match bytes.checked_ilog2() { + Some(log2_bytes_u32) => { + let log2_bytes = log2_bytes_u32 as u8; + assert!(log2_bytes as u32 == log2_bytes_u32); + Some(Self { log2_bytes }) + } + None => None, + } + } + + pub const fn bits(self) -> u32 { + 8 * (1 << self.log2_bytes) + } +} + +/// Bit-width of a supported floating-point type (only power-of-two multiples of a byte). +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct FloatWidth(IntWidth); + +impl FloatWidth { + pub const F16: Self = Self::try_from_bits_unwrap(16); + pub const F32: Self = Self::try_from_bits_unwrap(32); + pub const F64: Self = Self::try_from_bits_unwrap(64); + + // FIXME(eddyb) remove when `Option::unwrap` is stabilized. + const fn try_from_bits_unwrap(bits: u32) -> Self { + match Self::try_from_bits(bits) { + Some(w) => w, + None => unreachable!(), + } + } + + pub const fn try_from_bits(bits: u32) -> Option { + match IntWidth::try_from_bits(bits) { + Some(w) => Some(Self(w)), + None => None, + } + } + + pub const fn bits(self) -> u32 { + self.0.bits() + } +} + +// FIXME(eddyb) document the 128-bit limitations. +// HACK(eddyb) `(Type, u128)` would waste almost half its size on padding, and +// packing will only impact accessing the `bits`, while allowing e.g. being +// wrapped in an outer `enum`, before reaching the same size as `(u128, u128)`. 
+#[repr(Rust, packed)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct Const { + ty: Type, + bits: u128, +} + +impl Const { + pub const FALSE: Const = Const::from_bool(false); + pub const TRUE: Const = Const::from_bool(true); + + // FIXME(eddyb) document the panic conditions. + // FIXME(eddyb) make this public? + const fn from_bits_trunc(ty: Type, bits: u128) -> Const { + // FIXME(eddyb) this ensures `Const`s cannot be created when that could + // potentially need more than 128 bits for e.g. constant-folding. + let width = ty.bit_width(); + assert!(width <= 128); + + Const { ty, bits: bits & (!0u128 >> (128 - width)) } + } + + // FIXME(eddyb) document the panic conditions. + pub const fn from_bits(ty: Type, bits: u128) -> Const { + let ct_trunc = Const::from_bits_trunc(ty, bits); + assert!(ct_trunc.bits == bits); + ct_trunc + } + + pub const fn try_from_bits(ty: Type, bits: u128) -> Option { + let ct_trunc = Const::from_bits_trunc(ty, bits); + if ct_trunc.bits == bits { Some(ct_trunc) } else { None } + } + + pub const fn from_bool(v: bool) -> Const { + Const::from_bits(Type::Bool, v as u128) + } + + pub const fn from_u32(v: u32) -> Const { + Const::from_bits(Type::U32, v as u128) + } + + /// Returns `Some(ct)` iff `ty` is `{S,U}Int` and can represent `v: i128` + /// (i.e. `ct` has the same sign and absolute value as `v` does). + pub fn int_try_from_i128(ty: Type, v: i128) -> Option { + let ct_trunc = Const::from_bits_trunc(ty, v as u128); + (ct_trunc.int_as_i128() == Some(v)).then_some(ct_trunc) + } + + pub const fn ty(&self) -> Type { + self.ty + } + + pub const fn bits(&self) -> u128 { + self.bits + } + + // FIXME(eddyb) make this public? + fn try_bit_cast_to(&self, ty: Type) -> Option { + (self.ty.bit_width() == ty.bit_width()).then_some(Const { ty, ..*self }) + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: i128` + /// (i.e. `self` has the same sign and absolute value as `v` does). 
+ pub fn int_as_i128(&self) -> Option { + match self.ty { + Type::Bool | Type::Float(_) => None, + Type::SInt(_) => { + let width = self.ty.bit_width(); + Some((self.bits as i128) << (128 - width) >> (128 - width)) + } + Type::UInt(_) => self.bits.try_into().ok(), + } + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: u128` + /// (i.e. `self` is positive and has the same absolute value as `v` does). + pub fn int_as_u128(&self) -> Option { + match self.ty { + Type::Bool | Type::Float(_) => None, + Type::SInt(_) => self.int_as_i128()?.try_into().ok(), + Type::UInt(_) => Some(self.bits), + } + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: i32` + /// (i.e. `self` has the same sign and absolute value as `v` does). + pub fn int_as_i32(&self) -> Option { + self.int_as_i128()?.try_into().ok() + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: u32` + /// (i.e. `self` is positive and has the same absolute value as `v` does). + pub fn int_as_u32(&self) -> Option { + self.int_as_u128()?.try_into().ok() + } +} + +/// Pure operations with scalar inputs and outputs. +// +// FIXME(eddyb) these are not some "perfect" grouping, but allow for more +// flexibility in users of this `enum` (and its component `enum`s). +#[derive(Copy, Clone, PartialEq, Eq, Hash, derive_more::From)] +pub enum Op { + BoolUnary(BoolUnOp), + BoolBinary(BoolBinOp), + + IntUnary(IntUnOp), + IntBinary(IntBinOp), + + FloatUnary(FloatUnOp), + FloatBinary(FloatBinOp), +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum BoolUnOp { + Not, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum BoolBinOp { + Eq, + // FIXME(eddyb) should this be `Xor` instead? + Ne, + Or, + And, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum IntUnOp { + Neg, + Not, + CountOnes, + + // FIXME(eddyb) ideally `Trunc` should be separated and common. 
+ TruncOrZeroExtend, + TruncOrSignExtend, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum IntBinOp { + // I×I→I + Add, + Sub, + Mul, + DivU, + DivS, + ModU, + RemS, + ModS, + ShrU, + ShrS, + Shl, + Or, + Xor, + And, + + // I×I→I×I + CarryingAdd, + BorrowingSub, + WideningMulU, + WideningMulS, + + // I×I→B + Eq, + Ne, + // FIXME(eddyb) deduplicate between signed and unsigned. + GtU, + GtS, + GeU, + GeS, + LtU, + LtS, + LeU, + LeS, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum FloatUnOp { + // F→F + Neg, + + // F→B + IsNan, + IsInf, + + // FIXME(eddyb) these are a complicated mix of signatures. + FromUInt, + FromSInt, + ToUInt, + ToSInt, + Convert, + QuantizeAsF16, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum FloatBinOp { + // F×F→F + Add, + Sub, + Mul, + Div, + Rem, + Mod, + + // F×F→B + Cmp(FloatCmp), + // FIXME(eddyb) this doesn't properly convey that this is effectively the + // boolean flip of the opposite comparison, e.g. `CmpOrUnord(Ge)` is really + // a fused version of `Not(Cmp(Lt))`, because `x < y` is never `true` for + // unordered `x` and `y` (i.e. `PartialOrd::partial_cmp(x, y) == None`), + // but that maps to `!(x < y)` always being `true` for unordered `x` and `y`, + // and thus `x >= y` is only equivalent for the ordered cases. + CmpOrUnord(FloatCmp), +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum FloatCmp { + Eq, + Ne, + Lt, + Gt, + Le, + Ge, +} + +pub enum EvalError { + // FIXME(eddyb) provide more detail. + OpSignatureMismatch, + + UnsupportedFloatWidth(FloatWidth), + + // FIXME(eddyb) is there a better name for this? + FloatException, + + // FIXME(eddyb) not exactly an error, and can be replaced with `undef`. 
+ PoisonOutput, + + UndefinedBehavior { cause: &'static str }, +} + +impl Op { + pub fn output_count(self) -> usize { + match self { + Op::IntBinary(op) => op.output_count(), + _ => 1, + } + } + + pub fn name(self) -> &'static str { + match self { + Op::BoolUnary(op) => op.name(), + Op::BoolBinary(op) => op.name(), + + Op::IntUnary(op) => op.name(), + Op::IntBinary(op) => op.name(), + + Op::FloatUnary(op) => op.name(), + Op::FloatBinary(op) => op.name(), + } + } + + pub fn try_eval( + self, + inputs: &[Const], + output_types: &[Type], + ) -> Result, EvalError> { + let single_output = match (self, inputs, output_types) { + (Op::BoolUnary(op), &[Const { ty: Type::Bool, bits: x @ (0..=1) }], &[Type::Bool]) => { + Const::from_bool(op.eval(x != 0)) + } + ( + Op::BoolBinary(op), + &[ + Const { ty: Type::Bool, bits: a @ (0..=1) }, + Const { ty: Type::Bool, bits: b @ (0..=1) }, + ], + &[Type::Bool], + ) => Const::from_bool(op.eval(a != 0, b != 0)), + (Op::IntUnary(op), &[x], &[output_type]) => op.try_eval(x, output_type)?, + (Op::IntBinary(op), &[a, b], _) => return op.try_eval(a, b, output_types), + (Op::FloatUnary(op), &[x], &[output_type]) => op.try_eval(x, output_type)?, + (Op::FloatBinary(op), &[a, b], &[output_type]) => op.try_eval(a, b, output_type)?, + _ => return Err(EvalError::OpSignatureMismatch), + }; + Ok([single_output].into_iter().collect()) + } +} + +impl BoolUnOp { + pub fn name(self) -> &'static str { + match self { + BoolUnOp::Not => "bool.not", + } + } + + pub fn eval(self, x: bool) -> bool { + match self { + BoolUnOp::Not => !x, + } + } +} + +impl BoolBinOp { + pub fn name(self) -> &'static str { + match self { + BoolBinOp::Eq => "bool.eq", + BoolBinOp::Ne => "bool.ne", + BoolBinOp::Or => "bool.or", + BoolBinOp::And => "bool.and", + } + } + + pub fn eval(self, a: bool, b: bool) -> bool { + match self { + BoolBinOp::Eq => a == b, + BoolBinOp::Ne => a != b, + BoolBinOp::Or => a | b, + BoolBinOp::And => a & b, + } + } +} + +impl IntUnOp { + pub fn 
name(self) -> &'static str { + match self { + IntUnOp::Neg => "i.neg", + IntUnOp::Not => "i.not", + IntUnOp::CountOnes => "i.count_ones", + + IntUnOp::TruncOrZeroExtend => "u.trunc_or_zext", + IntUnOp::TruncOrSignExtend => "s.trunc_or_sext", + } + } + + pub fn try_eval(self, x: Const, output_type: Type) -> Result { + // FIXME(eddyb) try to dedup these helpers with `IntBinOp`. + let int_width = |ty| match ty { + Type::UInt(w) | Type::SInt(w) => Ok(w), + _ => Err(EvalError::OpSignatureMismatch), + }; + let output_width = int_width(output_type)?; + + let x_width = int_width(x.ty())?; + let (x, x_s) = + (x.bits(), x.try_bit_cast_to(Type::SInt(x_width)).unwrap().int_as_i128().unwrap()); + + let valid_widths = output_width == x_width + || matches!( + self, + IntUnOp::CountOnes | IntUnOp::TruncOrZeroExtend | IntUnOp::TruncOrSignExtend + ); + if !valid_widths { + return Err(EvalError::OpSignatureMismatch); + } + + let output_bits = match self { + IntUnOp::Neg => x_s.wrapping_neg() as u128, + IntUnOp::Not => !x, + IntUnOp::CountOnes => x.count_ones().into(), + IntUnOp::TruncOrZeroExtend => x, + IntUnOp::TruncOrSignExtend => x_s as u128, + }; + Ok(Const::from_bits_trunc(output_type, output_bits)) + } +} + +impl IntBinOp { + pub fn output_count(self) -> usize { + // FIXME(eddyb) should these 4 go into a different `enum`? 
+ match self { + IntBinOp::CarryingAdd + | IntBinOp::BorrowingSub + | IntBinOp::WideningMulU + | IntBinOp::WideningMulS => 2, + _ => 1, + } + } + + pub fn name(self) -> &'static str { + match self { + IntBinOp::Add => "i.add", + IntBinOp::Sub => "i.sub", + IntBinOp::Mul => "i.mul", + IntBinOp::DivU => "u.div", + IntBinOp::DivS => "s.div", + IntBinOp::ModU => "u.mod", + IntBinOp::RemS => "s.rem", + IntBinOp::ModS => "s.mod", + IntBinOp::ShrU => "u.shr", + IntBinOp::ShrS => "s.shr", + IntBinOp::Shl => "i.shl", + IntBinOp::Or => "i.or", + IntBinOp::Xor => "i.xor", + IntBinOp::And => "i.and", + IntBinOp::CarryingAdd => "i.carrying_add", + IntBinOp::BorrowingSub => "i.borrowing_sub", + IntBinOp::WideningMulU => "u.widening_mul", + IntBinOp::WideningMulS => "s.widening_mul", + IntBinOp::Eq => "i.eq", + IntBinOp::Ne => "i.ne", + IntBinOp::GtU => "u.gt", + IntBinOp::GtS => "s.gt", + IntBinOp::GeU => "u.ge", + IntBinOp::GeS => "s.ge", + IntBinOp::LtU => "u.lt", + IntBinOp::LtS => "s.lt", + IntBinOp::LeU => "u.le", + IntBinOp::LeS => "s.le", + } + } + + pub fn try_eval( + self, + a: Const, + b: Const, + output_types: &[Type], + ) -> Result, EvalError> { + let output_type = output_types + .iter() + .copied() + .dedup() + .exactly_one() + .ok() + .filter(|_| output_types.len() == self.output_count()) + .ok_or(EvalError::OpSignatureMismatch)?; + + // FIXME(eddyb) try to dedup these helpers with `IntUnOp`. + let int_width = |ty| match ty { + Type::UInt(w) | Type::SInt(w) => Ok(w), + _ => Err(EvalError::OpSignatureMismatch), + }; + let output_width = match self { + // FIXME(eddyb) should comparisons be handled separately? 
+ IntBinOp::Eq + | IntBinOp::Ne + | IntBinOp::GtU + | IntBinOp::GtS + | IntBinOp::GeU + | IntBinOp::GeS + | IntBinOp::LtU + | IntBinOp::LtS + | IntBinOp::LeU + | IntBinOp::LeS => None, + + _ => Some(int_width(output_type)?), + }; + + let as_u128_i128 = |x: Const| { + let x_width = int_width(x.ty())?; + Ok(( + x_width, + x.bits(), + x.try_bit_cast_to(Type::SInt(x_width)).unwrap().int_as_i128().unwrap(), + )) + }; + let (a_width, a, a_s) = as_u128_i128(a)?; + let (b_width, b, b_s) = as_u128_i128(b)?; + + let valid_widths = output_width.is_none_or(|w| w == a_width) + && (a_width == b_width + || matches!(self, IntBinOp::ShrU | IntBinOp::ShrS | IntBinOp::Shl)); + if !valid_widths { + return Err(EvalError::OpSignatureMismatch); + } + + let div_ub_err = EvalError::UndefinedBehavior { + cause: if b_s == 0 { "division by 0" } else { "signed division overflow" }, + }; + let b_as_shift_amount = + || u32::try_from(b).ok().filter(|&b| b < a_width.bits()).ok_or(EvalError::PoisonOutput); + + // FIXME(eddyb) replace with `u128::widening_mul` when it stabilizes. + fn u128_widening_mul(a: u128, b: u128) -> (u128, u128) { + // HACK(eddyb) the code below extracts `lo` and `hi`, + // such that `lo + 2¹²⁸hi` is equal to this expansion of `a · b`: + // `(al + 2⁶⁴ah) · (bl + 2⁶⁴bh) = al·bl + 2⁶⁴(al·bh + ah·bl) + 2¹²⁸(ah·bh)` + let [(al, ah), (bl, bh)] = [a, b].map(|x| (x as u64 as u128, x >> 64)); + let [[al_bl, al_bh], [ah_bl, ah_bh]] = + [al, ah].map(|a| [bl, bh].map(|b| a.checked_mul(b).unwrap())); + + let (mid, mid_carry) = al_bh.overflowing_add(ah_bl); + let (lo, lo_carry) = al_bl.overflowing_add(mid << 64); + let hi = [ah_bh, mid >> 64, (mid_carry as u128) << 64, lo_carry as u128] + .into_iter() + .reduce(|a, b| a.checked_add(b).unwrap()) + .unwrap(); + + assert_eq!(lo, a.wrapping_mul(b)); + + (lo, hi) + } + + // FIXME(eddyb) replace with `i128::widening_mul` when it stabilizes. 
+ fn i128_widening_mul(a: i128, b: i128) -> (u128, i128) { + // HACK(eddyb) to avoid duplication and signedness subtleties, + // the sign is handled on top of the unsigned implementation above. + let (abs_lo, abs_hi) = u128_widening_mul(a.unsigned_abs(), b.unsigned_abs()); + if a.signum() * b.signum() == -1 { + // HACK(eddyb) `-x` is equivalent to `(!x).wrapping_add(1)`, + // which can be directly applied to a double-width integer. + let (lo, lo_carry) = (!abs_lo).overflowing_add(1); + (lo, (!abs_hi).wrapping_add(lo_carry as u128) as i128) + } else { + (abs_lo, abs_hi as i128) + } + } + + let wide_result = |[lo, hi]: [u128; 2]| { + // HACK(eddyb) `lo + 2¹²⁸hi` form a 256-bit result, but the true + // result for an N-bit operation will only match those two halves + // for N=128, for smaller N both halves can be found in `lo`. + let width = output_width.unwrap().bits(); + let hi = if width == 128 || { + // HACK(eddyb) because subtraction overflow is centered around `0`, + // and not `2^N`, the 128-bit `hi` is already the correct top half, + // and it's not obvious how to otherwise get that correct value, + // without this (otherwise quite annoying) special-case. + self == IntBinOp::BorrowingSub + } { + hi + } else { + lo.checked_shr(width).unwrap() + }; + + Ok([lo, hi].map(|x| Const::from_bits_trunc(output_type, x)).into_iter().collect()) + }; + + // HACK(eddyb) can't trust `checked_{div,rem}` to handle the "MIN" part + // correctly, because `iN::MIN as i128 != i128::MIN` for `N < 128`. + if let IntBinOp::DivS | IntBinOp::RemS | IntBinOp::ModS = self + && a_s == -1 << (a_width.bits() - 1) + && b_s == -1 + { + return Err(div_ub_err); + } + + let output_bits = match self { + IntBinOp::Add => a.wrapping_add(b), + IntBinOp::Sub => a.wrapping_sub(b), + IntBinOp::Mul => a.wrapping_mul(b), + IntBinOp::DivU => a.checked_div(b).ok_or(div_ub_err)?, + IntBinOp::DivS => a_s.checked_div(b_s).ok_or(div_ub_err)? 
as u128, + IntBinOp::ModU => a.checked_rem(b).ok_or(div_ub_err)?, + IntBinOp::RemS => a_s.checked_rem(b_s).ok_or(div_ub_err)? as u128, + IntBinOp::ModS => { + let rem_s = a_s.checked_rem(b_s).ok_or(div_ub_err)?; + let mod_s = if rem_s.signum() * b_s.signum() == -1 { + // |b_s| > |rem_s|, so |b_s + rem_s| = |b_s| - |rem_s|, and + // the sum will have sign of `b_s` (as required by SPIR-V). + rem_s.checked_add(b_s).unwrap() + } else { + rem_s + }; + mod_s as u128 + } + IntBinOp::ShrU => a.checked_shr(b_as_shift_amount()?).unwrap(), + IntBinOp::ShrS => a_s.checked_shr(b_as_shift_amount()?).unwrap() as u128, + IntBinOp::Shl => a.checked_shl(b_as_shift_amount()?).unwrap(), + IntBinOp::Or => a | b, + IntBinOp::Xor => a ^ b, + IntBinOp::And => a & b, + IntBinOp::CarryingAdd => { + let (lo, hi) = a.overflowing_add(b); + return wide_result([lo, hi as u128]); + } + IntBinOp::BorrowingSub => { + let (lo, hi) = a.overflowing_sub(b); + return wide_result([lo, hi as u128]); + } + IntBinOp::WideningMulU => { + let (lo, hi) = u128_widening_mul(a, b); + return wide_result([lo, hi]); + } + IntBinOp::WideningMulS => { + let (lo, hi) = i128_widening_mul(a_s, b_s); + return wide_result([lo, hi as u128]); + } + IntBinOp::Eq => (a == b) as u128, + IntBinOp::Ne => (a != b) as u128, + IntBinOp::GtU => (a > b) as u128, + IntBinOp::GtS => (a_s > b_s) as u128, + IntBinOp::GeU => (a >= b) as u128, + IntBinOp::GeS => (a_s >= b_s) as u128, + IntBinOp::LtU => (a < b) as u128, + IntBinOp::LtS => (a_s < b_s) as u128, + IntBinOp::LeU => (a <= b) as u128, + IntBinOp::LeS => (a_s <= b_s) as u128, + }; + Ok([Const::from_bits_trunc(output_type, output_bits)].into_iter().collect()) + } +} + +impl FloatUnOp { + pub fn name(self) -> &'static str { + match self { + FloatUnOp::Neg => "f.neg", + + FloatUnOp::IsNan => "f.is_nan", + FloatUnOp::IsInf => "f.is_inf", + + FloatUnOp::FromUInt => "f.from_uint", + FloatUnOp::FromSInt => "f.from_sint", + FloatUnOp::ToUInt => "f.to_uint", + FloatUnOp::ToSInt => 
"f.to_sint", + FloatUnOp::Convert => "f.convert", + FloatUnOp::QuantizeAsF16 => "f.quantize_as_f16", + } + } + + pub fn try_eval(self, x: Const, output_type: Type) -> Result { + let float_type = match self { + FloatUnOp::Neg + | FloatUnOp::IsNan + | FloatUnOp::IsInf + | FloatUnOp::ToUInt + | FloatUnOp::ToSInt + | FloatUnOp::Convert => x.ty(), + FloatUnOp::FromUInt | FloatUnOp::FromSInt => output_type, + FloatUnOp::QuantizeAsF16 => Type::F32, + }; + + match float_type { + Type::F16 => self.try_eval_specialized::(x, output_type), + Type::F32 => self.try_eval_specialized::(x, output_type), + Type::F64 => self.try_eval_specialized::(x, output_type), + Type::Float(w) => Err(EvalError::UnsupportedFloatWidth(w)), + _ => Err(EvalError::OpSignatureMismatch), + } + } + + fn try_eval_specialized(self, x: Const, output_type: Type) -> Result + where + F: rustc_apfloat::Float + + rustc_apfloat::FloatConvert + + rustc_apfloat::FloatConvert + + rustc_apfloat::FloatConvert, + rustc_apfloat::ieee::Half: rustc_apfloat::FloatConvert, + { + use rustc_apfloat::{Float, FloatConvert, Status, StatusAnd}; + + // HACK(eddyb) more convenient conversion helper. + fn convert, U: Float>(x: T) -> StatusAnd { + x.convert(&mut false) + } + + let int_width = |ty| match ty { + Type::UInt(w) | Type::SInt(w) => Ok(w), + _ => Err(EvalError::OpSignatureMismatch), + }; + + // FIXME(eddyb) try to dedup these helpers with `FloatBinOp`. 
+ let expected_float_type = + Type::Float(FloatWidth::try_from_bits(F::BITS.try_into().unwrap()).unwrap()); + let f_from_const = |x: Const| { + if x.ty() != expected_float_type { + return Err(EvalError::OpSignatureMismatch); + } + Ok(F::from_bits(x.bits())) + }; + let const_f = |x: F| Const::from_bits(expected_float_type, x.to_bits()); + let const_bool = |x: bool| Const::from_bits(Type::Bool, x as u128); + + let status_and_output = match self { + FloatUnOp::Neg => Status::OK.and(-f_from_const(x)?).map(const_f), + FloatUnOp::IsNan => Status::OK.and(f_from_const(x)?.is_nan()).map(const_bool), + FloatUnOp::IsInf => Status::OK.and(f_from_const(x)?.is_infinite()).map(const_bool), + FloatUnOp::FromUInt => { + F::from_u128(x.int_as_u128().ok_or(EvalError::OpSignatureMismatch)?).map(const_f) + } + FloatUnOp::FromSInt => { + F::from_i128(x.int_as_i128().ok_or(EvalError::OpSignatureMismatch)?).map(const_f) + } + FloatUnOp::ToUInt => { + let width = int_width(output_type)?; + f_from_const(x)? + .to_u128(width.bits() as usize) + .map(|r| Const::from_bits(Type::UInt(width), r)) + } + FloatUnOp::ToSInt => { + let width = int_width(output_type)?; + f_from_const(x)? + .to_i128(width.bits() as usize) + .map(|r| Const::int_try_from_i128(Type::SInt(width), r).unwrap()) + } + FloatUnOp::Convert => { + let x = f_from_const(x)?; + let status_and_output_bits = match output_type { + Type::F16 => convert::<_, rustc_apfloat::ieee::Half>(x).map(|r| r.to_bits()), + Type::F32 => convert::<_, rustc_apfloat::ieee::Single>(x).map(|r| r.to_bits()), + Type::F64 => convert::<_, rustc_apfloat::ieee::Double>(x).map(|r| r.to_bits()), + Type::Float(w) => return Err(EvalError::UnsupportedFloatWidth(w)), + _ => return Err(EvalError::OpSignatureMismatch), + }; + status_and_output_bits.map(|output_bits| Const::from_bits(output_type, output_bits)) + } + FloatUnOp::QuantizeAsF16 => convert::<_, rustc_apfloat::ieee::Half>(f_from_const(x)?) 
+ .map(|x_f16| convert::<_, F>(x_f16).value) + .map(const_f), + }; + + if status_and_output.status.intersects(Status::INVALID_OP | Status::DIV_BY_ZERO) { + return Err(EvalError::FloatException); + } + + let output = status_and_output.value; + if output.ty() != output_type { + return Err(EvalError::OpSignatureMismatch); + } + Ok(output) + } +} + +impl FloatBinOp { + pub fn name(self) -> &'static str { + match self { + FloatBinOp::Add => "f.add", + FloatBinOp::Sub => "f.sub", + FloatBinOp::Mul => "f.mul", + FloatBinOp::Div => "f.div", + FloatBinOp::Rem => "f.rem", + FloatBinOp::Mod => "f.mod", + FloatBinOp::Cmp(FloatCmp::Eq) => "f.eq", + FloatBinOp::Cmp(FloatCmp::Ne) => "f.ne", + FloatBinOp::Cmp(FloatCmp::Lt) => "f.lt", + FloatBinOp::Cmp(FloatCmp::Gt) => "f.gt", + FloatBinOp::Cmp(FloatCmp::Le) => "f.le", + FloatBinOp::Cmp(FloatCmp::Ge) => "f.ge", + FloatBinOp::CmpOrUnord(FloatCmp::Eq) => "f.eq_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Ne) => "f.ne_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Lt) => "f.lt_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Gt) => "f.gt_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Le) => "f.le_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Ge) => "f.ge_or_unord", + } + } + + pub fn try_eval(self, a: Const, b: Const, output_type: Type) -> Result { + if a.ty() != b.ty() { + return Err(EvalError::OpSignatureMismatch); + } + + match a.ty() { + Type::F16 => self.try_eval_specialized::(a, b, output_type), + Type::F32 => { + self.try_eval_specialized::(a, b, output_type) + } + Type::F64 => { + self.try_eval_specialized::(a, b, output_type) + } + Type::Float(w) => Err(EvalError::UnsupportedFloatWidth(w)), + _ => Err(EvalError::OpSignatureMismatch), + } + } + + fn try_eval_specialized( + self, + a: Const, + b: Const, + output_type: Type, + ) -> Result { + use rustc_apfloat::Status; + + // FIXME(eddyb) try to dedup these helpers with `FloatBinOp`. 
+ let expected_float_type = + Type::Float(FloatWidth::try_from_bits(F::BITS.try_into().unwrap()).unwrap()); + let f_from_const = |x: Const| { + if x.ty() != expected_float_type { + return Err(EvalError::OpSignatureMismatch); + } + Ok(F::from_bits(x.bits())) + }; + let const_f = |x: F| Const::from_bits(expected_float_type, x.to_bits()); + let const_bool = |x: bool| Const::from_bits(Type::Bool, x as u128); + + let status_and_output = match self { + FloatBinOp::Add => (f_from_const(a)? + f_from_const(b)?).map(const_f), + FloatBinOp::Sub => (f_from_const(a)? - f_from_const(b)?).map(const_f), + FloatBinOp::Mul => (f_from_const(a)? * f_from_const(b)?).map(const_f), + FloatBinOp::Div => (f_from_const(a)? / f_from_const(b)?).map(const_f), + FloatBinOp::Rem => (f_from_const(a)? % f_from_const(b)?).map(const_f), + FloatBinOp::Mod => { + let (a, b) = (f_from_const(a)?, f_from_const(b)?); + (a % b) + .map(|rem| { + if !rem.is_zero() && rem.is_negative() != b.is_negative() { + // |b| > |rem|, so |b + rem| = |b| - |rem|, and the sum + // will have sign of `b` (as required by SPIR-V). + (rem + b).value + } else { + rem + } + }) + .map(const_f) + } + FloatBinOp::Cmp(cmp) => { + Status::OK.and(cmp.eval(&f_from_const(a)?, &f_from_const(b)?)).map(const_bool) + } + // HACK(eddyb) see comment on `FloatBinOp::CmpOrUnord` for an explanation. + FloatBinOp::CmpOrUnord(cmp) => Status::OK + .and((!cmp).eval(&f_from_const(a)?, &f_from_const(b)?)) + .map(|r| const_bool(!r)), + }; + + if status_and_output.status.intersects(Status::INVALID_OP | Status::DIV_BY_ZERO) { + return Err(EvalError::FloatException); + } + + let output = status_and_output.value; + if output.ty() != output_type { + return Err(EvalError::OpSignatureMismatch); + } + Ok(output) + } +} + +// HACK(eddyb) see comment on `FloatBinOp::CmpOrUnord` for why this "flipping" +// is useful - i.e. `FloatBinOp::CmpOrUnord(cmp)` is equivalent to first applying +// `FloatBinOp::Cmp(!cmp)` then passing its result to `BoolUnOp::Not`. 
+impl std::ops::Not for FloatCmp { + type Output = FloatCmp; + fn not(self) -> FloatCmp { + match self { + FloatCmp::Eq => FloatCmp::Ne, + FloatCmp::Ne => FloatCmp::Eq, + FloatCmp::Lt => FloatCmp::Ge, + FloatCmp::Gt => FloatCmp::Le, + FloatCmp::Le => FloatCmp::Gt, + FloatCmp::Ge => FloatCmp::Lt, + } + } +} + +impl FloatCmp { + fn eval(self, a: &T, b: &T) -> bool { + match self { + FloatCmp::Eq => *a == *b, + FloatCmp::Ne => *a != *b, + FloatCmp::Lt => *a < *b, + FloatCmp::Gt => *a > *b, + FloatCmp::Le => *a <= *b, + FloatCmp::Ge => *a >= *b, + } + } +} diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs index 45a0424a..efd8e95f 100644 --- a/src/spv/canonical.rs +++ b/src/spv/canonical.rs @@ -7,16 +7,22 @@ // // FIXME(eddyb) should interning attempts check/apply these canonicalizations? -use crate::ConstKind; use crate::spv::{self, spec}; +use crate::{ConstKind, Context, NodeKind, Type, TypeKind, scalar}; use lazy_static::lazy_static; // FIXME(eddyb) these ones could maybe make use of build script generation. macro_rules! def_mappable_ops { - ($($op:ident),+ $(,)?) => { + ( + type { $($ty_op:ident),+ $(,)? } + const { $($ct_op:ident),+ $(,)? } + $($enum_path:path { $($variant_op:ident <=> $variant:ident$(($($variant_args:tt)*))?),+ $(,)? })* + ) => { #[allow(non_snake_case)] struct MappableOps { - $($op: spec::Opcode,)+ + $($ty_op: spec::Opcode,)+ + $($ct_op: spec::Opcode,)+ + $($($variant_op: spec::Opcode,)+)* } impl MappableOps { #[inline(always)] @@ -26,33 +32,302 @@ macro_rules! 
def_mappable_ops { static ref MAPPABLE_OPS: MappableOps = { let spv_spec = spec::Spec::get(); MappableOps { - $($op: spv_spec.instructions.lookup(stringify!($op)).unwrap(),)+ + $($ty_op: spv_spec.instructions.lookup(stringify!($ty_op)).unwrap(),)+ + $($ct_op: spv_spec.instructions.lookup(stringify!($ct_op)).unwrap(),)+ + $($($variant_op: spv_spec.instructions.lookup(stringify!($variant_op)).unwrap(),)+)* } }; } &MAPPABLE_OPS } } + // NOTE(eddyb) these should stay private, hence not implementing `TryFrom`. + $(impl $enum_path { + fn try_from_opcode(opcode: spec::Opcode) -> Option { + let mo = MappableOps::get(); + $(if opcode == mo.$variant_op { + return Some(Self::$variant$(($($variant_args)*))?); + })+ + None + } + fn to_opcode(self) -> spec::Opcode { + let mo = MappableOps::get(); + match self { + $(Self::$variant$(($($variant_args)*))? => mo.$variant_op,)+ + } + } + })* }; } def_mappable_ops! { - OpUndef, + // FIXME(eddyb) these categories don't actually do anything right now + type { + OpTypeBool, + OpTypeInt, + OpTypeFloat, + } + const { + OpUndef, + OpConstantFalse, + OpConstantTrue, + OpConstant, + } + scalar::BoolUnOp { + OpLogicalNot <=> Not, + } + scalar::BoolBinOp { + OpLogicalEqual <=> Eq, + OpLogicalNotEqual <=> Ne, + OpLogicalOr <=> Or, + OpLogicalAnd <=> And, + } + scalar::IntUnOp { + OpSNegate <=> Neg, + OpNot <=> Not, + OpBitCount <=> CountOnes, + + OpUConvert <=> TruncOrZeroExtend, + OpSConvert <=> TruncOrSignExtend, + } + scalar::IntBinOp { + // I×I→I + OpIAdd <=> Add, + OpISub <=> Sub, + OpIMul <=> Mul, + OpUDiv <=> DivU, + OpSDiv <=> DivS, + OpUMod <=> ModU, + OpSRem <=> RemS, + OpSMod <=> ModS, + OpShiftRightLogical <=> ShrU, + OpShiftRightArithmetic <=> ShrS, + OpShiftLeftLogical <=> Shl, + OpBitwiseOr <=> Or, + OpBitwiseXor <=> Xor, + OpBitwiseAnd <=> And, + + // I×I→I×I + OpIAddCarry <=> CarryingAdd, + OpISubBorrow <=> BorrowingSub, + OpUMulExtended <=> WideningMulU, + OpSMulExtended <=> WideningMulS, + + // I×I→B + OpIEqual <=> Eq, + 
OpINotEqual <=> Ne, + OpUGreaterThan <=> GtU, + OpSGreaterThan <=> GtS, + OpUGreaterThanEqual <=> GeU, + OpSGreaterThanEqual <=> GeS, + OpULessThan <=> LtU, + OpSLessThan <=> LtS, + OpULessThanEqual <=> LeU, + OpSLessThanEqual <=> LeS, + } + scalar::FloatUnOp { + // F→F + OpFNegate <=> Neg, + + // F→B + OpIsNan <=> IsNan, + OpIsInf <=> IsInf, + + OpConvertUToF <=> FromUInt, + OpConvertSToF <=> FromSInt, + OpConvertFToU <=> ToUInt, + OpConvertFToS <=> ToSInt, + OpFConvert <=> Convert, + OpQuantizeToF16 <=> QuantizeAsF16, + } + scalar::FloatBinOp { + // F×F→F + OpFAdd <=> Add, + OpFSub <=> Sub, + OpFMul <=> Mul, + OpFDiv <=> Div, + OpFRem <=> Rem, + OpFMod <=> Mod, + + // F×F→B + OpFOrdEqual <=> Cmp(scalar::FloatCmp::Eq), + OpFOrdNotEqual <=> Cmp(scalar::FloatCmp::Ne), + OpFOrdLessThan <=> Cmp(scalar::FloatCmp::Lt), + OpFOrdGreaterThan <=> Cmp(scalar::FloatCmp::Gt), + OpFOrdLessThanEqual <=> Cmp(scalar::FloatCmp::Le), + OpFOrdGreaterThanEqual <=> Cmp(scalar::FloatCmp::Ge), + OpFUnordEqual <=> CmpOrUnord(scalar::FloatCmp::Eq), + OpFUnordNotEqual <=> CmpOrUnord(scalar::FloatCmp::Ne), + OpFUnordLessThan <=> CmpOrUnord(scalar::FloatCmp::Lt), + OpFUnordGreaterThan <=> CmpOrUnord(scalar::FloatCmp::Gt), + OpFUnordLessThanEqual <=> CmpOrUnord(scalar::FloatCmp::Le), + OpFUnordGreaterThanEqual <=> CmpOrUnord(scalar::FloatCmp::Ge), + } +} + +impl scalar::Const { + // HACK(eddyb) this is not private so `spv::lower` can use it for `OpSwitch`. + pub(super) fn try_decode_from_spv_imms( + ty: scalar::Type, + imms: &[spv::Imm], + ) -> Option { + // FIXME(eddyb) don't hardcode the 128-bit limitation, + // but query `scalar::Const` somehow instead. 
+ if ty.bit_width() > 128 { + return None; + } + let imm_words = usize::try_from(ty.bit_width().div_ceil(32)).unwrap(); + if imms.len() != imm_words { + return None; + } + let mut bits = 0; + for (i, &imm) in imms.iter().enumerate() { + let w = match imm { + spv::Imm::Short(_, w) if imm_words == 1 => w, + spv::Imm::LongStart(_, w) if i == 0 && imm_words > 1 => w, + spv::Imm::LongCont(_, w) if i > 0 => w, + _ => return None, + }; + bits |= (w as u128) << (i * 32); + } + + // HACK(eddyb) signed integers are encoded sign-extended into immediates. + if let scalar::Type::SInt(_) = ty { + let imm_width = imm_words * 32; + scalar::Const::int_try_from_i128( + ty, + (bits as i128) << (128 - imm_width) >> (128 - imm_width), + ) + } else { + scalar::Const::try_from_bits(ty, bits) + } + } + + // HACK(eddyb) this is not private so `spv::lift` can use it for `OpSwitch`. + pub(super) fn encode_as_spv_imms(&self) -> impl Iterator { + let wk = &spec::Spec::get().well_known; + + let ty = self.ty(); + let imm_words = ty.bit_width().div_ceil(32); + + let bits = self.bits(); + + // HACK(eddyb) signed integers are encoded sign-extended into immediates. + let bits = if let scalar::Type::SInt(_) = ty { + let imm_width = imm_words * 32; + (self.int_as_i128().unwrap() as u128) & (!0 >> (128 - imm_width)) + } else { + bits + }; + + (0..imm_words).map(move |i| { + let k = wk.LiteralContextDependentNumber; + let w = (bits >> (i * 32)) as u32; + if imm_words == 1 { + spv::Imm::Short(k, w) + } else if i == 0 { + spv::Imm::LongStart(k, w) + } else { + spv::Imm::LongCont(k, w) + } + }) + } } // FIXME(eddyb) decide on a visibility scope - `pub(super)` avoids some mistakes // (using these methods outside of `spv::{lower,lift}`), but may be too restrictive. impl spv::Inst { - pub(super) fn as_canonical_const(&self) -> Option { + // HACK(eddyb) exported only for `spv::read`/`LiteralContextDependentNumber`. 
+ pub(super) fn int_or_float_type_bit_width(&self) -> Option { + let mo = MappableOps::get(); + + match self.imms[..] { + [spv::Imm::Short(_, bit_width), _] if self.opcode == mo.OpTypeInt => Some(bit_width), + [spv::Imm::Short(_, bit_width)] if self.opcode == mo.OpTypeFloat => Some(bit_width), + _ => None, + } + } + + // FIXME(eddyb) automate bidirectional mappings more (although the need + // for conditional, i.e. "partial", mappings, adds a lot of complexity). + pub(super) fn as_canonical_type(&self) -> Option { let Self { opcode, imms } = self; let (&opcode, imms) = (opcode, &imms[..]); let mo = MappableOps::get(); - if opcode == mo.OpUndef { - assert_eq!(imms.len(), 0); - Some(ConstKind::Undef) - } else { - None + let int_width = || scalar::IntWidth::try_from_bits(self.int_or_float_type_bit_width()?); + match imms { + [] if opcode == mo.OpTypeBool => Some(scalar::Type::Bool.into()), + &[_, spv::Imm::Short(_, 0)] if opcode == mo.OpTypeInt => { + Some(scalar::Type::UInt(int_width()?).into()) + } + &[_, spv::Imm::Short(_, 1)] if opcode == mo.OpTypeInt => { + Some(scalar::Type::SInt(int_width()?).into()) + } + [_] if opcode == mo.OpTypeFloat => Some( + scalar::Type::Float(scalar::FloatWidth::try_from_bits( + self.int_or_float_type_bit_width()?, + )?) 
+ .into(), + ), + _ => None, + } + } + + pub(super) fn from_canonical_type(type_kind: &TypeKind) -> Option { + let wk = &spec::Spec::get().well_known; + let mo = MappableOps::get(); + + match type_kind { + &TypeKind::Scalar(ty) => match ty { + scalar::Type::Bool => Some(mo.OpTypeBool.into()), + scalar::Type::SInt(w) | scalar::Type::UInt(w) => Some(spv::Inst { + opcode: mo.OpTypeInt, + imms: [ + spv::Imm::Short(wk.LiteralInteger, w.bits()), + spv::Imm::Short( + wk.LiteralInteger, + matches!(ty, scalar::Type::SInt(_)) as u32, + ), + ] + .into_iter() + .collect(), + }), + scalar::Type::Float(w) => Some(spv::Inst { + opcode: mo.OpTypeFloat, + imms: [spv::Imm::Short(wk.LiteralInteger, w.bits())].into_iter().collect(), + }), + }, + + TypeKind::QPtr + | TypeKind::Thunk + | TypeKind::SpvInst { .. } + | TypeKind::SpvStringLiteralForExtInst => None, + } + } + + // HACK(eddyb) this only exists as a helper for `spv::lower`. + pub(super) fn always_lower_as_const(&self) -> bool { + let mo = MappableOps::get(); + mo.OpUndef == self.opcode + } + + // FIXME(eddyb) automate bidirectional mappings more (although the need + // for conditional, i.e. "partial", mappings, adds a lot of complexity). 
+ pub(super) fn as_canonical_const(&self, cx: &Context, ty: Type) -> Option { + let Self { opcode, imms } = self; + let (&opcode, imms) = (opcode, &imms[..]); + + let mo = MappableOps::get(); + + match imms { + [] if opcode == mo.OpUndef => Some(ConstKind::Undef), + [] if opcode == mo.OpConstantFalse => Some(scalar::Const::FALSE.into()), + [] if opcode == mo.OpConstantTrue => Some(scalar::Const::TRUE.into()), + _ if opcode == mo.OpConstant => { + Some(scalar::Const::try_decode_from_spv_imms(ty.as_scalar(cx)?, imms)?.into()) + } + _ => None, } } @@ -61,6 +336,11 @@ impl spv::Inst { match const_kind { ConstKind::Undef => Some(mo.OpUndef.into()), + ConstKind::Scalar(scalar::Const::FALSE) => Some(mo.OpConstantFalse.into()), + ConstKind::Scalar(scalar::Const::TRUE) => Some(mo.OpConstantTrue.into()), + ConstKind::Scalar(ct) => { + Some(spv::Inst { opcode: mo.OpConstant, imms: ct.encode_as_spv_imms().collect() }) + } ConstKind::PtrToGlobalVar(_) | ConstKind::PtrToFunc(_) @@ -68,4 +348,45 @@ impl spv::Inst { | ConstKind::SpvStringLiteralForExtInst(_) => None, } } + + // HACK(eddyb) exported to facilitate `OpSpecConstantOp` handling elsewhere. + pub fn as_canonical_node_kind(&self, cx: &Context, output_types: &[Type]) -> Option { + let Self { opcode, imms } = self; + let (&opcode, imms) = (opcode, &imms[..]); + + let scalar_op = (scalar::BoolUnOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::BoolBinOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::IntUnOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::IntBinOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::FloatUnOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::FloatBinOp::try_from_opcode(opcode).map(scalar::Op::from)); + if let Some(op) = scalar_op { + assert_eq!(imms.len(), 0); + + // FIXME(eddyb) support vector versions of these ops as well. 
+ if output_types.len() == op.output_count() + && output_types.iter().all(|ty| ty.as_scalar(cx).is_some()) + { + Some(op.into()) + } else { + None + } + } else { + None + } + } + + pub(super) fn from_canonical_node_kind(node_kind: &NodeKind) -> Option { + match node_kind { + &NodeKind::Scalar(op) => Some(match op { + scalar::Op::BoolUnary(op) => op.to_opcode().into(), + scalar::Op::BoolBinary(op) => op.to_opcode().into(), + scalar::Op::IntUnary(op) => op.to_opcode().into(), + scalar::Op::IntBinary(op) => op.to_opcode().into(), + scalar::Op::FloatUnary(op) => op.to_opcode().into(), + scalar::Op::FloatBinary(op) => op.to_opcode().into(), + }), + _ => None, + } + } } diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 06e32150..ba0955f8 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -9,7 +9,7 @@ use crate::{ DataInstKind, DbgSrcLoc, DeclDef, ExportKey, Exportee, Func, FuncDecl, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDefBody, Import, Module, ModuleDebugInfo, ModuleDialect, Node, NodeDef, NodeKind, OrdAssertEq, Region, Type, TypeDef, TypeKind, TypeOrConst, Value, Var, - VarDecl, VarKind, + VarDecl, VarKind, scalar, }; use itertools::Itertools; use rustc_hash::FxHashMap; @@ -121,6 +121,8 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { } let ty_def = &self.cx[ty]; match ty_def.kind { + TypeKind::Scalar(_) | TypeKind::SpvInst { .. } => {} + // FIXME(eddyb) this should be a proper `Result`-based error instead, // and/or `spv::lift` should mutate the module for legalization. TypeKind::QPtr => { @@ -132,7 +134,6 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { return; } - TypeKind::SpvInst { .. } => {} TypeKind::SpvStringLiteralForExtInst => { unreachable!( "`TypeKind::SpvStringLiteralForExtInst` should not be used \ @@ -155,6 +156,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { } ConstKind::Undef + | ConstKind::Scalar(_) | ConstKind::PtrToGlobalVar(_) | ConstKind::PtrToFunc(_) | ConstKind::SpvInst { .. 
} => { @@ -223,11 +225,14 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { } fn visit_node_def(&mut self, func_at_node: FuncAt<'_, Node>) { - #[allow(clippy::match_same_arms)] match func_at_node.def().kind { - NodeKind::Select(_) | NodeKind::Loop { .. } | NodeKind::ExitInvocation(_) => {} - - DataInstKind::FuncCall(_) => {} + NodeKind::Select(_) + | NodeKind::Loop { .. } + | NodeKind::ExitInvocation(_) + | DataInstKind::Scalar(_) + | DataInstKind::FuncCall(_) + | DataInstKind::ThunkBind(_) + | DataInstKind::SpvInst(_) => {} // FIXME(eddyb) this should be a proper `Result`-based error instead, // and/or `spv::lift` should mutate the module for legalization. @@ -241,9 +246,6 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { unreachable!("`DataInstKind::QPtr` should be legalized away before lifting"); } - DataInstKind::ThunkBind(_) => {} - - DataInstKind::SpvInst(_) => {} DataInstKind::SpvExtInst { ext_set, .. } => { self.ext_inst_imports.insert(&self.cx[ext_set]); } @@ -490,7 +492,8 @@ impl<'p> FuncAt<'_, CfgCursor<'p>> { | NodeKind::Loop { .. } | NodeKind::ExitInvocation { .. 
} => None, - DataInstKind::FuncCall(_) + DataInstKind::Scalar(_) + | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) | DataInstKind::ThunkBind(_) @@ -582,8 +585,6 @@ impl<'a> FuncLifting<'a> { func_decl: &'a FuncDecl, mut alloc_id: impl FnMut() -> Result, ) -> Result { - let wk = &spec::Spec::get().well_known; - let func_id = alloc_id()?; let param_ids = func_decl.params.iter().map(|_| alloc_id()).collect::>()?; @@ -705,7 +706,8 @@ impl<'a> FuncLifting<'a> { SmallVec::new() } - DataInstKind::FuncCall(_) + DataInstKind::Scalar(_) + | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) | DataInstKind::SpvInst(_) @@ -875,7 +877,8 @@ impl<'a> FuncLifting<'a> { merge: None, }, - DataInstKind::FuncCall(_) + DataInstKind::Scalar(_) + | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) | DataInstKind::ThunkBind(_) @@ -916,14 +919,9 @@ impl<'a> FuncLifting<'a> { .collect(); let is_infinite_loop = match repeat_condition { - Value::Const(cond) => match &cx[cond].kind { - ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, _const_inputs) = - &**spv_inst_and_const_inputs; - spv_inst.opcode == wk.OpConstantTrue - } - _ => false, - }, + Value::Const(cond) => { + matches!(cx[cond].kind, ConstKind::Scalar(scalar::Const::TRUE)) + } Value::Var(_) => false, }; if is_infinite_loop { @@ -956,6 +954,7 @@ impl<'a> FuncLifting<'a> { } NodeKind::ExitInvocation { .. } + | DataInstKind::Scalar(_) | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) @@ -1276,6 +1275,7 @@ impl LazyInst<'_, '_> { ConstKind::Undef | ConstKind::PtrToFunc(_) + | ConstKind::Scalar(_) | ConstKind::SpvInst { .. } => (ct_def.attrs, None), // Not inserted into `globals` while visiting. 
@@ -1351,27 +1351,45 @@ impl LazyInst<'_, '_> { let (result_id, attrs, _) = self.result_id_attrs_and_import(module, ids); let inst = match self { Self::Global(global) => match global { - Global::Type(ty) => match &cx[ty].kind { - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => spv::InstWithIds { - without_ids: spv_inst.clone(), - result_type_id: None, - result_id, - ids: type_and_const_inputs - .iter() - .map(|&ty_or_ct| { - ids.globals[&match ty_or_ct { - TypeOrConst::Type(ty) => Global::Type(ty), - TypeOrConst::Const(ct) => Global::Const(ct), - }] - }) - .collect(), - }, + Global::Type(ty) => { + let ty_def = &cx[ty]; + match spv::Inst::from_canonical_type(&ty_def.kind).ok_or(&ty_def.kind) { + Ok(spv_inst) => spv::InstWithIds { + without_ids: spv_inst, + result_type_id: None, + result_id, + ids: [].into_iter().collect(), + }, - // Not inserted into `globals` while visiting. - TypeKind::QPtr | TypeKind::Thunk | TypeKind::SpvStringLiteralForExtInst => { - unreachable!() + Err(TypeKind::Scalar(_)) => { + unreachable!("should've been handled as canonical") + } + + Err(TypeKind::SpvInst { spv_inst, type_and_const_inputs }) => { + spv::InstWithIds { + without_ids: spv_inst.clone(), + result_type_id: None, + result_id, + ids: type_and_const_inputs + .iter() + .map(|&ty_or_ct| { + ids.globals[&match ty_or_ct { + TypeOrConst::Type(ty) => Global::Type(ty), + TypeOrConst::Const(ct) => Global::Const(ct), + }] + }) + .collect(), + } + } + + // Not inserted into `globals` while visiting. 
+ Err( + TypeKind::QPtr | TypeKind::Thunk | TypeKind::SpvStringLiteralForExtInst, + ) => { + unreachable!() + } } - }, + } Global::Const(ct) => { let ct_def = &cx[ct]; match spv::Inst::from_canonical_const(&ct_def.kind).ok_or(&ct_def.kind) { @@ -1382,7 +1400,7 @@ impl LazyInst<'_, '_> { ids: [].into_iter().collect(), }, - Err(ConstKind::Undef) => { + Err(ConstKind::Undef | ConstKind::Scalar(_)) => { unreachable!("should've been handled as canonical") } @@ -1496,29 +1514,43 @@ impl LazyInst<'_, '_> { .collect(), }, Self::DataInst { parent_func, result_id: _, data_inst_def } => { - let (inst, extra_initial_id_operand) = match &data_inst_def.kind { - NodeKind::Select(_) | NodeKind::Loop { .. } | NodeKind::ExitInvocation(_) => { - unreachable!() - } + let kind = &data_inst_def.kind; + let (inst, extra_initial_id_operand) = + match spv::Inst::from_canonical_node_kind(kind).ok_or(kind) { + Ok(spv_inst) => (spv_inst, None), + + Err( + NodeKind::Select(_) + | NodeKind::Loop { .. } + | NodeKind::ExitInvocation(_), + ) => unreachable!(), + + Err(DataInstKind::Scalar(_)) => { + unreachable!("should've been handled as canonical") + } - DataInstKind::Mem(_) | DataInstKind::QPtr(_) | DataInstKind::ThunkBind(_) => { - // Disallowed while visiting. - unreachable!() - } + Err( + DataInstKind::Mem(_) + | DataInstKind::QPtr(_) + | DataInstKind::ThunkBind(_), + ) => { + // Disallowed while visiting. 
+ unreachable!() + } - &DataInstKind::FuncCall(callee) => { - (wk.OpFunctionCall.into(), Some(ids.funcs[&callee].func_id)) - } - DataInstKind::SpvInst(inst) => (inst.clone(), None), - &DataInstKind::SpvExtInst { ext_set, inst } => ( - spv::Inst { - opcode: wk.OpExtInst, - imms: iter::once(spv::Imm::Short(wk.LiteralExtInstInteger, inst)) - .collect(), - }, - Some(ids.ext_inst_imports[&cx[ext_set]]), - ), - }; + Err(&DataInstKind::FuncCall(callee)) => { + (wk.OpFunctionCall.into(), Some(ids.funcs[&callee].func_id)) + } + Err(DataInstKind::SpvInst(inst)) => (inst.clone(), None), + Err(&DataInstKind::SpvExtInst { ext_set, inst }) => ( + spv::Inst { + opcode: wk.OpExtInst, + imms: iter::once(spv::Imm::Short(wk.LiteralExtInstInteger, inst)) + .collect(), + }, + Some(ids.ext_inst_imports[&cx[ext_set]]), + ), + }; spv::InstWithIds { without_ids: inst, // HACK(eddyb) multi-output instructions don't exist pre-disaggregate. @@ -1553,6 +1585,14 @@ impl LazyInst<'_, '_> { ids: [merge_label_id, continue_label_id].into_iter().collect(), }, Self::Terminator { parent_func, terminator } => { + let mut ids: SmallVec<[_; 4]> = terminator + .inputs + .iter() + .map(|&v| value_to_id(parent_func, v)) + .chain(terminator.targets.iter().map(|&target| parent_func.label_ids[&target])) + .collect(); + + // FIXME(eddyb) move some of this to `spv::canonical`. 
let inst = match terminator.kind { TerminatorKind::Unreachable => wk.OpUnreachable.into(), TerminatorKind::Return => { @@ -1562,28 +1602,30 @@ impl LazyInst<'_, '_> { wk.OpReturnValue.into() } } - TerminatorKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(inst)) - | TerminatorKind::SelectBranch(SelectionKind::SpvInst(inst)) => inst.clone(), + TerminatorKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(inst)) => { + inst.clone() + } TerminatorKind::Branch => wk.OpBranch.into(), TerminatorKind::SelectBranch(SelectionKind::BoolCond) => { wk.OpBranchConditional.into() } + TerminatorKind::SelectBranch(SelectionKind::Switch { case_consts }) => { + // HACK(eddyb) move the default case from last back to first. + let default_target = ids.pop().unwrap(); + ids.insert(1, default_target); + + spv::Inst { + opcode: wk.OpSwitch, + imms: case_consts + .iter() + .flat_map(|ct| ct.encode_as_spv_imms()) + .collect(), + } + } }; - spv::InstWithIds { - without_ids: inst, - result_type_id: None, - result_id: None, - ids: terminator - .inputs - .iter() - .map(|&v| value_to_id(parent_func, v)) - .chain( - terminator.targets.iter().map(|&target| parent_func.label_ids[&target]), - ) - .collect(), - } + spv::InstWithIds { without_ids: inst, result_type_id: None, result_id: None, ids } } Self::OpFunctionEnd => spv::InstWithIds { without_ids: wk.OpFunctionEnd.into(), diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 48137c0c..72572f41 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -9,6 +9,7 @@ use crate::{ DbgSrcLoc, DeclDef, Diag, EntityDefs, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, GlobalVarDefBody, Import, InternedStr, Module, NodeDef, NodeKind, Region, RegionDef, Type, TypeDef, TypeKind, TypeOrConst, Value, VarDecl, print, + scalar, }; use itertools::{Either, Itertools as _}; use rustc_hash::{FxHashMap, FxHashSet}; @@ -92,6 +93,20 @@ fn invalid(reason: &str) -> io::Error { io::Error::new(io::ErrorKind::InvalidData, 
format!("malformed SPIR-V ({reason})")) } +fn invalid_factory_for_spv_inst( + inst: &spv::Inst, + result_id: Option, + ids: &[spv::Id], +) -> impl Fn(&str) -> io::Error + use<> { + let opcode = inst.opcode; + let first_id_operand = ids.first().copied(); + move |msg: &str| { + let result_prefix = result_id.map(|id| format!("%{id} = ")).unwrap_or_default(); + let operand_suffix = first_id_operand.map(|id| format!(" %{id} ...")).unwrap_or_default(); + invalid(&format!("in {result_prefix}{}{operand_suffix}: {msg}", opcode.name())) + } +} + // FIXME(eddyb) provide more information about any normalization that happened: // * stats about deduplication that occured through interning // * sets of unused global vars and functions (and types+consts only they use) @@ -233,7 +248,7 @@ impl Module { while let Some(mut inst) = spv_insts.next().transpose()? { let opcode = inst.opcode; - let invalid = |msg: &str| invalid(&format!("in {}: {}", opcode.name(), msg)); + let invalid = invalid_factory_for_spv_inst(&inst, inst.result_id, &inst.ids); // Handle line debuginfo early, as it doesn't have its own section, // but rather can go almost anywhere among globals and functions. 
@@ -598,7 +613,7 @@ impl Module { } else if inst_category == spec::InstructionCategory::Type { assert!(inst.result_type_id.is_none()); let id = inst.result_id.unwrap(); - let type_and_const_inputs = inst + let type_and_const_inputs: SmallVec<_> = inst .ids .iter() .map(|&id| match id_defs.get(&id) { @@ -616,7 +631,15 @@ impl Module { let ty = cx.intern(TypeDef { attrs: mem::take(&mut attrs), - kind: TypeKind::SpvInst { spv_inst: inst.without_ids, type_and_const_inputs }, + kind: match inst.as_canonical_type() { + Some(type_kind) => { + assert_eq!(type_and_const_inputs.len(), 0); + type_kind + } + None => { + TypeKind::SpvInst { spv_inst: inst.without_ids, type_and_const_inputs } + } + }, }); id_defs.insert(id, IdDef::Type(ty)); @@ -653,31 +676,13 @@ impl Module { id_defs.insert(id, IdDef::Const(ct)); Seq::TypeConstOrGlobalVar - } else if let Some(const_kind) = inst.as_canonical_const() { + } else if inst_category == spec::InstructionCategory::Const + || inst.always_lower_as_const() + { let id = inst.result_id.unwrap(); - assert_eq!(inst.ids.len(), 0); - - // FIXME(eddyb) this is used below for sequencing, so maybe it - // may be useful to still have some access here to `wk.OpUndef`. - let is_op_undef = matches!(const_kind, ConstKind::Undef); + let ty = result_type.unwrap(); - let ct = cx.intern(ConstDef { - attrs: mem::take(&mut attrs), - ty: result_type.unwrap(), - kind: const_kind, - }); - id_defs.insert(id, IdDef::Const(ct)); - - if is_op_undef { - // `OpUndef` can appear either among constants, or in a - // function, so at most advance `seq` to globals. 
- seq.max(Some(Seq::TypeConstOrGlobalVar)).unwrap() - } else { - Seq::TypeConstOrGlobalVar - } - } else if inst_category == spec::InstructionCategory::Const { - let id = inst.result_id.unwrap(); - let const_inputs = inst + let const_inputs: SmallVec<_> = inst .ids .iter() .map(|&id| match id_defs.get(&id) { @@ -694,14 +699,26 @@ impl Module { let ct = cx.intern(ConstDef { attrs: mem::take(&mut attrs), - ty: result_type.unwrap(), - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new((inst.without_ids, const_inputs)), + ty, + kind: match inst.as_canonical_const(&cx, ty) { + Some(const_kind) => { + assert_eq!(const_inputs.len(), 0); + const_kind + } + None => ConstKind::SpvInst { + spv_inst_and_const_inputs: Rc::new((inst.without_ids, const_inputs)), + }, }, }); id_defs.insert(id, IdDef::Const(ct)); - Seq::TypeConstOrGlobalVar + if inst_category != spec::InstructionCategory::Const { + // `OpUndef` can appear either among constants, or in a + // function, so at most advance `seq` to globals. + seq.max(Some(Seq::TypeConstOrGlobalVar)).unwrap() + } else { + Seq::TypeConstOrGlobalVar + } } else if opcode == wk.OpVariable && current_func_body.is_none() { let global_var_id = inst.result_id.unwrap(); let type_of_ptr_to_global_var = result_type.unwrap(); @@ -1116,7 +1133,7 @@ impl Module { #[derive(Copy, Clone)] enum LocalIdDef { - Value(Value), + Value(Type, Value), BlockLabel(Region), } @@ -1162,7 +1179,7 @@ impl Module { if opcode == wk.OpLabel { current_block = match local_id_defs[&result_id.unwrap()] { LocalIdDef::BlockLabel(region) => region, - LocalIdDef::Value(_) => unreachable!(), + LocalIdDef::Value(..) => unreachable!(), }; continue; } @@ -1256,53 +1273,55 @@ impl Module { ref ids, } = *raw_inst; - let invalid = |msg: &str| invalid(&format!("in {}: {}", opcode.name(), msg)); + let invalid = invalid_factory_for_spv_inst(&raw_inst.without_ids, result_id, ids); // FIXME(eddyb) find a more compact name and/or make this a method. 
// FIXME(eddyb) this returns `LocalIdDef` even for global values. - let lookup_global_or_local_id_for_data_or_control_inst_input = - |id| match id_defs.get(&id) { - Some(&IdDef::Const(ct)) => Ok(LocalIdDef::Value(Value::Const(ct))), - Some(id_def @ IdDef::Type(_)) => Err(invalid(&format!( - "unsupported use of {} as an operand for \ + let lookup_global_or_local_id_for_data_or_control_inst_input = |id| match id_defs + .get(&id) + { + Some(&IdDef::Const(ct)) => Ok(LocalIdDef::Value(cx[ct].ty, Value::Const(ct))), + Some(id_def @ IdDef::Type(_)) => Err(invalid(&format!( + "unsupported use of {} as an operand for \ an instruction in a function", - id_def.descr(&cx), - ))), - Some(id_def @ IdDef::Func(_)) => Err(invalid(&format!( - "unsupported use of {} outside `OpFunctionCall`", - id_def.descr(&cx), - ))), - Some(id_def @ IdDef::SpvDebugString(s)) => { - if opcode == wk.OpExtInst { - // HACK(eddyb) intern `OpString`s as `Const`s on - // the fly, as it's a less likely usage than the - // `OpLine` one. - let ct = cx.intern(ConstDef { - attrs: AttrSet::default(), - ty: cx.intern(TypeKind::SpvStringLiteralForExtInst), - kind: ConstKind::SpvStringLiteralForExtInst(*s), - }); - Ok(LocalIdDef::Value(Value::Const(ct))) - } else { - Err(invalid(&format!( - "unsupported use of {} outside `OpSource`, \ + id_def.descr(&cx), + ))), + Some(id_def @ IdDef::Func(_)) => Err(invalid(&format!( + "unsupported use of {} outside `OpFunctionCall`", + id_def.descr(&cx), + ))), + Some(id_def @ IdDef::SpvDebugString(s)) => { + if opcode == wk.OpExtInst { + // HACK(eddyb) intern `OpString`s as `Const`s on + // the fly, as it's a less likely usage than the + // `OpLine` one. 
+ let ty = cx.intern(TypeKind::SpvStringLiteralForExtInst); + let ct = cx.intern(ConstDef { + attrs: AttrSet::default(), + ty, + kind: ConstKind::SpvStringLiteralForExtInst(*s), + }); + Ok(LocalIdDef::Value(ty, Value::Const(ct))) + } else { + Err(invalid(&format!( + "unsupported use of {} outside `OpSource`, \ `OpLine`, or `OpExtInst`", - id_def.descr(&cx), - ))) - } + id_def.descr(&cx), + ))) } - Some(id_def @ IdDef::SpvExtInstImport(_)) => Err(invalid(&format!( - "unsupported use of {} outside `OpExtInst`", - id_def.descr(&cx), - ))), - // FIXME(eddyb) scan the rest of the function for any - // instructions returning this ID, to report an invalid - // forward reference (use before def). - None | Some(IdDef::FuncForwardRef(_)) => local_id_defs - .get(&id) - .copied() - .ok_or_else(|| invalid(&format!("undefined ID %{id}",))), - }; + } + Some(id_def @ IdDef::SpvExtInstImport(_)) => Err(invalid(&format!( + "unsupported use of {} outside `OpExtInst`", + id_def.descr(&cx), + ))), + // FIXME(eddyb) scan the rest of the function for any + // instructions returning this ID, to report an invalid + // forward reference (use before def). + None | Some(IdDef::FuncForwardRef(_)) => local_id_defs + .get(&id) + .copied() + .ok_or_else(|| invalid(&format!("undefined ID %{id}",))), + }; if opcode == wk.OpFunctionParameter { if current_block.is_some() { @@ -1331,8 +1350,10 @@ impl Module { ); body_inputs.push(input_var); - local_id_defs - .insert(result_id.unwrap(), LocalIdDef::Value(Value::Var(input_var))); + local_id_defs.insert( + result_id.unwrap(), + LocalIdDef::Value(ty, Value::Var(input_var)), + ); } continue; @@ -1351,7 +1372,7 @@ impl Module { // to be able to have an entry in `local_id_defs`. let region = match local_id_defs[&result_id.unwrap()] { LocalIdDef::BlockLabel(region) => region, - LocalIdDef::Value(_) => unreachable!(), + LocalIdDef::Value(..) 
=> unreachable!(), }; let details = &block_details[®ion]; assert_eq!(details.label_id, result_id.unwrap()); @@ -1405,7 +1426,7 @@ impl Module { ); current_block_region_def.inputs.push(input_var); - (used_id, LocalIdDef::Value(Value::Var(input_var))) + (used_id, LocalIdDef::Value(ty, Value::Var(input_var))) }, ), ); @@ -1440,7 +1461,7 @@ impl Module { }; let phi_value_id_to_value = |phi_key: &PhiKey, id| { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(v) => Ok(v), + LocalIdDef::Value(_, v) => Ok(v), LocalIdDef::BlockLabel { .. } => Err(invalid(&format!( "unsupported use of block label as the value for {}", descr_phi_case(phi_key) @@ -1490,7 +1511,7 @@ impl Module { match lookup_global_or_local_id_for_data_or_control_inst_input( used_id, )? { - LocalIdDef::Value(v) => Ok(v), + LocalIdDef::Value(_, v) => Ok(v), LocalIdDef::BlockLabel(_) => unreachable!(), } }), @@ -1503,10 +1524,11 @@ impl Module { // Split the operands into value inputs (e.g. a branch's // condition or an `OpSwitch`'s selector) and target blocks. let mut inputs = SmallVec::new(); + let mut input_types = SmallVec::<[_; 2]>::new(); let mut targets = SmallVec::<[_; 4]>::new(); for &id in ids { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? 
{ - LocalIdDef::Value(v) => { + LocalIdDef::Value(ty, v) => { if !targets.is_empty() { return Err(invalid( "out of order: value operand \ @@ -1514,6 +1536,7 @@ impl Module { )); } inputs.push(v); + input_types.push(ty); } LocalIdDef::BlockLabel(target) => { record_cfg_edge(target)?; @@ -1567,13 +1590,75 @@ impl Module { Value::Var(thunk_var) }; - let selection_kind = if opcode == wk.OpBranchConditional { + let (selection_kind, targets_with_inputs) = if opcode == wk.OpBranchConditional + { assert_eq!((targets_with_inputs.len(), inputs.len()), (2, 1)); - Some(SelectionKind::BoolCond) + (Some(SelectionKind::BoolCond), Either::Left(targets_with_inputs)) } else if opcode == wk.OpSwitch { - Some(SelectionKind::SpvInst(raw_inst.without_ids.clone())) + assert_eq!(inputs.len(), 1); + + // HACK(eddyb) `spv::read` has to "redundantly" validate + // that such a type is `OpTypeInt`/`OpTypeFloat`, but + // there is still a limitation when it comes to `scalar::Const`. + // FIXME(eddyb) don't hardcode the 128-bit limitation, + // but query `scalar::Const` somehow instead. + let scrutinee_type = input_types[0]; + let scrutinee_type = scrutinee_type + .as_scalar(&cx) + .filter(|ty| { + matches!(ty, scalar::Type::UInt(_) | scalar::Type::SInt(_)) + && ty.bit_width() <= 128 + }) + .ok_or_else(|| { + invalid( + &print::Plan::for_root( + &cx, + &Diag::err([ + "unsupported `OpSwitch` scrutinee type `".into(), + scrutinee_type.into(), + "`".into(), + ]) + .message, + ) + .pretty_print() + .to_string(), + ) + })?; + + // FIXME(eddyb) move some of this to `spv::canonical`. + let imm_words_per_case = + usize::try_from(scrutinee_type.bit_width().div_ceil(32)).unwrap(); + + // NOTE(eddyb) these sanity-checks are redundant with `spv::read`. 
+ assert_eq!(imms.len() % imm_words_per_case, 0); + assert_eq!(targets_with_inputs.len(), 1 + imms.len() / imm_words_per_case); + + let case_consts = imms + .chunks(imm_words_per_case) + .map(|case_imms| { + scalar::Const::try_decode_from_spv_imms(scrutinee_type, case_imms) + .ok_or_else(|| { + invalid(&format!( + "invalid {}-bit `OpSwitch` case constant", + scrutinee_type.bit_width() + )) + }) + }) + .collect::>()?; + + // HACK(eddyb) move the default case from first to last. + let targets_with_inputs = { + let mut original_targets = targets_with_inputs; + let default_target = original_targets.next().unwrap(); + original_targets.chain([default_target]) + }; + + ( + Some(SelectionKind::Switch { case_consts }), + Either::Right(targets_with_inputs), + ) } else { - None + (None, Either::Left(targets_with_inputs)) }; // HACK(eddyb) see comment on `whole_func_merge`. @@ -1594,6 +1679,7 @@ impl Module { }) .collect(); + // FIXME(eddyb) move some of this to `spv::canonical`. let select_node = func_def_body.nodes.define( &cx, NodeDef { @@ -1625,12 +1711,21 @@ impl Module { } else if [wk.OpReturn, wk.OpReturnValue].contains(&opcode) && !treat_return_as_exit_invocation { - assert!(targets_with_inputs.len() == 0 && inputs.len() <= 1); + assert!(targets_with_inputs.count() == 0 && inputs.len() <= 1); build_thunk( func_def_body.at_mut(current_block.region), (cf::unstructured::ControlTarget::Return, mem::take(&mut inputs)), ) - } else if targets_with_inputs.len() == 0 { + } else if opcode == wk.OpBranch { + build_thunk( + func_def_body.at_mut(current_block.region), + targets_with_inputs.exactly_one().ok().unwrap(), + ) + } else { + if targets_with_inputs.count() > 0 { + return Err(invalid("unsupported control-flow instruction")); + } + if opcode != wk.OpUnreachable { let node = func_def_body.nodes.define( &cx, @@ -1658,13 +1753,6 @@ impl Module { ty: thunk_ty, kind: ConstKind::Undef, })) - } else if opcode == wk.OpBranch { - build_thunk( - 
func_def_body.at_mut(current_block.region), - targets_with_inputs.exactly_one().ok().unwrap(), - ) - } else { - return Err(invalid("unsupported control-flow instruction")); }; assert_eq!(inputs.len(), 0); @@ -1694,7 +1782,7 @@ impl Module { current_block_region_def.inputs.push(input_var); local_id_defs - .insert(result_id.unwrap(), LocalIdDef::Value(Value::Var(input_var))); + .insert(result_id.unwrap(), LocalIdDef::Value(ty, Value::Var(input_var))); } else if [wk.OpSelectionMerge, wk.OpLoopMerge].contains(&opcode) { let is_second_to_last_in_block = lookahead_raw_inst(2) .is_none_or(|next_raw_inst| next_raw_inst.without_ids.opcode == wk.OpLabel); @@ -1713,7 +1801,7 @@ impl Module { let loop_merge_target = match lookup_global_or_local_id_for_data_or_control_inst_input(ids[0])? { - LocalIdDef::Value(_) => return Err(invalid("expected label ID")), + LocalIdDef::Value(..) => return Err(invalid("expected label ID")), LocalIdDef::BlockLabel(target) => target, }; @@ -1731,7 +1819,13 @@ impl Module { // some "structured regions" replacement for the CFG. } else { let mut ids = &ids[..]; - let kind = if opcode == wk.OpFunctionCall { + let kind = if let Some(kind) = raw_inst.without_ids.as_canonical_node_kind( + &cx, + result_type.map(|ty| [ty]).as_ref().map_or(&[][..], |tys| &tys[..]), + ) { + // FIXME(eddyb) sanity-check the number/types of inputs. + kind + } else if opcode == wk.OpFunctionCall { assert!(imms.is_empty()); let callee_id = ids[0]; let maybe_callee = id_defs @@ -1794,7 +1888,7 @@ impl Module { .map(|&id| { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(v) => Ok(v), + LocalIdDef::Value(_, v) => Ok(v), LocalIdDef::BlockLabel { .. 
} => Err(invalid( "unsupported use of block label as a value, \ in non-terminator instruction", @@ -1829,7 +1923,10 @@ impl Module { ); outputs.push(output_var); - local_id_defs.insert(result_id, LocalIdDef::Value(Value::Var(output_var))); + local_id_defs.insert( + result_id, + LocalIdDef::Value(result_type.unwrap(), Value::Var(output_var)), + ); } current_block_region_def.children.insert_last(inst, &mut func_def_body.nodes); diff --git a/src/spv/read.rs b/src/spv/read.rs index eee77bdd..567a0a24 100644 --- a/src/spv/read.rs +++ b/src/spv/read.rs @@ -12,15 +12,14 @@ use std::{fs, io, iter, slice}; /// /// Used currently only to help parsing `LiteralContextDependentNumber`. enum KnownIdDef { - TypeInt(NonZeroU32), - TypeFloat(NonZeroU32), + TypeIntOrFloat(NonZeroU32), Uncategorized { opcode: spec::Opcode, result_type_id: Option }, } impl KnownIdDef { fn result_type_id(&self) -> Option { match *self { - Self::TypeInt(_) | Self::TypeFloat(_) => None, + Self::TypeIntOrFloat(_) => None, Self::Uncategorized { result_type_id, .. } => result_type_id, } } @@ -183,11 +182,8 @@ impl InstParser<'_> { .and_then(|id| self.known_ids.get(&id)) .ok_or(Error::MissingContextSensitiveLiteralType)?; - let extra_word_count = match *contextual_type { - KnownIdDef::TypeInt(width) | KnownIdDef::TypeFloat(width) => { - // HACK(eddyb) `(width + 31) / 32 - 1` but without overflow. - (width.get() - 1) / 32 - } + let word_count = match *contextual_type { + KnownIdDef::TypeIntOrFloat(width) => width.get().div_ceil(32), KnownIdDef::Uncategorized { opcode, .. 
} => { return Err(Error::UnsupportedContextSensitiveLiteralType { type_opcode: opcode, @@ -195,11 +191,11 @@ impl InstParser<'_> { } }; - if extra_word_count == 0 { + if word_count == 1 { self.inst.imms.push(spv::Imm::Short(kind, word)); } else { self.inst.imms.push(spv::Imm::LongStart(kind, word)); - for _ in 0..extra_word_count { + for _ in 1..word_count { let word = self.words.next().ok_or(Error::NotEnoughWords)?; self.inst.imms.push(spv::Imm::LongCont(kind, word)); } @@ -329,9 +325,6 @@ impl ModuleParser { impl Iterator for ModuleParser { type Item = io::Result; fn next(&mut self) -> Option { - let spv_spec = spec::Spec::get(); - let wk = &spv_spec.well_known; - let words = &bytemuck::cast_slice::(&self.word_bytes)[self.next_word..]; let &opcode = words.first()?; @@ -367,24 +360,11 @@ impl Iterator for ModuleParser { // HACK(eddyb) `Option::map` allows using `?` for `Result` in the closure. let maybe_known_id_result = inst.result_id.map(|id| { - let known_id_def = if opcode == wk.OpTypeInt { - KnownIdDef::TypeInt(match inst.imms[0] { - spv::Imm::Short(kind, n) => { - assert_eq!(kind, wk.LiteralInteger); - n.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))? - } - _ => unreachable!(), - }) - } else if opcode == wk.OpTypeFloat { - KnownIdDef::TypeFloat(match inst.imms[0] { - spv::Imm::Short(kind, n) => { - assert_eq!(kind, wk.LiteralInteger); - n.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))? 
- } - _ => unreachable!(), - }) - } else { - KnownIdDef::Uncategorized { opcode, result_type_id: inst.result_type_id } + let known_id_def = match inst.int_or_float_type_bit_width() { + Some(w) => KnownIdDef::TypeIntOrFloat( + w.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))?, + ), + None => KnownIdDef::Uncategorized { opcode, result_type_id: inst.result_type_id }, }; let old = self.known_ids.insert(id, known_id_def); diff --git a/src/spv/spec.rs b/src/spv/spec.rs index 6daff6c1..c1e80a40 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -117,9 +117,6 @@ def_well_known! { OpNoLine, OpTypeVoid, - OpTypeBool, - OpTypeInt, - OpTypeFloat, OpTypeVector, OpTypeMatrix, OpTypeArray, @@ -133,9 +130,6 @@ def_well_known! { OpTypeSampledImage, OpTypeAccelerationStructureKHR, - OpConstantFalse, - OpConstantTrue, - OpConstant, OpConstantFunctionPointerINTEL, OpVariable, diff --git a/src/transform.rs b/src/transform.rs index 148669ad..00b44a2e 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -438,7 +438,10 @@ impl InnerTransform for TypeDef { transform!({ attrs -> transformer.transform_attr_set_use(*attrs), kind -> match kind { - TypeKind::QPtr | TypeKind::Thunk | TypeKind::SpvStringLiteralForExtInst => Transformed::Unchanged, + TypeKind::Scalar(_) + | TypeKind::QPtr + | TypeKind::Thunk + | TypeKind::SpvStringLiteralForExtInst => Transformed::Unchanged, TypeKind::SpvInst { spv_inst, type_and_const_inputs } => Transformed::map_iter( type_and_const_inputs.iter(), @@ -472,6 +475,7 @@ impl InnerTransform for ConstDef { ty -> transformer.transform_type_use(*ty), kind -> match kind { ConstKind::Undef + | ConstKind::Scalar(_) | ConstKind::SpvStringLiteralForExtInst(_) => Transformed::Unchanged, ConstKind::PtrToGlobalVar(gv) => transform!({ @@ -625,9 +629,12 @@ impl InnerInPlaceTransform for FuncAtMut<'_, Node> { match kind { DataInstKind::FuncCall(func) => transformer.transform_func_use(*func).apply_to(func), - NodeKind::Select(SelectionKind::BoolCond | 
SelectionKind::SpvInst(_)) + NodeKind::Select( + SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, + ) | NodeKind::Loop { repeat_condition: _ } | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) + | DataInstKind::Scalar(_) | DataInstKind::Mem(MemOp::FuncLocalVar(_) | MemOp::Load | MemOp::Store) | DataInstKind::QPtr( QPtrOp::HandleArrayIndex diff --git a/src/visit.rs b/src/visit.rs index b25d4cbc..24e4a666 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -321,7 +321,10 @@ impl InnerVisit for TypeDef { visitor.visit_attr_set_use(*attrs); match kind { - TypeKind::QPtr | TypeKind::Thunk | TypeKind::SpvStringLiteralForExtInst => {} + TypeKind::Scalar(_) + | TypeKind::QPtr + | TypeKind::Thunk + | TypeKind::SpvStringLiteralForExtInst => {} TypeKind::SpvInst { spv_inst: _, type_and_const_inputs } => { for &ty_or_ct in type_and_const_inputs { @@ -342,7 +345,7 @@ impl InnerVisit for ConstDef { visitor.visit_attr_set_use(*attrs); visitor.visit_type_use(*ty); match kind { - ConstKind::Undef | ConstKind::SpvStringLiteralForExtInst(_) => {} + ConstKind::Undef | ConstKind::Scalar(_) | ConstKind::SpvStringLiteralForExtInst(_) => {} &ConstKind::PtrToGlobalVar(gv) => visitor.visit_global_var_use(gv), &ConstKind::PtrToFunc(func) => visitor.visit_func_use(func), @@ -468,9 +471,12 @@ impl<'a> FuncAt<'a, Node> { match kind { &DataInstKind::FuncCall(func) => visitor.visit_func_use(func), - NodeKind::Select(SelectionKind::BoolCond | SelectionKind::SpvInst(_)) + NodeKind::Select( + SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, + ) | NodeKind::Loop { repeat_condition: _ } | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) + | DataInstKind::Scalar(_) | DataInstKind::Mem(MemOp::FuncLocalVar(_) | MemOp::Load | MemOp::Store) | DataInstKind::QPtr( QPtrOp::HandleArrayIndex