diff --git a/CHANGELOG.md b/CHANGELOG.md index ebdf156f2b5..23be5f053d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,7 @@ Bottom level categories: - Make `wgpu_types::texture::format::TextureChannel` accessible as `wgpu::TextureChannel`. By @TornaxO7 in [#9394](https://github.com/gfx-rs/wgpu/pull/9349). - Add support for `per_vertex` in Metal and DX12, as well as some validation for `per_vertex`, and a new enable extension, `wgpu_per_vertex`. By @inner-daemons in [#9219](https://github.com/gfx-rs/wgpu/pull/9219). - Add `ComputePass` version of `CommandEncoder::transition_resources` that allows intra-pass transitions. By @wingertge in [#9371](https://github.com/gfx-rs/wgpu/pull/9371). +- Add support for shader `debugPrintf` in Metal and Vulkan, behind a wgsl extension. By @39ali [#9389](https://github.com/gfx-rs/wgpu/pull/9389). #### Metal diff --git a/Cargo.toml b/Cargo.toml index 2b4c6083734..194a45faa22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -255,6 +255,7 @@ objc2-metal = { version = "0.3.2", default-features = false, features = [ "MTLAccelerationStructureTypes", "MTLAccelerationStructureCommandEncoder", "MTLResidencySet", + "MTLLogState", ] } objc2-quartz-core = { version = "0.3.2", default-features = false, features = [ "std", diff --git a/examples/features/src/debug_printf/README.md b/examples/features/src/debug_printf/README.md new file mode 100644 index 00000000000..5e7309a8dd2 --- /dev/null +++ b/examples/features/src/debug_printf/README.md @@ -0,0 +1,9 @@ +# debugPrintf + +This example shows how to printf in shaders + +## To Run + +``` +cargo run --bin wgpu-examples debug_printf +``` diff --git a/examples/features/src/debug_printf/mod.rs b/examples/features/src/debug_printf/mod.rs new file mode 100644 index 00000000000..90d0892a8aa --- /dev/null +++ b/examples/features/src/debug_printf/mod.rs @@ -0,0 +1,84 @@ +struct Example {} + +impl crate::framework::Example for Example { + fn required_features() -> wgpu::Features { + wgpu::Features::DEBUG_PRINTF + } + + fn required_downlevel_capabilities() -> wgpu::DownlevelCapabilities { + wgpu::DownlevelCapabilities { + flags: wgpu::DownlevelFlags::COMPUTE_SHADERS, + ..Default::default() + } + } + + fn required_limits() -> wgpu::Limits { + wgpu::Limits::default() + } + + fn init( + _config: &wgpu::SurfaceConfiguration, + _adapter: &wgpu::Adapter, + device: &wgpu::Device, + queue: &wgpu::Queue, + ) -> Self { + let shader_source = r#" + enable wgpu_debug_printf; + + @compute @workgroup_size(8, 1, 1) + fn main(@builtin(local_invocation_index) idx: u32) { + debugPrintf("WGSL_LOG: Thread index is %u", idx); + } + "#; + + let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("DebugPrintfShader"), + source: wgpu::ShaderSource::Wgsl(shader_source.into()), + }); + + // Create a simple pipeline + let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Debug Pipeline"), + layout: None, + module: &shader, + entry_point: Some("main"), + compilation_options: Default::default(), + cache: None, + }); + + let mut encoder = + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&pipeline); + cpass.dispatch_workgroups(1, 1, 1); + } + + queue.submit(Some(encoder.finish())); + + device.poll(wgpu::PollType::wait_indefinitely()).unwrap(); + + Example {} + } + + fn update(&mut self, _event: winit::event::WindowEvent) { + //empty + } + + fn resize( + &mut self, + _config: &wgpu::SurfaceConfiguration, + _device: &wgpu::Device, + _queue: &wgpu::Queue, + ) { + } + + fn render(&mut self, _view: &wgpu::TextureView, _device: &wgpu::Device, _queue: &wgpu::Queue) {} +} + +pub fn main() { + crate::framework::run::("debug-printf"); +} diff --git a/examples/features/src/lib.rs b/examples/features/src/lib.rs index 62a547f82fd..0766117bc65 100644 --- a/examples/features/src/lib.rs +++ b/examples/features/src/lib.rs @@ -10,6 +10,7 @@ pub mod bunnymark; pub mod conservative_raster; pub mod cooperative_matrix; pub mod cube; +pub mod debug_printf; pub mod hello_synchronization; pub mod hello_triangle; pub mod hello_windows; diff --git a/examples/features/src/main.rs b/examples/features/src/main.rs index 7dd7f4698b6..c3e08d75abc 100644 --- a/examples/features/src/main.rs +++ b/examples/features/src/main.rs @@ -212,6 +212,12 @@ const EXAMPLES: &[ExampleDesc] = &[ webgl: false, webgpu: false, }, + ExampleDesc { + name: "debug_printf", + function: wgpu_examples::debug_printf::main, + webgl: false, + webgpu: false, + }, ]; fn get_example_name() -> Option { diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 9c617a3b1ed..09a0a2a1cd0 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -429,6 +429,15 @@ impl StatementGraph { "TraceRay" } }, + S::DebugPrintf { + format: _, + ref arguments, + } => { + for &expr in arguments { + self.dependencies.push((id, expr, "arg")); + } + "DebugPrintf" + } }; // Set the last node to the merge node last_node = merge_id; diff --git a/naga/src/back/glsl/writer.rs b/naga/src/back/glsl/writer.rs index dd116e03b09..21332e62be1 100644 --- a/naga/src/back/glsl/writer.rs +++ b/naga/src/back/glsl/writer.rs @@ -2264,6 +2264,7 @@ impl<'a, W: Write> Writer<'a, W> { } Statement::CooperativeStore { .. } => unimplemented!(), Statement::RayPipelineFunction(_) => unimplemented!(), + Statement::DebugPrintf { .. } => unimplemented!(), } Ok(()) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index d36541511ad..7de2fe834fd 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -3047,6 +3047,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } Statement::CooperativeStore { .. } => unimplemented!(), Statement::RayPipelineFunction(_) => unreachable!(), + Statement::DebugPrintf { .. } => unimplemented!(), } Ok(()) diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 32c7099a64e..d117d87ec50 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -929,6 +929,7 @@ pub fn supported_capabilities() -> crate::valid::Capabilities { // No DRAW_INDEX // No MEMORY_DECORATION_VOLATILE | Caps::MEMORY_DECORATION_COHERENT + | Caps::DEBUG_PRINTF } #[test] diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index c2931f7af41..c7e7bdfeb23 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -4373,6 +4373,21 @@ impl Writer { writeln!(self.out, ");")?; } crate::Statement::RayPipelineFunction(_) => unreachable!(), + crate::Statement::DebugPrintf { + ref format, + ref arguments, + } => { + write!( + self.out, + "{level}metal::os_log_default.log_info(\"{}\"", + format + )?; + for &arg in arguments { + write!(self.out, ", ")?; + self.put_expression(arg, &context.expression, true)?; + } + writeln!(self.out, ");")?; + } } } diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 6e86ee360e2..262bcab1c26 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -921,6 +921,14 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S adjust(payload); } }, + Statement::DebugPrintf { + format: _, + ref mut arguments, + } => { + for argument in arguments.iter_mut() { + adjust(argument); + } + } Statement::Break | Statement::Continue | Statement::Kill diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index f83d8920e24..62b6af1f541 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -4178,6 +4178,19 @@ impl BlockContext<'_> { }; } Statement::RayPipelineFunction(_) => unreachable!(), + Statement::DebugPrintf { + ref format, + ref arguments, + } => { + let mut format_params = Vec::with_capacity(arguments.len()); + for &arg in arguments { + let word = self.cached[arg]; + format_params.push(word); + } + + self.writer + .write_debug_printf(&mut block, format, &format_params); + } } } diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index c7b54cdb5ff..4fbc99f091a 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -1195,4 +1195,5 @@ pub fn supported_capabilities() -> crate::valid::Capabilities { | Caps::DRAW_INDEX | Caps::MEMORY_DECORATION_COHERENT | Caps::MEMORY_DECORATION_VOLATILE + | Caps::DEBUG_PRINTF } diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 993b3869c43..befed84f78d 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -283,6 +283,53 @@ impl Writer { Ok(()) } + /// Returns `true` if any `debugPrintf` statement is found in any function or entry point + /// in the module. + fn module_uses_debug_printf(module: &Module) -> bool { + let functions = module.functions.iter().map(|(_, f)| f); + let entry_points = module.entry_points.iter().map(|ep| &ep.function); + for func in functions.chain(entry_points) { + for stmt in func.body.iter() { + if Self::find_debug_printf(stmt) { + return true; + } + } + } + false + } + + /// Returns `true` if a `debugPrintf` is found anywhere within this statement, `false` otherwise. + /// + /// Does not traverse into function calls. + fn find_debug_printf(stmt: &crate::Statement) -> bool { + match *stmt { + crate::Statement::DebugPrintf { .. } => true, + crate::Statement::Block(ref b) => b.iter().any(Self::find_debug_printf), + crate::Statement::If { + ref accept, + ref reject, + .. + } => { + accept.iter().any(Self::find_debug_printf) + || reject.iter().any(Self::find_debug_printf) + } + crate::Statement::Loop { + ref body, + ref continuing, + .. + } => { + body.iter().any(Self::find_debug_printf) + || continuing.iter().any(Self::find_debug_printf) + } + crate::Statement::Switch { ref cases, .. } => cases + .iter() + .any(|c| c.body.iter().any(Self::find_debug_printf)), + // Note: does not match on function `Call`s to look inside function bodies recursively. + // This is fine because this is called on all functions and entrypoints in a module to check for `debugPrintf` usage. + _ => false, + } + } + /// Helper method which writes all the `enable` declarations /// needed for a module. fn write_enable_declarations(&mut self, module: &Module) -> BackendResult { @@ -298,6 +345,7 @@ impl Writer { ray_tracing_pipeline: bool, per_vertex: bool, binding_array: bool, + debug_printf: bool, } let mut needed = RequiredEnabled { mesh_shaders: module.uses_mesh_shaders(), @@ -427,6 +475,11 @@ impl Writer { needed.ray_tracing_pipeline = true; } + // search for debugPrintf statements in all functions and entry points + if Self::module_uses_debug_printf(module) { + needed.debug_printf = true; + } + // Write required declarations let mut any_written = false; if needed.f16 { @@ -469,6 +522,10 @@ impl Writer { writeln!(self.out, "enable wgpu_per_vertex;")?; any_written = true; } + if needed.debug_printf { + writeln!(self.out, "enable wgpu_debug_printf;")?; + any_written = true; + } if any_written { // Empty line for readability writeln!(self.out)?; @@ -1210,6 +1267,17 @@ impl Writer { writeln!(self.out, ");")? } }, + Statement::DebugPrintf { + ref format, + ref arguments, + } => { + write!(self.out, "{level}debugPrintf(\"{}\"", format)?; + for &arg in arguments { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + writeln!(self.out, ");")?; + } } Ok(()) diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 4b9dd13c899..ce2f8a2c4ae 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -168,6 +168,14 @@ impl FunctionTracer<'_> { self.expressions_used.insert(payload); } }, + St::DebugPrintf { + format: _, + ref arguments, + } => { + for &expr in arguments { + self.expressions_used.insert(expr); + } + } // Trivial statements. St::Break @@ -406,6 +414,14 @@ impl FunctionMap { adjust(payload); } }, + St::DebugPrintf { + format: _, + ref mut arguments, + } => { + for expr in arguments { + adjust(expr); + } + } // Trivial statements. St::Break diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 0ddb20b8185..13de4bd8b7d 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -1664,6 +1664,7 @@ impl> Frontend { } S::WorkGroupUniformLoad { .. } => unreachable!(), S::CooperativeStore { .. } => unreachable!(), + S::DebugPrintf { .. } => unreachable!(), } i += 1; } diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 59d2268333f..fa662aae277 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -444,6 +444,10 @@ pub(crate) enum Error<'a> { UnexpectedExprForTypeExpression(Span), MissingIncomingPayload(Span), UnterminatedBlockComment(Span), + InvalidStringLiteral { + span: Span, + description: &'static str, + }, } impl From for Error<'_> { @@ -501,6 +505,7 @@ impl<'a> Error<'a> { Token::Attribute => "@".to_string(), Token::Number(_) => "number".to_string(), Token::Word(s) => s.to_string(), + Token::String(_) => "string literal".to_string(), Token::Operation(c) => format!("operation (`{c}`)"), Token::LogicalOperation(c) => format!("logical operation (`{c}`)"), Token::ShiftOperation(c) => format!("bitshift (`{c}{c}`)"), @@ -1528,7 +1533,15 @@ impl<'a> Error<'a> { "must be closed with `*/`".into(), )], notes: vec![], - } + }, + Error::InvalidStringLiteral{span , description} =>ParseError { + message: "Invalid String literal".to_string(), + labels: vec![( + span, + description.into() + )], + notes: vec![], + } , } } } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 2572b07c122..25a0de357f3 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2545,6 +2545,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { access } + ast::Expression::String(_) => { + return Err(Box::new(Error::InvalidStringLiteral { + span, + description: "String literals are only supported in debugPrintf", + })); + } }; expr.try_map(|handle| ctx.append_expression(handle, span)) @@ -3823,6 +3829,56 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(ir::Statement::RayPipelineFunction(fun), function_span); return Ok(None); } + "debugPrintf" => { + if !ctx + .enable_extensions + .contains(crate::front::wgsl::ImplementedEnableExtension::WgpuDebugPrintf) + { + return Err(Box::new(Error::EnableExtensionNotEnabled { + span: function_span, + kind: crate::front::wgsl::ImplementedEnableExtension::WgpuDebugPrintf + .into(), + })); + } + + if arguments.is_empty() { + return Err(Box::new(Error::WrongArgumentCount { + expected: 1..u32::MAX, + found: 0, + span: function_span, + })); + } + + // extract the format string + let format_handle = arguments[0]; + let format = match ctx.ast_expressions[format_handle] { + ast::Expression::String(s) => s.to_string(), + _ => { + return Err(Box::new(Error::Internal( + "debugPrintf format must be a string literal", + ))) + } + }; + + // extract remaining arguments (if any) + let mut ir_arguments = Vec::with_capacity(arguments.len().saturating_sub(1)); + + for &ast_handle in &arguments[1..] { + let ir_handle = self.expression(ast_handle, ctx)?; + ir_arguments.push(ir_handle); + } + + let rctx = ctx.runtime_expression_ctx(function_span)?; + rctx.block.push( + ir::Statement::DebugPrintf { + format, + arguments: ir_arguments, + }, + function_span, + ); + + return Ok(None); + } _ => return Err(Box::new(Error::UnknownIdent(function_span, function_name))), } }; diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index 40230242a30..32dfc79a590 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -406,6 +406,7 @@ pub enum Expression<'a> { base: Handle>, field: Ident<'a>, }, + String(&'a str), } #[derive(Debug)] diff --git a/naga/src/front/wgsl/parse/directive/enable_extension.rs b/naga/src/front/wgsl/parse/directive/enable_extension.rs index 80e58e6078a..1c1d5bd9c12 100644 --- a/naga/src/front/wgsl/parse/directive/enable_extension.rs +++ b/naga/src/front/wgsl/parse/directive/enable_extension.rs @@ -23,6 +23,7 @@ pub(crate) struct EnableExtensions { primitive_index: bool, per_vertex: bool, wgpu_binding_array: bool, + debug_printf: bool, } impl EnableExtensions { @@ -40,6 +41,7 @@ impl EnableExtensions { primitive_index: false, per_vertex: false, wgpu_binding_array: false, + debug_printf: false, } } @@ -62,6 +64,7 @@ impl EnableExtensions { ImplementedEnableExtension::PrimitiveIndex => &mut self.primitive_index, ImplementedEnableExtension::PerVertex => &mut self.per_vertex, ImplementedEnableExtension::WgpuBindingArray => &mut self.wgpu_binding_array, + ImplementedEnableExtension::WgpuDebugPrintf => &mut self.debug_printf, }; *field = true; } @@ -83,6 +86,7 @@ impl EnableExtensions { ImplementedEnableExtension::PrimitiveIndex => self.primitive_index, ImplementedEnableExtension::PerVertex => self.per_vertex, ImplementedEnableExtension::WgpuBindingArray => self.wgpu_binding_array, + ImplementedEnableExtension::WgpuDebugPrintf => self.debug_printf, } } @@ -137,6 +141,7 @@ impl EnableExtension { const DRAW_INDEX: &'static str = "draw_index"; const PER_VERTEX: &'static str = "wgpu_per_vertex"; const BINDING_ARRAY: &'static str = "wgpu_binding_array"; + const DEBUG_PRINTF: &'static str = "wgpu_debug_printf"; /// Convert from a sentinel word in WGSL into its associated [`EnableExtension`], if possible. pub(crate) fn from_ident(word: &str, span: Span) -> Result<'_, Self> { @@ -162,6 +167,7 @@ impl EnableExtension { Self::PRIMITIVE_INDEX => Self::Implemented(ImplementedEnableExtension::PrimitiveIndex), Self::PER_VERTEX => Self::Implemented(ImplementedEnableExtension::PerVertex), Self::BINDING_ARRAY => Self::Implemented(ImplementedEnableExtension::WgpuBindingArray), + Self::DEBUG_PRINTF => Self::Implemented(ImplementedEnableExtension::WgpuDebugPrintf), _ => return Err(Box::new(Error::UnknownEnableExtension(span, word))), }) } @@ -184,6 +190,7 @@ impl EnableExtension { ImplementedEnableExtension::WgpuRayTracingPipeline => Self::RAY_TRACING_PIPELINE, ImplementedEnableExtension::PerVertex => Self::PER_VERTEX, ImplementedEnableExtension::WgpuBindingArray => Self::BINDING_ARRAY, + ImplementedEnableExtension::WgpuDebugPrintf => Self::DEBUG_PRINTF, }, Self::Unimplemented(kind) => match kind { UnimplementedEnableExtension::Subgroups => Self::SUBGROUPS, @@ -236,6 +243,8 @@ pub enum ImplementedEnableExtension { PerVertex, /// Enables the `wgpu_binding_array` extension, native only. WgpuBindingArray, + /// Enables the `wgpu_debug_printf` extension, allows using `debugPrintf`, native only. + WgpuDebugPrintf, } impl ImplementedEnableExtension { @@ -253,6 +262,7 @@ impl ImplementedEnableExtension { Self::PrimitiveIndex, Self::PerVertex, Self::WgpuBindingArray, + Self::WgpuDebugPrintf, ]; /// Returns slice of all variants of [`ImplementedEnableExtension`]. @@ -284,6 +294,7 @@ impl ImplementedEnableExtension { .union(C::TEXTURE_AND_SAMPLER_BINDING_ARRAY) .union(C::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING) .union(C::ACCELERATION_STRUCTURE_BINDING_ARRAY), + Self::WgpuDebugPrintf => C::DEBUG_PRINTF, } } } diff --git a/naga/src/front/wgsl/parse/lexer.rs b/naga/src/front/wgsl/parse/lexer.rs index cf50e3f25a7..9e7758fbd8a 100644 --- a/naga/src/front/wgsl/parse/lexer.rs +++ b/naga/src/front/wgsl/parse/lexer.rs @@ -36,6 +36,9 @@ pub enum Token<'a> { /// An identifier, possibly a reserved word. Word(&'a str), + /// A string literal, used for `debugPrintf` format strings. + String(&'a str), + /// A miscellaneous single-character operator, like an arithmetic unary or /// binary operator. This includes `=`, for assignment and initialization. Operation(char), @@ -270,6 +273,17 @@ fn consume_token( None => return (Token::End, ""), }; match cur { + '"' => { + // Find the next quote in the remaining string + match chars.as_str().find('"') { + Some(len) => { + let content = &chars.as_str()[..len]; + let rest = &chars.as_str()[len + 1..]; + (Token::String(content), rest) + } + None => (Token::Unknown('"'), chars.as_str()), + } + } ':' | ';' | ',' => (Token::Separator(cur), chars.as_str()), '.' => { let og_chars = chars.as_str(); diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 2acfe3cfa25..9a2d511e0f2 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -492,6 +492,7 @@ impl Parser { (Token::Word("RAY_QUERY_INTERSECTION_AABB"), _) => { literal_ray_intersection(crate::RayQueryIntersection::Aabb) } + (Token::String(s), _) => ast::Expression::String(s), (Token::Word(word), span) => { let ident = self.template_elaborated_ident(word, span, lexer, ctx)?; diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 7cfe810e411..d8efbb7ac6a 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -2404,6 +2404,11 @@ pub enum Statement { target: Handle, data: CooperativeData, }, + /// This corresponds to `debugPrintf` in WGSL when the `wgpu_debug_printf` extension is enabled. + DebugPrintf { + format: String, + arguments: Vec>, + }, } /// A function argument. diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index 3e3e32f5c6f..358dea5f498 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -45,7 +45,8 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::ControlBarrier(_) | S::MemoryBarrier(_) | S::CooperativeStore { .. } - | S::RayPipelineFunction(_)), + | S::RayPipelineFunction(_) + | S::DebugPrintf { .. }), ) | None => block.push(S::Return { value: None }, Default::default()), } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index ab834bdfeab..a2877016e25 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1226,6 +1226,22 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::DebugPrintf { + format: _, + ref arguments, + } => { + let mut requirements = UniformityRequirements::empty(); + for &expr in arguments { + requirements |= self.expressions[expr.index()].uniformity.requirements; + } + FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements, + }, + exit: ExitFlags::empty(), + } + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index a03564c00d2..9703a314957 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -234,6 +234,8 @@ pub enum FunctionError { InvalidPayloadAddressSpace(crate::AddressSpace), #[error("The payload type ({0:?}) passed to `traceRay` does not match the previous one {1:?}")] MismatchedPayloadType(Handle, Handle), + #[error("Argument {0:?} for `debugPrintf` must be a supported scalar type")] + InvalidDebugPrintfArgument(Handle), } bitflags::bitflags! { @@ -1759,6 +1761,35 @@ impl super::Validator { } } }, + + S::DebugPrintf { + format: _, + ref arguments, + } => { + if !self + .capabilities + .contains(super::Capabilities::DEBUG_PRINTF) + { + return Err(FunctionError::MissingCapability( + super::Capabilities::DEBUG_PRINTF, + ) + .with_span_static(span, "missing capability for this operation")); + } + for &argument in arguments { + let ty = + context.resolve_type_inner(argument, &self.valid_expression_set)?; + // Only scalar types are currently supported. Supporting other types is possible + // for example vector types would require mapping each vector component to + // a separate printf format specifier. + match *ty { + Ti::Scalar(_) => {} + _ => { + return Err(FunctionError::InvalidDebugPrintfArgument(argument) + .with_span_handle(argument, context.expressions)); + } + } + } + } } } Ok(BlockInfo { stages }) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 93af8a69df6..57ca73513c2 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -876,6 +876,15 @@ impl super::Validator { Ok(()) } }, + crate::Statement::DebugPrintf { + format: _, + ref arguments, + } => { + for &arg in arguments { + validate_expr(arg)?; + } + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 5eca1549ebb..9b70337e612 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -216,6 +216,8 @@ bitflags::bitflags! { const MEMORY_DECORATION_COHERENT = 1 << 41; /// Support for the `@volatile` memory decoration on storage buffers. const MEMORY_DECORATION_VOLATILE = 1 << 42; + /// Support for `debugPrintf`. + const DEBUG_PRINTF = 1 << 43; } } @@ -248,6 +250,7 @@ impl Capabilities { | Self::TEXTURE_AND_SAMPLER_BINDING_ARRAY_NON_UNIFORM_INDEXING => { Some(Ext::WgpuBindingArray) } + Self::DEBUG_PRINTF => Some(Ext::WgpuDebugPrintf), _ => None, } } diff --git a/naga/tests/naga/wgsl_errors.rs b/naga/tests/naga/wgsl_errors.rs index fbcd4866d65..fc1f57660cc 100644 --- a/naga/tests/naga/wgsl_errors.rs +++ b/naga/tests/naga/wgsl_errors.rs @@ -5439,3 +5439,75 @@ fn unterminated_block_comment_errors() { "unterminated block comment", ) } + +#[test] +fn debug_printf_enable_extension() { + check_extension_validation!( + Capabilities::DEBUG_PRINTF, + r#" @compute @workgroup_size(8, 1, 1) + fn main(@builtin(local_invocation_index) idx: u32) { + debugPrintf("WGSL_LOG: Thread index is %u",idx); + } +"#, + r#"error: the `wgpu_debug_printf` enable extension is not enabled + ┌─ wgsl:3:17 + │ +3 │ debugPrintf("WGSL_LOG: Thread index is %u",idx); + │ ^^^^^^^^^^^ the `wgpu_debug_printf` "Enable Extension" is needed for this functionality, but it is not currently enabled. + │ + = note: You can enable this extension by adding `enable wgpu_debug_printf;` at the top of the shader, before any other items. + +"#, + Err(naga::valid::ValidationError::EntryPoint { + stage: naga::ShaderStage::Compute, + source: naga::valid::EntryPointError::Function( + naga::valid::FunctionError::MissingCapability(Capabilities::DEBUG_PRINTF) + ), + .. + }) + ); +} + +#[test] +fn test_string_literal_usage_validation() { + check( + r#" + fn main() { + let x = "hello"; + } +"#, + r#"error: Invalid String literal + ┌─ wgsl:3:17 + │ +3 │ let x = "hello"; + │ ^^^^^^^ String literals are only supported in debugPrintf + +"#, + ); + + check( + r#" + fn log(msg: f32) {} fn main() { log("error"); } +"#, + r#"error: Invalid String literal + ┌─ wgsl:2:41 + │ +2 │ fn log(msg: f32) {} fn main() { log("error"); } + │ ^^^^^^^ String literals are only supported in debugPrintf + +"#, + ); + + check( + r#" + fn get_str() -> f32 { return "hello"; } +"#, + r#"error: Invalid String literal + ┌─ wgsl:2:34 + │ +2 │ fn get_str() -> f32 { return "hello"; } + │ ^^^^^^^ String literals are only supported in debugPrintf + +"#, + ); +} diff --git a/wgpu-core/src/limits.rs b/wgpu-core/src/limits.rs index 5448d80265a..763ea26d0ee 100644 --- a/wgpu-core/src/limits.rs +++ b/wgpu-core/src/limits.rs @@ -564,7 +564,8 @@ mod tests { .union(Features::PRIMITIVE_INDEX) //.union(Features::TEXTURE_COMPONENT_SWIZZLE) not implemented // Standard-track features not in official spec - .union(Features::IMMEDIATES), + .union(Features::IMMEDIATES) + .union(Features::DEBUG_PRINTF), ); assert!( difference.is_empty(), diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index e26cb3fccde..ccd5fd70712 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -1,14 +1,17 @@ +use block2::StackBlock; use objc2::runtime::{AnyObject, ProtocolObject, Sel}; use objc2::{available, sel}; -use objc2_foundation::{NSOperatingSystemVersion, NSProcessInfo}; +use objc2_foundation::{NSOperatingSystemVersion, NSProcessInfo, NSString}; use objc2_metal::{ - MTLArgumentBuffersTier, MTLCounterSamplingPoint, MTLDevice, MTLFeatureSet, MTLGPUFamily, - MTLIndirectAccelerationStructureInstanceDescriptor, MTLLanguageVersion, MTLPixelFormat, + MTLArgumentBuffersTier, MTLCommandQueueDescriptor, MTLCounterSamplingPoint, MTLDevice, + MTLFeatureSet, MTLGPUFamily, MTLIndirectAccelerationStructureInstanceDescriptor, + MTLLanguageVersion, MTLLogLevel, MTLLogState, MTLLogStateDescriptor, MTLPixelFormat, MTLReadWriteTextureTier, }; use wgt::{AstcBlock, AstcChannel}; use alloc::{string::ToString as _, sync::Arc, vec::Vec}; +use core::ptr::NonNull; use core::sync::atomic; use crate::metal::QueueShared; @@ -70,6 +73,31 @@ impl super::Adapter { } } +/// Sets up Metal shader logging for `debugPrintf` on the given command queue descriptor. +fn setup_debug_printf(device: &ProtocolObject, cq_desc: &MTLCommandQueueDescriptor) { + let log_desc = MTLLogStateDescriptor::new(); + log_desc.setLevel(MTLLogLevel::Debug); + + let Ok(log_state) = device.newLogStateWithDescriptor_error(&log_desc) else { + return; + }; + cq_desc.setLogState(Some(&log_state)); + + let handler = StackBlock::new( + |_subsystem: *mut NSString, + _category: *mut NSString, + _level: MTLLogLevel, + message: NonNull| { + // SAFETY: message is NonNull + let msg = unsafe { message.as_ref() }.to_string(); + println!("{msg}"); + }, + ); + + // SAFETY: handler is Send because it does not capture any state + unsafe { log_state.addLogHandler(&handler) }; +} + impl crate::Adapter for super::Adapter { type A = super::Api; @@ -79,11 +107,23 @@ impl crate::Adapter for super::Adapter { limits: &wgt::Limits, _memory_hints: &wgt::MemoryHints, ) -> Result, crate::DeviceError> { - let queue = self - .shared - .device - .newCommandQueueWithMaxCommandBufferCount(MAX_COMMAND_BUFFERS) - .unwrap(); + let device = &self.shared.device; + + let cq_desc = MTLCommandQueueDescriptor::new(); + // SAFETY: MAX_COMMAND_BUFFERS is a reasonable number of buffers. + unsafe { + cq_desc.setMaxCommandBufferCount(MAX_COMMAND_BUFFERS); + } + + let use_debug_printf = features.contains(wgt::Features::DEBUG_PRINTF) + && self.shared.private_caps.supports_debug_printf; + self.shared.use_debug_printf.replace(use_debug_printf); + + if use_debug_printf { + setup_debug_printf(device, &cq_desc); + } + + let queue = device.newCommandQueueWithDescriptor(&cq_desc).unwrap(); // Acquiring the meaning of timestamp ticks is hard with Metal! // The only thing there is a method correlating cpu & gpu timestamps (`device.sample_timestamps`). @@ -1134,6 +1174,7 @@ impl super::CapabilitiesQuery { tvos = 16.0, visionos = 1.0 ), + supports_debug_printf: msl_version >= MTLLanguageVersion::Version3_2, } } @@ -1270,6 +1311,7 @@ impl super::CapabilitiesQuery { features.set(F::EXPERIMENTAL_RAY_QUERY, self.supports_raytracing); features.set(F::MULTISAMPLE_ARRAY, self.supports_multisample_array); + features.set(F::DEBUG_PRINTF, self.supports_debug_printf); features } @@ -1466,6 +1508,7 @@ impl super::CapabilitiesQuery { timestamp_query_support: self.timestamp_query_support, supports_memoryless_storage: self.supports_memoryless_storage, mesh_shaders: self.mesh_shaders, + supports_debug_printf: self.supports_debug_printf, } } diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 02ef3bffa35..f3ef1ce0bfa 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -247,6 +247,10 @@ impl super::Device { options.setPreserveInvariance(true); } + if self.shared.use_debug_printf.get() { + options.setEnableLogging(true); + } + let library = self .shared .device diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index 13974241180..37109e894a3 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -33,7 +33,7 @@ use alloc::{ sync::Arc, vec::Vec, }; -use core::{fmt, iter, ops, ptr::NonNull, sync::atomic}; +use core::{cell::Cell, fmt, iter, ops, ptr::NonNull, sync::atomic}; use bitflags::bitflags; use hashbrown::HashMap; @@ -323,6 +323,7 @@ struct CapabilitiesQuery { supports_raytracing: bool, shader_per_vertex: bool, supports_multisample_array: bool, + supports_debug_printf: bool, } #[derive(Debug)] @@ -334,6 +335,7 @@ struct PrivateCapabilities { timestamp_query_support: TimestampQuerySupport, supports_memoryless_storage: bool, mesh_shaders: bool, + supports_debug_printf: bool, } #[derive(Debug)] @@ -390,6 +392,7 @@ struct AdapterShared { private_texture_format_caps: PrivateTextureFormatCapabilities, settings: Settings, presentation_timer: time::PresentationTimer, + use_debug_printf: Cell, } unsafe impl Send for AdapterShared {} @@ -412,6 +415,7 @@ impl AdapterShared { device, settings: Settings::default(), presentation_timer: time::PresentationTimer::new(), + use_debug_printf: Cell::new(false), } } diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index 9c287bc2f02..1bed2ccf413 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -1065,6 +1065,11 @@ impl PhysicalDeviceFeatures { || caps.supports_extension(c"VK_KHR_shader_draw_parameters"), ); + features.set( + F::DEBUG_PRINTF, + caps.supports_extension(c"VK_KHR_shader_non_semantic_info"), + ); + (features, dl_flags) } } diff --git a/wgpu-hal/src/vulkan/instance.rs b/wgpu-hal/src/vulkan/instance.rs index a80869eb250..39e70de9e6f 100644 --- a/wgpu-hal/src/vulkan/instance.rs +++ b/wgpu-hal/src/vulkan/instance.rs @@ -80,6 +80,16 @@ unsafe extern "system" fn debug_utils_messenger_callback( return vk::FALSE; } + const VUID_DEBUG_PRINTF: i32 = 0x4fe1fef9; + if cd.message_id_number == VUID_DEBUG_PRINTF { + let message = + unsafe { cd.message_as_c_str() }.map_or(Cow::Borrowed(""), CStr::to_string_lossy); + + log::warn!("[debugPrintf] {}", message); + + return vk::FALSE; + } + let level = match message_severity { // We intentionally suppress info messages down to debug // so that users are not innundated with info messages from the runtime. @@ -796,7 +806,7 @@ impl super::Instance { // Enable explicit validation features if available let mut validation_features; - let mut validation_feature_list: ArrayVec<_, 3>; + let mut validation_feature_list: ArrayVec<_, 4>; if validation_features_are_enabled { validation_feature_list = ArrayVec::new(); @@ -804,6 +814,8 @@ impl super::Instance { validation_feature_list .push(vk::ValidationFeatureEnableEXT::SYNCHRONIZATION_VALIDATION); + validation_feature_list.push(vk::ValidationFeatureEnableEXT::DEBUG_PRINTF); + // Only enable GPU assisted validation if requested. if should_enable_gpu_based_validation { validation_feature_list.push(vk::ValidationFeatureEnableEXT::GPU_ASSISTED); diff --git a/wgpu-naga-bridge/src/lib.rs b/wgpu-naga-bridge/src/lib.rs index 03c77a4ed78..591d7d08d2d 100644 --- a/wgpu-naga-bridge/src/lib.rs +++ b/wgpu-naga-bridge/src/lib.rs @@ -175,6 +175,11 @@ pub fn features_to_naga_capabilities( Caps::MEMORY_DECORATION_VOLATILE, features.contains(wgt::Features::MEMORY_DECORATION_VOLATILE), ); + + caps.set( + Caps::DEBUG_PRINTF, + features.contains(wgt::Features::DEBUG_PRINTF), + ); caps } diff --git a/wgpu-types/src/features.rs b/wgpu-types/src/features.rs index 1db5f5d0488..a80ceaed4d7 100644 --- a/wgpu-types/src/features.rs +++ b/wgpu-types/src/features.rs @@ -89,6 +89,9 @@ mod webgpu_impl { #[doc(hidden)] pub const WEBGPU_FEATURE_PRIMITIVE_INDEX: u64 = 1 << 17; + + #[doc(hidden)] + pub const WEBGPU_FEATURE_DEBUG_PRINTF: u64 = 1 << 18; } macro_rules! bitflags_array_impl { @@ -1811,6 +1814,16 @@ bitflags_array! { /// remain compatible with previous wgpu behavior. #[name("primitive-index", "shader-primitive-index")] const PRIMITIVE_INDEX = WEBGPU_FEATURE_PRIMITIVE_INDEX; + + /// Allows the user to printf inside the shader by using `debugPrintf()`. + /// + /// Supported platforms: + /// - Metal (3.2+) + /// - Vulkan (1.1+) + /// + /// This is a native only feature. + #[name("wgpu-debug-printf")] + const DEBUG_PRINTF =WEBGPU_FEATURE_DEBUG_PRINTF; } }