Skip to content

Commit 1e40dd7

Browse files
committed
First step towards safely creating universal functions in Rust.
1 parent 0832b28 commit 1e40dd7

File tree

4 files changed

+266
-0
lines changed

4 files changed

+266
-0
lines changed

src/dtype.rs

+9
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,11 @@ pub unsafe trait Element: Clone + Send {
691691

692692
/// Returns the associated type descriptor ("dtype") for the given element type.
693693
fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr>;
694+
695+
/// TODO
696+
fn get_npy_type() -> Option<NPY_TYPES> {
697+
None
698+
}
694699
}
695700

696701
fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
@@ -747,6 +752,10 @@ macro_rules! impl_element_scalar {
747752
fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
748753
PyArrayDescr::from_npy_type(py, $npy_type)
749754
}
755+
756+
fn get_npy_type() -> Option<NPY_TYPES> {
757+
Some($npy_type)
758+
}
750759
}
751760
};
752761
($ty:ty => $npy_type:ident $(,#[$meta:meta])*) => {

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ pub mod npyffi;
8383
mod slice_container;
8484
mod strings;
8585
mod sum_products;
86+
pub mod ufunc;
8687
mod untyped_array;
8788

8889
pub use ndarray;

src/npyffi/flags.rs

+7
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,10 @@ pub const NPY_OBJECT_DTYPE_FLAGS: npy_char = NPY_LIST_PICKLE
8181
| NPY_ITEM_REFCOUNT
8282
| NPY_NEEDS_INIT
8383
| NPY_NEEDS_PYAPI;
84+
85+
pub const NPY_UFUNC_ZERO: c_int = 0;
86+
pub const NPY_UFUNC_ONE: c_int = 1;
87+
pub const NPY_UFUNC_MINUS_ONE: c_int = 2;
88+
pub const NPY_UFUNC_NONE: c_int = -1;
89+
pub const NPY_UFUNC_REORDERABLE_NONE: c_int = -2;
90+
pub const NPY_UFUNC_IDENTITY_VALUE: c_int = -3;

src/ufunc.rs

+249
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
//! TODO
2+
3+
use std::ffi::CString;
4+
use std::mem::{size_of, MaybeUninit};
5+
use std::os::raw::{c_char, c_int, c_void};
6+
use std::ptr::null_mut;
7+
8+
use ndarray::{ArrayView1, ArrayViewMut1, Axis, Dim, Ix1, ShapeBuilder, StrideShape};
9+
use pyo3::{Bound, PyAny, PyResult, Python};
10+
11+
use crate::{
12+
dtype::Element,
13+
npyffi::{flags, npy_intp, objects::PyUFuncGenericFunction, ufunc::PY_UFUNC_API},
14+
};
15+
16+
/// TODO
17+
#[repr(i32)]
18+
#[derive(Debug)]
19+
pub enum Identity {
20+
/// UFunc has unit of 0, and the order of operations can be reordered
21+
/// This case allows reduction with multiple axes at once.
22+
Zero = flags::NPY_UFUNC_ZERO,
23+
/// UFunc has unit of 1, and the order of operations can be reordered
24+
/// This case allows reduction with multiple axes at once.
25+
One = flags::NPY_UFUNC_ONE,
26+
/// UFunc has unit of -1, and the order of operations can be reordered
27+
/// This case allows reduction with multiple axes at once. Intended for bitwise_and reduction.
28+
MinusOne = flags::NPY_UFUNC_MINUS_ONE,
29+
/// UFunc has no unit, and the order of operations cannot be reordered.
30+
/// This case does not allow reduction with multiple axes at once.
31+
None = flags::NPY_UFUNC_NONE,
32+
/// UFunc has no unit, and the order of operations can be reordered
33+
/// This case allows reduction with multiple axes at once.
34+
ReorderableNone = flags::NPY_UFUNC_REORDERABLE_NONE,
35+
/// UFunc unit is an identity_value, and the order of operations can be reordered
36+
/// This case allows reduction with multiple axes at once.
37+
IdentityValue = flags::NPY_UFUNC_IDENTITY_VALUE,
38+
}
39+
40+
/// TODO
41+
///
42+
/// ```
43+
/// # #![allow(mixed_script_confusables)]
44+
/// # use std::ffi::CString;
45+
/// #
46+
/// use pyo3::{py_run, Python};
47+
/// use ndarray::{azip, ArrayView1, ArrayViewMut1};
48+
/// use numpy::ufunc::{from_func, Identity};
49+
///
50+
/// Python::with_gil(|py| {
51+
/// let logit = |[p]: [ArrayView1<'_, f64>; 1], [α]: [ArrayViewMut1<'_, f64>; 1]| {
52+
/// azip!((p in p, α in α) {
53+
/// let mut tmp = *p;
54+
/// tmp /= 1.0 - tmp;
55+
/// *α = tmp.ln();
56+
/// });
57+
/// };
58+
///
59+
/// let logit =
60+
/// from_func(py, CString::new("logit").unwrap(), Identity::None, logit).unwrap();
61+
///
62+
/// py_run!(py, logit, "assert logit(0.5) == 0.0");
63+
///
64+
/// let np = py.import("numpy").unwrap();
65+
///
66+
/// py_run!(py, logit np, "assert (logit(np.full(100, 0.5)) == np.zeros(100)).all()");
67+
/// });
68+
/// ```
69+
///
70+
/// ```
71+
/// # #![allow(mixed_script_confusables)]
72+
/// # use std::ffi::CString;
73+
/// #
74+
/// use pyo3::{py_run, Python};
75+
/// use ndarray::{azip, ArrayView1, ArrayViewMut1};
76+
/// use numpy::ufunc::{from_func, Identity};
77+
///
78+
/// Python::with_gil(|py| {
79+
/// let cart_to_polar = |[x, y]: [ArrayView1<'_, f64>; 2], [r, φ]: [ArrayViewMut1<'_, f64>; 2]| {
80+
/// azip!((&x in x, &y in y, r in r, φ in φ) {
81+
/// *r = f64::hypot(x, y);
82+
/// *φ = f64::atan2(x, y);
83+
/// });
84+
/// };
85+
///
86+
/// let cart_to_polar = from_func(
87+
/// py,
88+
/// CString::new("cart_to_polar").unwrap(),
89+
/// Identity::None,
90+
/// cart_to_polar,
91+
/// )
92+
/// .unwrap();
93+
///
94+
/// let np = py.import("numpy").unwrap();
95+
///
96+
/// py_run!(py, cart_to_polar np, "np.testing.assert_array_almost_equal(cart_to_polar(3.0, 4.0), (5.0, 0.643501))");
97+
///
98+
/// py_run!(py, cart_to_polar np, "np.testing.assert_array_almost_equal(cart_to_polar(np.full((10, 10), 3.0), np.full((10, 10), 4.0))[0], np.full((10, 10), 5.0))");
99+
/// py_run!(py, cart_to_polar np, "np.testing.assert_array_almost_equal(cart_to_polar(np.full((10, 10), 3.0), np.full((10, 10), 4.0))[1], np.full((10, 10), 0.643501))");
100+
/// });
101+
/// ```
102+
pub fn from_func<'py, T, F, const NIN: usize, const NOUT: usize>(
103+
py: Python<'py>,
104+
name: CString,
105+
identity: Identity,
106+
func: F,
107+
) -> PyResult<Bound<'py, PyAny>>
108+
where
109+
T: Element,
110+
F: Fn([ArrayView1<'_, T>; NIN], [ArrayViewMut1<'_, T>; NOUT]) + 'static,
111+
{
112+
let wrap_func = [Some(wrap_func::<T, F, NIN, NOUT> as _)];
113+
114+
let r#type = T::get_npy_type().expect("universal function only work for built-in types");
115+
116+
let inputs = [r#type as _; NIN];
117+
let outputs = [r#type as _; NOUT];
118+
119+
let data = Data {
120+
func,
121+
wrap_func,
122+
name,
123+
inputs,
124+
outputs,
125+
};
126+
127+
let data = Box::leak(Box::new(data));
128+
129+
unsafe {
130+
Bound::from_owned_ptr_or_err(
131+
py,
132+
PY_UFUNC_API.PyUFunc_FromFuncAndData(
133+
py,
134+
data.wrap_func.as_mut_ptr(),
135+
data as *mut Data<F, NIN, NOUT> as *mut c_void as *mut *mut c_void,
136+
data.inputs.as_mut_ptr(),
137+
/* ntypes = */ 1,
138+
NIN as c_int,
139+
NOUT as c_int,
140+
identity as c_int,
141+
data.name.as_ptr(),
142+
/* doc = */ null_mut(),
143+
/* unused = */ 0,
144+
),
145+
)
146+
}
147+
}
148+
149+
#[repr(C)]
150+
struct Data<F, const NIN: usize, const NOUT: usize> {
151+
func: F,
152+
wrap_func: [PyUFuncGenericFunction; 1],
153+
name: CString,
154+
inputs: [c_char; NIN],
155+
outputs: [c_char; NOUT],
156+
}
157+
158+
unsafe extern "C" fn wrap_func<T, F, const NIN: usize, const NOUT: usize>(
159+
args: *mut *mut c_char,
160+
dims: *mut npy_intp,
161+
steps: *mut npy_intp,
162+
data: *mut c_void,
163+
) where
164+
F: Fn([ArrayView1<'_, T>; NIN], [ArrayViewMut1<'_, T>; NOUT]),
165+
{
166+
// TODO: Check aliasing requirements using the `borrow` module.
167+
168+
let mut inputs = MaybeUninit::<[ArrayView1<'_, T>; NIN]>::uninit();
169+
let inputs_ptr = inputs.as_mut_ptr() as *mut ArrayView1<'_, T>;
170+
171+
for i in 0..NIN {
172+
let (ptr, shape, invert) = unpack_arg(args, dims, steps, i);
173+
174+
let mut input = ArrayView1::from_shape_ptr(shape, ptr);
175+
if invert {
176+
input.invert_axis(Axis(0));
177+
}
178+
inputs_ptr.add(i).write(input);
179+
}
180+
181+
let mut outputs = MaybeUninit::<[ArrayViewMut1<'_, T>; NOUT]>::uninit();
182+
let outputs_ptr = outputs.as_mut_ptr() as *mut ArrayViewMut1<'_, T>;
183+
184+
for i in 0..NOUT {
185+
let (ptr, shape, invert) = unpack_arg(args, dims, steps, NIN + i);
186+
187+
let mut output = ArrayViewMut1::from_shape_ptr(shape, ptr);
188+
if invert {
189+
output.invert_axis(Axis(0));
190+
}
191+
outputs_ptr.add(i).write(output);
192+
}
193+
194+
let data = &*(data as *mut Data<F, NIN, NOUT>);
195+
(data.func)(inputs.assume_init(), outputs.assume_init());
196+
}
197+
198+
unsafe fn unpack_arg<T>(
199+
args: *mut *mut c_char,
200+
dims: *mut npy_intp,
201+
steps: *mut npy_intp,
202+
i: usize,
203+
) -> (*mut T, StrideShape<Ix1>, bool) {
204+
let dim = Dim([*dims as usize]);
205+
let itemsize = size_of::<T>();
206+
207+
let mut ptr = *args.add(i);
208+
let mut invert = false;
209+
210+
let step = *steps.add(i);
211+
212+
let step = if step >= 0 {
213+
Dim([step as usize / itemsize])
214+
} else {
215+
ptr = ptr.offset(step * (*dims - 1));
216+
invert = true;
217+
218+
Dim([(-step) as usize / itemsize])
219+
};
220+
221+
(ptr as *mut T, dim.strides(step), invert)
222+
}
223+
224+
#[cfg(test)]
225+
mod tests {
226+
use super::*;
227+
228+
use ndarray::azip;
229+
use pyo3::py_run;
230+
231+
#[test]
232+
fn from_func_handles_negative_strides() {
233+
Python::with_gil(|py| {
234+
let negate = from_func(
235+
py,
236+
CString::new("negate").unwrap(),
237+
Identity::None,
238+
|[x]: [ArrayView1<'_, f64>; 1], [y]: [ArrayViewMut1<'_, f64>; 1]| {
239+
azip!((x in x, y in y) *y = -x);
240+
},
241+
)
242+
.unwrap();
243+
244+
let np = py.import_bound("numpy").unwrap();
245+
246+
py_run!(py, negate np, "assert (negate(np.linspace(1.0, 10.0, 10)[::-1]) == np.linspace(-10.0, -1.0, 10)).all()");
247+
});
248+
}
249+
}

0 commit comments

Comments
 (0)