Skip to content

Commit 2276b54

Browse files
committed
First step towards safely creating universal functions in Rust.
1 parent 74a32b4 commit 2276b54

File tree

4 files changed

+263
-0
lines changed

4 files changed

+263
-0
lines changed

src/dtype.rs

+9
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,11 @@ pub unsafe trait Element: Clone + Send {
402402

403403
/// Returns the associated type descriptor ("dtype") for the given element type.
404404
fn get_dtype<'py>(py: Python<'py>) -> &'py PyArrayDescr;
405+
406+
/// TODO
407+
fn get_npy_type() -> Option<NPY_TYPES> {
408+
None
409+
}
405410
}
406411

407412
fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
@@ -458,6 +463,10 @@ macro_rules! impl_element_scalar {
458463
fn get_dtype<'py>(py: Python<'py>) -> &'py PyArrayDescr {
459464
PyArrayDescr::from_npy_type(py, $npy_type)
460465
}
466+
467+
fn get_npy_type() -> Option<NPY_TYPES> {
468+
Some($npy_type)
469+
}
461470
}
462471
};
463472
($ty:ty => $npy_type:ident $(,#[$meta:meta])*) => {

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ pub mod npyffi;
8383
mod slice_container;
8484
mod strings;
8585
mod sum_products;
86+
pub mod ufunc;
8687
mod untyped_array;
8788

8889
pub use ndarray;

src/npyffi/flags.rs

+7
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,10 @@ pub const NPY_OBJECT_DTYPE_FLAGS: npy_char = NPY_LIST_PICKLE
8181
| NPY_ITEM_REFCOUNT
8282
| NPY_NEEDS_INIT
8383
| NPY_NEEDS_PYAPI;
84+
85+
pub const NPY_UFUNC_ZERO: c_int = 0;
86+
pub const NPY_UFUNC_ONE: c_int = 1;
87+
pub const NPY_UFUNC_MINUS_ONE: c_int = 2;
88+
pub const NPY_UFUNC_NONE: c_int = -1;
89+
pub const NPY_UFUNC_REORDERABLE_NONE: c_int = -2;
90+
pub const NPY_UFUNC_IDENTITY_VALUE: c_int = -3;

src/ufunc.rs

+246
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
//! TODO
2+
3+
use std::ffi::CString;
4+
use std::mem::{size_of, MaybeUninit};
5+
use std::os::raw::{c_char, c_int, c_void};
6+
use std::ptr::null_mut;
7+
8+
use ndarray::{ArrayView1, ArrayViewMut1, Axis, Dim, Ix1, ShapeBuilder, StrideShape};
9+
use pyo3::{PyAny, PyResult, Python};
10+
11+
use crate::{
12+
dtype::Element,
13+
npyffi::{flags, npy_intp, objects::PyUFuncGenericFunction, ufunc::PY_UFUNC_API},
14+
};
15+
16+
/// TODO
17+
#[repr(i32)]
18+
#[derive(Debug)]
19+
pub enum Identity {
20+
/// UFunc has unit of 0, and the order of operations can be reordered
21+
/// This case allows reduction with multiple axes at once.
22+
Zero = flags::NPY_UFUNC_ZERO,
23+
/// UFunc has unit of 1, and the order of operations can be reordered
24+
/// This case allows reduction with multiple axes at once.
25+
One = flags::NPY_UFUNC_ONE,
26+
/// UFunc has unit of -1, and the order of operations can be reordered
27+
/// This case allows reduction with multiple axes at once. Intended for bitwise_and reduction.
28+
MinusOne = flags::NPY_UFUNC_MINUS_ONE,
29+
/// UFunc has no unit, and the order of operations cannot be reordered.
30+
/// This case does not allow reduction with multiple axes at once.
31+
None = flags::NPY_UFUNC_NONE,
32+
/// UFunc has no unit, and the order of operations can be reordered
33+
/// This case allows reduction with multiple axes at once.
34+
ReorderableNone = flags::NPY_UFUNC_REORDERABLE_NONE,
35+
/// UFunc unit is an identity_value, and the order of operations can be reordered
36+
/// This case allows reduction with multiple axes at once.
37+
IdentityValue = flags::NPY_UFUNC_IDENTITY_VALUE,
38+
}
39+
40+
/// TODO
41+
///
42+
/// ```
43+
/// # #![allow(mixed_script_confusables)]
44+
/// # use std::ffi::CString;
45+
/// #
46+
/// use pyo3::{py_run, Python};
47+
/// use ndarray::{azip, ArrayView1, ArrayViewMut1};
48+
/// use numpy::ufunc::{from_func, Identity};
49+
///
50+
/// Python::with_gil(|py| {
51+
/// let logit = |[p]: [ArrayView1<'_, f64>; 1], [α]: [ArrayViewMut1<'_, f64>; 1]| {
52+
/// azip!((p in p, α in α) {
53+
/// let mut tmp = *p;
54+
/// tmp /= 1.0 - tmp;
55+
/// *α = tmp.ln();
56+
/// });
57+
/// };
58+
///
59+
/// let logit =
60+
/// from_func(py, CString::new("logit").unwrap(), Identity::None, logit).unwrap();
61+
///
62+
/// py_run!(py, logit, "assert logit(0.5) == 0.0");
63+
///
64+
/// let np = py.import("numpy").unwrap();
65+
///
66+
/// py_run!(py, logit np, "assert (logit(np.full(100, 0.5)) == np.zeros(100)).all()");
67+
/// });
68+
/// ```
69+
///
70+
/// ```
71+
/// # #![allow(mixed_script_confusables)]
72+
/// # use std::ffi::CString;
73+
/// #
74+
/// use pyo3::{py_run, Python};
75+
/// use ndarray::{azip, ArrayView1, ArrayViewMut1};
76+
/// use numpy::ufunc::{from_func, Identity};
77+
///
78+
/// Python::with_gil(|py| {
79+
/// let cart_to_polar = |[x, y]: [ArrayView1<'_, f64>; 2], [r, φ]: [ArrayViewMut1<'_, f64>; 2]| {
80+
/// azip!((&x in x, &y in y, r in r, φ in φ) {
81+
/// *r = f64::hypot(x, y);
82+
/// *φ = f64::atan2(x, y);
83+
/// });
84+
/// };
85+
///
86+
/// let cart_to_polar = from_func(
87+
/// py,
88+
/// CString::new("cart_to_polar").unwrap(),
89+
/// Identity::None,
90+
/// cart_to_polar,
91+
/// )
92+
/// .unwrap();
93+
///
94+
/// let np = py.import("numpy").unwrap();
95+
///
96+
/// py_run!(py, cart_to_polar np, "np.testing.assert_array_almost_equal(cart_to_polar(3.0, 4.0), (5.0, 0.643501))");
97+
///
98+
/// py_run!(py, cart_to_polar np, "np.testing.assert_array_almost_equal(cart_to_polar(np.full((10, 10), 3.0), np.full((10, 10), 4.0))[0], np.full((10, 10), 5.0))");
99+
/// py_run!(py, cart_to_polar np, "np.testing.assert_array_almost_equal(cart_to_polar(np.full((10, 10), 3.0), np.full((10, 10), 4.0))[1], np.full((10, 10), 0.643501))");
100+
/// });
101+
/// ```
102+
pub fn from_func<'py, T, F, const NIN: usize, const NOUT: usize>(
103+
py: Python<'py>,
104+
name: CString,
105+
identity: Identity,
106+
func: F,
107+
) -> PyResult<&'py PyAny>
108+
where
109+
T: Element,
110+
F: Fn([ArrayView1<'_, T>; NIN], [ArrayViewMut1<'_, T>; NOUT]) + 'static,
111+
{
112+
let wrap_func = [Some(wrap_func::<T, F, NIN, NOUT> as _)];
113+
114+
let r#type = T::get_npy_type().expect("universal function only work for built-in types");
115+
116+
let inputs = [r#type as _; NIN];
117+
let outputs = [r#type as _; NOUT];
118+
119+
let data = Data {
120+
func,
121+
wrap_func,
122+
name,
123+
inputs,
124+
outputs,
125+
};
126+
127+
let data = Box::leak(Box::new(data));
128+
129+
unsafe {
130+
py.from_owned_ptr_or_err(PY_UFUNC_API.PyUFunc_FromFuncAndData(
131+
py,
132+
data.wrap_func.as_mut_ptr(),
133+
data as *mut Data<F, NIN, NOUT> as *mut c_void as *mut *mut c_void,
134+
data.inputs.as_mut_ptr(),
135+
/* ntypes = */ 1,
136+
NIN as c_int,
137+
NOUT as c_int,
138+
identity as c_int,
139+
data.name.as_ptr(),
140+
/* doc = */ null_mut(),
141+
/* unused = */ 0,
142+
))
143+
}
144+
}
145+
146+
#[repr(C)]
147+
struct Data<F, const NIN: usize, const NOUT: usize> {
148+
func: F,
149+
wrap_func: [PyUFuncGenericFunction; 1],
150+
name: CString,
151+
inputs: [c_char; NIN],
152+
outputs: [c_char; NOUT],
153+
}
154+
155+
unsafe extern "C" fn wrap_func<T, F, const NIN: usize, const NOUT: usize>(
156+
args: *mut *mut c_char,
157+
dims: *mut npy_intp,
158+
steps: *mut npy_intp,
159+
data: *mut c_void,
160+
) where
161+
F: Fn([ArrayView1<'_, T>; NIN], [ArrayViewMut1<'_, T>; NOUT]),
162+
{
163+
// TODO: Check aliasing requirements using the `borrow` module.
164+
165+
let mut inputs = MaybeUninit::<[ArrayView1<'_, T>; NIN]>::uninit();
166+
let inputs_ptr = inputs.as_mut_ptr() as *mut ArrayView1<'_, T>;
167+
168+
for i in 0..NIN {
169+
let (ptr, shape, invert) = unpack_arg(args, dims, steps, i);
170+
171+
let mut input = ArrayView1::from_shape_ptr(shape, ptr);
172+
if invert {
173+
input.invert_axis(Axis(0));
174+
}
175+
inputs_ptr.add(i).write(input);
176+
}
177+
178+
let mut outputs = MaybeUninit::<[ArrayViewMut1<'_, T>; NOUT]>::uninit();
179+
let outputs_ptr = outputs.as_mut_ptr() as *mut ArrayViewMut1<'_, T>;
180+
181+
for i in 0..NOUT {
182+
let (ptr, shape, invert) = unpack_arg(args, dims, steps, NIN + i);
183+
184+
let mut output = ArrayViewMut1::from_shape_ptr(shape, ptr);
185+
if invert {
186+
output.invert_axis(Axis(0));
187+
}
188+
outputs_ptr.add(i).write(output);
189+
}
190+
191+
let data = &*(data as *mut Data<F, NIN, NOUT>);
192+
(data.func)(inputs.assume_init(), outputs.assume_init());
193+
}
194+
195+
unsafe fn unpack_arg<T>(
196+
args: *mut *mut c_char,
197+
dims: *mut npy_intp,
198+
steps: *mut npy_intp,
199+
i: usize,
200+
) -> (*mut T, StrideShape<Ix1>, bool) {
201+
let dim = Dim([*dims as usize]);
202+
let itemsize = size_of::<T>();
203+
204+
let mut ptr = *args.add(i);
205+
let mut invert = false;
206+
207+
let step = *steps.add(i);
208+
209+
let step = if step >= 0 {
210+
Dim([step as usize / itemsize])
211+
} else {
212+
ptr = ptr.offset(step * (*dims - 1));
213+
invert = true;
214+
215+
Dim([(-step) as usize / itemsize])
216+
};
217+
218+
(ptr as *mut T, dim.strides(step), invert)
219+
}
220+
221+
#[cfg(test)]
222+
mod tests {
223+
use super::*;
224+
225+
use ndarray::azip;
226+
use pyo3::py_run;
227+
228+
#[test]
229+
fn from_func_handles_negative_strides() {
230+
Python::with_gil(|py| {
231+
let negate = from_func(
232+
py,
233+
CString::new("negate").unwrap(),
234+
Identity::None,
235+
|[x]: [ArrayView1<'_, f64>; 1], [y]: [ArrayViewMut1<'_, f64>; 1]| {
236+
azip!((x in x, y in y) *y = -x);
237+
},
238+
)
239+
.unwrap();
240+
241+
let np = py.import("numpy").unwrap();
242+
243+
py_run!(py, negate np, "assert (negate(np.linspace(1.0, 10.0, 10)[::-1]) == np.linspace(-10.0, -1.0, 10)).all()");
244+
});
245+
}
246+
}

0 commit comments

Comments
 (0)