Skip to content

Commit b02c6df

Browse files
authored
Merge pull request #216 from adamreichold/object-element
Give PyArray<PyObject> another try.
2 parents 29f2737 + b6e58e2 commit b02c6df

File tree

4 files changed

+113
-25
lines changed

4 files changed

+113
-25
lines changed

src/array.rs

+28-5
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use std::{cell::Cell, mem, os::raw::c_int, ptr, slice};
1010
use std::{iter::ExactSizeIterator, marker::PhantomData};
1111

1212
use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
13-
use crate::dtype::Element;
13+
use crate::dtype::{DataType, Element};
1414
use crate::error::{FromVecError, NotContiguousError, ShapeError};
1515
use crate::slice_box::SliceBox;
1616

@@ -731,8 +731,17 @@ impl<T: Element> PyArray<T, Ix1> {
731731
/// ```
732732
pub fn from_slice<'py>(py: Python<'py>, slice: &[T]) -> &'py Self {
733733
let array = PyArray::new(py, [slice.len()], false);
734-
unsafe {
735-
array.copy_ptr(slice.as_ptr(), slice.len());
734+
if T::DATA_TYPE != DataType::Object {
735+
unsafe {
736+
array.copy_ptr(slice.as_ptr(), slice.len());
737+
}
738+
} else {
739+
unsafe {
740+
let data_ptr = array.data();
741+
for (i, item) in slice.iter().enumerate() {
742+
data_ptr.add(i).write(item.clone());
743+
}
744+
}
736745
}
737746
array
738747
}
@@ -767,7 +776,14 @@ impl<T: Element> PyArray<T, Ix1> {
767776
/// });
768777
/// ```
769778
pub fn from_exact_iter(py: Python<'_>, iter: impl ExactSizeIterator<Item = T>) -> &Self {
770-
let array = Self::new(py, [iter.len()], false);
779+
// Use zero-initialized pointers for object arrays
780+
// so that partially initialized arrays can be dropped safely
781+
// in case the iterator implementation panics.
782+
let array = if T::DATA_TYPE == DataType::Object {
783+
Self::zeros(py, [iter.len()], false)
784+
} else {
785+
Self::new(py, [iter.len()], false)
786+
};
771787
unsafe {
772788
for (i, item) in iter.enumerate() {
773789
*array.uget_mut([i]) = item;
@@ -795,7 +811,14 @@ impl<T: Element> PyArray<T, Ix1> {
795811
let iter = iter.into_iter();
796812
let (min_len, max_len) = iter.size_hint();
797813
let mut capacity = max_len.unwrap_or_else(|| min_len.max(512 / mem::size_of::<T>()));
798-
let array = Self::new(py, [capacity], false);
814+
// Use zero-initialized pointers for object arrays
815+
// so that partially initialized arrays can be dropped safely
816+
// in case the iterator implementation panics.
817+
let array = if T::DATA_TYPE == DataType::Object {
818+
Self::zeros(py, [capacity], false)
819+
} else {
820+
Self::new(py, [capacity], false)
821+
};
799822
let mut length = 0;
800823
unsafe {
801824
for (i, item) in iter.enumerate() {

src/convert.rs

+22-18
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::{mem, os::raw::c_int};
77

88
use crate::{
99
npyffi::{self, npy_intp},
10-
Element, PyArray,
10+
DataType, Element, PyArray,
1111
};
1212

1313
/// Covnersion trait from some rust types to `PyArray`.
@@ -130,25 +130,29 @@ where
130130
type Dim = D;
131131
fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
132132
let len = self.len();
133-
if let Some(order) = self.order() {
134-
// if the array is contiguous, copy it by `copy_ptr`.
135-
let strides = self.npy_strides();
136-
unsafe {
137-
let array = PyArray::new_(py, self.raw_dim(), strides.as_ptr(), order.to_flag());
138-
array.copy_ptr(self.as_ptr(), len);
139-
array
133+
match self.order() {
134+
Some(order) if A::DATA_TYPE != DataType::Object => {
135+
// if the array is contiguous, copy it by `copy_ptr`.
136+
let strides = self.npy_strides();
137+
unsafe {
138+
let array =
139+
PyArray::new_(py, self.raw_dim(), strides.as_ptr(), order.to_flag());
140+
array.copy_ptr(self.as_ptr(), len);
141+
array
142+
}
140143
}
141-
} else {
142-
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
143-
let dim = self.raw_dim();
144-
let strides = NpyStrides::from_dim(&dim, mem::size_of::<A>());
145-
unsafe {
146-
let array = PyArray::<A, _>::new_(py, dim, strides.as_ptr(), 0);
147-
let data_ptr = array.data();
148-
for (i, item) in self.iter().enumerate() {
149-
data_ptr.add(i).write(item.clone());
144+
_ => {
145+
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
146+
let dim = self.raw_dim();
147+
let strides = NpyStrides::from_dim(&dim, mem::size_of::<A>());
148+
unsafe {
149+
let array = PyArray::<A, _>::new_(py, dim, strides.as_ptr(), 0);
150+
let data_ptr = array.data();
151+
for (i, item) in self.iter().enumerate() {
152+
data_ptr.add(i).write(item.clone());
153+
}
154+
array
150155
}
151-
array
152156
}
153157
}
154158
}

src/dtype.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,9 @@ impl DataType {
179179
///
180180
/// A type `T` that implements this trait should be safe when managed in numpy array,
181181
/// thus implementing this trait is marked unsafe.
182-
/// For example, we don't support `PyObject` because of [an odd segfault](https://github.com/PyO3/rust-numpy/pull/143),
183-
/// although numpy itself supports it.
182+
/// This means that all data types except for `DataType::Object` are assumed to be trivially copyable.
183+
/// Furthermore, it is assumed that for `DataType::Object` the elements are pointers into the Python heap
184+
/// and that the corresponding `Clone` implemenation will never panic as it only increases the reference count.
184185
pub unsafe trait Element: Clone + Send {
185186
/// `DataType` corresponding to this type.
186187
const DATA_TYPE: DataType;
@@ -246,3 +247,10 @@ cfg_if! {
246247
impl_num_element!(u64, Uint64, NPY_ULONG, NPY_ULONGLONG);
247248
}
248249
}
250+
251+
unsafe impl Element for PyObject {
252+
const DATA_TYPE: DataType = DataType::Object;
253+
fn is_same_type(dtype: &PyArrayDescr) -> bool {
254+
dtype.get_typenum() == NPY_TYPES::NPY_OBJECT as i32
255+
}
256+
}

tests/to_py.rs

+53
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,56 @@ fn forder_into_pyarray() {
169169
pyo3::py_run!(py, fmat_py, "assert fmat_py.flags['F_CONTIGUOUS']")
170170
})
171171
}
172+
173+
#[test]
174+
fn to_pyarray_object_vec() {
175+
use pyo3::{
176+
types::{PyDict, PyString},
177+
ToPyObject,
178+
};
179+
use std::cmp::Ordering;
180+
181+
pyo3::Python::with_gil(|py| {
182+
let dict = PyDict::new(py);
183+
let string = PyString::new(py, "Hello:)");
184+
let vec = vec![dict.to_object(py), string.to_object(py)];
185+
let arr = vec.to_pyarray(py).readonly();
186+
187+
for (a, b) in vec.iter().zip(arr.as_slice().unwrap().iter()) {
188+
assert_eq!(
189+
a.as_ref(py).compare(b).map_err(|e| e.print(py)).unwrap(),
190+
Ordering::Equal
191+
);
192+
}
193+
})
194+
}
195+
196+
#[test]
197+
fn to_pyarray_object_array() {
198+
use ndarray::Array2;
199+
use pyo3::{
200+
types::{PyDict, PyString},
201+
ToPyObject,
202+
};
203+
use std::cmp::Ordering;
204+
205+
pyo3::Python::with_gil(|py| {
206+
let mut nd_arr = Array2::from_shape_fn((2, 3), |(_, _)| py.None());
207+
nd_arr[(0, 2)] = PyDict::new(py).to_object(py);
208+
nd_arr[(1, 0)] = PyString::new(py, "Hello:)").to_object(py);
209+
210+
let py_arr = nd_arr.to_pyarray(py).readonly();
211+
212+
for (a, b) in nd_arr
213+
.as_slice()
214+
.unwrap()
215+
.iter()
216+
.zip(py_arr.as_slice().unwrap().iter())
217+
{
218+
assert_eq!(
219+
a.as_ref(py).compare(b).map_err(|e| e.print(py)).unwrap(),
220+
Ordering::Equal
221+
);
222+
}
223+
})
224+
}

0 commit comments

Comments
 (0)