Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0c678d9

Browse files
committedJan 10, 2022
Add dtype() top-level function for convenience
1 parent 37b4596 commit 0c678d9

File tree

4 files changed

+12
-7
lines changed

4 files changed

+12
-7
lines changed
 

‎src/array.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ impl<T, D> PyArray<T, D> {
168168
/// pyo3::Python::with_gil(|py| {
169169
/// let array = numpy::PyArray::from_vec(py, vec![1, 2, 3i32]);
170170
/// let dtype = array.dtype();
171-
/// assert!(dtype.is_equiv_to(numpy::PyArrayDescr::of::<i32>(py)));
171+
/// assert!(dtype.is_equiv_to(numpy::dtype::<i32>(py)));
172172
/// });
173173
/// ```
174174
pub fn dtype(&self) -> &crate::PyArrayDescr {

‎src/dtype.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pub use num_complex::{Complex32, Complex64};
2020
/// .unwrap()
2121
/// .downcast()
2222
/// .unwrap();
23-
/// assert!(dtype.is_equiv_to(numpy::PyArrayDescr::of::<f64>(py)));
23+
/// assert!(dtype.is_equiv_to(numpy::dtype::<f64>(py)));
2424
/// });
2525
/// ```
2626
pub struct PyArrayDescr(PyAny);
@@ -39,6 +39,11 @@ unsafe fn arraydescr_check(op: *mut ffi::PyObject) -> c_int {
3939
)
4040
}
4141

42+
/// Returns the type descriptor ("dtype") for a registered type.
43+
pub fn dtype<T: Element>(py: Python) -> &PyArrayDescr {
44+
T::get_dtype(py)
45+
}
46+
4247
impl PyArrayDescr {
4348
/// Returns `self` as `*mut PyArray_Descr`.
4449
pub fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
@@ -72,7 +77,7 @@ impl PyArrayDescr {
7277
Self::from_npy_type(py, NPY_TYPES::NPY_OBJECT)
7378
}
7479

75-
/// Returns the type descriptor for a registered type.
80+
/// Returns the type descriptor ("dtype") for a registered type.
7681
pub fn of<T: Element>(py: Python) -> &Self {
7782
T::get_dtype(py)
7883
}
@@ -254,12 +259,12 @@ unsafe impl Element for PyObject {
254259
mod tests {
255260
use std::mem::size_of;
256261

257-
use super::{Complex32, Complex64, Element, PyArrayDescr};
262+
use super::{dtype, Complex32, Complex64, Element};
258263

259264
#[test]
260265
fn test_dtype_names() {
261266
fn type_name<T: Element>(py: pyo3::Python) -> &str {
262-
PyArrayDescr::of::<T>(py).get_type().name().unwrap()
267+
dtype::<T>(py).get_type().name().unwrap()
263268
}
264269
pyo3::Python::with_gil(|py| {
265270
assert_eq!(type_name::<bool>(py), "bool_");

‎src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ pub use crate::array::{
4646
PyArray6, PyArrayDyn,
4747
};
4848
pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
49-
pub use crate::dtype::{Complex32, Complex64, Element, PyArrayDescr};
49+
pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr};
5050
pub use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
5151
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
5252
pub use crate::npyiter::{

‎tests/array.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ fn dtype_from_py() {
249249
.downcast()
250250
.unwrap();
251251
assert_eq!(&format!("{:?}", dtype), "dtype('uint32')");
252-
assert!(dtype.is_equiv_to(numpy::PyArrayDescr::of::<u32>(py)));
252+
assert!(dtype.is_equiv_to(numpy::dtype::<u32>(py)));
253253
})
254254
}
255255

0 commit comments

Comments
 (0)
Please sign in to comment.