|
| 1 | +//! TODO |
| 2 | +
|
| 3 | +use std::ffi::CString; |
| 4 | +use std::mem::{size_of, MaybeUninit}; |
| 5 | +use std::os::raw::{c_char, c_int, c_void}; |
| 6 | +use std::ptr::null_mut; |
| 7 | + |
| 8 | +use ndarray::{ArrayView1, ArrayViewMut1, Axis, Dim, Ix1, ShapeBuilder, StrideShape}; |
| 9 | +use pyo3::{Bound, PyAny, PyResult, Python}; |
| 10 | + |
| 11 | +use crate::{ |
| 12 | + dtype::Element, |
| 13 | + npyffi::{flags, npy_intp, objects::PyUFuncGenericFunction, ufunc::PY_UFUNC_API}, |
| 14 | +}; |
| 15 | + |
| 16 | +/// TODO |
| 17 | +#[repr(i32)] |
| 18 | +#[derive(Debug)] |
| 19 | +pub enum Identity { |
| 20 | + /// UFunc has unit of 0, and the order of operations can be reordered |
| 21 | + /// This case allows reduction with multiple axes at once. |
| 22 | + Zero = flags::NPY_UFUNC_ZERO, |
| 23 | + /// UFunc has unit of 1, and the order of operations can be reordered |
| 24 | + /// This case allows reduction with multiple axes at once. |
| 25 | + One = flags::NPY_UFUNC_ONE, |
| 26 | + /// UFunc has unit of -1, and the order of operations can be reordered |
| 27 | + /// This case allows reduction with multiple axes at once. Intended for bitwise_and reduction. |
| 28 | + MinusOne = flags::NPY_UFUNC_MINUS_ONE, |
| 29 | + /// UFunc has no unit, and the order of operations cannot be reordered. |
| 30 | + /// This case does not allow reduction with multiple axes at once. |
| 31 | + None = flags::NPY_UFUNC_NONE, |
| 32 | + /// UFunc has no unit, and the order of operations can be reordered |
| 33 | + /// This case allows reduction with multiple axes at once. |
| 34 | + ReorderableNone = flags::NPY_UFUNC_REORDERABLE_NONE, |
| 35 | + /// UFunc unit is an identity_value, and the order of operations can be reordered |
| 36 | + /// This case allows reduction with multiple axes at once. |
| 37 | + IdentityValue = flags::NPY_UFUNC_IDENTITY_VALUE, |
| 38 | +} |
| 39 | + |
| 40 | +/// TODO |
| 41 | +/// |
| 42 | +/// ``` |
| 43 | +/// # #![allow(mixed_script_confusables)] |
| 44 | +/// # use std::ffi::CString; |
| 45 | +/// # |
| 46 | +/// use pyo3::{py_run, Python}; |
| 47 | +/// use ndarray::{azip, ArrayView1, ArrayViewMut1}; |
| 48 | +/// use numpy::ufunc::{from_func, Identity}; |
| 49 | +/// |
| 50 | +/// Python::with_gil(|py| { |
| 51 | +/// let logit = |[p]: [ArrayView1<'_, f64>; 1], [α]: [ArrayViewMut1<'_, f64>; 1]| { |
| 52 | +/// azip!((p in p, α in α) { |
| 53 | +/// let mut tmp = *p; |
| 54 | +/// tmp /= 1.0 - tmp; |
| 55 | +/// *α = tmp.ln(); |
| 56 | +/// }); |
| 57 | +/// }; |
| 58 | +/// |
| 59 | +/// let logit = |
| 60 | +/// from_func(py, CString::new("logit").unwrap(), Identity::None, logit).unwrap(); |
| 61 | +/// |
| 62 | +/// py_run!(py, logit, "assert logit(0.5) == 0.0"); |
| 63 | +/// |
| 64 | +/// let np = py.import("numpy").unwrap(); |
| 65 | +/// |
| 66 | +/// py_run!(py, logit np, "assert (logit(np.full(100, 0.5)) == np.zeros(100)).all()"); |
| 67 | +/// }); |
| 68 | +/// ``` |
| 69 | +/// |
| 70 | +/// ``` |
| 71 | +/// # #![allow(mixed_script_confusables)] |
| 72 | +/// # use std::ffi::CString; |
| 73 | +/// # |
| 74 | +/// use pyo3::{py_run, Python}; |
| 75 | +/// use ndarray::{azip, ArrayView1, ArrayViewMut1}; |
| 76 | +/// use numpy::ufunc::{from_func, Identity}; |
| 77 | +/// |
| 78 | +/// Python::with_gil(|py| { |
| 79 | +/// let cart_to_polar = |[x, y]: [ArrayView1<'_, f64>; 2], [r, φ]: [ArrayViewMut1<'_, f64>; 2]| { |
| 80 | +/// azip!((&x in x, &y in y, r in r, φ in φ) { |
| 81 | +/// *r = f64::hypot(x, y); |
| 82 | +/// *φ = f64::atan2(x, y); |
| 83 | +/// }); |
| 84 | +/// }; |
| 85 | +/// |
| 86 | +/// let cart_to_polar = from_func( |
| 87 | +/// py, |
| 88 | +/// CString::new("cart_to_polar").unwrap(), |
| 89 | +/// Identity::None, |
| 90 | +/// cart_to_polar, |
| 91 | +/// ) |
| 92 | +/// .unwrap(); |
| 93 | +/// |
| 94 | +/// let np = py.import("numpy").unwrap(); |
| 95 | +/// |
| 96 | +/// py_run!(py, cart_to_polar np, "np.testing.assert_array_almost_equal(cart_to_polar(3.0, 4.0), (5.0, 0.643501))"); |
| 97 | +/// |
| 98 | +/// py_run!(py, cart_to_polar np, "np.testing.assert_array_almost_equal(cart_to_polar(np.full((10, 10), 3.0), np.full((10, 10), 4.0))[0], np.full((10, 10), 5.0))"); |
| 99 | +/// py_run!(py, cart_to_polar np, "np.testing.assert_array_almost_equal(cart_to_polar(np.full((10, 10), 3.0), np.full((10, 10), 4.0))[1], np.full((10, 10), 0.643501))"); |
| 100 | +/// }); |
| 101 | +/// ``` |
| 102 | +pub fn from_func<'py, T, F, const NIN: usize, const NOUT: usize>( |
| 103 | + py: Python<'py>, |
| 104 | + name: CString, |
| 105 | + identity: Identity, |
| 106 | + func: F, |
| 107 | +) -> PyResult<Bound<'py, PyAny>> |
| 108 | +where |
| 109 | + T: Element, |
| 110 | + F: Fn([ArrayView1<'_, T>; NIN], [ArrayViewMut1<'_, T>; NOUT]) + 'static, |
| 111 | +{ |
| 112 | + let wrap_func = [Some(wrap_func::<T, F, NIN, NOUT> as _)]; |
| 113 | + |
| 114 | + let r#type = T::get_npy_type().expect("universal function only work for built-in types"); |
| 115 | + |
| 116 | + let inputs = [r#type as _; NIN]; |
| 117 | + let outputs = [r#type as _; NOUT]; |
| 118 | + |
| 119 | + let data = Data { |
| 120 | + func, |
| 121 | + wrap_func, |
| 122 | + name, |
| 123 | + inputs, |
| 124 | + outputs, |
| 125 | + }; |
| 126 | + |
| 127 | + let data = Box::leak(Box::new(data)); |
| 128 | + |
| 129 | + unsafe { |
| 130 | + Bound::from_owned_ptr_or_err( |
| 131 | + py, |
| 132 | + PY_UFUNC_API.PyUFunc_FromFuncAndData( |
| 133 | + py, |
| 134 | + data.wrap_func.as_mut_ptr(), |
| 135 | + data as *mut Data<F, NIN, NOUT> as *mut c_void as *mut *mut c_void, |
| 136 | + data.inputs.as_mut_ptr(), |
| 137 | + /* ntypes = */ 1, |
| 138 | + NIN as c_int, |
| 139 | + NOUT as c_int, |
| 140 | + identity as c_int, |
| 141 | + data.name.as_ptr(), |
| 142 | + /* doc = */ null_mut(), |
| 143 | + /* unused = */ 0, |
| 144 | + ), |
| 145 | + ) |
| 146 | + } |
| 147 | +} |
| 148 | + |
| 149 | +#[repr(C)] |
| 150 | +struct Data<F, const NIN: usize, const NOUT: usize> { |
| 151 | + func: F, |
| 152 | + wrap_func: [PyUFuncGenericFunction; 1], |
| 153 | + name: CString, |
| 154 | + inputs: [c_char; NIN], |
| 155 | + outputs: [c_char; NOUT], |
| 156 | +} |
| 157 | + |
| 158 | +unsafe extern "C" fn wrap_func<T, F, const NIN: usize, const NOUT: usize>( |
| 159 | + args: *mut *mut c_char, |
| 160 | + dims: *mut npy_intp, |
| 161 | + steps: *mut npy_intp, |
| 162 | + data: *mut c_void, |
| 163 | +) where |
| 164 | + F: Fn([ArrayView1<'_, T>; NIN], [ArrayViewMut1<'_, T>; NOUT]), |
| 165 | +{ |
| 166 | + // TODO: Check aliasing requirements using the `borrow` module. |
| 167 | + |
| 168 | + let mut inputs = MaybeUninit::<[ArrayView1<'_, T>; NIN]>::uninit(); |
| 169 | + let inputs_ptr = inputs.as_mut_ptr() as *mut ArrayView1<'_, T>; |
| 170 | + |
| 171 | + for i in 0..NIN { |
| 172 | + let (ptr, shape, invert) = unpack_arg(args, dims, steps, i); |
| 173 | + |
| 174 | + let mut input = ArrayView1::from_shape_ptr(shape, ptr); |
| 175 | + if invert { |
| 176 | + input.invert_axis(Axis(0)); |
| 177 | + } |
| 178 | + inputs_ptr.add(i).write(input); |
| 179 | + } |
| 180 | + |
| 181 | + let mut outputs = MaybeUninit::<[ArrayViewMut1<'_, T>; NOUT]>::uninit(); |
| 182 | + let outputs_ptr = outputs.as_mut_ptr() as *mut ArrayViewMut1<'_, T>; |
| 183 | + |
| 184 | + for i in 0..NOUT { |
| 185 | + let (ptr, shape, invert) = unpack_arg(args, dims, steps, NIN + i); |
| 186 | + |
| 187 | + let mut output = ArrayViewMut1::from_shape_ptr(shape, ptr); |
| 188 | + if invert { |
| 189 | + output.invert_axis(Axis(0)); |
| 190 | + } |
| 191 | + outputs_ptr.add(i).write(output); |
| 192 | + } |
| 193 | + |
| 194 | + let data = &*(data as *mut Data<F, NIN, NOUT>); |
| 195 | + (data.func)(inputs.assume_init(), outputs.assume_init()); |
| 196 | +} |
| 197 | + |
| 198 | +unsafe fn unpack_arg<T>( |
| 199 | + args: *mut *mut c_char, |
| 200 | + dims: *mut npy_intp, |
| 201 | + steps: *mut npy_intp, |
| 202 | + i: usize, |
| 203 | +) -> (*mut T, StrideShape<Ix1>, bool) { |
| 204 | + let dim = Dim([*dims as usize]); |
| 205 | + let itemsize = size_of::<T>(); |
| 206 | + |
| 207 | + let mut ptr = *args.add(i); |
| 208 | + let mut invert = false; |
| 209 | + |
| 210 | + let step = *steps.add(i); |
| 211 | + |
| 212 | + let step = if step >= 0 { |
| 213 | + Dim([step as usize / itemsize]) |
| 214 | + } else { |
| 215 | + ptr = ptr.offset(step * (*dims - 1)); |
| 216 | + invert = true; |
| 217 | + |
| 218 | + Dim([(-step) as usize / itemsize]) |
| 219 | + }; |
| 220 | + |
| 221 | + (ptr as *mut T, dim.strides(step), invert) |
| 222 | +} |
| 223 | + |
| 224 | +#[cfg(test)] |
| 225 | +mod tests { |
| 226 | + use super::*; |
| 227 | + |
| 228 | + use ndarray::azip; |
| 229 | + use pyo3::py_run; |
| 230 | + |
| 231 | + #[test] |
| 232 | + fn from_func_handles_negative_strides() { |
| 233 | + Python::with_gil(|py| { |
| 234 | + let negate = from_func( |
| 235 | + py, |
| 236 | + CString::new("negate").unwrap(), |
| 237 | + Identity::None, |
| 238 | + |[x]: [ArrayView1<'_, f64>; 1], [y]: [ArrayViewMut1<'_, f64>; 1]| { |
| 239 | + azip!((x in x, y in y) *y = -x); |
| 240 | + }, |
| 241 | + ) |
| 242 | + .unwrap(); |
| 243 | + |
| 244 | + let np = py.import_bound("numpy").unwrap(); |
| 245 | + |
| 246 | + py_run!(py, negate np, "assert (negate(np.linspace(1.0, 10.0, 10)[::-1]) == np.linspace(-10.0, -1.0, 10)).all()"); |
| 247 | + }); |
| 248 | + } |
| 249 | +} |
0 commit comments