Skip to content

Commit 131bb5f

Browse files
authored
Merge pull request #220 from adamreichold/pyarray-uninit
Make the PyArray::new method unsafe and document what can and cannot be done with its return value.
2 parents 4a8f79c + 11182ec commit 131bb5f

File tree

2 files changed

+74
-47
lines changed

2 files changed

+74
-47
lines changed

src/array.rs

+71-44
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ impl<T, D> PyArray<T, D> {
249249
/// ```
250250
/// use numpy::PyArray3;
251251
/// pyo3::Python::with_gil(|py| {
252-
/// let arr = PyArray3::<f64>::new(py, [4, 5, 6], false);
252+
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
253253
/// assert_eq!(arr.ndim(), 3);
254254
/// });
255255
/// ```
@@ -266,7 +266,7 @@ impl<T, D> PyArray<T, D> {
266266
/// ```
267267
/// use numpy::PyArray3;
268268
/// pyo3::Python::with_gil(|py| {
269-
/// let arr = PyArray3::<f64>::new(py, [4, 5, 6], false);
269+
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
270270
/// assert_eq!(arr.strides(), &[240, 48, 8]);
271271
/// });
272272
/// ```
@@ -287,7 +287,7 @@ impl<T, D> PyArray<T, D> {
287287
/// ```
288288
/// use numpy::PyArray3;
289289
/// pyo3::Python::with_gil(|py| {
290-
/// let arr = PyArray3::<f64>::new(py, [4, 5, 6], false);
290+
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
291291
/// assert_eq!(arr.shape(), &[4, 5, 6]);
292292
/// });
293293
/// ```
@@ -371,20 +371,46 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
371371
///
372372
/// If `is_fortran == true`, returns Fortran-order array. Else, returns C-order array.
373373
///
374+
/// # Safety
375+
///
376+
/// The returned array will always be safe to be dropped as the elements must either
377+
/// be trivially copyable or have `DATA_TYPE == DataType::Object`, i.e. be pointers
378+
/// into Python's heap, which NumPy will automatically zero-initialize.
379+
///
380+
/// However, the elements themselves will not be valid and should only be accessed
381+
/// via raw pointers obtained via [uget_raw](#method.uget_raw).
382+
///
383+
/// All methods which produce references to the elements invoke undefined behaviour.
384+
/// In particular, zero-initialized pointers are _not_ valid instances of `PyObject`.
385+
///
374386
/// # Example
375387
/// ```
376388
/// use numpy::PyArray3;
389+
///
377390
/// pyo3::Python::with_gil(|py| {
378-
/// let arr = PyArray3::<i32>::new(py, [4, 5, 6], false);
391+
/// let arr = unsafe {
392+
/// let arr = PyArray3::<i32>::new(py, [4, 5, 6], false);
393+
///
394+
/// for i in 0..4 {
395+
/// for j in 0..5 {
396+
/// for k in 0..6 {
397+
/// arr.uget_raw([i, j, k]).write((i * j * k) as i32);
398+
/// }
399+
/// }
400+
/// }
401+
///
402+
/// arr
403+
/// };
404+
///
379405
/// assert_eq!(arr.shape(), &[4, 5, 6]);
380406
/// });
381407
/// ```
382-
pub fn new<ID>(py: Python, dims: ID, is_fortran: bool) -> &Self
408+
pub unsafe fn new<ID>(py: Python, dims: ID, is_fortran: bool) -> &Self
383409
where
384410
ID: IntoDimension<Dim = D>,
385411
{
386412
let flags = if is_fortran { 1 } else { 0 };
387-
unsafe { PyArray::new_(py, dims, ptr::null_mut(), flags) }
413+
PyArray::new_(py, dims, ptr::null_mut(), flags)
388414
}
389415

390416
pub(crate) unsafe fn new_<ID>(
@@ -447,6 +473,9 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
447473
/// If `is_fortran` is true, then
448474
/// a fortran order array is created, otherwise a C-order array is created.
449475
///
476+
/// For elements with `DATA_TYPE == DataType::Object`, this will fill the array
477+
/// with valid pointers to zero-valued Python integer objects.
478+
///
450479
/// See also [PyArray_Zeros](https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Zeros)
451480
///
452481
/// # Example
@@ -593,6 +622,16 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
593622
&mut *(self.data().offset(offset) as *mut _)
594623
}
595624

625+
/// Same as [uget](#method.uget), but returns `*mut T`.
626+
#[inline(always)]
627+
pub unsafe fn uget_raw<Idx>(&self, index: Idx) -> *mut T
628+
where
629+
Idx: NpyIndex<Dim = D>,
630+
{
631+
let offset = index.get_unchecked::<T>(self.strides());
632+
self.data().offset(offset) as *mut _
633+
}
634+
596635
/// Get dynamic dimensioned array from fixed dimension array.
597636
pub fn to_dyn(&self) -> &PyArray<T, IxDyn> {
598637
let python = self.py();
@@ -730,20 +769,18 @@ impl<T: Element> PyArray<T, Ix1> {
730769
/// });
731770
/// ```
732771
pub fn from_slice<'py>(py: Python<'py>, slice: &[T]) -> &'py Self {
733-
let array = PyArray::new(py, [slice.len()], false);
734-
if T::DATA_TYPE != DataType::Object {
735-
unsafe {
772+
unsafe {
773+
let array = PyArray::new(py, [slice.len()], false);
774+
if T::DATA_TYPE != DataType::Object {
736775
array.copy_ptr(slice.as_ptr(), slice.len());
737-
}
738-
} else {
739-
unsafe {
776+
} else {
740777
let data_ptr = array.data();
741778
for (i, item) in slice.iter().enumerate() {
742779
data_ptr.add(i).write(item.clone());
743780
}
744781
}
782+
array
745783
}
746-
array
747784
}
748785

749786
/// Construct one-dimension PyArray
@@ -776,20 +813,15 @@ impl<T: Element> PyArray<T, Ix1> {
776813
/// });
777814
/// ```
778815
pub fn from_exact_iter(py: Python<'_>, iter: impl ExactSizeIterator<Item = T>) -> &Self {
779-
// Use zero-initialized pointers for object arrays
780-
// so that partially initialized arrays can be dropped safely
781-
// in case the iterator implementation panics.
782-
let array = if T::DATA_TYPE == DataType::Object {
783-
Self::zeros(py, [iter.len()], false)
784-
} else {
785-
Self::new(py, [iter.len()], false)
786-
};
816+
// NumPy will always zero-initialize object pointers,
817+
// so the array can be dropped safely if the iterator panics.
787818
unsafe {
819+
let array = Self::new(py, [iter.len()], false);
788820
for (i, item) in iter.enumerate() {
789-
*array.uget_mut([i]) = item;
821+
array.uget_raw([i]).write(item);
790822
}
823+
array
791824
}
792-
array
793825
}
794826

795827
/// Construct one-dimension PyArray from a type which implements
@@ -811,16 +843,11 @@ impl<T: Element> PyArray<T, Ix1> {
811843
let iter = iter.into_iter();
812844
let (min_len, max_len) = iter.size_hint();
813845
let mut capacity = max_len.unwrap_or_else(|| min_len.max(512 / mem::size_of::<T>()));
814-
// Use zero-initialized pointers for object arrays
815-
// so that partially initialized arrays can be dropped safely
816-
// in case the iterator implementation panics.
817-
let array = if T::DATA_TYPE == DataType::Object {
818-
Self::zeros(py, [capacity], false)
819-
} else {
820-
Self::new(py, [capacity], false)
821-
};
822-
let mut length = 0;
823846
unsafe {
847+
// NumPy will always zero-initialize object pointers,
848+
// so the array can be dropped safely if the iterator panics.
849+
let array = Self::new(py, [capacity], false);
850+
let mut length = 0;
824851
for (i, item) in iter.enumerate() {
825852
length += 1;
826853
if length > capacity {
@@ -829,13 +856,13 @@ impl<T: Element> PyArray<T, Ix1> {
829856
.resize(capacity)
830857
.expect("PyArray::from_iter: Failed to allocate memory");
831858
}
832-
*array.uget_mut([i]) = item;
859+
array.uget_raw([i]).write(item);
833860
}
861+
if capacity > length {
862+
array.resize(length).unwrap()
863+
}
864+
array
834865
}
835-
if capacity > length {
836-
array.resize(length).unwrap()
837-
}
838-
array
839866
}
840867

841868
/// Extends or trancates the length of 1 dimension PyArray.
@@ -909,15 +936,15 @@ impl<T: Element> PyArray<T, Ix2> {
909936
return Err(FromVecError::new(v.len(), last_len));
910937
}
911938
let dims = [v.len(), last_len];
912-
let array = Self::new(py, dims, false);
913939
unsafe {
940+
let array = Self::new(py, dims, false);
914941
for (y, vy) in v.iter().enumerate() {
915942
for (x, vyx) in vy.iter().enumerate() {
916-
*array.uget_mut([y, x]) = vyx.clone();
943+
array.uget_raw([y, x]).write(vyx.clone());
917944
}
918945
}
946+
Ok(array)
919947
}
920-
Ok(array)
921948
}
922949
}
923950

@@ -951,17 +978,17 @@ impl<T: Element> PyArray<T, Ix3> {
951978
return Err(FromVecError::new(v.len(), len3));
952979
}
953980
let dims = [v.len(), len2, len3];
954-
let array = Self::new(py, dims, false);
955981
unsafe {
982+
let array = Self::new(py, dims, false);
956983
for (z, vz) in v.iter().enumerate() {
957984
for (y, vzy) in vz.iter().enumerate() {
958985
for (x, vzyx) in vzy.iter().enumerate() {
959-
*array.uget_mut([z, y, x]) = vzyx.clone();
986+
array.uget_raw([z, y, x]).write(vzyx.clone());
960987
}
961988
}
962989
}
990+
Ok(array)
963991
}
964-
Ok(array)
965992
}
966993
}
967994

@@ -972,7 +999,7 @@ impl<T: Element, D> PyArray<T, D> {
972999
/// use numpy::PyArray;
9731000
/// pyo3::Python::with_gil(|py| {
9741001
/// let pyarray_f = PyArray::arange(py, 2.0, 5.0, 1.0);
975-
/// let pyarray_i = PyArray::<i64, _>::new(py, [3], false);
1002+
/// let pyarray_i = unsafe { PyArray::<i64, _>::new(py, [3], false) };
9761003
/// assert!(pyarray_f.copy_to(pyarray_i).is_ok());
9771004
/// assert_eq!(pyarray_i.readonly().as_slice().unwrap(), &[2, 3, 4]);
9781005
/// });

tests/array.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ fn not_contiguous_array<'py>(py: Python<'py>) -> &'py PyArray1<i32> {
2525
fn new_c_order() {
2626
let dim = [3, 5];
2727
pyo3::Python::with_gil(|py| {
28-
let arr = PyArray::<f64, _>::new(py, dim, false);
28+
let arr = PyArray::<f64, _>::zeros(py, dim, false);
2929
assert!(arr.ndim() == 2);
3030
assert!(arr.dims() == dim);
3131
let size = std::mem::size_of::<f64>() as isize;
@@ -37,7 +37,7 @@ fn new_c_order() {
3737
fn new_fortran_order() {
3838
let dim = [3, 5];
3939
pyo3::Python::with_gil(|py| {
40-
let arr = PyArray::<f64, _>::new(py, dim, true);
40+
let arr = PyArray::<f64, _>::zeros(py, dim, true);
4141
assert!(arr.ndim() == 2);
4242
assert!(arr.dims() == dim);
4343
let size = std::mem::size_of::<f64>() as isize;
@@ -109,7 +109,7 @@ fn as_slice() {
109109
#[test]
110110
fn is_instance() {
111111
pyo3::Python::with_gil(|py| {
112-
let arr = PyArray2::<f64>::new(py, [3, 5], false);
112+
let arr = PyArray2::<f64>::zeros(py, [3, 5], false);
113113
assert!(arr.is_instance::<PyArray2<f64>>().unwrap());
114114
assert!(!arr.is_instance::<PyList>().unwrap());
115115
})

0 commit comments

Comments
 (0)