Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async fn support for free fns and methods #108

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/pure/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ahash.workspace = true
env_logger.workspace = true
pyo3-stub-gen = { path = "../../pyo3-stub-gen" }
pyo3.workspace = true
pyo3.features = ["experimental-async"]

[[bin]]
name = "stub_gen"
Expand Down
6 changes: 6 additions & 0 deletions examples/pure/pure.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class A:
def ref_test(self, x:dict) -> dict:
...

async def async_get_x(self) -> int:
...


class Number(Enum):
FLOAT = auto()
Expand All @@ -24,6 +27,9 @@ class Number(Enum):
def ahash_dict() -> dict[str, int]:
...

async def async_num() -> int:
...

def create_a(x:int) -> A:
...

Expand Down
2 changes: 1 addition & 1 deletion examples/pure/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ name = "pure"
requires-python = ">=3.9"

[project.optional-dependencies]
test = ["pytest", "pyright", "ruff"]
test = ["pytest", "pyright", "ruff", "pytest-asyncio"]
11 changes: 11 additions & 0 deletions examples/pure/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ impl A {
fn ref_test<'a>(&self, x: Bound<'a, PyDict>) -> Bound<'a, PyDict> {
x
}

async fn async_get_x(&self) -> usize {
self.x
}
}

#[gen_stub_pyfunction]
Expand Down Expand Up @@ -101,6 +105,12 @@ pub enum Number {

module_variable!("pure", "MY_CONSTANT", usize);

#[gen_stub_pyfunction]
#[pyfunction]
async fn async_num() -> i32 {
123
}

/// Initializes the Python module
#[pymodule]
fn pure(m: &Bound<PyModule>) -> PyResult<()> {
Expand All @@ -115,6 +125,7 @@ fn pure(m: &Bound<PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(str_len, m)?)?;
m.add_function(wrap_pyfunction!(echo_path, m)?)?;
m.add_function(wrap_pyfunction!(ahash_dict, m)?)?;
m.add_function(wrap_pyfunction!(async_num, m)?)?;
Ok(())
}

Expand Down
9 changes: 8 additions & 1 deletion examples/pure/tests/test_python.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pure import sum, create_dict, read_dict, echo_path, ahash_dict
from pure import sum, create_dict, read_dict, echo_path, ahash_dict, async_num, create_a
import pytest
import pathlib

Expand Down Expand Up @@ -42,3 +42,10 @@ def test_path():

out = echo_path("test")
assert out == "test"


@pytest.mark.asyncio
async def test_async():
assert await async_num() == 123
a = create_a(1337)
assert await a.async_get_x() == 1337
4 changes: 4 additions & 0 deletions pyo3-stub-gen-derive/src/gen_stub/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub struct MethodInfo {
doc: String,
is_static: bool,
is_class: bool,
is_async: bool,
}

fn replace_inner(ty: &mut Type, self_: &Type) {
Expand Down Expand Up @@ -83,6 +84,7 @@ impl TryFrom<ImplItemFn> for MethodInfo {
doc,
is_static,
is_class,
is_async: sig.asyncness.is_some(),
})
}
}
Expand All @@ -97,6 +99,7 @@ impl ToTokens for MethodInfo {
doc,
is_class,
is_static,
is_async,
} = self;
let sig_tt = quote_option(sig);
let ret_tt = if let Some(ret) = ret {
Expand All @@ -113,6 +116,7 @@ impl ToTokens for MethodInfo {
doc: #doc,
is_static: #is_static,
is_class: #is_class,
is_async: #is_async,
}
})
}
Expand Down
4 changes: 4 additions & 0 deletions pyo3-stub-gen-derive/src/gen_stub/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub struct PyFunctionInfo {
sig: Option<Signature>,
doc: String,
module: Option<String>,
is_async: bool,
}

struct ModuleAttr {
Expand Down Expand Up @@ -69,6 +70,7 @@ impl TryFrom<ItemFn> for PyFunctionInfo {
name,
doc,
module: None,
is_async: item.sig.asyncness.is_some(),
})
}
}
Expand All @@ -82,6 +84,7 @@ impl ToTokens for PyFunctionInfo {
doc,
sig,
module,
is_async,
} = self;
let ret_tt = if let Some(ret) = ret {
quote! { <#ret as pyo3_stub_gen::PyStubType>::type_output }
Expand All @@ -98,6 +101,7 @@ impl ToTokens for PyFunctionInfo {
doc: #doc,
signature: #sig_tt,
module: #module_tt,
is_async: #is_async,
}
})
}
Expand Down
5 changes: 4 additions & 1 deletion pyo3-stub-gen/src/generate/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub struct FunctionDef {
pub r#return: TypeInfo,
pub signature: Option<&'static str>,
pub doc: &'static str,
pub is_async: bool,
}

impl Import for FunctionDef {
Expand All @@ -29,13 +30,15 @@ impl From<&PyFunctionInfo> for FunctionDef {
r#return: (info.r#return)(),
doc: info.doc,
signature: info.signature,
is_async: info.is_async,
}
}
}

impl fmt::Display for FunctionDef {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "def {}(", self.name)?;
let async_ = if self.is_async { "async" } else { "" };
write!(f, "{async_}def {}(", self.name)?;
if let Some(sig) = self.signature {
write!(f, "{}", sig)?;
} else {
Expand Down
9 changes: 6 additions & 3 deletions pyo3-stub-gen/src/generate/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct MethodDef {
pub doc: &'static str,
pub is_static: bool,
pub is_class: bool,
pub is_async: bool,
}

impl Import for MethodDef {
Expand All @@ -33,6 +34,7 @@ impl From<&MethodInfo> for MethodDef {
doc: info.doc,
is_static: info.is_static,
is_class: info.is_class,
is_async: info.is_async,
}
}
}
Expand All @@ -41,15 +43,16 @@ impl fmt::Display for MethodDef {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let indent = indent();
let mut needs_comma = false;
let async_ = if self.is_async { "async" } else { "" };
if self.is_static {
writeln!(f, "{indent}@staticmethod")?;
write!(f, "{indent}def {}(", self.name)?;
write!(f, "{indent}{async_}def {}(", self.name)?;
} else if self.is_class {
writeln!(f, "{indent}@classmethod")?;
write!(f, "{indent}def {}(cls", self.name)?;
write!(f, "{indent}{async_}def {}(cls", self.name)?;
needs_comma = true;
} else {
write!(f, "{indent}def {}(self", self.name)?;
write!(f, "{indent}{async_}def {}(self", self.name)?;
needs_comma = true;
}
if let Some(signature) = self.signature {
Expand Down
2 changes: 2 additions & 0 deletions pyo3-stub-gen/src/type_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub struct MethodInfo {
pub doc: &'static str,
pub is_static: bool,
pub is_class: bool,
pub is_async: bool,
}

/// Info of getter method decorated with `#[getter]` or `#[pyo3(get, set)]` appears in `#[pyclass]`
Expand Down Expand Up @@ -125,6 +126,7 @@ pub struct PyFunctionInfo {
pub doc: &'static str,
pub signature: Option<&'static str>,
pub module: Option<&'static str>,
pub is_async: bool,
}

inventory::collect!(PyFunctionInfo);
Expand Down