Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes and refinements to dynamic borrow checking #302

Merged
merged 3 commits into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

- Unreleased
- Add dynamic borrow checking to safely construct references into the interior of NumPy arrays. ([#274](https://github.com/PyO3/rust-numpy/pull/274))
- The deprecated iterator builders `NpySingleIterBuilder::{readonly,readwrite}` and `NpyMultiIterBuilder::add_{readonly,readwrite}` now take referencces to `PyReadonlyArray` and `PyReadwriteArray` instead of consuming them.
- The destructive `PyArray::resize` method is now unsafe if used without an instance of `PyReadwriteArray`. ([#302](https://github.com/PyO3/rust-numpy/pull/302))
- The `inner`, `dot` and `einsum` functions can also return a scalar instead of a zero-dimensional array to match NumPy's types ([#285](https://github.com/PyO3/rust-numpy/pull/285))
- Deprecate `PyArray::from_exact_iter` after optimizing `PyArray::from_iter`. ([#292](https://github.com/PyO3/rust-numpy/pull/292))

Expand Down
45 changes: 38 additions & 7 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
use crate::cold;
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::dtype::{Element, PyArrayDescr};
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
use crate::error::{BorrowError, DimensionalityError, FromVecError, NotContiguousError, TypeError};
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
use crate::slice_container::PySliceContainer;

Expand Down Expand Up @@ -846,13 +846,33 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
}

/// Get an immutable borrow of the NumPy array
pub fn try_readonly(&self) -> Result<PyReadonlyArray<'_, T, D>, BorrowError> {
PyReadonlyArray::try_new(self)
}

/// Get an immutable borrow of the NumPy array
///
/// # Panics
///
/// Panics if the allocation backing the array is currently mutably borrowed.
/// For a non-panicking variant, use [`try_readonly`][Self::try_readonly].
pub fn readonly(&self) -> PyReadonlyArray<'_, T, D> {
PyReadonlyArray::try_new(self).unwrap()
self.try_readonly().unwrap()
}

/// Get a mutable borrow of the NumPy array
pub fn try_readwrite(&self) -> Result<PyReadwriteArray<'_, T, D>, BorrowError> {
PyReadwriteArray::try_new(self)
}

/// Get a mutable borrow of the NumPy array
///
/// # Panics
///
/// Panics if the allocation backing the array is currently borrowed.
/// For a non-panicking variant, use [`try_readwrite`][Self::try_readwrite].
pub fn readwrite(&self) -> PyReadwriteArray<'_, T, D> {
PyReadwriteArray::try_new(self).unwrap()
self.try_readwrite().unwrap()
}

/// Returns the internal array as [`ArrayView`].
Expand Down Expand Up @@ -1057,19 +1077,30 @@ impl<T: Element> PyArray<T, Ix1> {
data.into_pyarray(py)
}

/// Extends or trancates the length of 1 dimension PyArray.
/// Extends or truncates the length of a one-dimensional array.
///
/// # Safety
///
/// There should be no outstanding references (shared or exclusive) into the array
/// as this method might re-allocate it and thereby invalidate all pointers into it.
///
/// # Example
///
/// ```
/// use numpy::PyArray;
/// pyo3::Python::with_gil(|py| {
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let pyarray = PyArray::arange(py, 0, 10, 1);
/// assert_eq!(pyarray.len(), 10);
/// pyarray.resize(100).unwrap();
///
/// unsafe {
/// pyarray.resize(100).unwrap();
/// }
/// assert_eq!(pyarray.len(), 100);
/// });
/// ```
pub fn resize(&self, new_elems: usize) -> PyResult<()> {
pub unsafe fn resize(&self, new_elems: usize) -> PyResult<()> {
self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER)
}

Expand Down
193 changes: 129 additions & 64 deletions src/borrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,86 @@ impl BorrowFlags {
unsafe fn get(&self) -> &mut HashMap<usize, isize> {
(*self.0.get()).get_or_insert_with(HashMap::new)
}

fn acquire<T, D>(&self, array: &PyArray<T, D>) -> Result<(), BorrowError> {
let address = base_address(array);

// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
// and we are not calling into user code which might re-enter this function.
let borrow_flags = unsafe { BORROW_FLAGS.get() };

match borrow_flags.entry(address) {
Entry::Occupied(entry) => {
let readers = entry.into_mut();

let new_readers = readers.wrapping_add(1);

if new_readers <= 0 {
cold();
return Err(BorrowError::AlreadyBorrowed);
}

*readers = new_readers;
}
Entry::Vacant(entry) => {
entry.insert(1);
}
}

Ok(())
}

fn release<T, D>(&self, array: &PyArray<T, D>) {
let address = base_address(array);

// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
// and we are not calling into user code which might re-enter this function.
let borrow_flags = unsafe { BORROW_FLAGS.get() };

let readers = borrow_flags.get_mut(&address).unwrap();

*readers -= 1;

if *readers == 0 {
borrow_flags.remove(&address).unwrap();
}
}

fn acquire_mut<T, D>(&self, array: &PyArray<T, D>) -> Result<(), BorrowError> {
let address = base_address(array);

// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
// and we are not calling into user code which might re-enter this function.
let borrow_flags = unsafe { BORROW_FLAGS.get() };

match borrow_flags.entry(address) {
Entry::Occupied(entry) => {
let writers = entry.into_mut();

if *writers != 0 {
cold();
return Err(BorrowError::AlreadyBorrowed);
}

*writers = -1;
}
Entry::Vacant(entry) => {
entry.insert(-1);
}
}

Ok(())
}

fn release_mut<T, D>(&self, array: &PyArray<T, D>) {
let address = base_address(array);

// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
// and we are not calling into user code which might re-enter this function.
let borrow_flags = unsafe { self.get() };

borrow_flags.remove(&address).unwrap();
}
}

static BORROW_FLAGS: BorrowFlags = BorrowFlags::new();
Expand Down Expand Up @@ -224,29 +304,7 @@ where
D: Dimension,
{
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Result<Self, BorrowError> {
let address = base_address(array);

// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
// and we are not calling into user code which might re-enter this function.
let borrow_flags = unsafe { BORROW_FLAGS.get() };

match borrow_flags.entry(address) {
Entry::Occupied(entry) => {
let readers = entry.into_mut();

let new_readers = readers.wrapping_add(1);

if new_readers <= 0 {
cold();
return Err(BorrowError::AlreadyBorrowed);
}

*readers = new_readers;
}
Entry::Vacant(entry) => {
entry.insert(1);
}
}
BORROW_FLAGS.acquire(array)?;

Ok(Self(array))
}
Expand Down Expand Up @@ -275,21 +333,19 @@ where
}
}

impl<'a, T, D> Clone for PyReadonlyArray<'a, T, D>
where
T: Element,
D: Dimension,
{
fn clone(&self) -> Self {
Self::try_new(self.0).unwrap()
}
}

impl<'a, T, D> Drop for PyReadonlyArray<'a, T, D> {
fn drop(&mut self) {
let address = base_address(self.0);

// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
// and we are not calling into user code which might re-enter this function.
let borrow_flags = unsafe { BORROW_FLAGS.get() };

let readers = borrow_flags.get_mut(&address).unwrap();

*readers -= 1;

if *readers == 0 {
borrow_flags.remove(&address).unwrap();
}
BORROW_FLAGS.release(self.0);
}
}

Expand Down Expand Up @@ -348,27 +404,7 @@ where
return Err(BorrowError::NotWriteable);
}

let address = base_address(array);

// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
// and we are not calling into user code which might re-enter this function.
let borrow_flags = unsafe { BORROW_FLAGS.get() };

match borrow_flags.entry(address) {
Entry::Occupied(entry) => {
let writers = entry.into_mut();

if *writers != 0 {
cold();
return Err(BorrowError::AlreadyBorrowed);
}

*writers = -1;
}
Entry::Vacant(entry) => {
entry.insert(-1);
}
}
BORROW_FLAGS.acquire_mut(array)?;

Ok(Self(array))
}
Expand Down Expand Up @@ -397,15 +433,44 @@ where
}
}

impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> {
fn drop(&mut self) {
let address = base_address(self.0);
impl<'py, T> PyReadwriteArray<'py, T, Ix1>
where
T: Element,
{
/// Extends or truncates the length of a one-dimensional array.
///
/// # Example
///
/// ```
/// use numpy::PyArray;
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let pyarray = PyArray::arange(py, 0, 10, 1);
/// assert_eq!(pyarray.len(), 10);
///
/// let pyarray = pyarray.readwrite();
/// let pyarray = pyarray.resize(100).unwrap();
/// assert_eq!(pyarray.len(), 100);
/// });
/// ```
pub fn resize(self, new_elems: usize) -> PyResult<Self> {
BORROW_FLAGS.release_mut(self.0);

// SAFETY: Ownership of `self` proves exclusive access to the interior of the array.
unsafe {
self.0.resize(new_elems)?;
}

// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
// and we are not calling into user code which might re-enter this function.
let borrow_flags = unsafe { BORROW_FLAGS.get() };
BORROW_FLAGS.acquire_mut(self.0)?;

borrow_flags.remove(&address).unwrap();
Ok(self)
}
}

impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> {
fn drop(&mut self) {
BORROW_FLAGS.release_mut(self.0);
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ use crate::sealed::Sealed;
/// assert_eq!(py_array.readonly().as_slice().unwrap(), &[1, 2, 3]);
///
/// // Array cannot be resized when its data is owned by Rust.
/// assert!(py_array.resize(100).is_err());
/// unsafe {
/// assert!(py_array.resize(100).is_err());
/// }
/// });
/// ```
pub trait IntoPyArray {
Expand Down
27 changes: 27 additions & 0 deletions tests/borrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,19 @@ fn borrows_span_threads() {
});
}

#[test]
fn shared_borrows_can_be_cloned() {
Python::with_gil(|py| {
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);

let shared1 = array.readonly();
let shared2 = shared1.clone();

assert_eq!(shared2.shape(), [1, 2, 3]);
assert_eq!(shared1.shape(), [1, 2, 3]);
});
}

#[test]
#[should_panic(expected = "AlreadyBorrowed")]
fn overlapping_views_conflict() {
Expand Down Expand Up @@ -235,3 +248,17 @@ fn readwrite_as_array_slice() {
assert_eq!(*array.get_mut([0, 1, 2]).unwrap(), 0.0);
});
}

#[test]
fn resize_using_exclusive_borrow() {
Python::with_gil(|py| {
let array = PyArray::<f64, _>::zeros(py, 3, false);
assert_eq!(array.shape(), [3]);

let mut array = array.readwrite();
assert_eq!(array.as_slice_mut().unwrap(), &[0.0; 3]);

let mut array = array.resize(5).unwrap();
assert_eq!(array.as_slice_mut().unwrap(), &[0.0; 5]);
});
}
4 changes: 3 additions & 1 deletion tests/to_py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ fn into_pyarray_cannot_resize() {
Python::with_gil(|py| {
let arr = vec![1, 2, 3].into_pyarray(py);

assert!(arr.resize(100).is_err())
unsafe {
assert!(arr.resize(100).is_err());
}
});
}

Expand Down