diff --git a/newsfragments/5957.added.md b/newsfragments/5957.added.md new file mode 100644 index 00000000000..c030d226fa9 --- /dev/null +++ b/newsfragments/5957.added.md @@ -0,0 +1 @@ +Add `#[pyo3(overload(...))]` attribute for declaring `@typing.overload` variants in generated stubs. diff --git a/pyo3-introspection/src/introspection.rs b/pyo3-introspection/src/introspection.rs index 88931bb38bd..f853b77ed1d 100644 --- a/pyo3-introspection/src/introspection.rs +++ b/pyo3-introspection/src/introspection.rs @@ -1,6 +1,6 @@ use crate::model::{ Argument, Arguments, Attribute, Class, Constant, Expr, Function, Module, Operator, - VariableLengthArgument, + OverloadSignature, VariableLengthArgument, }; use anyhow::{anyhow, bail, ensure, Context, Result}; use goblin::elf::section_header::SHN_XINDEX; @@ -177,6 +177,7 @@ fn convert_members<'a>( is_async, returns, doc, + overloads, } => functions.push(convert_function( name, arguments, @@ -184,6 +185,7 @@ fn convert_members<'a>( returns, *is_async, doc.as_deref(), + overloads, type_hint_for_annotation_id, )), Chunk::Attribute { @@ -268,6 +270,7 @@ fn convert_class( }) } +#[expect(clippy::too_many_arguments)] fn convert_function( name: &str, arguments: &ChunkArguments, @@ -275,6 +278,7 @@ fn convert_function( returns: &Option, is_async: bool, docstring: Option<&str>, + overloads: &[ChunkOverload], type_hint_for_annotation_id: &HashMap, ) -> Function { Function { @@ -283,36 +287,53 @@ fn convert_function( .iter() .map(|e| convert_expr(e, type_hint_for_annotation_id)) .collect(), - arguments: Arguments { - positional_only_arguments: arguments - .posonlyargs - .iter() - .map(|a| convert_argument(a, type_hint_for_annotation_id)) - .collect(), - arguments: arguments - .args - .iter() - .map(|a| convert_argument(a, type_hint_for_annotation_id)) - .collect(), - vararg: arguments - .vararg - .as_ref() - .map(|a| convert_variable_length_argument(a, type_hint_for_annotation_id)), - keyword_only_arguments: arguments - .kwonlyargs - .iter() - .map(|e| convert_argument(e, type_hint_for_annotation_id)) - .collect(), - kwarg: arguments - .kwarg - .as_ref() - .map(|a| convert_variable_length_argument(a, type_hint_for_annotation_id)), - }, + arguments: convert_arguments(arguments, type_hint_for_annotation_id), returns: returns .as_ref() .map(|a| convert_expr(a, type_hint_for_annotation_id)), is_async, docstring: docstring.map(Into::into), + overloads: overloads + .iter() + .map(|o| OverloadSignature { + arguments: convert_arguments(&o.arguments, type_hint_for_annotation_id), + returns: o + .returns + .as_ref() + .map(|r| convert_expr(r, type_hint_for_annotation_id)), + }) + .collect(), + } +} + +fn convert_arguments( + arguments: &ChunkArguments, + type_hint_for_annotation_id: &HashMap, +) -> Arguments { + Arguments { + positional_only_arguments: arguments + .posonlyargs + .iter() + .map(|a| convert_argument(a, type_hint_for_annotation_id)) + .collect(), + arguments: arguments + .args + .iter() + .map(|a| convert_argument(a, type_hint_for_annotation_id)) + .collect(), + vararg: arguments + .vararg + .as_ref() + .map(|a| convert_variable_length_argument(a, type_hint_for_annotation_id)), + keyword_only_arguments: arguments + .kwonlyargs + .iter() + .map(|e| convert_argument(e, type_hint_for_annotation_id)) + .collect(), + kwarg: arguments + .kwarg + .as_ref() + .map(|a| convert_variable_length_argument(a, type_hint_for_annotation_id)), } } @@ -700,6 +721,8 @@ enum Chunk { is_async: bool, #[serde(default)] doc: Option, + #[serde(default)] + overloads: Vec, }, Attribute { #[serde(default)] @@ -739,6 +762,13 @@ struct ChunkArgument { annotation: Option, } +#[derive(Deserialize)] +struct ChunkOverload { + arguments: Box, + #[serde(default)] + returns: Option, +} + #[derive(Deserialize)] #[serde(tag = "type", rename_all = "lowercase")] enum ChunkExpr { diff --git a/pyo3-introspection/src/model.rs b/pyo3-introspection/src/model.rs index b6ca8d28a8d..50d0f0652fc 100644 --- a/pyo3-introspection/src/model.rs +++ b/pyo3-introspection/src/model.rs @@ -31,6 +31,14 @@ pub struct Function { pub returns: Option, pub is_async: bool, pub docstring: Option, + /// `@overload` variants for this function + pub overloads: Vec, +} + +#[derive(Debug, Eq, PartialEq, Clone, Hash)] +pub struct OverloadSignature { + pub arguments: Arguments, + pub returns: Option, } #[derive(Debug, Eq, PartialEq, Clone, Hash)] diff --git a/pyo3-introspection/src/stubs.rs b/pyo3-introspection/src/stubs.rs index 9877652cc40..59ff404a183 100644 --- a/pyo3-introspection/src/stubs.rs +++ b/pyo3-introspection/src/stubs.rs @@ -85,6 +85,7 @@ fn module_stubs(module: &Module, parents: &[&str]) -> String { }), is_async: false, docstring: None, + overloads: Vec::new(), }, &imports, None, @@ -175,36 +176,89 @@ fn class_stubs(class: &Class, imports: &Imports) -> String { } fn function_stubs(function: &Function, imports: &Imports, class_name: Option<&str>) -> String { + if function.overloads.is_empty() { + return single_function_stub( + &function.name, + &function.decorators, + &function.arguments, + function.returns.as_ref(), + function.is_async, + function.docstring.as_deref(), + imports, + class_name, + ); + } + + let mut buffer = String::new(); + + for overload in &function.overloads { + let mut overload_decorators = vec![Expr::Attribute { + value: Box::new(Expr::Name { + id: "typing".into(), + }), + attr: "overload".into(), + }]; + overload_decorators.extend(function.decorators.iter().cloned()); + + buffer.push_str(&single_function_stub( + &function.name, + &overload_decorators, + &overload.arguments, + overload.returns.as_ref(), + function.is_async, + None, + imports, + class_name, + )); + buffer.push('\n'); + } + + buffer.truncate(buffer.trim_end_matches('\n').len()); + + buffer +} + +#[expect(clippy::too_many_arguments)] +fn single_function_stub( + name: &str, + decorators: &[Expr], + arguments: &Arguments, + returns: Option<&Expr>, + is_async: bool, + docstring: Option<&str>, + imports: &Imports, + class_name: Option<&str>, +) -> String { // Signature let mut parameters = Vec::new(); - for argument in &function.arguments.positional_only_arguments { + for argument in &arguments.positional_only_arguments { parameters.push(argument_stub(argument, imports)); } - if !function.arguments.positional_only_arguments.is_empty() { + if !arguments.positional_only_arguments.is_empty() { parameters.push("/".into()); } - for argument in &function.arguments.arguments { + for argument in &arguments.arguments { parameters.push(argument_stub(argument, imports)); } - if let Some(argument) = &function.arguments.vararg { + if let Some(argument) = &arguments.vararg { parameters.push(format!( "*{}", variable_length_argument_stub(argument, imports) )); - } else if !function.arguments.keyword_only_arguments.is_empty() { + } else if !arguments.keyword_only_arguments.is_empty() { parameters.push("*".into()); } - for argument in &function.arguments.keyword_only_arguments { + for argument in &arguments.keyword_only_arguments { parameters.push(argument_stub(argument, imports)); } - if let Some(argument) = &function.arguments.kwarg { + if let Some(argument) = &arguments.kwarg { parameters.push(format!( "**{}", variable_length_argument_stub(argument, imports) )); } let mut buffer = String::new(); - for decorator in &function.decorators { + for decorator in decorators { buffer.push('@'); // We remove the class name if it's a prefix to get nicer decorators let mut decorator_buffer = String::new(); @@ -217,20 +271,20 @@ fn function_stubs(function: &Function, imports: &Imports, class_name: Option<&st buffer.push_str(&decorator_buffer); buffer.push('\n'); } - if function.is_async { + if is_async { buffer.push_str("async "); } buffer.push_str("def "); - buffer.push_str(&function.name); + buffer.push_str(name); buffer.push('('); buffer.push_str(¶meters.join(", ")); buffer.push(')'); - if let Some(returns) = &function.returns { + if let Some(returns) = returns { buffer.push_str(" -> "); imports.serialize_expr(returns, &mut buffer); } - if let Some(docstring) = &function.docstring { + if let Some(docstring) = docstring { buffer.push_str(":\n \"\"\""); for line in docstring.lines() { buffer.push_str("\n "); @@ -562,29 +616,42 @@ impl ElementsUsedInAnnotations { for decorator in &function.decorators { self.walk_expr(decorator); } - for arg in function - .arguments + if function.overloads.is_empty() { + self.walk_arguments(&function.arguments, function.returns.as_ref()); + } + if !function.overloads.is_empty() { + self.module_to_name + .entry("typing".into()) + .or_default() + .insert("overload".into()); + } + for overload in &function.overloads { + self.walk_arguments(&overload.arguments, overload.returns.as_ref()); + } + } + + fn walk_arguments(&mut self, arguments: &Arguments, returns: Option<&Expr>) { + for arg in arguments .positional_only_arguments .iter() - .chain(&function.arguments.arguments) - .chain(&function.arguments.keyword_only_arguments) + .chain(&arguments.arguments) + .chain(&arguments.keyword_only_arguments) { if let Some(type_hint) = &arg.annotation { self.walk_expr(type_hint); } } - for arg in function - .arguments + for arg in arguments .vararg .as_ref() .iter() - .chain(&function.arguments.kwarg.as_ref()) + .chain(&arguments.kwarg.as_ref()) { if let Some(type_hint) = &arg.annotation { self.walk_expr(type_hint); } } - if let Some(type_hint) = &function.returns { + if let Some(type_hint) = returns { self.walk_expr(type_hint); } } @@ -669,6 +736,7 @@ mod tests { }), is_async: false, docstring: None, + overloads: Vec::new(), }; assert_eq!( "def func(posonly, /, arg, *varargs, karg: \"str\", **kwarg: \"str\") -> \"list[str]\": ...", @@ -711,6 +779,7 @@ mod tests { returns: None, is_async: false, docstring: None, + overloads: Vec::new(), }; assert_eq!( "def afunc(posonly=1, /, arg=True, *, karg: \"str\" = \"foo\"): ...", @@ -733,6 +802,7 @@ mod tests { returns: None, is_async: true, docstring: None, + overloads: Vec::new(), }; assert_eq!( "async def foo(): ...", @@ -828,6 +898,7 @@ mod tests { returns: Some(big_type.clone()), is_async: false, docstring: None, + overloads: Vec::new(), }], attributes: Vec::new(), incomplete: true, @@ -849,4 +920,346 @@ mod tests { imports.serialize_expr(&big_type, &mut output); assert_eq!(output, "dict[A, (A3.C, A3.D, B, A2, int, int2, float)]"); } + + #[test] + fn function_stubs_with_overloads() { + use crate::model::OverloadSignature; + + let str_int = Expr::Constant { + value: Constant::Str("int".into()), + }; + let str_str = Expr::Constant { + value: Constant::Str("str".into()), + }; + let str_int_or_str = Expr::Constant { + value: Constant::Str("int | str".into()), + }; + + let module = Module { + name: "test".into(), + modules: Vec::new(), + classes: Vec::new(), + functions: vec![Function { + name: "process".into(), + decorators: Vec::new(), + arguments: Arguments { + positional_only_arguments: Vec::new(), + arguments: vec![Argument { + name: "x".into(), + default_value: None, + annotation: Some(str_int_or_str.clone()), + }], + vararg: None, + keyword_only_arguments: Vec::new(), + kwarg: None, + }, + returns: Some(str_int_or_str), + is_async: false, + docstring: None, + overloads: vec![ + OverloadSignature { + arguments: Arguments { + positional_only_arguments: Vec::new(), + arguments: vec![Argument { + name: "x".into(), + default_value: None, + annotation: Some(str_int.clone()), + }], + vararg: None, + keyword_only_arguments: Vec::new(), + kwarg: None, + }, + returns: Some(str_int), + }, + OverloadSignature { + arguments: Arguments { + positional_only_arguments: Vec::new(), + arguments: vec![Argument { + name: "x".into(), + default_value: None, + annotation: Some(str_str.clone()), + }], + vararg: None, + keyword_only_arguments: Vec::new(), + kwarg: None, + }, + returns: Some(str_str), + }, + ], + }], + attributes: Vec::new(), + incomplete: false, + docstring: None, + }; + let stubs = module_stubs(&module, &[]); + assert_eq!( + stubs, + "from typing import overload\n\n@overload\ndef process(x: \"int\") -> \"int\": ...\n@overload\ndef process(x: \"str\") -> \"str\": ...\n" + ); + } + + #[test] + fn overloaded_method_stubs_on_class() { + use crate::model::OverloadSignature; + + let str_int = Expr::Constant { + value: Constant::Str("int".into()), + }; + let str_str = Expr::Constant { + value: Constant::Str("str".into()), + }; + + let module = Module { + name: "test".into(), + modules: Vec::new(), + classes: vec![Class { + name: "MyClass".into(), + bases: Vec::new(), + methods: vec![Function { + name: "process".into(), + decorators: Vec::new(), + arguments: Arguments { + positional_only_arguments: vec![Argument { + name: "self".into(), + default_value: None, + annotation: None, + }], + arguments: vec![Argument { + name: "x".into(), + default_value: None, + annotation: None, + }], + vararg: None, + keyword_only_arguments: Vec::new(), + kwarg: None, + }, + returns: None, + is_async: false, + docstring: None, + overloads: vec![ + OverloadSignature { + arguments: Arguments { + positional_only_arguments: vec![Argument { + name: "self".into(), + default_value: None, + annotation: None, + }], + arguments: vec![Argument { + name: "x".into(), + default_value: None, + annotation: Some(str_int.clone()), + }], + vararg: None, + keyword_only_arguments: Vec::new(), + kwarg: None, + }, + returns: Some(str_int), + }, + OverloadSignature { + arguments: Arguments { + positional_only_arguments: vec![Argument { + name: "self".into(), + default_value: None, + annotation: None, + }], + arguments: vec![Argument { + name: "x".into(), + default_value: None, + annotation: Some(str_str.clone()), + }], + vararg: None, + keyword_only_arguments: Vec::new(), + kwarg: None, + }, + returns: Some(str_str), + }, + ], + }], + attributes: Vec::new(), + decorators: Vec::new(), + inner_classes: Vec::new(), + docstring: None, + }], + functions: Vec::new(), + attributes: Vec::new(), + incomplete: false, + docstring: None, + }; + let stubs = module_stubs(&module, &[]); + assert_eq!( + stubs, + concat!( + "from typing import overload\n\n", + "class MyClass:\n", + " @overload\n", + " def process(self, /, x: \"int\") -> \"int\": ...\n", + " @overload\n", + " def process(self, /, x: \"str\") -> \"str\": ...\n", + ) + ); + } + + #[test] + fn overloaded_function_with_keyword_only_args() { + use crate::model::OverloadSignature; + + let str_int = Expr::Constant { + value: Constant::Str("int".into()), + }; + let str_str = Expr::Constant { + value: Constant::Str("str".into()), + }; + + let module = Module { + name: "test".into(), + modules: Vec::new(), + classes: Vec::new(), + functions: vec![Function { + name: "fetch".into(), + decorators: Vec::new(), + arguments: Arguments { + positional_only_arguments: Vec::new(), + arguments: vec![Argument { + name: "url".into(), + default_value: None, + annotation: Some(str_str.clone()), + }], + vararg: None, + keyword_only_arguments: vec![Argument { + name: "timeout".into(), + default_value: Some(Expr::Constant { + value: Constant::Int("30".into()), + }), + annotation: Some(str_int.clone()), + }], + kwarg: None, + }, + returns: Some(str_str.clone()), + is_async: false, + docstring: None, + overloads: vec![ + OverloadSignature { + arguments: Arguments { + positional_only_arguments: Vec::new(), + arguments: vec![Argument { + name: "url".into(), + default_value: None, + annotation: Some(str_str.clone()), + }], + vararg: None, + keyword_only_arguments: vec![Argument { + name: "timeout".into(), + default_value: Some(Expr::Constant { + value: Constant::Int("30".into()), + }), + annotation: Some(str_int.clone()), + }], + kwarg: None, + }, + returns: Some(str_str.clone()), + }, + OverloadSignature { + arguments: Arguments { + positional_only_arguments: Vec::new(), + arguments: vec![Argument { + name: "url".into(), + default_value: None, + annotation: Some(str_str.clone()), + }], + vararg: None, + keyword_only_arguments: vec![Argument { + name: "timeout".into(), + default_value: None, + annotation: Some(str_str.clone()), + }], + kwarg: None, + }, + returns: Some(str_int), + }, + ], + }], + attributes: Vec::new(), + incomplete: false, + docstring: None, + }; + let stubs = module_stubs(&module, &[]); + assert_eq!( + stubs, + concat!( + "from typing import overload\n\n", + "@overload\n", + "def fetch(url: \"str\", *, timeout: \"int\" = 30) -> \"str\": ...\n", + "@overload\n", + "def fetch(url: \"str\", *, timeout: \"str\") -> \"int\": ...\n", + ) + ); + } + + #[test] + fn overloaded_function_preserves_existing_decorators() { + use crate::model::OverloadSignature; + + let str_int = Expr::Constant { + value: Constant::Str("int".into()), + }; + + let module = Module { + name: "test".into(), + modules: Vec::new(), + classes: vec![Class { + name: "MyClass".into(), + bases: Vec::new(), + methods: vec![Function { + name: "my_method".into(), + decorators: vec![Expr::Name { + id: "staticmethod".into(), + }], + arguments: Arguments { + positional_only_arguments: Vec::new(), + arguments: Vec::new(), + vararg: None, + keyword_only_arguments: Vec::new(), + kwarg: None, + }, + returns: None, + is_async: false, + docstring: None, + overloads: vec![OverloadSignature { + arguments: Arguments { + positional_only_arguments: Vec::new(), + arguments: vec![Argument { + name: "x".into(), + default_value: None, + annotation: Some(str_int.clone()), + }], + vararg: None, + keyword_only_arguments: Vec::new(), + kwarg: None, + }, + returns: Some(str_int), + }], + }], + attributes: Vec::new(), + decorators: Vec::new(), + inner_classes: Vec::new(), + docstring: None, + }], + functions: Vec::new(), + attributes: Vec::new(), + incomplete: false, + docstring: None, + }; + let stubs = module_stubs(&module, &[]); + // @overload should come before @staticmethod + assert_eq!( + stubs, + concat!( + "from typing import overload\n\n", + "class MyClass:\n", + " @overload\n", + " @staticmethod\n", + " def my_method(x: \"int\") -> \"int\": ...\n", + ) + ); + } } diff --git a/pyo3-macros-backend/src/attributes.rs b/pyo3-macros-backend/src/attributes.rs index 9894c463628..d9dacfa6eda 100644 --- a/pyo3-macros-backend/src/attributes.rs +++ b/pyo3-macros-backend/src/attributes.rs @@ -56,6 +56,7 @@ pub mod kw { syn::custom_keyword!(category); syn::custom_keyword!(from_py_object); syn::custom_keyword!(skip_from_py_object); + syn::custom_keyword!(overload); } fn take_int(read: &mut &str, tracker: &mut usize) -> String { diff --git a/pyo3-macros-backend/src/introspection.rs b/pyo3-macros-backend/src/introspection.rs index aad7eb1f267..34f656c848c 100644 --- a/pyo3-macros-backend/src/introspection.rs +++ b/pyo3-macros-backend/src/introspection.rs @@ -10,6 +10,7 @@ use crate::method::{FnArg, RegularArg}; use crate::py_expr::PyExpr; +use crate::pyfunction::signature::{Signature, SignatureItem}; use crate::pyfunction::FunctionSignature; use crate::utils::{PyO3CratePath, PythonDoc, StrOrExpr}; use proc_macro2::{Span, TokenStream}; @@ -109,6 +110,7 @@ pub fn function_introspection_code( is_returning_not_implemented_on_extraction_error: bool, doc: Option<&PythonDoc>, parent: Option<&Type>, + overloads: &[Signature], ) -> TokenStream { let mut desc = HashMap::from([ ("type", IntrospectionNode::String("function".into())), @@ -161,6 +163,26 @@ pub fn function_introspection_code( IntrospectionNode::IntrospectionId(Some(Cow::Borrowed(parent))), ); } + if !overloads.is_empty() { + let overload_nodes: Vec> = overloads + .iter() + .map(|overload| { + let mut overload_map = HashMap::new(); + overload_map.insert( + "arguments", + overload_arguments_from_signature(overload, first_argument), + ); + if let Some((_, returns)) = &overload.returns { + overload_map.insert( + "returns", + IntrospectionNode::TypeHint(Cow::Owned(returns.as_type_hint())), + ); + } + IntrospectionNode::Map(overload_map).into() + }) + .collect(); + desc.insert("overloads", IntrospectionNode::List(overload_nodes)); + } IntrospectionNode::Map(desc).emit(pyo3_crate_path) } @@ -302,23 +324,7 @@ fn arguments_introspection_data<'a>( kwarg = Some(IntrospectionNode::Map(params)); } - let mut map = HashMap::new(); - if !posonlyargs.is_empty() { - map.insert("posonlyargs", IntrospectionNode::List(posonlyargs)); - } - if !args.is_empty() { - map.insert("args", IntrospectionNode::List(args)); - } - if let Some(vararg) = vararg { - map.insert("vararg", vararg); - } - if !kwonlyargs.is_empty() { - map.insert("kwonlyargs", IntrospectionNode::List(kwonlyargs)); - } - if let Some(kwarg) = kwarg { - map.insert("kwarg", kwarg); - } - IntrospectionNode::Map(map) + build_arguments_map(posonlyargs, args, vararg, kwonlyargs, kwarg) } fn argument_introspection_data<'a>( @@ -347,6 +353,137 @@ fn argument_introspection_data<'a>( IntrospectionNode::Map(params).into() } +fn overload_arguments_from_signature<'a>( + signature: &'a Signature, + first_argument: Option<&'a str>, +) -> IntrospectionNode<'a> { + let mut posonlyargs = Vec::new(); + let mut args = Vec::new(); + let mut vararg = None; + let mut kwonlyargs = Vec::new(); + let mut kwarg = None; + + let mut seen_posargs_sep = false; + + if let Some(first_argument) = first_argument { + posonlyargs.push( + IntrospectionNode::Map( + [("name", IntrospectionNode::String(first_argument.into()))].into(), + ) + .into(), + ); + seen_posargs_sep = true; + } + let mut seen_varargs_sep = false; + + for item in &signature.items { + match item { + SignatureItem::Argument(arg) => { + let mut params: HashMap<&'static str, IntrospectionNode<'a>> = [( + "name", + IntrospectionNode::String(arg.ident.to_string().into()), + )] + .into(); + if let Some((_, annotation)) = &arg.colon_and_annotation { + params.insert( + "annotation", + IntrospectionNode::TypeHint(Cow::Owned(annotation.as_type_hint())), + ); + } + if let Some((_, default)) = &arg.eq_and_default { + params.insert( + "default", + IntrospectionNode::TypeHint(Cow::Owned(PyExpr::constant_from_expression( + default, + ))), + ); + } + let node: AttributedIntrospectionNode<'a> = IntrospectionNode::Map(params).into(); + if seen_varargs_sep { + kwonlyargs.push(node); + } else if !seen_posargs_sep { + posonlyargs.push(node); + } else { + args.push(node); + } + } + SignatureItem::PosargsSep(_) => { + seen_posargs_sep = true; + } + SignatureItem::VarargsSep(_) => { + seen_varargs_sep = true; + if !seen_posargs_sep { + args.append(&mut posonlyargs); + } + } + SignatureItem::Varargs(v) => { + seen_varargs_sep = true; + if !seen_posargs_sep { + args.append(&mut posonlyargs); + } + let mut params: HashMap<&'static str, IntrospectionNode<'a>> = [( + "name", + IntrospectionNode::String(v.ident.to_string().into()), + )] + .into(); + if let Some((_, annotation)) = &v.colon_and_annotation { + params.insert( + "annotation", + IntrospectionNode::TypeHint(Cow::Owned(annotation.as_type_hint())), + ); + } + vararg = Some(IntrospectionNode::Map(params)); + } + SignatureItem::Kwargs(kw) => { + let mut params: HashMap<&'static str, IntrospectionNode<'a>> = [( + "name", + IntrospectionNode::String(kw.ident.to_string().into()), + )] + .into(); + if let Some((_, annotation)) = &kw.colon_and_annotation { + params.insert( + "annotation", + IntrospectionNode::TypeHint(Cow::Owned(annotation.as_type_hint())), + ); + } + kwarg = Some(IntrospectionNode::Map(params)); + } + } + } + + if !seen_posargs_sep && !seen_varargs_sep { + args.append(&mut posonlyargs); + } + + build_arguments_map(posonlyargs, args, vararg, kwonlyargs, kwarg) +} + +fn build_arguments_map<'a>( + posonlyargs: Vec>, + args: Vec>, + vararg: Option>, + kwonlyargs: Vec>, + kwarg: Option>, +) -> IntrospectionNode<'a> { + let mut map = HashMap::new(); + if !posonlyargs.is_empty() { + map.insert("posonlyargs", IntrospectionNode::List(posonlyargs)); + } + if !args.is_empty() { + map.insert("args", IntrospectionNode::List(args)); + } + if let Some(vararg) = vararg { + map.insert("vararg", vararg); + } + if !kwonlyargs.is_empty() { + map.insert("kwonlyargs", IntrospectionNode::List(kwonlyargs)); + } + if let Some(kwarg) = kwarg { + map.insert("kwarg", kwarg); + } + IntrospectionNode::Map(map) +} + enum IntrospectionNode<'a> { String(Cow<'a, str>), Bool(bool), diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 3ec89dc08ca..0f53b40138c 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -10,7 +10,7 @@ use syn::{ext::IdentExt, spanned::Spanned, Ident, Result}; use crate::params::is_forwarded_args; #[cfg(feature = "experimental-inspect")] use crate::py_expr::PyExpr; -use crate::pyfunction::{PyFunctionWarning, WarningFactory}; +use crate::pyfunction::{OverloadAttribute, PyFunctionWarning, WarningFactory}; use crate::utils::Ctx; use crate::{ attributes::{FromPyWithAttribute, TextSignatureAttribute, TextSignatureAttributeValue}, @@ -449,6 +449,8 @@ pub struct FnSpec<'a> { pub asyncness: Option, pub unsafety: Option, pub warnings: Vec, + #[cfg_attr(not(feature = "experimental-inspect"), allow(dead_code))] + pub overloads: Vec, pub output: syn::ReturnType, } @@ -491,6 +493,7 @@ impl<'a> FnSpec<'a> { name, signature, warnings, + overloads, .. } = options; @@ -528,6 +531,7 @@ impl<'a> FnSpec<'a> { asyncness: sig.asyncness, unsafety: sig.unsafety, warnings, + overloads, output: sig.output.clone(), }) } diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index f309c7702a8..a4cf27f3932 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -1813,6 +1813,7 @@ impl FunctionIntrospectionData<'_> { self.is_returning_not_implemented_on_extraction_error, None, Some(cls), + &[], ) }) .collect() @@ -2044,6 +2045,7 @@ fn complex_enum_struct_variant_new<'a>( asyncness: None, unsafety: None, warnings: vec![], + overloads: vec![], output: syn::ReturnType::Default, }; @@ -2109,6 +2111,7 @@ fn complex_enum_tuple_variant_new<'a>( asyncness: None, unsafety: None, warnings: vec![], + overloads: vec![], output: syn::ReturnType::Default, }; @@ -2152,6 +2155,7 @@ fn complex_enum_variant_field_getter( asyncness: None, unsafety: None, warnings: vec![], + overloads: vec![], output: parse_quote!(-> #field_type), }; @@ -2222,6 +2226,7 @@ fn descriptors_to_items( false, utils::get_doc(&field.attrs, None).as_ref(), Some(&parse_quote!(#cls)), + &[], )); } items.push(getter); @@ -2277,6 +2282,7 @@ fn descriptors_to_items( false, get_doc(&field.attrs, None).as_ref(), Some(&parse_quote!(#cls)), + &[], )); } items.push(setter); diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs index 3fa4b9b5317..a3db3b43c2b 100644 --- a/pyo3-macros-backend/src/pyfunction.rs +++ b/pyo3-macros-backend/src/pyfunction.rs @@ -24,9 +24,9 @@ use syn::punctuated::Punctuated; use syn::LitCStr; use syn::{ext::IdentExt, spanned::Spanned, LitStr, Path, Result, Token}; -mod signature; +pub mod signature; -pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute}; +pub use self::signature::{ConstructorAttribute, FunctionSignature, Signature, SignatureAttribute}; #[derive(Clone, Debug)] pub struct PyFunctionArgPyO3Attributes { @@ -235,6 +235,20 @@ impl ToTokens for PyFunctionWarningAttribute { } } +#[derive(Clone)] +pub struct OverloadAttribute { + pub kw: attributes::kw::overload, + pub value: Signature, +} + +impl Parse for OverloadAttribute { + fn parse(input: ParseStream<'_>) -> Result { + let kw = input.parse()?; + let value = input.parse()?; + Ok(Self { kw, value }) + } +} + #[derive(Default)] pub struct PyFunctionOptions { pub pass_module: Option, @@ -243,6 +257,7 @@ pub struct PyFunctionOptions { pub text_signature: Option, pub krate: Option, pub warnings: Vec, + pub overloads: Vec, } impl Parse for PyFunctionOptions { @@ -263,6 +278,7 @@ pub enum PyFunctionOption { TextSignature(TextSignatureAttribute), Crate(CrateAttribute), Warning(PyFunctionWarningAttribute), + Overload(OverloadAttribute), } impl Parse for PyFunctionOption { @@ -280,6 +296,8 @@ impl Parse for PyFunctionOption { input.parse().map(PyFunctionOption::Crate) } else if lookahead.peek(attributes::kw::warn) { input.parse().map(PyFunctionOption::Warning) + } else if lookahead.peek(attributes::kw::overload) { + input.parse().map(PyFunctionOption::Overload) } else { Err(lookahead.error()) } @@ -318,6 +336,17 @@ impl PyFunctionOptions { PyFunctionOption::Warning(warning) => { self.warnings.push(warning.into()); } + PyFunctionOption::Overload(overload) => { + ensure_spanned!( + cfg!(feature = "experimental-inspect"), + overload.kw.span() => "`overload` is only supported with the `experimental-inspect` feature" + ); + ensure_spanned!( + overload.value.returns.is_some(), + overload.kw.span() => "`overload` must include a return type annotation (e.g. `overload(x: \"int\") -> \"int\"`)" + ); + self.overloads.push(overload); + } } } Ok(()) @@ -346,6 +375,7 @@ pub fn impl_wrap_pyfunction( text_signature, krate, warnings, + overloads, } = options; let ctx = &Ctx::new(&krate, Some(&func.sig)); @@ -395,12 +425,15 @@ pub fn impl_wrap_pyfunction( asyncness: func.sig.asyncness, unsafety: func.sig.unsafety, warnings, + overloads, output: func.sig.output.clone(), }; let vis = &func.vis; let name = &func.sig.ident; + #[cfg(feature = "experimental-inspect")] + let overload_sigs: Vec<_> = spec.overloads.iter().map(|o| o.value.clone()).collect(); #[cfg(feature = "experimental-inspect")] let introspection = function_introspection_code( pyo3_path, @@ -414,6 +447,7 @@ pub fn impl_wrap_pyfunction( false, get_doc(&func.attrs, None).as_ref(), None, + &overload_sigs, ); #[cfg(not(feature = "experimental-inspect"))] let introspection = quote! {}; diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index 95d8a6c45c6..b5919d0e004 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -492,6 +492,7 @@ pub fn method_introspection_code( } else { spec.output.clone() }; + let overload_sigs: Vec<_> = spec.overloads.iter().map(|o| o.value.clone()).collect(); function_introspection_code( pyo3_path, None, @@ -504,5 +505,6 @@ pub fn method_introspection_code( is_returning_not_implemented_on_extraction_error, get_doc(attrs, None).as_ref(), Some(parent), + &overload_sigs, ) } diff --git a/pytests/src/pyclasses.rs b/pytests/src/pyclasses.rs index e6f5fd78b4d..74f4afa044b 100644 --- a/pytests/src/pyclasses.rs +++ b/pytests/src/pyclasses.rs @@ -330,11 +330,34 @@ impl Number { } } +#[cfg(feature = "experimental-inspect")] +#[pyclass] +struct ClassWithOverloads; + +#[cfg(feature = "experimental-inspect")] +#[pymethods] +impl ClassWithOverloads { + #[new] + fn new() -> Self { + ClassWithOverloads + } + + #[pyo3(overload(x: "int") -> "int")] + #[pyo3(overload(x: "str") -> "str")] + #[pyo3(signature = (x))] + fn process<'py>(&self, x: Bound<'py, PyAny>) -> Bound<'py, PyAny> { + x + } +} + #[pymodule] pub mod pyclasses { #[cfg(any(Py_3_10, not(Py_LIMITED_API)))] #[pymodule_export] use super::ClassWithDict; + #[cfg(feature = "experimental-inspect")] + #[pymodule_export] + use super::ClassWithOverloads; #[cfg(not(any(Py_LIMITED_API, GraalPy)))] #[pymodule_export] use super::SubClassWithInit; diff --git a/pytests/src/pyfunctions.rs b/pytests/src/pyfunctions.rs index 6e1015e7627..c0ca3904831 100644 --- a/pytests/src/pyfunctions.rs +++ b/pytests/src/pyfunctions.rs @@ -88,6 +88,15 @@ fn with_custom_type_annotations<'py>( a } +#[cfg(feature = "experimental-inspect")] +#[pyfunction] +#[pyo3(overload(x: "int") -> "int")] +#[pyo3(overload(x: "str") -> "str")] +#[pyo3(signature = (x))] +fn with_overloads<'py>(x: Any<'py>) -> Any<'py> { + x +} + #[cfg(feature = "experimental-async")] #[pyfunction] async fn with_async() {} @@ -143,12 +152,12 @@ pub mod pyfunctions { #[cfg(feature = "experimental-async")] #[pymodule_export] use super::with_async; - #[cfg(feature = "experimental-inspect")] - #[pymodule_export] - use super::with_custom_type_annotations; #[pymodule_export] use super::{ args_kwargs, many_keyword_arguments, none, positional_only, simple, simple_args, simple_args_kwargs, simple_kwargs, with_typed_args, }; + #[cfg(feature = "experimental-inspect")] + #[pymodule_export] + use super::{with_custom_type_annotations, with_overloads}; } diff --git a/pytests/stubs/pyclasses.pyi b/pytests/stubs/pyclasses.pyi index 9dd201c2d03..2cc80bda4d7 100644 --- a/pytests/stubs/pyclasses.pyi +++ b/pytests/stubs/pyclasses.pyi @@ -1,5 +1,5 @@ from _typeshed import Incomplete -from typing import Final, final +from typing import Final, final, overload class AssertingBaseClass: """ @@ -44,6 +44,14 @@ class ClassWithDecorators: class ClassWithDict: def __new__(cls, /) -> ClassWithDict: ... +@final +class ClassWithOverloads: + def __new__(cls, /) -> ClassWithOverloads: ... + @overload + def process(self, /, x: "int") -> "int": ... + @overload + def process(self, /, x: "str") -> "str": ... + @final class ClassWithoutConstructor: ... diff --git a/pytests/stubs/pyfunctions.pyi b/pytests/stubs/pyfunctions.pyi index 1d5cca9a33d..c2a4919afc2 100644 --- a/pytests/stubs/pyfunctions.pyi +++ b/pytests/stubs/pyfunctions.pyi @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, overload def args_kwargs(*args, **kwargs) -> tuple[tuple, dict | None]: ... def many_keyword_arguments( @@ -38,6 +38,10 @@ async def with_async() -> None: ... def with_custom_type_annotations( a: "int", *_args: "str", _b: "int | None" = None, **_kwargs: "bool" ) -> "int": ... +@overload +def with_overloads(x: "int") -> "int": ... +@overload +def with_overloads(x: "str") -> "str": ... def with_typed_args( a: bool = False, b: int = 0, c: float = 0.0, d: str = "" ) -> tuple[bool, int, float, str]: ... diff --git a/tests/ui/invalid_pyfunction_signatures.stderr b/tests/ui/invalid_pyfunction_signatures.stderr index c7ddb11dd37..610a8c076b7 100644 --- a/tests/ui/invalid_pyfunction_signatures.stderr +++ b/tests/ui/invalid_pyfunction_signatures.stderr @@ -16,7 +16,7 @@ error: expected argument from function definition `y` but got argument `x` 13 | #[pyo3(signature = (x))] | ^ -error: expected one of: `name`, `pass_module`, `signature`, `text_signature`, `crate`, `warn` +error: expected one of: `name`, `pass_module`, `signature`, `text_signature`, `crate`, `warn`, `overload` --> tests/ui/invalid_pyfunction_signatures.rs:18:14 | 18 | #[pyfunction(x)]