Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PyArray::borrow_from_array to expose array data with ownership being tied to anyother Python object #230

Merged
merged 1 commit into from
Jan 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 66 additions & 2 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@ use pyo3::{
ffi, prelude::*, type_object, types::PyAny, AsPyPointer, PyDowncastError, PyNativeType,
PyResult,
};
use std::{cell::Cell, mem, os::raw::c_int, ptr, slice};
use std::{
cell::Cell,
mem,
os::raw::{c_int, c_void},
ptr, slice,
};
use std::{iter::ExactSizeIterator, marker::PhantomData};

use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::dtype::{DataType, Element};
use crate::error::{FromVecError, NotContiguousError, ShapeError};
use crate::slice_box::SliceBox;
Expand Down Expand Up @@ -468,6 +473,65 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
Self::from_owned_ptr(py, ptr)
}

/// Creates a NumPy array backed by `array` and ties its ownership to the Python object `owner`.
///
/// # Safety
///
/// `owner` is set as a base object of the returned array which must not be dropped until `owner` is dropped.
/// Furthermore, `array` must not be reallocated from the time this method is called and until `owner` is dropped.
///
/// # Example
///
/// ```rust
/// # use pyo3::prelude::*;
/// # use numpy::{ndarray::Array1, PyArray1};
/// #
/// #[pyclass]
/// struct Owner {
/// array: Array1<f64>,
/// }
///
/// #[pymethods]
/// impl Owner {
/// #[getter]
/// fn array<'py>(this: &'py PyCell<Self>) -> &'py PyArray1<f64> {
/// let array = &this.borrow().array;
///
/// // SAFETY: The memory backing `array` will stay valid as long as this object is alive
/// // as we do not modify `array` in any way which would cause it to be reallocated.
/// unsafe { PyArray1::borrow_from_array(array, this) }
/// }
/// }
/// ```
pub unsafe fn borrow_from_array<'py, S>(array: &ArrayBase<S, D>, owner: &'py PyAny) -> &'py Self
where
S: Data<Elem = T>,
{
let (strides, dims) = (array.npy_strides(), array.raw_dim());
let data_ptr = array.as_ptr();

let ptr = PY_ARRAY_API.PyArray_New(
PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
dims.ndim_cint(),
dims.as_dims_ptr(),
T::npy_type() as c_int,
strides.as_ptr() as *mut npy_intp, // strides
data_ptr as *mut c_void, // data
mem::size_of::<T>() as c_int, // itemsize
0, // flag
ptr::null_mut(), // obj
);

mem::forget(owner.to_object(owner.py()));

PY_ARRAY_API.PyArray_SetBaseObject(
ptr as *mut npyffi::PyArrayObject,
owner as *const PyAny as *mut PyAny as *mut ffi::PyObject,
);

Self::from_owned_ptr(owner.py(), ptr)
}

/// Construct a new nd-dimensional array filled with 0.
///
/// If `is_fortran` is true, then
Expand Down
8 changes: 4 additions & 4 deletions src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ where
}
}

enum Order {
pub(crate) enum Order {
Standard,
Fortran,
}
Expand All @@ -172,7 +172,7 @@ impl Order {
}
}

trait ArrayExt {
pub(crate) trait ArrayExt {
fn npy_strides(&self) -> NpyStrides;
fn order(&self) -> Option<Order>;
}
Expand Down Expand Up @@ -201,13 +201,13 @@ where
}

/// Numpy strides with short array optimization
enum NpyStrides {
pub(crate) enum NpyStrides {
Short([npyffi::npy_intp; 8]),
Long(Vec<npyffi::npy_intp>),
}

impl NpyStrides {
fn as_ptr(&self) -> *const npy_intp {
pub(crate) fn as_ptr(&self) -> *const npy_intp {
match self {
NpyStrides::Short(inner) => inner.as_ptr(),
NpyStrides::Long(inner) => inner.as_ptr(),
Expand Down
37 changes: 37 additions & 0 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,40 @@ fn dtype_from_py() {
assert_eq!(dtype.get_datatype().unwrap(), numpy::DataType::Uint32);
})
}

#[test]
fn borrow_from_array() {
use numpy::ndarray::Array1;
use pyo3::py_run;

#[pyclass]
struct Owner {
array: Array1<f64>,
}

#[pymethods]
impl Owner {
#[getter]
fn array<'py>(this: &'py PyCell<Self>) -> &'py PyArray1<f64> {
let array = &this.borrow().array;

unsafe { PyArray1::borrow_from_array(array, this) }
}
}

let array = Python::with_gil(|py| {
let owner = Py::new(
py,
Owner {
array: Array1::linspace(0., 1., 10),
},
)
.unwrap();

owner.getattr(py, "array").unwrap()
});

Python::with_gil(|py| {
py_run!(py, array, "assert array.shape == (10,)");
});
}