Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give PyArray<PyObject> another try. #216

Merged
merged 4 commits into from
Nov 8, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{cell::Cell, mem, os::raw::c_int, ptr, slice};
use std::{iter::ExactSizeIterator, marker::PhantomData};

use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::dtype::Element;
use crate::dtype::{DataType, Element};
use crate::error::{FromVecError, NotContiguousError, ShapeError};
use crate::slice_box::SliceBox;

Expand Down Expand Up @@ -731,8 +731,17 @@ impl<T: Element> PyArray<T, Ix1> {
/// ```
pub fn from_slice<'py>(py: Python<'py>, slice: &[T]) -> &'py Self {
let array = PyArray::new(py, [slice.len()], false);
unsafe {
array.copy_ptr(slice.as_ptr(), slice.len());
if T::DATA_TYPE != DataType::Object {
unsafe {
array.copy_ptr(slice.as_ptr(), slice.len());
}
} else {
unsafe {
let data_ptr = array.data();
for (i, item) in slice.iter().enumerate() {
data_ptr.add(i).write(item.clone());
}
}
}
array
}
Expand Down Expand Up @@ -767,7 +776,14 @@ impl<T: Element> PyArray<T, Ix1> {
/// });
/// ```
pub fn from_exact_iter(py: Python<'_>, iter: impl ExactSizeIterator<Item = T>) -> &Self {
let array = Self::new(py, [iter.len()], false);
// Use zero-initialized pointers for object arrays
// so that partially initialized arrays can be dropped safely
// in case the iterator implementation panics.
let array = if T::DATA_TYPE == DataType::Object {
Self::zeros(py, [iter.len()], false)
} else {
Self::new(py, [iter.len()], false)
};
unsafe {
for (i, item) in iter.enumerate() {
*array.uget_mut([i]) = item;
Expand Down Expand Up @@ -795,7 +811,14 @@ impl<T: Element> PyArray<T, Ix1> {
let iter = iter.into_iter();
let (min_len, max_len) = iter.size_hint();
let mut capacity = max_len.unwrap_or_else(|| min_len.max(512 / mem::size_of::<T>()));
let array = Self::new(py, [capacity], false);
// Use zero-initialized pointers for object arrays
// so that partially initialized arrays can be dropped safely
// in case the iterator implementation panics.
let array = if T::DATA_TYPE == DataType::Object {
Self::zeros(py, [capacity], false)
} else {
Self::new(py, [capacity], false)
};
let mut length = 0;
unsafe {
for (i, item) in iter.enumerate() {
Expand Down
40 changes: 22 additions & 18 deletions src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{mem, os::raw::c_int};

use crate::{
npyffi::{self, npy_intp},
Element, PyArray,
DataType, Element, PyArray,
};

/// Covnersion trait from some rust types to `PyArray`.
Expand Down Expand Up @@ -130,25 +130,29 @@ where
type Dim = D;
fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
let len = self.len();
if let Some(order) = self.order() {
// if the array is contiguous, copy it by `copy_ptr`.
let strides = self.npy_strides();
unsafe {
let array = PyArray::new_(py, self.raw_dim(), strides.as_ptr(), order.to_flag());
array.copy_ptr(self.as_ptr(), len);
array
match self.order() {
Some(order) if A::DATA_TYPE != DataType::Object => {
// if the array is contiguous, copy it by `copy_ptr`.
let strides = self.npy_strides();
unsafe {
let array =
PyArray::new_(py, self.raw_dim(), strides.as_ptr(), order.to_flag());
array.copy_ptr(self.as_ptr(), len);
array
}
}
} else {
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
let dim = self.raw_dim();
let strides = NpyStrides::from_dim(&dim, mem::size_of::<A>());
unsafe {
let array = PyArray::<A, _>::new_(py, dim, strides.as_ptr(), 0);
let data_ptr = array.data();
for (i, item) in self.iter().enumerate() {
data_ptr.add(i).write(item.clone());
_ => {
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
let dim = self.raw_dim();
let strides = NpyStrides::from_dim(&dim, mem::size_of::<A>());
unsafe {
let array = PyArray::<A, _>::new_(py, dim, strides.as_ptr(), 0);
let data_ptr = array.data();
for (i, item) in self.iter().enumerate() {
data_ptr.add(i).write(item.clone());
}
array
}
array
}
}
}
Expand Down
12 changes: 10 additions & 2 deletions src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,9 @@ impl DataType {
///
/// A type `T` that implements this trait should be safe when managed in numpy array,
/// thus implementing this trait is marked unsafe.
/// For example, we don't support `PyObject` because of [an odd segfault](https://github.com/PyO3/rust-numpy/pull/143),
/// although numpy itself supports it.
/// This means that all data types except for `DataType::Object` are assumed to be trivially copyable.
/// Furthermore, it is assumed that for `DataType::Object` the elements are pointers into the Python heap
/// and that the corresponding `Clone` implemenation will never panic as it only increases the reference count.
pub unsafe trait Element: Clone + Send {
/// `DataType` corresponding to this type.
const DATA_TYPE: DataType;
Expand Down Expand Up @@ -246,3 +247,10 @@ cfg_if! {
impl_num_element!(u64, Uint64, NPY_ULONG, NPY_ULONGLONG);
}
}

unsafe impl Element for PyObject {
const DATA_TYPE: DataType = DataType::Object;
fn is_same_type(dtype: &PyArrayDescr) -> bool {
dtype.get_typenum() == NPY_TYPES::NPY_OBJECT as i32
}
}
23 changes: 23 additions & 0 deletions tests/to_py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,26 @@ fn forder_into_pyarray() {
pyo3::py_run!(py, fmat_py, "assert fmat_py.flags['F_CONTIGUOUS']")
})
}

#[test]
fn to_pyarray_object_vec() {
use pyo3::{
types::{PyDict, PyString},
ToPyObject,
};
use std::cmp::Ordering;

pyo3::Python::with_gil(|py| {
let dict = PyDict::new(py);
let string = PyString::new(py, "Hello:)");
let vec = vec![dict.to_object(py), string.to_object(py)];
let arr = vec.to_pyarray(py).readonly();

for (a, b) in vec.iter().zip(arr.as_slice().unwrap().iter()) {
assert_eq!(
a.as_ref(py).compare(b).map_err(|e| e.print(py)).unwrap(),
Ordering::Equal
);
}
})
}