diff --git a/CHANGELOG.md b/CHANGELOG.md index 9aedebfaf..f5dec9b3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,17 @@ - Support borrowing arrays that are part of other Python objects via `PyArray::borrow_from_array` ([#230](https://github.com/PyO3/rust-numpy/pull/216)) - `PyArray::new` is now `unsafe`, as it produces uninitialized arrays ([#220](https://github.com/PyO3/rust-numpy/pull/220)) - `rayon` feature is now removed, and directly specifying the feature via `ndarray` dependency is recommended ([#250](https://github.com/PyO3/rust-numpy/pull/250)) + - Descriptors rework and related changes ([#256](https://github.com/PyO3/rust-numpy/pull/256)): + - Remove `DataType` + - Add the top-level `dtype` function for easy access to registered dtypes + - Add `PyArrayDescr::of`, `PyArrayDescr::into_dtype_ptr` and `PyArrayDescr::is_equiv_to` + - `Element` trait has been simplified to just `IS_COPY` const and `get_dtype` method + - `Element` is now implemented for `isize` + - `c32` and `c64` aliases have been replaced with `Complex32` and `Complex64` + - `ShapeError` has been split into `TypeError` and `DimensionalityError` + - `i32`, `i64`, `u32` and `u64` are now guaranteed to map to + `np.int32`, `np.int64`, `np.uint32` and `np.uint64` respectively + - Remove `cfg_if` dependency - v0.15.1 - Make arrays produced via `IntoPyArray`, i.e. those owning Rust data, writeable ([#235](https://github.com/PyO3/rust-numpy/pull/235)) diff --git a/Cargo.toml b/Cargo.toml index 8981aba5c..6bcb761e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ keywords = ["python", "numpy", "ffi", "pyo3"] license = "BSD-2-Clause" [dependencies] -cfg-if = "1.0" libc = "0.2" num-complex = ">= 0.2, <= 0.4" num-traits = "0.2" diff --git a/examples/simple-extension/src/lib.rs b/examples/simple-extension/src/lib.rs index 321377136..9bfa847e9 100644 --- a/examples/simple-extension/src/lib.rs +++ b/examples/simple-extension/src/lib.rs @@ -1,5 +1,5 @@ use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD}; -use numpy::{c64, IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn}; +use numpy::{Complex64, IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn}; use pyo3::prelude::{pymodule, PyModule, PyResult, Python}; #[pymodule] @@ -15,7 +15,7 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> { } // complex example - fn conj(x: ArrayViewD<'_, c64>) -> ArrayD { + fn conj(x: ArrayViewD<'_, Complex64>) -> ArrayD { x.map(|c| c.conj()) } @@ -44,7 +44,10 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> { // wrapper of `conj` #[pyfn(m)] #[pyo3(name = "conj")] - fn conj_py<'py>(py: Python<'py>, x: PyReadonlyArrayDyn<'_, c64>) -> &'py PyArrayDyn { + fn conj_py<'py>( + py: Python<'py>, + x: PyReadonlyArrayDyn<'_, Complex64>, + ) -> &'py PyArrayDyn { conj(x.as_array()).into_pyarray(py) } diff --git a/src/array.rs b/src/array.rs index bb740bc96..3463b0c3d 100644 --- a/src/array.rs +++ b/src/array.rs @@ -15,8 +15,8 @@ use std::{ use std::{iter::ExactSizeIterator, marker::PhantomData}; use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; -use crate::dtype::{DataType, Element}; -use crate::error::{FromVecError, NotContiguousError, ShapeError}; +use crate::dtype::Element; +use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError}; use crate::slice_container::PySliceContainer; /// A safe, static-typed interface for @@ -136,13 +136,21 @@ impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray { } &*(ob as *const PyAny as *const PyArray) }; - let dtype = array.dtype(); - let dim = array.shape().len(); - if T::is_same_type(dtype) && D::NDIM.map(|n| n == dim).unwrap_or(true) { - Ok(array) - } else { - Err(ShapeError::new(dtype, dim, T::DATA_TYPE, D::NDIM).into()) + + let src_dtype = array.dtype(); + let dst_dtype = T::get_dtype(ob.py()); + if !src_dtype.is_equiv_to(dst_dtype) { + return Err(TypeError::new(src_dtype, dst_dtype).into()); + } + + let src_ndim = array.shape().len(); + if let Some(dst_ndim) = D::NDIM { + if src_ndim != dst_ndim { + return Err(DimensionalityError::new(src_ndim, dst_ndim).into()); + } } + + Ok(array) } } @@ -160,7 +168,7 @@ impl PyArray { /// pyo3::Python::with_gil(|py| { /// let array = numpy::PyArray::from_vec(py, vec![1, 2, 3i32]); /// let dtype = array.dtype(); - /// assert_eq!(dtype.get_datatype().unwrap(), numpy::DataType::Int32); + /// assert!(dtype.is_equiv_to(numpy::dtype::(py))); /// }); /// ``` pub fn dtype(&self) -> &crate::PyArrayDescr { @@ -428,14 +436,13 @@ impl PyArray { ID: IntoDimension, { let dims = dims.into_dimension(); - let ptr = PY_ARRAY_API.PyArray_New( + let ptr = PY_ARRAY_API.PyArray_NewFromDescr( PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type), + T::get_dtype(py).into_dtype_ptr(), dims.ndim_cint(), dims.as_dims_ptr(), - T::npy_type() as c_int, strides as *mut npy_intp, // strides ptr::null_mut(), // data - 0, // itemsize flag, // flag ptr::null_mut(), // obj ); @@ -453,16 +460,15 @@ impl PyArray { ID: IntoDimension, { let dims = dims.into_dimension(); - let ptr = PY_ARRAY_API.PyArray_New( + let ptr = PY_ARRAY_API.PyArray_NewFromDescr( PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type), + T::get_dtype(py).into_dtype_ptr(), dims.ndim_cint(), dims.as_dims_ptr(), - T::npy_type() as c_int, - strides as *mut npy_intp, // strides - data_ptr as *mut c_void, // data - mem::size_of::() as c_int, // itemsize - npyffi::NPY_ARRAY_WRITEABLE, // flag - ptr::null_mut(), // obj + strides as *mut npy_intp, // strides + data_ptr as *mut c_void, // data + npyffi::NPY_ARRAY_WRITEABLE, // flag + ptr::null_mut(), // obj ); PY_ARRAY_API.PyArray_SetBaseObject( @@ -569,11 +575,10 @@ impl PyArray { { let dims = dims.into_dimension(); unsafe { - let dtype = T::get_dtype(py); let ptr = PY_ARRAY_API.PyArray_Zeros( dims.ndim_cint(), dims.as_dims_ptr(), - dtype.into_ptr() as _, + T::get_dtype(py).into_dtype_ptr(), if is_fortran { -1 } else { 0 }, ); Self::from_owned_ptr(py, ptr) @@ -847,7 +852,7 @@ impl PyArray { pub fn from_slice<'py>(py: Python<'py>, slice: &[T]) -> &'py Self { unsafe { let array = PyArray::new(py, [slice.len()], false); - if T::DATA_TYPE != DataType::Object { + if T::IS_COPY { array.copy_ptr(slice.as_ptr(), slice.len()); } else { let data_ptr = array.data(); @@ -1104,10 +1109,9 @@ impl PyArray { /// ``` pub fn cast<'py, U: Element>(&'py self, is_fortran: bool) -> PyResult<&'py PyArray> { let ptr = unsafe { - let dtype = U::get_dtype(self.py()); PY_ARRAY_API.PyArray_CastToType( self.as_array_ptr(), - dtype.into_ptr() as _, + U::get_dtype(self.py()).into_dtype_ptr(), if is_fortran { -1 } else { 0 }, ) }; @@ -1193,7 +1197,7 @@ impl> PyArray { start.as_(), stop.as_(), step.as_(), - T::npy_type() as i32, + T::get_dtype(py).get_typenum(), ); Self::from_owned_ptr(py, ptr) } @@ -1221,4 +1225,16 @@ mod tests { array.to_dyn().to_owned_array(); }) } + + #[test] + fn test_hasobject_flag() { + use super::ToPyArray; + use pyo3::{py_run, types::PyList, Py, PyAny}; + + pyo3::Python::with_gil(|py| { + let a = ndarray::Array2::from_shape_fn((2, 3), |(_i, _j)| PyList::empty(py).into()); + let arr: &PyArray, _> = a.to_pyarray(py); + py_run!(py, arr, "assert arr.dtype.hasobject"); + }); + } } diff --git a/src/convert.rs b/src/convert.rs index 5ea649458..b7e585570 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -7,7 +7,7 @@ use std::{mem, os::raw::c_int}; use crate::{ npyffi::{self, npy_intp}, - DataType, Element, PyArray, + Element, PyArray, }; /// Conversion trait from some rust types to `PyArray`. @@ -123,7 +123,7 @@ where fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray { let len = self.len(); match self.order() { - Some(order) if A::DATA_TYPE != DataType::Object => { + Some(order) if A::IS_COPY => { // if the array is contiguous, copy it by `copy_ptr`. let strides = self.npy_strides(); unsafe { diff --git a/src/dtype.rs b/src/dtype.rs index 47f617d1c..e7b744917 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -1,10 +1,12 @@ -use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API}; -use cfg_if::cfg_if; +use std::mem::size_of; +use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort}; + +use num_traits::{Bounded, Zero}; use pyo3::{ffi, prelude::*, pyobject_native_type_core, types::PyType, AsPyPointer, PyNativeType}; -use std::os::raw::c_int; -pub use num_complex::Complex32 as c32; -pub use num_complex::Complex64 as c64; +use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API}; + +pub use num_complex::{Complex32, Complex64}; /// Binding of [`numpy.dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html). /// @@ -18,7 +20,7 @@ pub use num_complex::Complex64 as c64; /// .unwrap() /// .downcast() /// .unwrap(); -/// assert_eq!(dtype.get_datatype().unwrap(), numpy::DataType::Float64); +/// assert!(dtype.is_equiv_to(numpy::dtype::(py))); /// }); /// ``` pub struct PyArrayDescr(PyAny); @@ -37,12 +39,24 @@ unsafe fn arraydescr_check(op: *mut ffi::PyObject) -> c_int { ) } +/// Returns the type descriptor ("dtype") for a registered type. +pub fn dtype(py: Python) -> &PyArrayDescr { + T::get_dtype(py) +} + impl PyArrayDescr { /// Returns `self` as `*mut PyArray_Descr`. pub fn as_dtype_ptr(&self) -> *mut PyArray_Descr { self.as_ptr() as _ } + /// Returns `self` as `*mut PyArray_Descr` while increasing the reference count. + /// + /// Useful in cases where the descriptor is stolen by the API. + pub fn into_dtype_ptr(&self) -> *mut PyArray_Descr { + self.into_ptr() as _ + } + /// Returns the internal `PyType` that this `dtype` holds. /// /// # Example @@ -58,112 +72,33 @@ impl PyArrayDescr { unsafe { PyType::from_type_ptr(self.py(), dtype_type_ptr) } } - /// Returns the data type as `DataType` enum. - pub fn get_datatype(&self) -> Option { - DataType::from_typenum(self.get_typenum()) - } - - fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self { - unsafe { - let descr = PY_ARRAY_API.PyArray_DescrFromType(npy_type as i32); - py.from_owned_ptr(descr as _) - } + /// Shortcut for creating a descriptor of 'object' type. + pub fn object(py: Python) -> &Self { + Self::from_npy_type(py, NPY_TYPES::NPY_OBJECT) } - fn get_typenum(&self) -> std::os::raw::c_int { - unsafe { *self.as_dtype_ptr() }.type_num + /// Returns the type descriptor ("dtype") for a registered type. + pub fn of(py: Python) -> &Self { + T::get_dtype(py) } -} -/// Represents numpy data type. -/// -/// This is an incomplete counterpart of -/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types) -/// in numpy C-API. -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum DataType { - Bool, - Int8, - Int16, - Int32, - Int64, - Uint8, - Uint16, - Uint32, - Uint64, - Float32, - Float64, - Complex32, - Complex64, - Object, -} - -impl DataType { - /// Construct `DataType` from - /// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types). - pub fn from_typenum(typenum: c_int) -> Option { - Some(match typenum { - x if x == NPY_TYPES::NPY_BOOL as i32 => DataType::Bool, - x if x == NPY_TYPES::NPY_BYTE as i32 => DataType::Int8, - x if x == NPY_TYPES::NPY_SHORT as i32 => DataType::Int16, - x if x == NPY_TYPES::NPY_INT as i32 => DataType::Int32, - x if x == NPY_TYPES::NPY_LONG as i32 => return DataType::from_clong(false), - x if x == NPY_TYPES::NPY_LONGLONG as i32 => DataType::Int64, - x if x == NPY_TYPES::NPY_UBYTE as i32 => DataType::Uint8, - x if x == NPY_TYPES::NPY_USHORT as i32 => DataType::Uint16, - x if x == NPY_TYPES::NPY_UINT as i32 => DataType::Uint32, - x if x == NPY_TYPES::NPY_ULONG as i32 => return DataType::from_clong(true), - x if x == NPY_TYPES::NPY_ULONGLONG as i32 => DataType::Uint64, - x if x == NPY_TYPES::NPY_FLOAT as i32 => DataType::Float32, - x if x == NPY_TYPES::NPY_DOUBLE as i32 => DataType::Float64, - x if x == NPY_TYPES::NPY_CFLOAT as i32 => DataType::Complex32, - x if x == NPY_TYPES::NPY_CDOUBLE as i32 => DataType::Complex64, - x if x == NPY_TYPES::NPY_OBJECT as i32 => DataType::Object, - _ => return None, - }) + /// Returns true if two type descriptors are equivalent. + pub fn is_equiv_to(&self, other: &Self) -> bool { + unsafe { PY_ARRAY_API.PyArray_EquivTypes(self.as_dtype_ptr(), other.as_dtype_ptr()) != 0 } } - /// Convert `self` into - /// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types). - pub fn into_ctype(self) -> NPY_TYPES { - match self { - DataType::Bool => NPY_TYPES::NPY_BOOL, - DataType::Int8 => NPY_TYPES::NPY_BYTE, - DataType::Int16 => NPY_TYPES::NPY_SHORT, - DataType::Int32 => NPY_TYPES::NPY_INT, - #[cfg(all(target_pointer_width = "64", not(windows)))] - DataType::Int64 => NPY_TYPES::NPY_LONG, - #[cfg(any(target_pointer_width = "32", windows))] - DataType::Int64 => NPY_TYPES::NPY_LONGLONG, - DataType::Uint8 => NPY_TYPES::NPY_UBYTE, - DataType::Uint16 => NPY_TYPES::NPY_USHORT, - DataType::Uint32 => NPY_TYPES::NPY_UINT, - DataType::Uint64 => NPY_TYPES::NPY_ULONGLONG, - DataType::Float32 => NPY_TYPES::NPY_FLOAT, - DataType::Float64 => NPY_TYPES::NPY_DOUBLE, - DataType::Complex32 => NPY_TYPES::NPY_CFLOAT, - DataType::Complex64 => NPY_TYPES::NPY_CDOUBLE, - DataType::Object => NPY_TYPES::NPY_OBJECT, + fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self { + unsafe { + let descr = PY_ARRAY_API.PyArray_DescrFromType(npy_type as _); + py.from_owned_ptr(descr as _) } } - #[inline(always)] - fn from_clong(is_usize: bool) -> Option { - if cfg!(any(target_pointer_width = "32", windows)) { - Some(if is_usize { - DataType::Uint32 - } else { - DataType::Int32 - }) - } else if cfg!(all(target_pointer_width = "64", not(windows))) { - Some(if is_usize { - DataType::Uint64 - } else { - DataType::Int64 - }) - } else { - None - } + /// Retrieves the + /// [enumerated type](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types) + /// for this type descriptor. + pub fn get_typenum(&self) -> c_int { + unsafe { *self.as_dtype_ptr() }.type_num } } @@ -176,11 +111,14 @@ impl DataType { /// /// # Safety /// -/// A type `T` that implements this trait should be safe when managed in numpy array, -/// thus implementing this trait is marked unsafe. -/// This means that all data types except for `DataType::Object` are assumed to be trivially copyable. -/// Furthermore, it is assumed that for `DataType::Object` the elements are pointers into the Python heap -/// and that the corresponding `Clone` implemenation will never panic as it only increases the reference count. +/// A type `T` that implements this trait should be safe when managed in numpy +/// array, thus implementing this trait is marked unsafe. Data types that don't +/// contain Python objects (i.e., either the object type itself or record types +/// containing object-type fields) are assumed to be trivially copyable, which +/// is reflected in the `IS_COPY` flag. Furthermore, it is assumed that for +/// the object type the elements are pointers into the Python heap and that the +/// corresponding `Clone` implemenation will never panic as it only increases +/// the reference count. /// /// # Custom element types /// @@ -188,7 +126,7 @@ impl DataType { /// on Python's heap using PyO3's [Py](pyo3::Py) type. /// /// ``` -/// use numpy::{ndarray::Array2, DataType, Element, PyArray, PyArrayDescr, ToPyArray}; +/// use numpy::{ndarray::Array2, Element, PyArray, PyArrayDescr, ToPyArray}; /// use pyo3::{pyclass, Py, Python}; /// /// #[pyclass] @@ -201,10 +139,10 @@ impl DataType { /// pub struct Wrapper(pub Py); /// /// unsafe impl Element for Wrapper { -/// const DATA_TYPE: DataType = DataType::Object; +/// const IS_COPY: bool = false; /// -/// fn is_same_type(dtype: &PyArrayDescr) -> bool { -/// dtype.get_datatype() == Some(DataType::Object) +/// fn get_dtype(py: Python) -> &PyArrayDescr { +/// PyArrayDescr::object(py) /// } /// } /// @@ -217,74 +155,136 @@ impl DataType { /// }); /// ``` pub unsafe trait Element: Clone + Send { - /// `DataType` corresponding to this type. - const DATA_TYPE: DataType; + /// Flag that indicates whether this type is trivially copyable. + /// + /// It should be set to true for all trivially copyable types (like scalar types + /// and record/array types only containing trivially copyable fields and elements). + /// + /// This flag should *always* be set to `false` for object types or record types + /// that contain object-type fields. + const IS_COPY: bool; - /// Returns if the give `dtype` is convertible to `Self` in Rust. - fn is_same_type(dtype: &PyArrayDescr) -> bool; + /// Returns the associated array descriptor ("dtype") for the given type. + fn get_dtype(py: Python) -> &PyArrayDescr; +} - /// Returns the corresponding - /// [Enumerated Type](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types). - #[inline] - fn npy_type() -> NPY_TYPES { - Self::DATA_TYPE.into_ctype() +fn npy_int_type_lookup(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES { + // `npy_common.h` defines the integer aliases. In order, it checks: + // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR + // and assigns the alias to the first matching size, so we should check in this order. + match size_of::() { + x if x == size_of::() => npy_types[0], + x if x == size_of::() => npy_types[1], + x if x == size_of::() => npy_types[2], + _ => panic!("Unable to match integer type descriptor: {:?}", npy_types), } +} - /// Create `dtype`. - fn get_dtype(py: Python) -> &PyArrayDescr { - PyArrayDescr::from_npy_type(py, Self::npy_type()) +fn npy_int_type() -> NPY_TYPES { + let is_unsigned = T::min_value() == T::zero(); + let bit_width = size_of::() << 3; + + match (is_unsigned, bit_width) { + (false, 8) => NPY_TYPES::NPY_BYTE, + (false, 16) => NPY_TYPES::NPY_SHORT, + (false, 32) => npy_int_type_lookup::([ + NPY_TYPES::NPY_LONG, + NPY_TYPES::NPY_INT, + NPY_TYPES::NPY_SHORT, + ]), + (false, 64) => npy_int_type_lookup::([ + NPY_TYPES::NPY_LONG, + NPY_TYPES::NPY_LONGLONG, + NPY_TYPES::NPY_INT, + ]), + (true, 8) => NPY_TYPES::NPY_UBYTE, + (true, 16) => NPY_TYPES::NPY_USHORT, + (true, 32) => npy_int_type_lookup::([ + NPY_TYPES::NPY_ULONG, + NPY_TYPES::NPY_UINT, + NPY_TYPES::NPY_USHORT, + ]), + (true, 64) => npy_int_type_lookup::([ + NPY_TYPES::NPY_ULONG, + NPY_TYPES::NPY_ULONGLONG, + NPY_TYPES::NPY_UINT, + ]), + _ => unreachable!(), } } -macro_rules! impl_num_element { - ($t:ty, $npy_dat_t:ident $(,$npy_types: ident)+) => { - unsafe impl Element for $t { - const DATA_TYPE: DataType = DataType::$npy_dat_t; - fn is_same_type(dtype: &PyArrayDescr) -> bool { - $(dtype.get_typenum() == NPY_TYPES::$npy_types as i32 ||)+ false +macro_rules! impl_element_scalar { + (@impl: $ty:ty, $npy_type:expr $(,#[$meta:meta])*) => { + $(#[$meta])* + unsafe impl Element for $ty { + const IS_COPY: bool = true; + fn get_dtype(py: Python) -> &PyArrayDescr { + PyArrayDescr::from_npy_type(py, $npy_type) } } }; + ($ty:ty => $npy_type:ident $(,#[$meta:meta])*) => { + impl_element_scalar!(@impl: $ty, NPY_TYPES::$npy_type $(,#[$meta])*); + }; + ($($tys:ty),+) => { + $(impl_element_scalar!(@impl: $tys, npy_int_type::<$tys>());)+ + }; } -impl_num_element!(bool, Bool, NPY_BOOL); -impl_num_element!(i8, Int8, NPY_BYTE); -impl_num_element!(i16, Int16, NPY_SHORT); -impl_num_element!(u8, Uint8, NPY_UBYTE); -impl_num_element!(u16, Uint16, NPY_USHORT); -impl_num_element!(f32, Float32, NPY_FLOAT); -impl_num_element!(f64, Float64, NPY_DOUBLE); -impl_num_element!(c32, Complex32, NPY_CFLOAT); -impl_num_element!(c64, Complex64, NPY_CDOUBLE); +impl_element_scalar!(bool => NPY_BOOL); +impl_element_scalar!(i8, i16, i32, i64); +impl_element_scalar!(u8, u16, u32, u64); +impl_element_scalar!(f32 => NPY_FLOAT); +impl_element_scalar!(f64 => NPY_DOUBLE); +impl_element_scalar!(Complex32 => NPY_CFLOAT, + #[doc = "Complex type with `f32` components which maps to `np.csingle` (`np.complex64`)."]); +impl_element_scalar!(Complex64 => NPY_CDOUBLE, + #[doc = "Complex type with `f64` components which maps to `np.cdouble` (`np.complex128`)."]); -cfg_if! { - if #[cfg(all(target_pointer_width = "64", windows))] { - impl_num_element!(usize, Uint64, NPY_ULONGLONG); - } else if #[cfg(all(target_pointer_width = "64", not(windows)))] { - impl_num_element!(usize, Uint64, NPY_ULONG, NPY_ULONGLONG); - } else if #[cfg(all(target_pointer_width = "32", windows))] { - impl_num_element!(usize, Uint32, NPY_UINT, NPY_ULONG); - } else if #[cfg(all(target_pointer_width = "32", not(windows)))] { - impl_num_element!(usize, Uint32, NPY_UINT); - } -} -cfg_if! { - if #[cfg(any(target_pointer_width = "32", windows))] { - impl_num_element!(i32, Int32, NPY_INT, NPY_LONG); - impl_num_element!(u32, Uint32, NPY_UINT, NPY_ULONG); - impl_num_element!(i64, Int64, NPY_LONGLONG); - impl_num_element!(u64, Uint64, NPY_ULONGLONG); - } else if #[cfg(all(target_pointer_width = "64", not(windows)))] { - impl_num_element!(i32, Int32, NPY_INT); - impl_num_element!(u32, Uint32, NPY_UINT); - impl_num_element!(i64, Int64, NPY_LONG, NPY_LONGLONG); - impl_num_element!(u64, Uint64, NPY_ULONG, NPY_ULONGLONG); +#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] +impl_element_scalar!(usize, isize); + +unsafe impl Element for PyObject { + const IS_COPY: bool = false; + + fn get_dtype(py: Python) -> &PyArrayDescr { + PyArrayDescr::object(py) } } -unsafe impl Element for PyObject { - const DATA_TYPE: DataType = DataType::Object; - fn is_same_type(dtype: &PyArrayDescr) -> bool { - dtype.get_typenum() == NPY_TYPES::NPY_OBJECT as i32 +#[cfg(test)] +mod tests { + use super::{dtype, Complex32, Complex64, Element}; + + #[test] + fn test_dtype_names() { + fn type_name(py: pyo3::Python) -> &str { + dtype::(py).get_type().name().unwrap() + } + pyo3::Python::with_gil(|py| { + assert_eq!(type_name::(py), "bool_"); + assert_eq!(type_name::(py), "int8"); + assert_eq!(type_name::(py), "int16"); + assert_eq!(type_name::(py), "int32"); + assert_eq!(type_name::(py), "int64"); + assert_eq!(type_name::(py), "uint8"); + assert_eq!(type_name::(py), "uint16"); + assert_eq!(type_name::(py), "uint32"); + assert_eq!(type_name::(py), "uint64"); + assert_eq!(type_name::(py), "float32"); + assert_eq!(type_name::(py), "float64"); + assert_eq!(type_name::(py), "complex64"); + assert_eq!(type_name::(py), "complex128"); + #[cfg(target_pointer_width = "32")] + { + assert_eq!(type_name::(py), "uint32"); + assert_eq!(type_name::(py), "int32"); + } + #[cfg(target_pointer_width = "64")] + { + assert_eq!(type_name::(py), "uint64"); + assert_eq!(type_name::(py), "int64"); + } + }); } } diff --git a/src/error.rs b/src/error.rs index 2eaca331d..c69df25b4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,62 +1,10 @@ //! Defines error types. -use crate::DataType; -use pyo3::{exceptions as exc, PyErr, PyErrArguments, PyObject, Python, ToPyObject}; -use std::fmt; - -/// Represents a dimension and dtype of numpy array. -/// -/// Only for error formatting. -#[derive(Debug)] -pub(crate) struct ArrayDim { - dim: Option, - dtype: Option, -} -impl fmt::Display for ArrayDim { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let ArrayDim { dim, dtype } = self; - match (dim, dtype) { - (Some(dim), Some(dtype)) => write!(f, "dim={:?}, dtype={:?}", dim, dtype), - (None, Some(dtype)) => write!(f, "dim=_, dtype={:?}", dtype), - (Some(dim), None) => write!(f, "dim={:?}, dtype=Unknown", dim), - (None, None) => write!(f, "dim=_, dtype=Unknown"), - } - } -} - -/// Represents that shapes of the given arrays don't match. -#[derive(Debug)] -pub struct ShapeError { - from: ArrayDim, - to: ArrayDim, -} +use std::fmt; -impl ShapeError { - pub(crate) fn new( - from_dtype: &crate::PyArrayDescr, - from_dim: usize, - to_type: DataType, - to_dim: Option, - ) -> Self { - ShapeError { - from: ArrayDim { - dim: Some(from_dim), - dtype: from_dtype.get_datatype(), - }, - to: ArrayDim { - dim: to_dim, - dtype: Some(to_type), - }, - } - } -} +use pyo3::{exceptions as exc, PyErr, PyErrArguments, PyObject, Python, ToPyObject}; -impl fmt::Display for ShapeError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let ShapeError { from, to } = self; - write!(f, "Shape Mismatch:\n from=({}), to=({})", from, to) - } -} +use crate::dtype::PyArrayDescr; macro_rules! impl_pyerr { ($err_type: ty) => { @@ -76,7 +24,57 @@ macro_rules! impl_pyerr { }; } -impl_pyerr!(ShapeError); +/// Represents that dimensionalities of the given arrays don't match. +#[derive(Debug)] +pub struct DimensionalityError { + from: usize, + to: usize, +} + +impl DimensionalityError { + pub(crate) fn new(from: usize, to: usize) -> Self { + Self { from, to } + } +} + +impl fmt::Display for DimensionalityError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let Self { from, to } = self; + write!(f, "dimensionality mismatch:\n from={}, to={}", from, to) + } +} + +impl_pyerr!(DimensionalityError); + +/// Represents that types of the given arrays don't match. +#[derive(Debug)] +pub struct TypeError { + from: String, + to: String, +} + +impl TypeError { + pub(crate) fn new(from: &PyArrayDescr, to: &PyArrayDescr) -> Self { + let dtype_to_str = |dtype: &PyArrayDescr| { + dtype + .str() + .map_or_else(|_| "(unknown)".into(), |s| s.to_string_lossy().into_owned()) + }; + Self { + from: dtype_to_str(from), + to: dtype_to_str(to), + } + } +} + +impl fmt::Display for TypeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let Self { from, to } = self; + write!(f, "type mismatch:\n from={}, to={}", from, to) + } +} + +impl_pyerr!(TypeError); /// Represents that given vec cannot be treated as array. #[derive(Debug)] diff --git a/src/lib.rs b/src/lib.rs index 3b7629415..5b3c77a6b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,8 +46,8 @@ pub use crate::array::{ PyArray6, PyArrayDyn, }; pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; -pub use crate::dtype::{c32, c64, DataType, Element, PyArrayDescr}; -pub use crate::error::{FromVecError, NotContiguousError, ShapeError}; +pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr}; +pub use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError}; pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API}; pub use crate::npyiter::{ IterMode, NpyIterFlag, NpyMultiIter, NpyMultiIterBuilder, NpySingleIter, NpySingleIterBuilder, diff --git a/tests/array.rs b/tests/array.rs index c6fd1f831..3997b9ea1 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -249,7 +249,7 @@ fn dtype_from_py() { .downcast() .unwrap(); assert_eq!(&format!("{:?}", dtype), "dtype('uint32')"); - assert_eq!(dtype.get_datatype().unwrap(), numpy::DataType::Uint32); + assert!(dtype.is_equiv_to(numpy::dtype::(py))); }) }