From 8fcc984a0adc5aafeb7d3e03079c23cb66168dfc Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Tue, 9 Sep 2025 11:42:12 +0300 Subject: [PATCH 1/2] Add `TypeKind::Vector`&`ConstKind::Vector` for vector types&consts. --- src/lib.rs | 36 ++++++++++- src/mem/layout.rs | 31 +++++---- src/print/mod.rs | 149 +++++++++++++++++++------------------------ src/spv/canonical.rs | 140 ++++++++++++++++++++++++++++------------ src/spv/lift.rs | 87 ++++++++++++++++--------- src/spv/lower.rs | 24 ++----- src/spv/spec.rs | 2 +- src/transform.rs | 2 + src/vector.rs | 123 +++++++++++++++++++++++++++++++++++ src/visit.rs | 6 +- 10 files changed, 415 insertions(+), 185 deletions(-) create mode 100644 src/vector.rs diff --git a/src/lib.rs b/src/lib.rs index 76731bb4..bfbc465d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -171,6 +171,7 @@ pub mod mem; pub mod qptr; pub mod scalar; pub mod spv; +pub mod vector; use smallvec::SmallVec; use std::borrow::Cow; @@ -544,6 +545,13 @@ pub enum TypeKind { #[from] Scalar(scalar::Type), + /// Vector (small array of [`scalar`]s) type, with some limitations on the + /// supported component counts (but all standard ones should be included). + /// + /// See also the [`vector`] module for more documentation and definitions. + #[from] + Vector(vector::Type), + /// "Quasi-pointer", an untyped pointer-like abstract scalar that can represent /// both memory locations (in any address space) and other kinds of locations /// (e.g. SPIR-V `OpVariable`s in non-memory "storage classes"). @@ -585,7 +593,7 @@ macro_rules! impl_intern_type_kind { })+ } } -impl_intern_type_kind!(TypeKind, scalar::Type); +impl_intern_type_kind!(TypeKind, scalar::Type, vector::Type); // HACK(eddyb) this is like `Either`, only used in `TypeKind::SpvInst`, // and only because SPIR-V type definitions can references both types and consts. @@ -603,6 +611,12 @@ impl Type { _ => None, } } + pub fn as_vector(self, cx: &Context) -> Option { + match cx[self].kind { + TypeKind::Vector(ty) => Some(ty), + _ => None, + } + } } /// Interned handle for a [`ConstDef`](crate::ConstDef) (a constant value). @@ -638,6 +652,18 @@ pub enum ConstKind { #[from] Scalar(scalar::Const), + /// Vector (small array of [`scalar`]s) constant, which must have + /// a type of [`TypeKind::Vector`] (of the same [`vector::Type`]). + /// + /// See also the [`vector`] module for more documentation and definitions. + // + // FIXME(eddyb) maybe document the 128-bit limitation inherited from `scalar::Const`? + // FIXME(eddyb) this technically makes the `vector::Type` redundant, could + // it get out of sync? (perhaps "forced canonicalization" could be used to + // enforce that interning simply doesn't allow such scenarios?). + #[from] + Vector(vector::Const), + // FIXME(eddyb) maybe merge these? however, their connection is somewhat // tenuous (being one of the LLVM-isms SPIR-V inherited, among other things), // there's still the need to rename "global variable" post-`Var`-refactor, @@ -674,7 +700,7 @@ macro_rules! impl_intern_const_kind { })+ } } -impl_intern_const_kind!(scalar::Const); +impl_intern_const_kind!(scalar::Const, vector::Const); // HACK(eddyb) on `Const` instead of `ConstDef` for ergonomics reasons. impl Const { @@ -684,6 +710,12 @@ impl Const { _ => None, } } + pub fn as_vector(self, cx: &Context) -> Option<&vector::Const> { + match &cx[self].kind { + ConstKind::Vector(ct) => Some(ct), + _ => None, + } + } } /// Declarations ([`GlobalVarDecl`], [`FuncDecl`]) can contain a full definition, diff --git a/src/mem/layout.rs b/src/mem/layout.rs index 43c68c17..1a4721c4 100644 --- a/src/mem/layout.rs +++ b/src/mem/layout.rs @@ -335,6 +335,21 @@ impl<'a> LayoutCache<'a> { } TypeKind::Scalar(ty) => return Ok(scalar(ty.bit_width())), + TypeKind::Vector(ty) => { + let len = u32::from(ty.elem_count.get()); + return array( + cx.intern(ty.elem), + ArrayParams { + fixed_len: Some(len), + known_stride: None, + + // NOTE(eddyb) this is specifically Vulkan "base alignment". + min_legacy_align: 1, + legacy_align_multiplier: if len <= 2 { 2 } else { 4 }, + }, + ); + } + // FIXME(eddyb) treat `QPtr`s as scalars. TypeKind::QPtr => { return Err(LayoutError(Diag::bug( @@ -362,15 +377,7 @@ impl<'a> LayoutCache<'a> { // FIXME(eddyb) categorize `OpTypePointer` by storage class and split on // logical vs physical here. scalar_with_size_and_align(self.config.logical_ptr_size_align) - } else if [wk.OpTypeVector, wk.OpTypeMatrix].contains(&spv_inst.opcode) { - let len = short_imm_at(0); - let (min_legacy_align, legacy_align_multiplier) = if spv_inst.opcode == wk.OpTypeVector - { - // NOTE(eddyb) this is specifically Vulkan "base alignment". - (1, if len <= 2 { 2 } else { 4 }) - } else { - (self.config.min_aggregate_legacy_align, 1) - }; + } else if spv_inst.opcode == wk.OpTypeMatrix { // NOTE(eddyb) `RowMajor` is disallowed on `OpTypeStruct` members below. array( match type_and_const_inputs[..] { @@ -378,10 +385,10 @@ impl<'a> LayoutCache<'a> { _ => unreachable!(), }, ArrayParams { - fixed_len: Some(len), + fixed_len: Some(short_imm_at(0)), known_stride: None, - min_legacy_align, - legacy_align_multiplier, + min_legacy_align: self.config.min_aggregate_legacy_align, + legacy_align_multiplier: 1, }, )? } else if [wk.OpTypeArray, wk.OpTypeRuntimeArray].contains(&spv_inst.opcode) { diff --git a/src/print/mod.rs b/src/print/mod.rs index c868763e..5ef5ea4a 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -991,7 +991,6 @@ impl<'a, VRN: FnMut(FuncAt<'a, Either>)> Visitor<'a> impl<'a> Printer<'a> { fn new(plan: &Plan<'a>) -> Self { let cx = plan.cx; - let wk = &spv::spec::Spec::get().well_known; let mut attrs_with_spv_name_in_use = FxHashMap::default(); let mut per_region_dbg_scope_def_placements: EntityOrientedDenseMap< @@ -1087,22 +1086,19 @@ impl<'a> Printer<'a> { CxInterned::Type(ty) => { let ty_def = &cx[ty]; - // FIXME(eddyb) remove the duplication between - // here and `TypeDef`'s `Print` impl. - let has_compact_print_or_is_leaf = match &ty_def.kind { - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { - spv_inst.opcode == wk.OpTypeVector - || type_and_const_inputs.is_empty() + let is_leaf = match &ty_def.kind { + TypeKind::SpvInst { type_and_const_inputs, .. } => { + type_and_const_inputs.is_empty() } TypeKind::Scalar(_) + | TypeKind::Vector(_) | TypeKind::QPtr | TypeKind::Thunk | TypeKind::SpvStringLiteralForExtInst => true, }; - ty_def.attrs == AttrSet::default() - && has_compact_print_or_is_leaf + ty_def.attrs == AttrSet::default() && is_leaf } CxInterned::Const(ct) => { let ct_def = &cx[ct]; @@ -3047,72 +3043,45 @@ impl Print for TypeDef { let wk = &spv::spec::Spec::get().well_known; - // FIXME(eddyb) should this be done by lowering SPIR-V types to SPIR-T? let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); - #[allow(irrefutable_let_patterns)] - let compact_def = if let &TypeKind::SpvInst { - spv_inst: spv::Inst { opcode, ref imms }, - ref type_and_const_inputs, - } = kind - { - if opcode == wk.OpTypeVector { - let (elem_ty, elem_count) = match (&imms[..], &type_and_const_inputs[..]) { - (&[spv::Imm::Short(_, elem_count)], &[TypeOrConst::Type(elem_ty)]) => { - (elem_ty, elem_count) - } - _ => unreachable!(), - }; - Some(pretty::Fragment::new([ - elem_ty.print(printer), - "×".into(), - printer.numeric_literal_style().apply(format!("{elem_count}")).into(), - ])) - } else { - None + // FIXME(eddyb) should this just be `fmt::Display` on `scalar::Type`? + let print_scalar = |ty: scalar::Type| { + let width = ty.bit_width(); + match ty { + scalar::Type::Bool => "bool".into(), + scalar::Type::SInt(_) => format!("s{width}"), + scalar::Type::UInt(_) => format!("u{width}"), + scalar::Type::Float(_) => format!("f{width}"), } - } else { - None }; AttrsAndDef { attrs: attrs.print(printer), - def_without_name: if let Some(def) = compact_def { - def - } else { - match kind { - TypeKind::Scalar(ty) => { - let width = ty.bit_width(); - kw(match ty { - scalar::Type::Bool => "bool".into(), - scalar::Type::SInt(_) => format!("s{width}"), - scalar::Type::UInt(_) => format!("u{width}"), - scalar::Type::Float(_) => format!("f{width}"), - }) - } - - // FIXME(eddyb) should this be shortened to `qtr`? - TypeKind::QPtr => printer.declarative_keyword_style().apply("qptr").into(), - - TypeKind::Thunk => printer.imperative_keyword_style().apply("thunk").into(), - - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => printer - .pretty_spv_inst( - printer.spv_op_style(), - spv_inst.opcode, - &spv_inst.imms, - type_and_const_inputs.iter().map(|&ty_or_ct| match ty_or_ct { - TypeOrConst::Type(ty) => ty.print(printer), - TypeOrConst::Const(ct) => ct.print(printer), - }), - ), - TypeKind::SpvStringLiteralForExtInst => pretty::Fragment::new([ - printer.error_style().apply("type_of").into(), - "(".into(), - printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), - ")".into(), - ]), - } + def_without_name: match kind { + &TypeKind::Scalar(ty) => kw(print_scalar(ty)), + &TypeKind::Vector(ty) => kw(format!("{}×{}", print_scalar(ty.elem), ty.elem_count)), + + // FIXME(eddyb) should this be shortened to `qtr`? + TypeKind::QPtr => printer.declarative_keyword_style().apply("qptr").into(), + + TypeKind::Thunk => printer.imperative_keyword_style().apply("thunk").into(), + + TypeKind::SpvInst { spv_inst, type_and_const_inputs } => printer.pretty_spv_inst( + printer.spv_op_style(), + spv_inst.opcode, + &spv_inst.imms, + type_and_const_inputs.iter().map(|&ty_or_ct| match ty_or_ct { + TypeOrConst::Type(ty) => ty.print(printer), + TypeOrConst::Const(ct) => ct.print(printer), + }), + ), + TypeKind::SpvStringLiteralForExtInst => pretty::Fragment::new([ + printer.error_style().apply("type_of").into(), + "(".into(), + printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), + ")".into(), + ]), }, } } @@ -3127,14 +3096,11 @@ impl Print for ConstDef { let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); - let def_without_name = match kind { - ConstKind::Undef => pretty::Fragment::new([ - printer.imperative_keyword_style().apply("undef").into(), - printer.pretty_type_ascription_suffix(*ty), - ]), - ConstKind::Scalar(scalar::Const::FALSE) => kw("false"), - ConstKind::Scalar(scalar::Const::TRUE) => kw("true"), - ConstKind::Scalar(ct) => { + // FIXME(eddyb) should this just a method on `scalar::Const` instead? + let print_scalar = |ct: scalar::Const, include_type_suffix: bool| match ct { + scalar::Const::FALSE => kw("false"), + scalar::Const::TRUE => kw("true"), + _ => { let ty = ct.ty(); let width = ty.bit_width(); let (maybe_printed_value, ty_prefix) = match ty { @@ -3192,14 +3158,18 @@ impl Print for ConstDef { }; match maybe_printed_value { Some(printed_value) => { - let literal_ty_suffix = pretty::Styles { - // HACK(eddyb) the exact type detracts from the value. - color_opacity: Some(0.4), - subscript: true, - ..printer.declarative_keyword_style() + if include_type_suffix { + let literal_ty_suffix = pretty::Styles { + // HACK(eddyb) the exact type detracts from the value. + color_opacity: Some(0.4), + subscript: true, + ..printer.declarative_keyword_style() + } + .apply(format!("{ty_prefix}{width}")); + pretty::Fragment::new([printed_value, literal_ty_suffix.into()]) + } else { + printed_value } - .apply(format!("{ty_prefix}{width}")); - pretty::Fragment::new([printed_value, literal_ty_suffix.into()]) } // HACK(eddyb) fallback using the bitwise representation. None => pretty::Fragment::new([ @@ -3220,6 +3190,18 @@ impl Print for ConstDef { ]), } } + }; + + let def_without_name = match kind { + ConstKind::Undef => pretty::Fragment::new([ + printer.imperative_keyword_style().apply("undef").into(), + printer.pretty_type_ascription_suffix(*ty), + ]), + &ConstKind::Scalar(ct) => print_scalar(ct, true), + ConstKind::Vector(ct) => pretty::Fragment::new([ + ty.print(printer), + pretty::join_comma_sep("(", ct.elems().map(|elem| print_scalar(elem, false)), ")"), + ]), &ConstKind::PtrToGlobalVar(gv) => { pretty::Fragment::new(["&".into(), gv.print(printer)]) } @@ -4052,6 +4034,7 @@ impl FuncAt<'_, DataInst> { if let Value::Const(ct) = v { match &printer.cx[ct].kind { ConstKind::Undef + | ConstKind::Vector(_) | ConstKind::PtrToGlobalVar(_) | ConstKind::PtrToFunc(_) | ConstKind::SpvInst { .. } => {} diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs index efd8e95f..45b1fe6f 100644 --- a/src/spv/canonical.rs +++ b/src/spv/canonical.rs @@ -8,8 +8,9 @@ // FIXME(eddyb) should interning attempts check/apply these canonicalizations? use crate::spv::{self, spec}; -use crate::{ConstKind, Context, NodeKind, Type, TypeKind, scalar}; +use crate::{Const, ConstKind, Context, NodeKind, Type, TypeKind, TypeOrConst, scalar, vector}; use lazy_static::lazy_static; +use smallvec::SmallVec; // FIXME(eddyb) these ones could maybe make use of build script generation. macro_rules! def_mappable_ops { @@ -65,6 +66,7 @@ def_mappable_ops! { OpTypeBool, OpTypeInt, OpTypeFloat, + OpTypeVector, } const { OpUndef, @@ -249,55 +251,86 @@ impl spv::Inst { // FIXME(eddyb) automate bidirectional mappings more (although the need // for conditional, i.e. "partial", mappings, adds a lot of complexity). - pub(super) fn as_canonical_type(&self) -> Option { + pub(super) fn as_canonical_type( + &self, + cx: &Context, + type_and_const_inputs: &[TypeOrConst], + ) -> Option { let Self { opcode, imms } = self; let (&opcode, imms) = (opcode, &imms[..]); let mo = MappableOps::get(); let int_width = || scalar::IntWidth::try_from_bits(self.int_or_float_type_bit_width()?); - match imms { - [] if opcode == mo.OpTypeBool => Some(scalar::Type::Bool.into()), - &[_, spv::Imm::Short(_, 0)] if opcode == mo.OpTypeInt => { + match (imms, type_and_const_inputs) { + ([], []) if opcode == mo.OpTypeBool => Some(scalar::Type::Bool.into()), + (&[_, spv::Imm::Short(_, 0)], []) if opcode == mo.OpTypeInt => { Some(scalar::Type::UInt(int_width()?).into()) } - &[_, spv::Imm::Short(_, 1)] if opcode == mo.OpTypeInt => { + (&[_, spv::Imm::Short(_, 1)], []) if opcode == mo.OpTypeInt => { Some(scalar::Type::SInt(int_width()?).into()) } - [_] if opcode == mo.OpTypeFloat => Some( + ([_], []) if opcode == mo.OpTypeFloat => Some( scalar::Type::Float(scalar::FloatWidth::try_from_bits( self.int_or_float_type_bit_width()?, )?) .into(), ), + (&[spv::Imm::Short(_, elem_count)], &[TypeOrConst::Type(elem_type)]) + if opcode == mo.OpTypeVector => + { + Some( + vector::Type { + elem: elem_type.as_scalar(cx)?, + elem_count: u8::try_from(elem_count).ok()?.try_into().ok()?, + } + .into(), + ) + } _ => None, } } - pub(super) fn from_canonical_type(type_kind: &TypeKind) -> Option { + pub(super) fn from_canonical_type( + cx: &Context, + type_kind: &TypeKind, + ) -> Option<(Self, SmallVec<[TypeOrConst; 2]>)> { let wk = &spec::Spec::get().well_known; let mo = MappableOps::get(); match type_kind { - &TypeKind::Scalar(ty) => match ty { - scalar::Type::Bool => Some(mo.OpTypeBool.into()), - scalar::Type::SInt(w) | scalar::Type::UInt(w) => Some(spv::Inst { - opcode: mo.OpTypeInt, - imms: [ - spv::Imm::Short(wk.LiteralInteger, w.bits()), - spv::Imm::Short( - wk.LiteralInteger, - matches!(ty, scalar::Type::SInt(_)) as u32, - ), - ] - .into_iter() - .collect(), - }), - scalar::Type::Float(w) => Some(spv::Inst { - opcode: mo.OpTypeFloat, - imms: [spv::Imm::Short(wk.LiteralInteger, w.bits())].into_iter().collect(), - }), - }, + &TypeKind::Scalar(ty) => Some(( + match ty { + scalar::Type::Bool => mo.OpTypeBool.into(), + scalar::Type::SInt(w) | scalar::Type::UInt(w) => spv::Inst { + opcode: mo.OpTypeInt, + imms: [ + spv::Imm::Short(wk.LiteralInteger, w.bits()), + spv::Imm::Short( + wk.LiteralInteger, + matches!(ty, scalar::Type::SInt(_)) as u32, + ), + ] + .into_iter() + .collect(), + }, + scalar::Type::Float(w) => spv::Inst { + opcode: mo.OpTypeFloat, + imms: [spv::Imm::Short(wk.LiteralInteger, w.bits())].into_iter().collect(), + }, + }, + [].into_iter().collect(), + )), + + TypeKind::Vector(ty) => Some(( + spv::Inst { + opcode: mo.OpTypeVector, + imms: [spv::Imm::Short(wk.LiteralInteger, ty.elem_count.get().into())] + .into_iter() + .collect(), + }, + [TypeOrConst::Type(cx.intern(ty.elem))].into_iter().collect(), + )), TypeKind::QPtr | TypeKind::Thunk @@ -314,33 +347,60 @@ impl spv::Inst { // FIXME(eddyb) automate bidirectional mappings more (although the need // for conditional, i.e. "partial", mappings, adds a lot of complexity). - pub(super) fn as_canonical_const(&self, cx: &Context, ty: Type) -> Option { + pub(super) fn as_canonical_const( + &self, + cx: &Context, + ty: Type, + const_inputs: &[Const], + ) -> Option { let Self { opcode, imms } = self; let (&opcode, imms) = (opcode, &imms[..]); + let wk = &spec::Spec::get().well_known; let mo = MappableOps::get(); - match imms { - [] if opcode == mo.OpUndef => Some(ConstKind::Undef), - [] if opcode == mo.OpConstantFalse => Some(scalar::Const::FALSE.into()), - [] if opcode == mo.OpConstantTrue => Some(scalar::Const::TRUE.into()), - _ if opcode == mo.OpConstant => { + match (imms, const_inputs) { + ([], []) if opcode == mo.OpUndef => Some(ConstKind::Undef), + ([], []) if opcode == mo.OpConstantFalse => Some(scalar::Const::FALSE.into()), + ([], []) if opcode == mo.OpConstantTrue => Some(scalar::Const::TRUE.into()), + (_, []) if opcode == mo.OpConstant => { Some(scalar::Const::try_decode_from_spv_imms(ty.as_scalar(cx)?, imms)?.into()) } + _ if opcode == wk.OpConstantComposite => { + let ty = ty.as_vector(cx)?; + let elems = (const_inputs.len() == usize::from(ty.elem_count.get()) + && const_inputs.iter().all(|ct| ct.as_scalar(cx).is_some())) + .then(|| const_inputs.iter().map(|ct| *ct.as_scalar(cx).unwrap()))?; + Some(vector::Const::from_elems(ty, elems).into()) + } _ => None, } } - pub(super) fn from_canonical_const(const_kind: &ConstKind) -> Option { + pub(super) fn from_canonical_const( + cx: &Context, + const_kind: &ConstKind, + ) -> Option<(Self, SmallVec<[Const; 4]>)> { + let wk = &spec::Spec::get().well_known; let mo = MappableOps::get(); match const_kind { - ConstKind::Undef => Some(mo.OpUndef.into()), - ConstKind::Scalar(scalar::Const::FALSE) => Some(mo.OpConstantFalse.into()), - ConstKind::Scalar(scalar::Const::TRUE) => Some(mo.OpConstantTrue.into()), - ConstKind::Scalar(ct) => { - Some(spv::Inst { opcode: mo.OpConstant, imms: ct.encode_as_spv_imms().collect() }) - } + ConstKind::Undef => Some((mo.OpUndef.into(), [].into_iter().collect())), + &ConstKind::Scalar(ct) => Some(( + match ct { + scalar::Const::FALSE => mo.OpConstantFalse.into(), + scalar::Const::TRUE => mo.OpConstantTrue.into(), + _ => { + spv::Inst { opcode: mo.OpConstant, imms: ct.encode_as_spv_imms().collect() } + } + }, + [].into_iter().collect(), + )), + + ConstKind::Vector(ct) => Some(( + wk.OpConstantComposite.into(), + ct.elems().map(|elem| cx.intern(elem)).collect(), + )), ConstKind::PtrToGlobalVar(_) | ConstKind::PtrToFunc(_) diff --git a/src/spv/lift.rs b/src/spv/lift.rs index ba0955f8..3c5830f9 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -120,8 +120,22 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { return; } let ty_def = &self.cx[ty]; + + // HACK(eddyb) there isn't a great way to handle canonical types, but + // perhaps this result should be recorded in `self.globals`? + if let Some((_spv_inst, type_and_const_inputs)) = + spv::Inst::from_canonical_type(self.cx, &ty_def.kind) + { + for ty_or_ct in type_and_const_inputs { + match ty_or_ct { + TypeOrConst::Type(ty) => self.visit_type_use(ty), + TypeOrConst::Const(ct) => self.visit_const_use(ct), + } + } + } + match ty_def.kind { - TypeKind::Scalar(_) | TypeKind::SpvInst { .. } => {} + TypeKind::Scalar(_) | TypeKind::Vector(_) | TypeKind::SpvInst { .. } => {} // FIXME(eddyb) this should be a proper `Result`-based error instead, // and/or `spv::lift` should mutate the module for legalization. @@ -141,6 +155,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { ); } } + self.visit_type_def(ty_def); self.globals.insert(global); } @@ -150,6 +165,17 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { return; } let ct_def = &self.cx[ct]; + + // HACK(eddyb) there isn't a great way to handle canonical consts, but + // perhaps this result should be recorded in `self.globals`? + if let Some((_spv_inst, const_inputs)) = + spv::Inst::from_canonical_const(self.cx, &ct_def.kind) + { + for ct in const_inputs { + self.visit_const_use(ct); + } + } + match ct_def.kind { ConstKind::Undef if matches!(self.cx[ct_def.ty].kind, TypeKind::Thunk) => { // HACK(eddyb) unstructured control-flow may use `undef` thunks. @@ -157,6 +183,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { ConstKind::Undef | ConstKind::Scalar(_) + | ConstKind::Vector(_) | ConstKind::PtrToGlobalVar(_) | ConstKind::PtrToFunc(_) | ConstKind::SpvInst { .. } => { @@ -1274,8 +1301,9 @@ impl LazyInst<'_, '_> { } ConstKind::Undef - | ConstKind::PtrToFunc(_) | ConstKind::Scalar(_) + | ConstKind::Vector(_) + | ConstKind::PtrToFunc(_) | ConstKind::SpvInst { .. } => (ct_def.attrs, None), // Not inserted into `globals` while visiting. @@ -1353,19 +1381,16 @@ impl LazyInst<'_, '_> { Self::Global(global) => match global { Global::Type(ty) => { let ty_def = &cx[ty]; - match spv::Inst::from_canonical_type(&ty_def.kind).ok_or(&ty_def.kind) { - Ok(spv_inst) => spv::InstWithIds { - without_ids: spv_inst, - result_type_id: None, - result_id, - ids: [].into_iter().collect(), - }, - - Err(TypeKind::Scalar(_)) => { + match spv::Inst::from_canonical_type(cx, &ty_def.kind) + .as_ref() + .ok_or(&ty_def.kind) + { + Err(TypeKind::Scalar(_) | TypeKind::Vector(_)) => { unreachable!("should've been handled as canonical") } - Err(TypeKind::SpvInst { spv_inst, type_and_const_inputs }) => { + Ok((spv_inst, type_and_const_inputs)) + | Err(TypeKind::SpvInst { spv_inst, type_and_const_inputs }) => { spv::InstWithIds { without_ids: spv_inst.clone(), result_type_id: None, @@ -1392,15 +1417,32 @@ impl LazyInst<'_, '_> { } Global::Const(ct) => { let ct_def = &cx[ct]; - match spv::Inst::from_canonical_const(&ct_def.kind).ok_or(&ct_def.kind) { - Ok(spv_inst) => spv::InstWithIds { + match spv::Inst::from_canonical_const(cx, &ct_def.kind).ok_or(&ct_def.kind) { + // FIXME(eddyb) this duplicates the `ConstKind::SpvInst` + // case, only due to an inability to pattern-match `Rc`. + Ok((spv_inst, const_inputs)) => spv::InstWithIds { without_ids: spv_inst, result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]), result_id, - ids: [].into_iter().collect(), + ids: const_inputs + .iter() + .map(|&ct| ids.globals[&Global::Const(ct)]) + .collect(), }, + Err(ConstKind::SpvInst { spv_inst_and_const_inputs }) => { + let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; + spv::InstWithIds { + without_ids: spv_inst.clone(), + result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]), + result_id, + ids: const_inputs + .iter() + .map(|&ct| ids.globals[&Global::Const(ct)]) + .collect(), + } + } - Err(ConstKind::Undef | ConstKind::Scalar(_)) => { + Err(ConstKind::Undef | ConstKind::Scalar(_) | ConstKind::Vector(_)) => { unreachable!("should've been handled as canonical") } @@ -1444,19 +1486,6 @@ impl LazyInst<'_, '_> { ids: [ids.funcs[&func].func_id].into_iter().collect(), }, - Err(ConstKind::SpvInst { spv_inst_and_const_inputs }) => { - let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; - spv::InstWithIds { - without_ids: spv_inst.clone(), - result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]), - result_id, - ids: const_inputs - .iter() - .map(|&ct| ids.globals[&Global::Const(ct)]) - .collect(), - } - } - // Not inserted into `globals` while visiting. Err(ConstKind::SpvStringLiteralForExtInst(_)) => unreachable!(), } diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 72572f41..31c37a6e 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -631,15 +631,9 @@ impl Module { let ty = cx.intern(TypeDef { attrs: mem::take(&mut attrs), - kind: match inst.as_canonical_type() { - Some(type_kind) => { - assert_eq!(type_and_const_inputs.len(), 0); - type_kind - } - None => { - TypeKind::SpvInst { spv_inst: inst.without_ids, type_and_const_inputs } - } - }, + kind: inst.as_canonical_type(&cx, &type_and_const_inputs).unwrap_or( + TypeKind::SpvInst { spv_inst: inst.without_ids, type_and_const_inputs }, + ), }); id_defs.insert(id, IdDef::Type(ty)); @@ -700,15 +694,11 @@ impl Module { let ct = cx.intern(ConstDef { attrs: mem::take(&mut attrs), ty, - kind: match inst.as_canonical_const(&cx, ty) { - Some(const_kind) => { - assert_eq!(const_inputs.len(), 0); - const_kind - } - None => ConstKind::SpvInst { + kind: inst.as_canonical_const(&cx, ty, &const_inputs).unwrap_or_else(|| { + ConstKind::SpvInst { spv_inst_and_const_inputs: Rc::new((inst.without_ids, const_inputs)), - }, - }, + } + }), }); id_defs.insert(id, IdDef::Const(ct)); diff --git a/src/spv/spec.rs b/src/spv/spec.rs index c1e80a40..483e1b0a 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -117,7 +117,6 @@ def_well_known! { OpNoLine, OpTypeVoid, - OpTypeVector, OpTypeMatrix, OpTypeArray, OpTypeRuntimeArray, @@ -130,6 +129,7 @@ def_well_known! { OpTypeSampledImage, OpTypeAccelerationStructureKHR, + OpConstantComposite, OpConstantFunctionPointerINTEL, OpVariable, diff --git a/src/transform.rs b/src/transform.rs index 00b44a2e..ee1ba7cb 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -439,6 +439,7 @@ impl InnerTransform for TypeDef { attrs -> transformer.transform_attr_set_use(*attrs), kind -> match kind { TypeKind::Scalar(_) + | TypeKind::Vector(_) | TypeKind::QPtr | TypeKind::Thunk | TypeKind::SpvStringLiteralForExtInst => Transformed::Unchanged, @@ -476,6 +477,7 @@ impl InnerTransform for ConstDef { kind -> match kind { ConstKind::Undef | ConstKind::Scalar(_) + | ConstKind::Vector(_) | ConstKind::SpvStringLiteralForExtInst(_) => Transformed::Unchanged, ConstKind::PtrToGlobalVar(gv) => transform!({ diff --git a/src/vector.rs b/src/vector.rs new file mode 100644 index 00000000..b5b60fb2 --- /dev/null +++ b/src/vector.rs @@ -0,0 +1,123 @@ +//! Vector types (small arrays of [`scalar`](crate::scalar)s) and associated functionality. +//! +//! **Note**: these are similar to SIMD types in other IRs, but SPIR-V often uses +//! its `OpTypeVector` to represent geometrical vectors, colors, etc. without any +//! expectation of SIMD execution (which most GPU execution models use implicitly, +//! i.e. one non-uniform scalar becomes a hardware SIMD vector, while a high-level +//! "vector" of N "lanes", becomes N separate hardware SIMD vectors). + +use crate::scalar; +use smallvec::SmallVec; +use std::num::NonZeroU8; +use std::rc::Rc; + +// FIXME(eddyb) this entire module shorthands "element" as "elem", is that good? + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct Type { + pub elem: scalar::Type, + // FIXME(eddyb) maybe wrap this in a type that abstracts away the encoding? + pub elem_count: NonZeroU8, +} + +// FIXME(eddyb) document the 128-bit limitations inherited from `scalar::Const`. +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct Const(ConstRepr); + +// HACK(eddyb) `#[repr(packed)]` not allowed on `enum`s themselves. +#[repr(Rust, packed)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +struct Packed(T); + +// FIXME(eddyb) maybe build an abstraction for "N-dimensional" bit arrays? +#[derive(Clone, PartialEq, Eq, Hash)] +#[repr(u8)] +enum ConstRepr { + // HACK(eddyb) `(Type, u128)` would waste almost half its size on padding, and + // packing will only impact accessing the bits, while allowing e.g. being + // wrapped in an outer `enum`, before reaching the same size as `(u128, u128)`. + Inline(Type, Packed), + + // HACK(eddyb) this does raise the alignment, but the size and alignment are + // kept at one pointer (so likely half of `u128`) - `Packed>` is sadly + // not an option because `#[derive(...)]` + `#[repr(packed)]` often requires + // `Copy` in order to be able to safely take references (to a copy of a field). + Boxed(Type, Rc>), +} + +impl Const { + pub const fn ty(&self) -> Type { + match self.0 { + ConstRepr::Inline(ty, _) | ConstRepr::Boxed(ty, _) => ty, + } + } + + pub fn from_elems(ty: Type, elems: impl IntoIterator) -> Const { + let elem_width = ty.elem.bit_width(); + assert!(elem_width <= 128); + + let expected_elem_count = u32::from(ty.elem_count.get()); + + let num_limbs = elem_width.checked_mul(expected_elem_count).unwrap().div_ceil(128); + assert_ne!(num_limbs, 0); + let mut limbs = SmallVec::<[u128; 1]>::from_elem(0, usize::try_from(num_limbs).unwrap()); + + let mut found_elem_count = 0; + for ct in elems { + let i: u32 = found_elem_count; + found_elem_count = found_elem_count.checked_add(1).unwrap(); + if i >= expected_elem_count { + continue; + } + + // FIXME(eddyb) get better names (perhaps from miri-like memory?). + let first_bit_idx = i.checked_mul(elem_width).unwrap(); + let limb_idx = first_bit_idx / 128; + let intra_limb_first_bit_idx = first_bit_idx % 128; + assert!(intra_limb_first_bit_idx + elem_width <= 128); + + limbs[usize::try_from(limb_idx).unwrap()] |= ct.bits() << intra_limb_first_bit_idx; + } + assert_eq!(found_elem_count, expected_elem_count); + + match limbs.into_inner() { + Ok([limb]) => Const(ConstRepr::Inline(ty, Packed(limb))), + Err(limbs) => Const(ConstRepr::Boxed(ty, Rc::new(limbs.into_vec()))), + } + } + + pub fn get_elem(&self, i: usize) -> Option { + let ty = self.ty(); + if i >= usize::from(ty.elem_count.get()) { + return None; + } + let i = u32::try_from(i).unwrap(); + let elem_width = ty.elem.bit_width(); + assert!(elem_width <= 128); + + // FIXME(eddyb) get better names (perhaps from miri-like memory?). + let first_bit_idx = i.checked_mul(elem_width).unwrap(); + let limb_idx = first_bit_idx / 128; + let intra_limb_first_bit_idx = first_bit_idx % 128; + assert!(intra_limb_first_bit_idx + elem_width <= 128); + + let limb = match &self.0 { + ConstRepr::Inline(_, limb) => { + assert_eq!(limb_idx, 0); + limb.0 + } + ConstRepr::Boxed(_, limbs) => limbs[usize::try_from(limb_idx).unwrap()], + }; + + Some(scalar::Const::from_bits( + ty.elem, + (limb >> intra_limb_first_bit_idx) & (!0 >> (128 - elem_width)), + )) + } + + pub fn elems(&self) -> impl Iterator + '_ { + let ty = self.ty(); + // FIXME(eddyb) there should be a more efficient way to do this. + (0..usize::from(ty.elem_count.get())).map(|i| self.get_elem(i).unwrap()) + } +} diff --git a/src/visit.rs b/src/visit.rs index 24e4a666..059a6ce8 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -322,6 +322,7 @@ impl InnerVisit for TypeDef { visitor.visit_attr_set_use(*attrs); match kind { TypeKind::Scalar(_) + | TypeKind::Vector(_) | TypeKind::QPtr | TypeKind::Thunk | TypeKind::SpvStringLiteralForExtInst => {} @@ -345,7 +346,10 @@ impl InnerVisit for ConstDef { visitor.visit_attr_set_use(*attrs); visitor.visit_type_use(*ty); match kind { - ConstKind::Undef | ConstKind::Scalar(_) | ConstKind::SpvStringLiteralForExtInst(_) => {} + ConstKind::Undef + | ConstKind::Scalar(_) + | ConstKind::Vector(_) + | ConstKind::SpvStringLiteralForExtInst(_) => {} &ConstKind::PtrToGlobalVar(gv) => visitor.visit_global_var_use(gv), &ConstKind::PtrToFunc(func) => visitor.visit_func_use(func), From 9d78488b9d08d3389abf22d50197f58e00faa695 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Tue, 9 Sep 2025 11:42:12 +0300 Subject: [PATCH 2/2] Add `NodeKind::Vector` for pure vector ops. --- src/lib.rs | 6 ++ src/mem/analyze.rs | 3 +- src/print/mod.rs | 64 ++++++++++++++++--- src/qptr/lift.rs | 2 +- src/qptr/lower.rs | 1 + src/spv/canonical.rs | 146 ++++++++++++++++++++++++++++++++++++++++--- src/spv/lift.rs | 7 ++- src/spv/lower.rs | 32 +++++++--- src/spv/spec.rs | 6 ++ src/transform.rs | 1 + src/vector.rs | 57 +++++++++++++++++ src/visit.rs | 1 + 12 files changed, 297 insertions(+), 29 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index bfbc465d..beecdc7d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1000,6 +1000,12 @@ pub enum NodeKind { #[from] Scalar(scalar::Op), + /// Vector (small array of [`scalar`]s) pure operations. + /// + /// See also the [`vector`] module for more documentation and definitions. + #[from] + Vector(vector::Op), + FuncCall(Func), /// Memory-specific operations (see [`mem::MemOp`]). diff --git a/src/mem/analyze.rs b/src/mem/analyze.rs index e1551fbf..d9390e34 100644 --- a/src/mem/analyze.rs +++ b/src/mem/analyze.rs @@ -932,6 +932,7 @@ impl<'a> GatherAccesses<'a> { } DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) @@ -951,7 +952,7 @@ impl<'a> GatherAccesses<'a> { unreachable!() } - DataInstKind::Scalar(_) => {} + DataInstKind::Scalar(_) | DataInstKind::Vector(_) => {} &DataInstKind::FuncCall(callee) => { match self.gather_accesses_in_func(module, callee) { diff --git a/src/print/mod.rs b/src/print/mod.rs index 5ef5ea4a..dafb022f 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -31,7 +31,7 @@ use crate::{ EntityOrientedDenseMap, ExportKey, Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDecl, GlobalVarDefBody, Import, InternedStr, Module, ModuleDebugInfo, ModuleDialect, Node, NodeDef, NodeKind, OrdAssertEq, Region, - RegionDef, Type, TypeDef, TypeKind, TypeOrConst, Value, Var, VarDecl, scalar, spv, + RegionDef, Type, TypeDef, TypeKind, TypeOrConst, Value, Var, VarDecl, scalar, spv, vector, }; use arrayvec::ArrayVec; use itertools::Either; @@ -3096,7 +3096,7 @@ impl Print for ConstDef { let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); - // FIXME(eddyb) should this just a method on `scalar::Const` instead? + // FIXME(eddyb) should this be a method on `scalar::Const` instead? let print_scalar = |ct: scalar::Const, include_type_suffix: bool| match ct { scalar::Const::FALSE => kw("false"), scalar::Const::TRUE => kw("true"), @@ -3746,6 +3746,7 @@ impl Print for FuncAt<'_, Node> { ), DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) @@ -3788,21 +3789,66 @@ impl FuncAt<'_, DataInst> { let mut output_type_to_print = output_type; + // FIXME(eddyb) should this be a method on `scalar::Op` instead? + let print_scalar = |op: scalar::Op| { + let name = op.name(); + let (namespace_prefix, name) = name.split_at(name.find('.').unwrap() + 1); + pretty::Fragment::new([ + printer + .demote_style_for_namespace_prefix(printer.declarative_keyword_style()) + .apply(namespace_prefix), + printer.declarative_keyword_style().apply(name), + ]) + }; + let def_without_type = match kind { NodeKind::Select(_) | NodeKind::Loop { .. } | NodeKind::ExitInvocation(_) => { unreachable!() } - &DataInstKind::Scalar(op) => { - let name = op.name(); + &DataInstKind::Scalar(op) => pretty::Fragment::new([ + print_scalar(op), + pretty::join_comma_sep("(", inputs.iter().map(|v| v.print(printer)), ")"), + ]), + + &DataInstKind::Vector(op) => { + let (name, extra_last_input) = match op { + vector::Op::Distribute(_) => ("vec.distribute", None), + vector::Op::Reduce(op) => (op.name(), None), + vector::Op::Whole(op) => ( + op.name(), + match op { + vector::WholeOp::Extract { elem_idx } + | vector::WholeOp::Insert { elem_idx } => Some( + printer.numeric_literal_style().apply(elem_idx.to_string()).into(), + ), + vector::WholeOp::New + | vector::WholeOp::DynExtract + | vector::WholeOp::DynInsert + | vector::WholeOp::Mul => None, + }, + ), + }; let (namespace_prefix, name) = name.split_at(name.find('.').unwrap() + 1); - pretty::Fragment::new([ + let mut pretty_name = pretty::Fragment::new([ printer .demote_style_for_namespace_prefix(printer.declarative_keyword_style()) - .apply(namespace_prefix) - .into(), - printer.declarative_keyword_style().apply(name).into(), - pretty::join_comma_sep("(", inputs.iter().map(|v| v.print(printer)), ")"), + .apply(namespace_prefix), + printer.declarative_keyword_style().apply(name), + ]); + if let vector::Op::Distribute(op) = op { + pretty_name = pretty::Fragment::new([ + pretty_name, + pretty::join_comma_sep("(", [print_scalar(op)], ")"), + ]); + } + pretty::Fragment::new([ + pretty_name, + pretty::join_comma_sep( + "(", + inputs.iter().map(|v| v.print(printer)).chain(extra_last_input), + ")", + ), ]) } diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index 804bfe0a..68381b6c 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -411,7 +411,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { return Ok(Transformed::Unchanged); } - DataInstKind::Scalar(_) => return Ok(Transformed::Unchanged), + DataInstKind::Scalar(_) | DataInstKind::Vector(_) => return Ok(Transformed::Unchanged), &DataInstKind::FuncCall(_callee) => { for &v in &data_inst_def.inputs { diff --git a/src/qptr/lower.rs b/src/qptr/lower.rs index f8a2e1b4..4eb97d2d 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -628,6 +628,7 @@ impl LowerFromSpvPtrInstsInFunc<'_> { | NodeKind::Loop { .. } | NodeKind::ExitInvocation(_) | DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs index 45b1fe6f..079bb15e 100644 --- a/src/spv/canonical.rs +++ b/src/spv/canonical.rs @@ -9,6 +9,7 @@ use crate::spv::{self, spec}; use crate::{Const, ConstKind, Context, NodeKind, Type, TypeKind, TypeOrConst, scalar, vector}; +use itertools::Itertools; use lazy_static::lazy_static; use smallvec::SmallVec; @@ -17,12 +18,14 @@ macro_rules! def_mappable_ops { ( type { $($ty_op:ident),+ $(,)? } const { $($ct_op:ident),+ $(,)? } + node { $($di_op:ident),+ $(,)? } $($enum_path:path { $($variant_op:ident <=> $variant:ident$(($($variant_args:tt)*))?),+ $(,)? })* ) => { #[allow(non_snake_case)] struct MappableOps { $($ty_op: spec::Opcode,)+ $($ct_op: spec::Opcode,)+ + $($di_op: spec::Opcode,)+ $($($variant_op: spec::Opcode,)+)* } impl MappableOps { @@ -35,6 +38,7 @@ macro_rules! def_mappable_ops { MappableOps { $($ty_op: spv_spec.instructions.lookup(stringify!($ty_op)).unwrap(),)+ $($ct_op: spv_spec.instructions.lookup(stringify!($ct_op)).unwrap(),)+ + $($di_op: spv_spec.instructions.lookup(stringify!($di_op)).unwrap(),)+ $($($variant_op: spv_spec.instructions.lookup(stringify!($variant_op)).unwrap(),)+)* } }; @@ -74,6 +78,11 @@ def_mappable_ops! { OpConstantTrue, OpConstant, } + node { + OpVectorExtractDynamic, + OpVectorInsertDynamic, + OpVectorTimesScalar, + } scalar::BoolUnOp { OpLogicalNot <=> Not, } @@ -164,6 +173,11 @@ def_mappable_ops! { OpFUnordLessThanEqual <=> CmpOrUnord(scalar::FloatCmp::Le), OpFUnordGreaterThanEqual <=> CmpOrUnord(scalar::FloatCmp::Ge), } + vector::ReduceOp { + OpDot <=> Dot, + OpAny <=> Any, + OpAll <=> All, + } } impl scalar::Const { @@ -410,7 +424,12 @@ impl spv::Inst { } // HACK(eddyb) exported to facilitate `OpSpecConstantOp` handling elsewhere. - pub fn as_canonical_node_kind(&self, cx: &Context, output_types: &[Type]) -> Option { + pub fn as_canonical_node_kind( + &self, + cx: &Context, + output_types: impl ExactSizeIterator, + input_types: impl ExactSizeIterator, + ) -> Option { let Self { opcode, imms } = self; let (&opcode, imms) = (opcode, &imms[..]); @@ -423,16 +442,84 @@ impl spv::Inst { if let Some(op) = scalar_op { assert_eq!(imms.len(), 0); - // FIXME(eddyb) support vector versions of these ops as well. - if output_types.len() == op.output_count() - && output_types.iter().all(|ty| ty.as_scalar(cx).is_some()) - { - Some(op.into()) + let (_scalar_type, vec_elem_count) = (output_types.len() == op.output_count()) + .then(|| { + output_types.map(|ty| match cx[ty].kind { + TypeKind::Scalar(ty) => Some((ty, None)), + TypeKind::Vector(ty) => Some((ty.elem, Some(ty.elem_count))), + _ => None, + }) + }) + .and_then(|outputs| outputs.dedup().exactly_one().ok()?)?; + + Some(if vec_elem_count.is_some() { + vector::Op::Distribute(op).into() } else { - None - } + op.into() + }) + } else if let Some(op) = vector::ReduceOp::try_from_opcode(opcode).map(vector::Op::from) { + assert_eq!(imms.len(), 0); + Some(op.into()) } else { - None + let wk = &spec::Spec::get().well_known; + let mo = MappableOps::get(); + + // FIXME(eddyb) automate this by supporting immediates in the macro. + let v_whole = |op| Some(vector::Op::Whole(op).into()); + match imms { + // FIXME(eddyb) should these kind of checks be done here? + // (if so, other ops above don't check anywhere near as much) + [] if opcode == wk.OpCompositeConstruct => { + output_types.exactly_one().ok()?.as_vector(cx).filter(|vec_ty| { + input_types.len() == usize::from(vec_ty.elem_count.get()) + && input_types + .map(|ty| ty.as_scalar(cx)) + .dedup() + .exactly_one() + .ok() + .flatten() + == Some(vec_ty.elem) + })?; + v_whole(vector::WholeOp::New) + } + &[spv::Imm::Short(_, elem_idx)] if opcode == wk.OpCompositeExtract => { + let (vec_ty, _extracted_ty) = input_types + .exactly_one() + .ok() + .and_then(|vec_ty| { + Some(( + vec_ty.as_vector(cx)?, + output_types.exactly_one().ok()?.as_scalar(cx)?, + )) + }) + .filter(|&(vec_ty, extracted_ty)| vec_ty.elem == extracted_ty)?; + v_whole(vector::WholeOp::Extract { + elem_idx: elem_idx + .try_into() + .ok() + .filter(|&elem_idx| elem_idx < vec_ty.elem_count.get())?, + }) + } + &[spv::Imm::Short(_, elem_idx)] if opcode == wk.OpCompositeInsert => { + let (vec_ty, _inserted_ty) = input_types + .collect_tuple::<(_, _)>() + .filter(|&(vec_ty, _)| Some(vec_ty) == output_types.exactly_one().ok()) + .and_then(|(vec_ty, inserted_ty)| { + Some((vec_ty.as_vector(cx)?, inserted_ty.as_scalar(cx)?)) + }) + .filter(|&(vec_ty, inserted_ty)| vec_ty.elem == inserted_ty)?; + v_whole(vector::WholeOp::Insert { + elem_idx: elem_idx + .try_into() + .ok() + .filter(|&elem_idx| elem_idx < vec_ty.elem_count.get())?, + }) + } + [] if opcode == mo.OpVectorExtractDynamic => v_whole(vector::WholeOp::DynExtract), + [] if opcode == mo.OpVectorInsertDynamic => v_whole(vector::WholeOp::DynInsert), + [] if opcode == mo.OpVectorTimesScalar => v_whole(vector::WholeOp::Mul), + _ => None, + } } } @@ -446,7 +533,46 @@ impl spv::Inst { scalar::Op::FloatUnary(op) => op.to_opcode().into(), scalar::Op::FloatBinary(op) => op.to_opcode().into(), }), - _ => None, + &NodeKind::Vector(op) => Some(match op { + vector::Op::Distribute(op) => { + Self::from_canonical_node_kind(&NodeKind::Scalar(op)).unwrap() + } + vector::Op::Reduce(op) => op.to_opcode().into(), + vector::Op::Whole(op) => { + let wk = &spec::Spec::get().well_known; + let mo = MappableOps::get(); + + // FIXME(eddyb) automate this by supporting immediates in the macro. + match op { + vector::WholeOp::New => wk.OpCompositeConstruct.into(), + vector::WholeOp::Extract { elem_idx } => spv::Inst { + opcode: wk.OpCompositeExtract, + imms: [spv::Imm::Short(wk.LiteralInteger, elem_idx.into())] + .into_iter() + .collect(), + }, + vector::WholeOp::Insert { elem_idx } => spv::Inst { + opcode: wk.OpCompositeInsert, + imms: [spv::Imm::Short(wk.LiteralInteger, elem_idx.into())] + .into_iter() + .collect(), + }, + vector::WholeOp::DynExtract => mo.OpVectorExtractDynamic.into(), + vector::WholeOp::DynInsert => mo.OpVectorInsertDynamic.into(), + vector::WholeOp::Mul => mo.OpVectorTimesScalar.into(), + } + } + }), + + NodeKind::Select(_) + | NodeKind::Loop { .. } + | NodeKind::ExitInvocation(_) + | NodeKind::FuncCall(_) + | NodeKind::Mem(_) + | NodeKind::QPtr(_) + | NodeKind::ThunkBind(_) + | NodeKind::SpvInst(..) + | NodeKind::SpvExtInst { .. } => None, } } } diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 3c5830f9..f94dced6 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -257,6 +257,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { | NodeKind::Loop { .. } | NodeKind::ExitInvocation(_) | DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::FuncCall(_) | DataInstKind::ThunkBind(_) | DataInstKind::SpvInst(_) => {} @@ -520,6 +521,7 @@ impl<'p> FuncAt<'_, CfgCursor<'p>> { | NodeKind::ExitInvocation { .. } => None, DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) @@ -734,6 +736,7 @@ impl<'a> FuncLifting<'a> { } DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) @@ -905,6 +908,7 @@ impl<'a> FuncLifting<'a> { }, DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) @@ -982,6 +986,7 @@ impl<'a> FuncLifting<'a> { NodeKind::ExitInvocation { .. } | DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::FuncCall(_) | DataInstKind::Mem(_) | DataInstKind::QPtr(_) @@ -1554,7 +1559,7 @@ impl LazyInst<'_, '_> { | NodeKind::ExitInvocation(_), ) => unreachable!(), - Err(DataInstKind::Scalar(_)) => { + Err(DataInstKind::Scalar(_) | DataInstKind::Vector(_)) => { unreachable!("should've been handled as canonical") } diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 31c37a6e..0a42db1b 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -1809,13 +1809,7 @@ impl Module { // some "structured regions" replacement for the CFG. } else { let mut ids = &ids[..]; - let kind = if let Some(kind) = raw_inst.without_ids.as_canonical_node_kind( - &cx, - result_type.map(|ty| [ty]).as_ref().map_or(&[][..], |tys| &tys[..]), - ) { - // FIXME(eddyb) sanity-check the number/types of inputs. - kind - } else if opcode == wk.OpFunctionCall { + let kind = if opcode == wk.OpFunctionCall { assert!(imms.is_empty()); let callee_id = ids[0]; let maybe_callee = id_defs @@ -1920,6 +1914,30 @@ impl Module { } current_block_region_def.children.insert_last(inst, &mut func_def_body.nodes); + + // HACK(eddyb) doing this after defining the maybe-uncanonical + // node, just to keep the iterators simpler. + let node_def = &mut func_def_body.nodes[inst]; + if let DataInstKind::SpvInst(spv_inst) = &node_def.kind + && let Some(canonical_kind) = spv_inst.as_canonical_node_kind( + &cx, + node_def + .outputs + .iter() + .map(|&output_var| func_def_body.vars[output_var].ty), + node_def.inputs.iter().map(|&v| { + // HACK(eddyb) `func_def_body.at(v).type_of(cx)` + // equivalent, without running into borrow issues. + match v { + Value::Const(ct) => cx[ct].ty, + Value::Var(var) => func_def_body.vars[var].ty, + } + }), + ) + { + // FIXME(eddyb) sanity-check the number/types of inputs. + node_def.kind = canonical_kind; + } } } diff --git a/src/spv/spec.rs b/src/spv/spec.rs index 483e1b0a..ef7a5d7b 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -129,6 +129,7 @@ def_well_known! { OpTypeSampledImage, OpTypeAccelerationStructureKHR, + // FIXME(eddyb) hide these from code, lowering should handle most cases. OpConstantComposite, OpConstantFunctionPointerINTEL, @@ -160,6 +161,11 @@ def_well_known! { OpPtrAccessChain, OpInBoundsPtrAccessChain, OpBitcast, + + // FIXME(eddyb) hide these from code, lowering should handle most cases. + OpCompositeInsert, + OpCompositeExtract, + OpCompositeConstruct, ], operand_kind: OperandKind = [ Capability, diff --git a/src/transform.rs b/src/transform.rs index ee1ba7cb..d6ed049f 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -637,6 +637,7 @@ impl InnerInPlaceTransform for FuncAtMut<'_, Node> { | NodeKind::Loop { repeat_condition: _ } | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) | DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::Mem(MemOp::FuncLocalVar(_) | MemOp::Load | MemOp::Store) | DataInstKind::QPtr( QPtrOp::HandleArrayIndex diff --git a/src/vector.rs b/src/vector.rs index b5b60fb2..3158dad8 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -121,3 +121,60 @@ impl Const { (0..usize::from(ty.elem_count.get())).map(|i| self.get_elem(i).unwrap()) } } + +/// Pure operations with vector inputs and/or outputs. +#[derive(Copy, Clone, PartialEq, Eq, Hash, derive_more::From)] +pub enum Op { + Distribute(scalar::Op), + Reduce(ReduceOp), + + // FIXME(eddyb) find a better name for this category of ops. + Whole(WholeOp), +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum ReduceOp { + // FIXME(eddyb) also support all the new integer dot product instructions. + Dot, + // FIXME(eddyb) model these using their respective `BoolBinOp`s? + Any, + All, +} + +impl ReduceOp { + pub fn name(self) -> &'static str { + match self { + ReduceOp::Dot => "vec.dot", + ReduceOp::Any => "vec.any", + ReduceOp::All => "vec.all", + } + } +} + +// FIXME(eddyb) find a better name for this category of ops. +// FIXME(eddyb) also support `OpVectorShuffle`. +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum WholeOp { + // FIXME(eddyb) better name for this (pack? make? "construct" is too long). + New, + Extract { elem_idx: u8 }, + Insert { elem_idx: u8 }, + DynExtract, + DynInsert, + + // FIXME(eddyb) may need a better name to indicate "scalar product". + Mul, +} + +impl WholeOp { + pub fn name(self) -> &'static str { + match self { + WholeOp::New => "vec.new", + WholeOp::Extract { .. } => "vec.extract", + WholeOp::Insert { .. } => "vec.insert", + WholeOp::DynExtract => "vec.dyn_extract", + WholeOp::DynInsert => "vec.dyn_insert", + WholeOp::Mul => "vec.mul", + } + } +} diff --git a/src/visit.rs b/src/visit.rs index 059a6ce8..2dbabd71 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -481,6 +481,7 @@ impl<'a> FuncAt<'a, Node> { | NodeKind::Loop { repeat_condition: _ } | NodeKind::ExitInvocation(cf::ExitInvocationKind::SpvInst(_)) | DataInstKind::Scalar(_) + | DataInstKind::Vector(_) | DataInstKind::Mem(MemOp::FuncLocalVar(_) | MemOp::Load | MemOp::Store) | DataInstKind::QPtr( QPtrOp::HandleArrayIndex