diff --git a/examples/pure/Cargo.toml b/examples/pure/Cargo.toml index 51bf261..84fedb6 100644 --- a/examples/pure/Cargo.toml +++ b/examples/pure/Cargo.toml @@ -11,6 +11,7 @@ ahash.workspace = true env_logger.workspace = true pyo3-stub-gen = { path = "../../pyo3-stub-gen" } pyo3.workspace = true +pyo3.features = ["experimental-async"] [[bin]] name = "stub_gen" diff --git a/examples/pure/pure.pyi b/examples/pure/pure.pyi index 0e22628..6ef390d 100644 --- a/examples/pure/pure.pyi +++ b/examples/pure/pure.pyi @@ -16,6 +16,9 @@ class A: def ref_test(self, x:dict) -> dict: ... + async def async_get_x(self) -> int: + ... + class Number(Enum): FLOAT = auto() @@ -24,6 +27,9 @@ class Number(Enum): def ahash_dict() -> dict[str, int]: ... +async def async_num() -> int: + ... + def create_a(x:int) -> A: ... diff --git a/examples/pure/pyproject.toml b/examples/pure/pyproject.toml index 90a81dd..6cc8cd7 100644 --- a/examples/pure/pyproject.toml +++ b/examples/pure/pyproject.toml @@ -7,4 +7,4 @@ name = "pure" requires-python = ">=3.9" [project.optional-dependencies] -test = ["pytest", "pyright", "ruff"] +test = ["pytest", "pyright", "ruff", "pytest-asyncio"] diff --git a/examples/pure/src/lib.rs b/examples/pure/src/lib.rs index f78bbd4..7d0e86b 100644 --- a/examples/pure/src/lib.rs +++ b/examples/pure/src/lib.rs @@ -56,6 +56,10 @@ impl A { fn ref_test<'a>(&self, x: Bound<'a, PyDict>) -> Bound<'a, PyDict> { x } + + async fn async_get_x(&self) -> usize { + self.x + } } #[gen_stub_pyfunction] @@ -101,6 +105,12 @@ pub enum Number { module_variable!("pure", "MY_CONSTANT", usize); +#[gen_stub_pyfunction] +#[pyfunction] +async fn async_num() -> i32 { + 123 +} + /// Initializes the Python module #[pymodule] fn pure(m: &Bound) -> PyResult<()> { @@ -115,6 +125,7 @@ fn pure(m: &Bound) -> PyResult<()> { m.add_function(wrap_pyfunction!(str_len, m)?)?; m.add_function(wrap_pyfunction!(echo_path, m)?)?; m.add_function(wrap_pyfunction!(ahash_dict, m)?)?; + m.add_function(wrap_pyfunction!(async_num, m)?)?; Ok(()) } diff --git a/examples/pure/tests/test_python.py b/examples/pure/tests/test_python.py index 5c27e5c..2f4e301 100644 --- a/examples/pure/tests/test_python.py +++ b/examples/pure/tests/test_python.py @@ -1,4 +1,4 @@ -from pure import sum, create_dict, read_dict, echo_path, ahash_dict +from pure import sum, create_dict, read_dict, echo_path, ahash_dict, async_num, create_a import pytest import pathlib @@ -42,3 +42,10 @@ def test_path(): out = echo_path("test") assert out == "test" + + +@pytest.mark.asyncio +async def test_async(): + assert await async_num() == 123 + a = create_a(1337) + assert await a.async_get_x() == 1337 diff --git a/pyo3-stub-gen-derive/src/gen_stub/method.rs b/pyo3-stub-gen-derive/src/gen_stub/method.rs index 633d9f6..f0bbcde 100644 --- a/pyo3-stub-gen-derive/src/gen_stub/method.rs +++ b/pyo3-stub-gen-derive/src/gen_stub/method.rs @@ -18,6 +18,7 @@ pub struct MethodInfo { doc: String, is_static: bool, is_class: bool, + is_async: bool, } fn replace_inner(ty: &mut Type, self_: &Type) { @@ -83,6 +84,7 @@ impl TryFrom for MethodInfo { doc, is_static, is_class, + is_async: sig.asyncness.is_some(), }) } } @@ -97,6 +99,7 @@ impl ToTokens for MethodInfo { doc, is_class, is_static, + is_async, } = self; let sig_tt = quote_option(sig); let ret_tt = if let Some(ret) = ret { @@ -113,6 +116,7 @@ impl ToTokens for MethodInfo { doc: #doc, is_static: #is_static, is_class: #is_class, + is_async: #is_async, } }) } diff --git a/pyo3-stub-gen-derive/src/gen_stub/pyfunction.rs b/pyo3-stub-gen-derive/src/gen_stub/pyfunction.rs index 19a4a13..721d12d 100644 --- a/pyo3-stub-gen-derive/src/gen_stub/pyfunction.rs +++ b/pyo3-stub-gen-derive/src/gen_stub/pyfunction.rs @@ -17,6 +17,7 @@ pub struct PyFunctionInfo { sig: Option, doc: String, module: Option, + is_async: bool, } struct ModuleAttr { @@ -69,6 +70,7 @@ impl TryFrom for PyFunctionInfo { name, doc, module: None, + is_async: item.sig.asyncness.is_some(), }) } } @@ -82,6 +84,7 @@ impl ToTokens for PyFunctionInfo { doc, sig, module, + is_async, } = self; let ret_tt = if let Some(ret) = ret { quote! { <#ret as pyo3_stub_gen::PyStubType>::type_output } @@ -98,6 +101,7 @@ impl ToTokens for PyFunctionInfo { doc: #doc, signature: #sig_tt, module: #module_tt, + is_async: #is_async, } }) } diff --git a/pyo3-stub-gen/src/generate/function.rs b/pyo3-stub-gen/src/generate/function.rs index c8c0eb1..7b2de13 100644 --- a/pyo3-stub-gen/src/generate/function.rs +++ b/pyo3-stub-gen/src/generate/function.rs @@ -9,6 +9,7 @@ pub struct FunctionDef { pub r#return: TypeInfo, pub signature: Option<&'static str>, pub doc: &'static str, + pub is_async: bool, } impl Import for FunctionDef { @@ -29,13 +30,15 @@ impl From<&PyFunctionInfo> for FunctionDef { r#return: (info.r#return)(), doc: info.doc, signature: info.signature, + is_async: info.is_async, } } } impl fmt::Display for FunctionDef { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "def {}(", self.name)?; + let async_ = if self.is_async { "async" } else { "" }; + write!(f, "{async_}def {}(", self.name)?; if let Some(sig) = self.signature { write!(f, "{}", sig)?; } else { diff --git a/pyo3-stub-gen/src/generate/method.rs b/pyo3-stub-gen/src/generate/method.rs index 82a20c6..1ac7af7 100644 --- a/pyo3-stub-gen/src/generate/method.rs +++ b/pyo3-stub-gen/src/generate/method.rs @@ -11,6 +11,7 @@ pub struct MethodDef { pub doc: &'static str, pub is_static: bool, pub is_class: bool, + pub is_async: bool, } impl Import for MethodDef { @@ -33,6 +34,7 @@ impl From<&MethodInfo> for MethodDef { doc: info.doc, is_static: info.is_static, is_class: info.is_class, + is_async: info.is_async, } } } @@ -41,15 +43,16 @@ impl fmt::Display for MethodDef { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let indent = indent(); let mut needs_comma = false; + let async_ = if self.is_async { "async" } else { "" }; if self.is_static { writeln!(f, "{indent}@staticmethod")?; - write!(f, "{indent}def {}(", self.name)?; + write!(f, "{indent}{async_}def {}(", self.name)?; } else if self.is_class { writeln!(f, "{indent}@classmethod")?; - write!(f, "{indent}def {}(cls", self.name)?; + write!(f, "{indent}{async_}def {}(cls", self.name)?; needs_comma = true; } else { - write!(f, "{indent}def {}(self", self.name)?; + write!(f, "{indent}{async_}def {}(self", self.name)?; needs_comma = true; } if let Some(signature) = self.signature { diff --git a/pyo3-stub-gen/src/type_info.rs b/pyo3-stub-gen/src/type_info.rs index 1ae983a..9aee12e 100644 --- a/pyo3-stub-gen/src/type_info.rs +++ b/pyo3-stub-gen/src/type_info.rs @@ -51,6 +51,7 @@ pub struct MethodInfo { pub doc: &'static str, pub is_static: bool, pub is_class: bool, + pub is_async: bool, } /// Info of getter method decorated with `#[getter]` or `#[pyo3(get, set)]` appears in `#[pyclass]` @@ -125,6 +126,7 @@ pub struct PyFunctionInfo { pub doc: &'static str, pub signature: Option<&'static str>, pub module: Option<&'static str>, + pub is_async: bool, } inventory::collect!(PyFunctionInfo);