Skip to content

Commit b81fbe5

Browse files
committed
Make PyArray::resize unsafe as it can invalidate existing pointers into the array.
1 parent b63f11c commit b81fbe5

File tree

6 files changed

+156
-70
lines changed

6 files changed

+156
-70
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
- Unreleased
44
- Add dynamic borrow checking to safely construct references into the interior of NumPy arrays. ([#274](https://github.com/PyO3/rust-numpy/pull/274))
5+
- The deprecated iterator builders `NpySingleIterBuilder::{readonly,readwrite}` and `NpyMultiIterBuilder::add_{readonly,readwrite}` now take referencces to `PyReadonlyArray` and `PyReadwriteArray` instead of consuming them.
6+
- The destructive `PyArray::resize` method is now unsafe if used without an instance of `PyReadwriteArray`. ([#302](https://github.com/PyO3/rust-numpy/pull/302))
57
- Deprecate `PyArray::from_exact_iter` after optimizing `PyArray::from_iter`. ([#292](https://github.com/PyO3/rust-numpy/pull/292))
68

79
- v0.16.2

src/array.rs

+15-4
Original file line numberDiff line numberDiff line change
@@ -1077,19 +1077,30 @@ impl<T: Element> PyArray<T, Ix1> {
10771077
data.into_pyarray(py)
10781078
}
10791079

1080-
/// Extends or trancates the length of 1 dimension PyArray.
1080+
/// Extends or truncates the length of a one-dimensional array.
1081+
///
1082+
/// # Safety
1083+
///
1084+
/// There should be no outstanding references (shared or exclusive) into the array
1085+
/// as this method might re-allocate it and thereby invalidate all pointers into it.
10811086
///
10821087
/// # Example
1088+
///
10831089
/// ```
10841090
/// use numpy::PyArray;
1085-
/// pyo3::Python::with_gil(|py| {
1091+
/// use pyo3::Python;
1092+
///
1093+
/// Python::with_gil(|py| {
10861094
/// let pyarray = PyArray::arange(py, 0, 10, 1);
10871095
/// assert_eq!(pyarray.len(), 10);
1088-
/// pyarray.resize(100).unwrap();
1096+
///
1097+
/// unsafe {
1098+
/// pyarray.resize(100).unwrap();
1099+
/// }
10891100
/// assert_eq!(pyarray.len(), 100);
10901101
/// });
10911102
/// ```
1092-
pub fn resize(&self, new_elems: usize) -> PyResult<()> {
1103+
pub unsafe fn resize(&self, new_elems: usize) -> PyResult<()> {
10931104
self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER)
10941105
}
10951106

src/borrow.rs

+119-64
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,86 @@ impl BorrowFlags {
170170
unsafe fn get(&self) -> &mut HashMap<usize, isize> {
171171
(*self.0.get()).get_or_insert_with(HashMap::new)
172172
}
173+
174+
fn acquire<T, D>(&self, array: &PyArray<T, D>) -> Result<(), BorrowError> {
175+
let address = base_address(array);
176+
177+
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
178+
// and we are not calling into user code which might re-enter this function.
179+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
180+
181+
match borrow_flags.entry(address) {
182+
Entry::Occupied(entry) => {
183+
let readers = entry.into_mut();
184+
185+
let new_readers = readers.wrapping_add(1);
186+
187+
if new_readers <= 0 {
188+
cold();
189+
return Err(BorrowError::AlreadyBorrowed);
190+
}
191+
192+
*readers = new_readers;
193+
}
194+
Entry::Vacant(entry) => {
195+
entry.insert(1);
196+
}
197+
}
198+
199+
Ok(())
200+
}
201+
202+
fn release<T, D>(&self, array: &PyArray<T, D>) {
203+
let address = base_address(array);
204+
205+
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
206+
// and we are not calling into user code which might re-enter this function.
207+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
208+
209+
let readers = borrow_flags.get_mut(&address).unwrap();
210+
211+
*readers -= 1;
212+
213+
if *readers == 0 {
214+
borrow_flags.remove(&address).unwrap();
215+
}
216+
}
217+
218+
fn acquire_mut<T, D>(&self, array: &PyArray<T, D>) -> Result<(), BorrowError> {
219+
let address = base_address(array);
220+
221+
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
222+
// and we are not calling into user code which might re-enter this function.
223+
let borrow_flags = unsafe { BORROW_FLAGS.get() };
224+
225+
match borrow_flags.entry(address) {
226+
Entry::Occupied(entry) => {
227+
let writers = entry.into_mut();
228+
229+
if *writers != 0 {
230+
cold();
231+
return Err(BorrowError::AlreadyBorrowed);
232+
}
233+
234+
*writers = -1;
235+
}
236+
Entry::Vacant(entry) => {
237+
entry.insert(-1);
238+
}
239+
}
240+
241+
Ok(())
242+
}
243+
244+
fn release_mut<T, D>(&self, array: &PyArray<T, D>) {
245+
let address = base_address(array);
246+
247+
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
248+
// and we are not calling into user code which might re-enter this function.
249+
let borrow_flags = unsafe { self.get() };
250+
251+
borrow_flags.remove(&address).unwrap();
252+
}
173253
}
174254

175255
static BORROW_FLAGS: BorrowFlags = BorrowFlags::new();
@@ -224,29 +304,7 @@ where
224304
D: Dimension,
225305
{
226306
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Result<Self, BorrowError> {
227-
let address = base_address(array);
228-
229-
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
230-
// and we are not calling into user code which might re-enter this function.
231-
let borrow_flags = unsafe { BORROW_FLAGS.get() };
232-
233-
match borrow_flags.entry(address) {
234-
Entry::Occupied(entry) => {
235-
let readers = entry.into_mut();
236-
237-
let new_readers = readers.wrapping_add(1);
238-
239-
if new_readers <= 0 {
240-
cold();
241-
return Err(BorrowError::AlreadyBorrowed);
242-
}
243-
244-
*readers = new_readers;
245-
}
246-
Entry::Vacant(entry) => {
247-
entry.insert(1);
248-
}
249-
}
307+
BORROW_FLAGS.acquire(array)?;
250308

251309
Ok(Self(array))
252310
}
@@ -287,19 +345,7 @@ where
287345

288346
impl<'a, T, D> Drop for PyReadonlyArray<'a, T, D> {
289347
fn drop(&mut self) {
290-
let address = base_address(self.0);
291-
292-
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
293-
// and we are not calling into user code which might re-enter this function.
294-
let borrow_flags = unsafe { BORROW_FLAGS.get() };
295-
296-
let readers = borrow_flags.get_mut(&address).unwrap();
297-
298-
*readers -= 1;
299-
300-
if *readers == 0 {
301-
borrow_flags.remove(&address).unwrap();
302-
}
348+
BORROW_FLAGS.release(self.0);
303349
}
304350
}
305351

@@ -358,27 +404,7 @@ where
358404
return Err(BorrowError::NotWriteable);
359405
}
360406

361-
let address = base_address(array);
362-
363-
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
364-
// and we are not calling into user code which might re-enter this function.
365-
let borrow_flags = unsafe { BORROW_FLAGS.get() };
366-
367-
match borrow_flags.entry(address) {
368-
Entry::Occupied(entry) => {
369-
let writers = entry.into_mut();
370-
371-
if *writers != 0 {
372-
cold();
373-
return Err(BorrowError::AlreadyBorrowed);
374-
}
375-
376-
*writers = -1;
377-
}
378-
Entry::Vacant(entry) => {
379-
entry.insert(-1);
380-
}
381-
}
407+
BORROW_FLAGS.acquire_mut(array)?;
382408

383409
Ok(Self(array))
384410
}
@@ -407,15 +433,44 @@ where
407433
}
408434
}
409435

410-
impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> {
411-
fn drop(&mut self) {
412-
let address = base_address(self.0);
436+
impl<'py, T> PyReadwriteArray<'py, T, Ix1>
437+
where
438+
T: Element,
439+
{
440+
/// Extends or truncates the length of a one-dimensional array.
441+
///
442+
/// # Example
443+
///
444+
/// ```
445+
/// use numpy::PyArray;
446+
/// use pyo3::Python;
447+
///
448+
/// Python::with_gil(|py| {
449+
/// let pyarray = PyArray::arange(py, 0, 10, 1);
450+
/// assert_eq!(pyarray.len(), 10);
451+
///
452+
/// let pyarray = pyarray.readwrite();
453+
/// let pyarray = pyarray.resize(100).unwrap();
454+
/// assert_eq!(pyarray.len(), 100);
455+
/// });
456+
/// ```
457+
pub fn resize(self, new_elems: usize) -> PyResult<Self> {
458+
BORROW_FLAGS.release_mut(self.0);
459+
460+
// SAFETY: Ownership of `self` proves exclusive access to the interior of the array.
461+
unsafe {
462+
self.0.resize(new_elems)?;
463+
}
413464

414-
// SAFETY: Access to a `&'py PyArray<T, D>` implies holding the GIL
415-
// and we are not calling into user code which might re-enter this function.
416-
let borrow_flags = unsafe { BORROW_FLAGS.get() };
465+
BORROW_FLAGS.acquire_mut(self.0)?;
417466

418-
borrow_flags.remove(&address).unwrap();
467+
Ok(self)
468+
}
469+
}
470+
471+
impl<'a, T, D> Drop for PyReadwriteArray<'a, T, D> {
472+
fn drop(&mut self) {
473+
BORROW_FLAGS.release_mut(self.0);
419474
}
420475
}
421476

src/convert.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ use crate::sealed::Sealed;
2929
/// assert_eq!(py_array.readonly().as_slice().unwrap(), &[1, 2, 3]);
3030
///
3131
/// // Array cannot be resized when its data is owned by Rust.
32-
/// assert!(py_array.resize(100).is_err());
32+
/// unsafe {
33+
/// assert!(py_array.resize(100).is_err());
34+
/// }
3335
/// });
3436
/// ```
3537
pub trait IntoPyArray {

tests/borrow.rs

+14
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,17 @@ fn readwrite_as_array_slice() {
248248
assert_eq!(*array.get_mut([0, 1, 2]).unwrap(), 0.0);
249249
});
250250
}
251+
252+
#[test]
253+
fn resize_using_exclusive_borrow() {
254+
Python::with_gil(|py| {
255+
let array = PyArray::<f64, _>::zeros(py, 3, false);
256+
assert_eq!(array.shape(), [3]);
257+
258+
let mut array = array.readwrite();
259+
assert_eq!(array.as_slice_mut().unwrap(), &[0.0; 3]);
260+
261+
let mut array = array.resize(5).unwrap();
262+
assert_eq!(array.as_slice_mut().unwrap(), &[0.0; 5]);
263+
});
264+
}

tests/to_py.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,9 @@ fn into_pyarray_cannot_resize() {
161161
Python::with_gil(|py| {
162162
let arr = vec![1, 2, 3].into_pyarray(py);
163163

164-
assert!(arr.resize(100).is_err())
164+
unsafe {
165+
assert!(arr.resize(100).is_err());
166+
}
165167
});
166168
}
167169

0 commit comments

Comments
 (0)