diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index c2931f7af41..603d726fb77 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3172,6 +3172,32 @@ impl Writer { Ok(check_written) } + fn is_root_workgroup_pointer( + &self, + chain: Handle, + context: &ExpressionContext, + ) -> bool { + // Check if the direct expression (without following access chains) is a workgroup pointer + match context.function.expressions[chain] { + crate::Expression::GlobalVariable(handle) => { + let var = &context.module.global_variables[handle]; + var.space == crate::AddressSpace::WorkGroup + } + crate::Expression::FunctionArgument(index) => { + let arg = &context.function.arguments[index as usize]; + let type_inner = &context.module.types[arg.ty].inner; + matches!( + type_inner, + crate::TypeInner::Pointer { + space: crate::AddressSpace::WorkGroup, + .. + } + ) + } + _ => false, + } + } + /// Write the access chain `chain`. /// /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions, @@ -3230,13 +3256,31 @@ impl Writer { // indexing a struct with an expression. match *base_ty { crate::TypeInner::Struct { .. } => { + let is_workgroup = { + let base_ty = context.resolve_type(base); + matches!( + base_ty, + crate::TypeInner::Pointer { + space: crate::AddressSpace::WorkGroup, + .. + } + ) + }; + let op = if is_workgroup { "->" } else { "." }; let base_ty = base_ty_handle.unwrap(); self.put_access_chain(base, policy, context)?; let name = &self.names[&NameKey::StructMember(base_ty, index)]; - write!(self.out, ".{name}")?; + write!(self.out, "{op}{name}")?; } crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => { + let is_workgroup_ptr = self.is_root_workgroup_pointer(base, context); + if is_workgroup_ptr { + write!(self.out, "(*")?; + } self.put_access_chain(base, policy, context)?; + if is_workgroup_ptr { + write!(self.out, ")")?; + } // Prior to Metal v2.1 component access for packed vectors wasn't available // however array indexing is if context.get_packed_vec_kind(base).is_some() { @@ -3296,9 +3340,35 @@ impl Writer { let accessing_wrapped_binding_array = matches!(*base_ty, crate::TypeInner::BindingArray { .. }); + let is_workgroup = { + let base_ty = context.resolve_type(base); + matches!( + base_ty, + crate::TypeInner::Pointer { + space: crate::AddressSpace::WorkGroup, + .. + } + ) + }; + + if is_workgroup && !accessing_wrapped_array { + write!(self.out, "(*")?; + } self.put_access_chain(base, policy, context)?; + if is_workgroup && !accessing_wrapped_array { + write!(self.out, ")")?; + } if accessing_wrapped_array { - write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?; + let is_direct_workgroup = if let crate::Expression::GlobalVariable(handle) = + &context.function.expressions[base] + { + let var = &context.module.global_variables[*handle]; + var.space == crate::AddressSpace::WorkGroup + } else { + false + }; + let op = if is_direct_workgroup { "->" } else { "." }; + write!(self.out, "{op}{WRAPPED_ARRAY_FIELD}")?; } write!(self.out, "[")?; @@ -3379,16 +3449,21 @@ impl Writer { .is_atomic_pointer(&context.module.types); if is_atomic_pointer { - write!( - self.out, - "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}" - )?; + write!(self.out, "{NAMESPACE}::atomic_load_explicit(")?; + let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, context); + if !is_workgroup_ptr { + write!(self.out, "{ATOMIC_REFERENCE}")?; + } self.put_access_chain(pointer, policy, context)?; write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?; } else { // We don't do any dereferencing with `*` here as pointer arguments to functions // are done by `&` references and not `*` pointers. These do not need to be - // dereferenced. + // dereferenced, except for workgroups pointers. + let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, context); + if is_workgroup_ptr { + write!(self.out, "*")?; + } self.put_access_chain(pointer, policy, context)?; } @@ -4035,9 +4110,13 @@ impl Writer { } // Put the atomic function invocation. + let is_workgroup_atomic = self.is_root_workgroup_pointer(pointer, context); match *fun { crate::AtomicFunction::Exchange { compare: Some(cmp) } => { - write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?; + write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}(")?; + if !is_workgroup_atomic { + write!(self.out, "{ATOMIC_REFERENCE}")?; + } self.put_access_chain(pointer, policy, context)?; write!(self.out, ", ")?; self.put_expression(cmp, context, true)?; @@ -4046,10 +4125,10 @@ impl Writer { write!(self.out, ")")?; } _ => { - write!( - self.out, - "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}" - )?; + write!(self.out, "{NAMESPACE}::atomic_{fun_key}_explicit(")?; + if !is_workgroup_atomic { + write!(self.out, "{ATOMIC_REFERENCE}")?; + } self.put_access_chain(pointer, policy, context)?; write!(self.out, ", ")?; self.put_expression(value, context, true)?; @@ -4423,16 +4502,21 @@ impl Writer { .is_atomic_pointer(&context.expression.module.types); if is_atomic_pointer { - write!( - self.out, - "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}" - )?; + write!(self.out, "{level}{NAMESPACE}::atomic_store_explicit(")?; + let is_workgroup_atomic = self.is_root_workgroup_pointer(pointer, &context.expression); + if !is_workgroup_atomic { + write!(self.out, "{ATOMIC_REFERENCE}")?; + } self.put_access_chain(pointer, policy, &context.expression)?; write!(self.out, ", ")?; self.put_expression(value, &context.expression, true)?; writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?; } else { write!(self.out, "{level}")?; + let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, &context.expression); + if is_workgroup_ptr { + write!(self.out, "*")?; + } self.put_access_chain(pointer, policy, &context.expression)?; write!(self.out, " = ")?; self.put_expression(value, &context.expression, true)?; @@ -8234,19 +8318,54 @@ mod workgroup_mem_init { } impl Access { + fn is_pointer_type(&self, module: &crate::Module) -> bool { + match *self { + Access::GlobalVariable(handle) => { + let var = &module.global_variables[handle]; + // workgroup variables are passed as pointers to the kernel + var.space == crate::AddressSpace::WorkGroup + || matches!(module.types[var.ty].inner, crate::TypeInner::Pointer { .. }) + } + Access::StructMember(struct_handle, member_index) => { + // check if the member's type is a pointer + if let crate::TypeInner::Struct { ref members, .. } = + module.types[struct_handle].inner + { + if let Some(member) = members.get(member_index as usize) { + matches!( + module.types[member.ty].inner, + crate::TypeInner::Pointer { .. } + ) + } else { + false + } + } else { + false + } + } + Access::Array(_) => false, + } + } + fn write( &self, writer: &mut W, names: &FastHashMap, + op: &str, ) -> Result<(), core::fmt::Error> { match *self { Access::GlobalVariable(handle) => { write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)]) } Access::StructMember(handle, index) => { - write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)]) + write!( + writer, + "{}{}", + op, + &names[&NameKey::StructMember(handle, index)] + ) } - Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"), + Access::Array(depth) => write!(writer, "{}{}[__i{depth}]", op, WRAPPED_ARRAY_FIELD), } } } @@ -8285,12 +8404,36 @@ mod workgroup_mem_init { &self, writer: &mut W, names: &FastHashMap, + module: &crate::Module, ) -> Result<(), core::fmt::Error> { - for next in self.stack.iter() { - next.write(writer, names)?; + for (i, next) in self.stack.iter().enumerate() { + let op = if i == 0 { + // root item doesn't get an operator prefix + "" + } else { + // check if the previous item is a pointer to determine the operator for this item + let prev = &self.stack[i - 1]; + if prev.is_pointer_type(module) { + "->" + } else { + "." + } + }; + next.write(writer, names, op)?; } Ok(()) } + + fn root_is_workgroup_pointer(&self, module: &crate::Module) -> bool { + if let Some(first) = self.stack.first() { + if let Access::GlobalVariable(handle) = first { + let var = &module.global_variables[*handle]; + // Workgroup variables are passed as pointers to the kernel, regardless of their type + return var.space == crate::AddressSpace::WorkGroup; + } + } + false + } } impl Writer { @@ -8365,16 +8508,31 @@ mod workgroup_mem_init { ) -> BackendResult { if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) { write!(self.out, "{level}")?; - access_stack.write(&mut self.out, &self.names)?; + // for workgroup pointers at the root level, only add * for scalar/vector/matrix types + // structs/arrays use -> operator which handles pointer dereference + // If we're nested in a struct, don't add * - the -> operator handles it + let is_root_workgroup = access_stack.root_is_workgroup_pointer(module); + let is_nested = access_stack.stack.len() > 1; + let skip_deref = is_nested + || match &module.types[ty].inner { + crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => true, + _ => false, + }; + if is_root_workgroup && !skip_deref { + write!(self.out, "*")?; + } + access_stack.write(&mut self.out, &self.names, module)?; writeln!(self.out, " = {{}};")?; } else { match module.types[ty].inner { crate::TypeInner::Atomic { .. } => { - write!( - self.out, - "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}" - )?; - access_stack.write(&mut self.out, &self.names)?; + write!(self.out, "{level}{NAMESPACE}::atomic_store_explicit(")?; + // only skip & for direct access to workgroup atomic + let is_nested = access_stack.stack.len() > 1; + if !access_stack.root_is_workgroup_pointer(module) || is_nested { + write!(self.out, "{ATOMIC_REFERENCE}")?; + } + access_stack.write(&mut self.out, &self.names, module)?; writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?; } crate::TypeInner::Array { base, size, .. } => { diff --git a/naga/tests/out/msl/wgsl-8820-multiple-local-invocation-index-id.metal b/naga/tests/out/msl/wgsl-8820-multiple-local-invocation-index-id.metal index 90411265c45..449a4de4b51 100644 --- a/naga/tests/out/msl/wgsl-8820-multiple-local-invocation-index-id.metal +++ b/naga/tests/out/msl/wgsl-8820-multiple-local-invocation-index-id.metal @@ -14,15 +14,15 @@ struct compute1_Input { kernel void compute1_( metal::uint3 local_invocation_id [[thread_position_in_threadgroup]] , uint local_invocation_index [[thread_index_in_threadgroup]] -, threadgroup uint& wg_var +, threadgroup uint* wg_var ) { if (local_invocation_index == 0u) { wg_var = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); const Input input = { local_invocation_id, local_invocation_index }; - wg_var = input.local_invocation_index * 2u; - uint _e6 = wg_var; - wg_var = _e6 + input.local_invocation_id[0]; + *wg_var = input.local_invocation_index * 2u; + uint _e6 = *wg_var; + *wg_var = _e6 + input.local_invocation_id[0]; return; } diff --git a/naga/tests/out/msl/wgsl-abstract-types-operators.metal b/naga/tests/out/msl/wgsl-abstract-types-operators.metal index bed023c0480..9fa28a553e1 100644 --- a/naga/tests/out/msl/wgsl-abstract-types-operators.metal +++ b/naga/tests/out/msl/wgsl-abstract-types-operators.metal @@ -100,9 +100,9 @@ void wgpu_4445_( } void wgpu_4435_( - threadgroup type_3& a + threadgroup type_3* a ) { - uint y = a.inner[as_type(as_type(1) - as_type(1))]; + uint y = a->inner[as_type(as_type(1) - as_type(1))]; return; } diff --git a/naga/tests/out/msl/wgsl-atomicOps.metal b/naga/tests/out/msl/wgsl-atomicOps.metal index 26f384cf11f..a92e513ee9d 100644 --- a/naga/tests/out/msl/wgsl-atomicOps.metal +++ b/naga/tests/out/msl/wgsl-atomicOps.metal @@ -80,18 +80,18 @@ kernel void cs_main( , device metal::atomic_uint& storage_atomic_scalar [[user(fake0)]] , device type_4& storage_atomic_arr [[user(fake0)]] , device Struct& storage_struct [[user(fake0)]] -, threadgroup metal::atomic_uint& workgroup_atomic_scalar -, threadgroup type_4& workgroup_atomic_arr -, threadgroup Struct& workgroup_struct +, threadgroup metal::atomic_uint* workgroup_atomic_scalar +, threadgroup type_4* workgroup_atomic_arr +, threadgroup Struct* workgroup_struct ) { if (__local_invocation_index == 0u) { metal::atomic_store_explicit(&workgroup_atomic_scalar, 0, metal::memory_order_relaxed); for (int __i0 = 0; __i0 < 2; __i0++) { - metal::atomic_store_explicit(&workgroup_atomic_arr.inner[__i0], 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_atomic_arr->inner[__i0], 0, metal::memory_order_relaxed); } - metal::atomic_store_explicit(&workgroup_struct.atomic_scalar, 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_struct->atomic_scalar, 0, metal::memory_order_relaxed); for (int __i0 = 0; __i0 < 2; __i0++) { - metal::atomic_store_explicit(&workgroup_struct.atomic_arr.inner[__i0], 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_struct->atomic_arr.inner[__i0], 0, metal::memory_order_relaxed); } } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); @@ -99,97 +99,97 @@ kernel void cs_main( metal::atomic_store_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::atomic_store_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); metal::atomic_store_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - metal::atomic_store_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - metal::atomic_store_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - metal::atomic_store_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - metal::atomic_store_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + metal::atomic_store_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint l0_ = metal::atomic_load_explicit(&storage_atomic_scalar, metal::memory_order_relaxed); int l1_ = metal::atomic_load_explicit(&storage_atomic_arr.inner[1], metal::memory_order_relaxed); uint l2_ = metal::atomic_load_explicit(&storage_struct.atomic_scalar, metal::memory_order_relaxed); int l3_ = metal::atomic_load_explicit(&storage_struct.atomic_arr.inner[1], metal::memory_order_relaxed); - uint l4_ = metal::atomic_load_explicit(&workgroup_atomic_scalar, metal::memory_order_relaxed); - int l5_ = metal::atomic_load_explicit(&workgroup_atomic_arr.inner[1], metal::memory_order_relaxed); - uint l6_ = metal::atomic_load_explicit(&workgroup_struct.atomic_scalar, metal::memory_order_relaxed); - int l7_ = metal::atomic_load_explicit(&workgroup_struct.atomic_arr.inner[1], metal::memory_order_relaxed); + uint l4_ = metal::atomic_load_explicit(workgroup_atomic_scalar, metal::memory_order_relaxed); + int l5_ = metal::atomic_load_explicit(&workgroup_atomic_arr->inner[1], metal::memory_order_relaxed); + uint l6_ = metal::atomic_load_explicit(&workgroup_struct->atomic_scalar, metal::memory_order_relaxed); + int l7_ = metal::atomic_load_explicit(&workgroup_struct->atomic_arr.inner[1], metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e51 = metal::atomic_fetch_add_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e55 = metal::atomic_fetch_add_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e59 = metal::atomic_fetch_add_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e64 = metal::atomic_fetch_add_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e67 = metal::atomic_fetch_add_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e71 = metal::atomic_fetch_add_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e75 = metal::atomic_fetch_add_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e80 = metal::atomic_fetch_add_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e67 = metal::atomic_fetch_add_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e71 = metal::atomic_fetch_add_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e75 = metal::atomic_fetch_add_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e80 = metal::atomic_fetch_add_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e83 = metal::atomic_fetch_sub_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e87 = metal::atomic_fetch_sub_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e91 = metal::atomic_fetch_sub_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e96 = metal::atomic_fetch_sub_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e99 = metal::atomic_fetch_sub_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e103 = metal::atomic_fetch_sub_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e107 = metal::atomic_fetch_sub_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e112 = metal::atomic_fetch_sub_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e99 = metal::atomic_fetch_sub_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e103 = metal::atomic_fetch_sub_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e107 = metal::atomic_fetch_sub_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e112 = metal::atomic_fetch_sub_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e115 = metal::atomic_fetch_max_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e119 = metal::atomic_fetch_max_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e123 = metal::atomic_fetch_max_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e128 = metal::atomic_fetch_max_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e131 = metal::atomic_fetch_max_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e135 = metal::atomic_fetch_max_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e139 = metal::atomic_fetch_max_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e144 = metal::atomic_fetch_max_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e131 = metal::atomic_fetch_max_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e135 = metal::atomic_fetch_max_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e139 = metal::atomic_fetch_max_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e144 = metal::atomic_fetch_max_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e147 = metal::atomic_fetch_min_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e151 = metal::atomic_fetch_min_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e155 = metal::atomic_fetch_min_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e160 = metal::atomic_fetch_min_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e163 = metal::atomic_fetch_min_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e167 = metal::atomic_fetch_min_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e171 = metal::atomic_fetch_min_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e176 = metal::atomic_fetch_min_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e163 = metal::atomic_fetch_min_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e167 = metal::atomic_fetch_min_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e171 = metal::atomic_fetch_min_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e176 = metal::atomic_fetch_min_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e179 = metal::atomic_fetch_and_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e183 = metal::atomic_fetch_and_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e187 = metal::atomic_fetch_and_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e192 = metal::atomic_fetch_and_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e195 = metal::atomic_fetch_and_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e199 = metal::atomic_fetch_and_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e203 = metal::atomic_fetch_and_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e208 = metal::atomic_fetch_and_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e195 = metal::atomic_fetch_and_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e199 = metal::atomic_fetch_and_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e203 = metal::atomic_fetch_and_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e208 = metal::atomic_fetch_and_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e211 = metal::atomic_fetch_or_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e215 = metal::atomic_fetch_or_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e219 = metal::atomic_fetch_or_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e224 = metal::atomic_fetch_or_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e227 = metal::atomic_fetch_or_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e231 = metal::atomic_fetch_or_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e235 = metal::atomic_fetch_or_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e240 = metal::atomic_fetch_or_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e227 = metal::atomic_fetch_or_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e231 = metal::atomic_fetch_or_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e235 = metal::atomic_fetch_or_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e240 = metal::atomic_fetch_or_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e243 = metal::atomic_fetch_xor_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e247 = metal::atomic_fetch_xor_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e251 = metal::atomic_fetch_xor_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e256 = metal::atomic_fetch_xor_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e259 = metal::atomic_fetch_xor_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e263 = metal::atomic_fetch_xor_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e267 = metal::atomic_fetch_xor_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e272 = metal::atomic_fetch_xor_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e259 = metal::atomic_fetch_xor_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e263 = metal::atomic_fetch_xor_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e267 = metal::atomic_fetch_xor_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e272 = metal::atomic_fetch_xor_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e275 = metal::atomic_exchange_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e279 = metal::atomic_exchange_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e283 = metal::atomic_exchange_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e288 = metal::atomic_exchange_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e291 = metal::atomic_exchange_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e295 = metal::atomic_exchange_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e299 = metal::atomic_exchange_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e304 = metal::atomic_exchange_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e291 = metal::atomic_exchange_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e295 = metal::atomic_exchange_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e299 = metal::atomic_exchange_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e304 = metal::atomic_exchange_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); _atomic_compare_exchange_result_Uint_4_ _e308 = naga_atomic_compare_exchange_weak_explicit(&storage_atomic_scalar, 1u, 2u); _atomic_compare_exchange_result_Sint_4_ _e313 = naga_atomic_compare_exchange_weak_explicit(&storage_atomic_arr.inner[1], 1, 2); _atomic_compare_exchange_result_Uint_4_ _e318 = naga_atomic_compare_exchange_weak_explicit(&storage_struct.atomic_scalar, 1u, 2u); _atomic_compare_exchange_result_Sint_4_ _e324 = naga_atomic_compare_exchange_weak_explicit(&storage_struct.atomic_arr.inner[1], 1, 2); - _atomic_compare_exchange_result_Uint_4_ _e328 = naga_atomic_compare_exchange_weak_explicit(&workgroup_atomic_scalar, 1u, 2u); - _atomic_compare_exchange_result_Sint_4_ _e333 = naga_atomic_compare_exchange_weak_explicit(&workgroup_atomic_arr.inner[1], 1, 2); - _atomic_compare_exchange_result_Uint_4_ _e338 = naga_atomic_compare_exchange_weak_explicit(&workgroup_struct.atomic_scalar, 1u, 2u); - _atomic_compare_exchange_result_Sint_4_ _e344 = naga_atomic_compare_exchange_weak_explicit(&workgroup_struct.atomic_arr.inner[1], 1, 2); + _atomic_compare_exchange_result_Uint_4_ _e328 = naga_atomic_compare_exchange_weak_explicit(workgroup_atomic_scalar, 1u, 2u); + _atomic_compare_exchange_result_Sint_4_ _e333 = naga_atomic_compare_exchange_weak_explicit(&workgroup_atomic_arr->inner[1], 1, 2); + _atomic_compare_exchange_result_Uint_4_ _e338 = naga_atomic_compare_exchange_weak_explicit(&workgroup_struct->atomic_scalar, 1u, 2u); + _atomic_compare_exchange_result_Sint_4_ _e344 = naga_atomic_compare_exchange_weak_explicit(&workgroup_struct->atomic_arr.inner[1], 1, 2); return; } diff --git a/naga/tests/out/msl/wgsl-globals.metal b/naga/tests/out/msl/wgsl-globals.metal index 8bd1a5cd377..2fbb7b69965 100644 --- a/naga/tests/out/msl/wgsl-globals.metal +++ b/naga/tests/out/msl/wgsl-globals.metal @@ -74,7 +74,7 @@ kernel void main_( ) { if (__local_invocation_index == 0u) { wg = {}; - metal::atomic_store_explicit(&at_1, 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(at_1, 0, metal::memory_order_relaxed); } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); float Foo = 1.0; @@ -82,20 +82,20 @@ kernel void main_( test_msl_packed_vec3_(alignment); metal::float4x2 _e5 = global_nested_arrays_of_matrices_4x2_.inner[0].inner[0]; metal::float4 _e10 = global_nested_arrays_of_matrices_2x4_.inner[0].inner[0][0]; - wg.inner[7] = (_e5 * _e10).x; + wg->inner[7] = (_e5 * _e10).x; metal::float3x2 _e16 = global_mat; metal::float3 _e18 = global_vec; - wg.inner[6] = (_e16 * _e18).x; + wg->inner[6] = (_e16 * _e18).x; float _e26 = dummy[1].y; - wg.inner[5] = _e26; + wg->inner[5] = _e26; float _e32 = float_vecs.inner[0].w; - wg.inner[4] = _e32; + wg->inner[4] = _e32; float _e37 = alignment.v1_; - wg.inner[3] = _e37; + wg->inner[3] = _e37; float _e43 = alignment.v3_[0]; - wg.inner[2] = _e43; + wg->inner[2] = _e43; alignment.v1_ = 4.0; - wg.inner[1] = static_cast(1 + (_buffer_sizes.size3 - 0 - 8) / 8); - metal::atomic_store_explicit(&at_1, 2u, metal::memory_order_relaxed); + wg->inner[1] = static_cast(1 + (_buffer_sizes.size3 - 0 - 8) / 8); + metal::atomic_store_explicit(at_1, 2u, metal::memory_order_relaxed); return; } diff --git a/naga/tests/out/msl/wgsl-interface.metal b/naga/tests/out/msl/wgsl-interface.metal index c4f63ca9ae7..522e222ba6d 100644 --- a/naga/tests/out/msl/wgsl-interface.metal +++ b/naga/tests/out/msl/wgsl-interface.metal @@ -75,13 +75,13 @@ struct computeInput { , uint local_index [[thread_index_in_threadgroup]] , metal::uint3 wg_id [[threadgroup_position_in_grid]] , metal::uint3 num_wgs [[threadgroups_per_grid]] -, threadgroup type_4& output +, threadgroup type_4* output ) { if (local_index == 0u) { output = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - output.inner[0] = (((global_id.x + local_id.x) + local_index) + wg_id.x) + num_wgs.x; + output->inner[0] = (((global_id.x + local_id.x) + local_index) + wg_id.x) + num_wgs.x; return; } diff --git a/naga/tests/out/msl/wgsl-overrides-atomicCompareExchangeWeak.metal b/naga/tests/out/msl/wgsl-overrides-atomicCompareExchangeWeak.metal index 8380883e370..4d7e12252af 100644 --- a/naga/tests/out/msl/wgsl-overrides-atomicCompareExchangeWeak.metal +++ b/naga/tests/out/msl/wgsl-overrides-atomicCompareExchangeWeak.metal @@ -44,6 +44,6 @@ kernel void f( metal::atomic_store_explicit(&a, 0, metal::memory_order_relaxed); } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - _atomic_compare_exchange_result_Uint_4_ _e5 = naga_atomic_compare_exchange_weak_explicit(&a, 2u, 1u); + _atomic_compare_exchange_result_Uint_4_ _e5 = naga_atomic_compare_exchange_weak_explicit(a, 2u, 1u); return; } diff --git a/naga/tests/out/msl/wgsl-policy-mix.metal b/naga/tests/out/msl/wgsl-policy-mix.metal index a31d80398f7..ecbc2dee5e2 100644 --- a/naga/tests/out/msl/wgsl-policy-mix.metal +++ b/naga/tests/out/msl/wgsl-policy-mix.metal @@ -39,14 +39,14 @@ metal::float4 mock_function( device InStorage const& in_storage, constant InUniform& in_uniform, metal::texture2d_array image_2d_array, - threadgroup type_5& in_workgroup, + threadgroup type_5* in_workgroup, thread type_6& in_private ) { type_9 in_function = type_9 {{metal::float4(0.707, 0.0, 0.0, 1.0), metal::float4(0.0, 0.707, 0.0, 1.0)}}; metal::float4 _e18 = in_storage.a.inner[i]; metal::float4 _e22 = in_uniform.a.inner[i]; metal::float4 _e25 = (uint(l) < image_2d_array.get_num_mip_levels() && uint(i) < image_2d_array.get_array_size() && metal::all(metal::uint2(c) < metal::uint2(image_2d_array.get_width(l), image_2d_array.get_height(l))) ? image_2d_array.read(metal::uint2(c), i, l): DefaultConstructible()); - float _e29 = in_workgroup.inner[metal::min(unsigned(i), 29u)]; + float _e29 = in_workgroup->inner[metal::min(unsigned(i), 29u)]; float _e34 = in_private.inner[metal::min(unsigned(i), 39u)]; metal::float4 _e38 = in_function.inner[metal::min(unsigned(i), 1u)]; return ((((_e18 + _e22) + _e25) + metal::float4(_e29)) + metal::float4(_e34)) + _e38; @@ -57,7 +57,7 @@ kernel void main_( , device InStorage const& in_storage [[user(fake0)]] , constant InUniform& in_uniform [[user(fake0)]] , metal::texture2d_array image_2d_array [[user(fake0)]] -, threadgroup type_5& in_workgroup +, threadgroup type_5* in_workgroup ) { type_6 in_private = {}; if (__local_invocation_index == 0u) { diff --git a/naga/tests/out/msl/wgsl-workgroup-uniform-load-atomic.metal b/naga/tests/out/msl/wgsl-workgroup-uniform-load-atomic.metal index ed4b600967d..345d444c40e 100644 --- a/naga/tests/out/msl/wgsl-workgroup-uniform-load-atomic.metal +++ b/naga/tests/out/msl/wgsl-workgroup-uniform-load-atomic.metal @@ -27,7 +27,7 @@ kernel void test_atomic_workgroup_uniform_load( metal::atomic_store_explicit(&wg_signed, 0, metal::memory_order_relaxed); metal::atomic_store_explicit(&wg_struct.atomic_scalar, 0, metal::memory_order_relaxed); for (int __i0 = 0; __i0 < 2; __i0++) { - metal::atomic_store_explicit(&wg_struct.atomic_arr.inner[__i0], 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(&wg_struct->atomic_arr.inner[__i0], 0, metal::memory_order_relaxed); } } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); @@ -35,22 +35,22 @@ kernel void test_atomic_workgroup_uniform_load( bool local_1 = {}; bool local_2 = {}; uint active_tile_index = workgroup_id.x + (workgroup_id.y * 32768u); - uint _e11 = metal::atomic_fetch_or_explicit(&wg_scalar, static_cast(active_tile_index >= 64u), metal::memory_order_relaxed); - int _e14 = metal::atomic_fetch_add_explicit(&wg_signed, 1, metal::memory_order_relaxed); - metal::atomic_store_explicit(&wg_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e22 = metal::atomic_fetch_add_explicit(&wg_struct.atomic_arr.inner[0], 1, metal::memory_order_relaxed); + uint _e11 = metal::atomic_fetch_or_explicit(wg_scalar, static_cast(active_tile_index >= 64u), metal::memory_order_relaxed); + int _e14 = metal::atomic_fetch_add_explicit(wg_signed, 1, metal::memory_order_relaxed); + metal::atomic_store_explicit(&wg_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e22 = metal::atomic_fetch_add_explicit(&wg_struct->atomic_arr.inner[0], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - uint unnamed = metal::atomic_load_explicit(&wg_scalar, metal::memory_order_relaxed); + uint unnamed = metal::atomic_load_explicit(wg_scalar, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - int unnamed_1 = metal::atomic_load_explicit(&wg_signed, metal::memory_order_relaxed); + int unnamed_1 = metal::atomic_load_explicit(wg_signed, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - uint unnamed_2 = metal::atomic_load_explicit(&wg_struct.atomic_scalar, metal::memory_order_relaxed); + uint unnamed_2 = metal::atomic_load_explicit(&wg_struct->atomic_scalar, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - int unnamed_3 = metal::atomic_load_explicit(&wg_struct.atomic_arr.inner[0], metal::memory_order_relaxed); + int unnamed_3 = metal::atomic_load_explicit(&wg_struct->atomic_arr.inner[0], metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); if (unnamed == 0u) { local = unnamed_1 > 0; diff --git a/naga/tests/out/msl/wgsl-workgroup-uniform-load.metal b/naga/tests/out/msl/wgsl-workgroup-uniform-load.metal index 5b8b513c36d..1280b570717 100644 --- a/naga/tests/out/msl/wgsl-workgroup-uniform-load.metal +++ b/naga/tests/out/msl/wgsl-workgroup-uniform-load.metal @@ -21,7 +21,7 @@ kernel void test_workgroupUniformLoad( } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - int unnamed = arr_i32_.inner[workgroup_id.x]; + int unnamed = arr_i32_->inner[workgroup_id.x]; metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); if (unnamed > 10) { metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); diff --git a/naga/tests/out/msl/wgsl-workgroup-var-init.metal b/naga/tests/out/msl/wgsl-workgroup-var-init.metal index 6bb6bf96f6e..31475ea7198 100644 --- a/naga/tests/out/msl/wgsl-workgroup-var-init.metal +++ b/naga/tests/out/msl/wgsl-workgroup-var-init.metal @@ -29,12 +29,12 @@ kernel void main_( metal::atomic_store_explicit(&w_mem.atom, 0, metal::memory_order_relaxed); for (int __i0 = 0; __i0 < 8; __i0++) { for (int __i1 = 0; __i1 < 8; __i1++) { - metal::atomic_store_explicit(&w_mem.atom_arr.inner[__i0].inner[__i1], 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(&w_mem->atom_arr.inner[__i0].inner[__i1], 0, metal::memory_order_relaxed); } } } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - type_1 _e3 = w_mem.arr; + type_1 _e3 = w_mem->arr; output = _e3; return; }