Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 184 additions & 26 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3172,6 +3172,32 @@ impl<W: Write> Writer<W> {
Ok(check_written)
}

fn is_root_workgroup_pointer(
&self,
chain: Handle<crate::Expression>,
context: &ExpressionContext,
) -> bool {
// Check if the direct expression (without following access chains) is a workgroup pointer
match context.function.expressions[chain] {
crate::Expression::GlobalVariable(handle) => {
let var = &context.module.global_variables[handle];
var.space == crate::AddressSpace::WorkGroup
}
crate::Expression::FunctionArgument(index) => {
let arg = &context.function.arguments[index as usize];
let type_inner = &context.module.types[arg.ty].inner;
matches!(
type_inner,
crate::TypeInner::Pointer {
space: crate::AddressSpace::WorkGroup,
..
}
)
}
_ => false,
}
}

/// Write the access chain `chain`.
///
/// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
Expand Down Expand Up @@ -3230,13 +3256,31 @@ impl<W: Write> Writer<W> {
// indexing a struct with an expression.
match *base_ty {
crate::TypeInner::Struct { .. } => {
let is_workgroup = {
let base_ty = context.resolve_type(base);
matches!(
base_ty,
crate::TypeInner::Pointer {
space: crate::AddressSpace::WorkGroup,
..
}
)
};
let op = if is_workgroup { "->" } else { "." };
let base_ty = base_ty_handle.unwrap();
self.put_access_chain(base, policy, context)?;
let name = &self.names[&NameKey::StructMember(base_ty, index)];
write!(self.out, ".{name}")?;
write!(self.out, "{op}{name}")?;
}
crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
let is_workgroup_ptr = self.is_root_workgroup_pointer(base, context);
if is_workgroup_ptr {
write!(self.out, "(*")?;
}
self.put_access_chain(base, policy, context)?;
if is_workgroup_ptr {
write!(self.out, ")")?;
}
// Prior to Metal v2.1 component access for packed vectors wasn't available
// however array indexing is
if context.get_packed_vec_kind(base).is_some() {
Expand Down Expand Up @@ -3296,9 +3340,35 @@ impl<W: Write> Writer<W> {
let accessing_wrapped_binding_array =
matches!(*base_ty, crate::TypeInner::BindingArray { .. });

let is_workgroup = {
let base_ty = context.resolve_type(base);
matches!(
base_ty,
crate::TypeInner::Pointer {
space: crate::AddressSpace::WorkGroup,
..
}
)
};

if is_workgroup && !accessing_wrapped_array {
write!(self.out, "(*")?;
}
self.put_access_chain(base, policy, context)?;
if is_workgroup && !accessing_wrapped_array {
write!(self.out, ")")?;
}
if accessing_wrapped_array {
write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
let is_direct_workgroup = if let crate::Expression::GlobalVariable(handle) =
&context.function.expressions[base]
{
let var = &context.module.global_variables[*handle];
var.space == crate::AddressSpace::WorkGroup
} else {
false
};
let op = if is_direct_workgroup { "->" } else { "." };
write!(self.out, "{op}{WRAPPED_ARRAY_FIELD}")?;
}
write!(self.out, "[")?;

Expand Down Expand Up @@ -3379,16 +3449,21 @@ impl<W: Write> Writer<W> {
.is_atomic_pointer(&context.module.types);

if is_atomic_pointer {
write!(
self.out,
"{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
)?;
write!(self.out, "{NAMESPACE}::atomic_load_explicit(")?;
let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, context);
if !is_workgroup_ptr {
write!(self.out, "{ATOMIC_REFERENCE}")?;
}
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
} else {
// We don't do any dereferencing with `*` here as pointer arguments to functions
// are done by `&` references and not `*` pointers. These do not need to be
// dereferenced.
// dereferenced, except for workgroups pointers.
let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, context);
if is_workgroup_ptr {
write!(self.out, "*")?;
}
self.put_access_chain(pointer, policy, context)?;
}

Expand Down Expand Up @@ -4035,9 +4110,13 @@ impl<W: Write> Writer<W> {
}

// Put the atomic function invocation.
let is_workgroup_atomic = self.is_root_workgroup_pointer(pointer, context);
match *fun {
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}(")?;
if !is_workgroup_atomic {
write!(self.out, "{ATOMIC_REFERENCE}")?;
}
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(cmp, context, true)?;
Expand All @@ -4046,10 +4125,10 @@ impl<W: Write> Writer<W> {
write!(self.out, ")")?;
}
_ => {
write!(
self.out,
"{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
)?;
write!(self.out, "{NAMESPACE}::atomic_{fun_key}_explicit(")?;
if !is_workgroup_atomic {
write!(self.out, "{ATOMIC_REFERENCE}")?;
}
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
Expand Down Expand Up @@ -4423,16 +4502,21 @@ impl<W: Write> Writer<W> {
.is_atomic_pointer(&context.expression.module.types);

if is_atomic_pointer {
write!(
self.out,
"{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
)?;
write!(self.out, "{level}{NAMESPACE}::atomic_store_explicit(")?;
let is_workgroup_atomic = self.is_root_workgroup_pointer(pointer, &context.expression);
if !is_workgroup_atomic {
write!(self.out, "{ATOMIC_REFERENCE}")?;
}
self.put_access_chain(pointer, policy, &context.expression)?;
write!(self.out, ", ")?;
self.put_expression(value, &context.expression, true)?;
writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
} else {
write!(self.out, "{level}")?;
let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, &context.expression);
if is_workgroup_ptr {
write!(self.out, "*")?;
}
self.put_access_chain(pointer, policy, &context.expression)?;
write!(self.out, " = ")?;
self.put_expression(value, &context.expression, true)?;
Expand Down Expand Up @@ -8234,19 +8318,54 @@ mod workgroup_mem_init {
}

impl Access {
fn is_pointer_type(&self, module: &crate::Module) -> bool {
match *self {
Access::GlobalVariable(handle) => {
let var = &module.global_variables[handle];
// workgroup variables are passed as pointers to the kernel
var.space == crate::AddressSpace::WorkGroup
|| matches!(module.types[var.ty].inner, crate::TypeInner::Pointer { .. })
}
Access::StructMember(struct_handle, member_index) => {
// check if the member's type is a pointer
if let crate::TypeInner::Struct { ref members, .. } =
module.types[struct_handle].inner
{
if let Some(member) = members.get(member_index as usize) {
matches!(
module.types[member.ty].inner,
crate::TypeInner::Pointer { .. }
)
} else {
false
}
} else {
false
}
}
Access::Array(_) => false,
}
}

fn write<W: Write>(
&self,
writer: &mut W,
names: &FastHashMap<NameKey, String>,
op: &str,
) -> Result<(), core::fmt::Error> {
match *self {
Access::GlobalVariable(handle) => {
write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
}
Access::StructMember(handle, index) => {
write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
write!(
writer,
"{}{}",
op,
&names[&NameKey::StructMember(handle, index)]
)
}
Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
Access::Array(depth) => write!(writer, "{}{}[__i{depth}]", op, WRAPPED_ARRAY_FIELD),
}
}
}
Expand Down Expand Up @@ -8285,12 +8404,36 @@ mod workgroup_mem_init {
&self,
writer: &mut W,
names: &FastHashMap<NameKey, String>,
module: &crate::Module,
) -> Result<(), core::fmt::Error> {
for next in self.stack.iter() {
next.write(writer, names)?;
for (i, next) in self.stack.iter().enumerate() {
let op = if i == 0 {
// root item doesn't get an operator prefix
""
} else {
// check if the previous item is a pointer to determine the operator for this item
let prev = &self.stack[i - 1];
if prev.is_pointer_type(module) {
"->"
} else {
"."
}
};
next.write(writer, names, op)?;
}
Ok(())
}

fn root_is_workgroup_pointer(&self, module: &crate::Module) -> bool {
if let Some(first) = self.stack.first() {
if let Access::GlobalVariable(handle) = first {
let var = &module.global_variables[*handle];
// Workgroup variables are passed as pointers to the kernel, regardless of their type
return var.space == crate::AddressSpace::WorkGroup;
}
}
false
}
}

impl<W: Write> Writer<W> {
Expand Down Expand Up @@ -8365,16 +8508,31 @@ mod workgroup_mem_init {
) -> BackendResult {
if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
write!(self.out, "{level}")?;
access_stack.write(&mut self.out, &self.names)?;
// for workgroup pointers at the root level, only add * for scalar/vector/matrix types
// structs/arrays use -> operator which handles pointer dereference
// If we're nested in a struct, don't add * - the -> operator handles it
let is_root_workgroup = access_stack.root_is_workgroup_pointer(module);
let is_nested = access_stack.stack.len() > 1;
let skip_deref = is_nested
|| match &module.types[ty].inner {
crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => true,
_ => false,
};
if is_root_workgroup && !skip_deref {
write!(self.out, "*")?;
}
access_stack.write(&mut self.out, &self.names, module)?;
writeln!(self.out, " = {{}};")?;
} else {
match module.types[ty].inner {
crate::TypeInner::Atomic { .. } => {
write!(
self.out,
"{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
)?;
access_stack.write(&mut self.out, &self.names)?;
write!(self.out, "{level}{NAMESPACE}::atomic_store_explicit(")?;
// only skip & for direct access to workgroup atomic
let is_nested = access_stack.stack.len() > 1;
if !access_stack.root_is_workgroup_pointer(module) || is_nested {
write!(self.out, "{ATOMIC_REFERENCE}")?;
}
access_stack.write(&mut self.out, &self.names, module)?;
writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
}
crate::TypeInner::Array { base, size, .. } => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ struct compute1_Input {
kernel void compute1_(
metal::uint3 local_invocation_id [[thread_position_in_threadgroup]]
, uint local_invocation_index [[thread_index_in_threadgroup]]
, threadgroup uint& wg_var
, threadgroup uint* wg_var
) {
if (local_invocation_index == 0u) {
wg_var = {};
}
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
const Input input = { local_invocation_id, local_invocation_index };
wg_var = input.local_invocation_index * 2u;
uint _e6 = wg_var;
wg_var = _e6 + input.local_invocation_id[0];
*wg_var = input.local_invocation_index * 2u;
uint _e6 = *wg_var;
*wg_var = _e6 + input.local_invocation_id[0];
return;
}
4 changes: 2 additions & 2 deletions naga/tests/out/msl/wgsl-abstract-types-operators.metal
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ void wgpu_4445_(
}

void wgpu_4435_(
threadgroup type_3& a
threadgroup type_3* a
) {
uint y = a.inner[as_type<int>(as_type<uint>(1) - as_type<uint>(1))];
uint y = a->inner[as_type<int>(as_type<uint>(1) - as_type<uint>(1))];
return;
}

Expand Down
Loading
Loading