Skip to content

Commit 7cc945c

Browse files
authored
Merge pull request #265 from PyO3/example-downcast
Fix type confusion during downcastsing and add a test case showing how to extract an array from a dictionary.
2 parents ad49760 + c8390e3 commit 7cc945c

File tree

6 files changed

+77
-20
lines changed

6 files changed

+77
-20
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
- Unreleased
44
- Support object arrays ([#216](https://github.com/PyO3/rust-numpy/pull/216))
55
- Support borrowing arrays that are part of other Python objects via `PyArray::borrow_from_array` ([#230](https://github.com/PyO3/rust-numpy/pull/216))
6+
- Fixed downcasting ignoring element type and dimensionality ([#265](https://github.com/PyO3/rust-numpy/pull/265))
67
- `PyArray::new` is now `unsafe`, as it produces uninitialized arrays ([#220](https://github.com/PyO3/rust-numpy/pull/220))
78
- `PyArray::from_exact_iter` does not unsoundly trust `ExactSizeIterator::len` any more ([#262](https://github.com/PyO3/rust-numpy/pull/262))
89
- `PyArray::as_cell_slice` was removed as it unsoundly interacts with `PyReadonlyArray` allowing safe code to violate aliasing rules ([#260](https://github.com/PyO3/rust-numpy/pull/260))

examples/simple-extension/src/lib.rs

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
2-
use numpy::{Complex64, IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn};
3-
use pyo3::{pymodule, types::PyModule, PyResult, Python};
2+
use numpy::{Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn};
3+
use pyo3::{
4+
pymodule,
5+
types::{PyDict, PyModule},
6+
PyResult, Python,
7+
};
48

59
#[pymodule]
610
fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
@@ -52,5 +56,17 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
5256
conj(x.as_array()).into_pyarray(py)
5357
}
5458

59+
#[pyfn(m)]
60+
#[pyo3(name = "extract")]
61+
fn extract(d: &PyDict) -> f64 {
62+
let x = d
63+
.get_item("x")
64+
.unwrap()
65+
.downcast::<PyArray1<f64>>()
66+
.unwrap();
67+
68+
x.readonly().as_array().sum()
69+
}
70+
5571
Ok(())
5672
}

examples/simple-extension/tests/test_ext.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from rust_ext import axpy, conj, mult
2+
from rust_ext import axpy, conj, mult, extract
33

44

55
def test_axpy():
@@ -22,3 +22,9 @@ def test_mult():
2222
def test_conj():
2323
x = np.array([1.0 + 2j, 2.0 + 3j, 3.0 + 4j])
2424
np.testing.assert_array_almost_equal(conj(x), np.conj(x))
25+
26+
27+
def test_extract():
28+
x = np.arange(5.0)
29+
d = { "x": x }
30+
np.testing.assert_almost_equal(extract(d), 10.0)

src/array.rs

+22-14
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ use ndarray::{
1313
};
1414
use num_traits::AsPrimitive;
1515
use pyo3::{
16-
ffi, pyobject_native_type_info, pyobject_native_type_named, type_object, types::PyModule,
17-
AsPyPointer, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyErr, PyNativeType, PyObject,
18-
PyResult, Python, ToPyObject,
16+
ffi, pyobject_native_type_named, type_object, types::PyModule, AsPyPointer, FromPyObject,
17+
IntoPy, Py, PyAny, PyDowncastError, PyErr, PyNativeType, PyObject, PyResult, PyTypeInfo,
18+
Python, ToPyObject,
1919
};
2020

2121
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
@@ -110,16 +110,24 @@ pub fn get_array_module(py: Python<'_>) -> PyResult<&PyModule> {
110110
}
111111

112112
unsafe impl<T, D> type_object::PyLayout<PyArray<T, D>> for npyffi::PyArrayObject {}
113+
113114
impl<T, D> type_object::PySizedLayout<PyArray<T, D>> for npyffi::PyArrayObject {}
114115

115-
pyobject_native_type_info!(
116-
PyArray<T, D>,
117-
*npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
118-
Some("numpy"),
119-
#checkfunction=npyffi::PyArray_Check
120-
; T
121-
; D
122-
);
116+
unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
117+
type AsRefTarget = Self;
118+
119+
const NAME: &'static str = "PyArray<T, D>";
120+
const MODULE: ::std::option::Option<&'static str> = Some("numpy");
121+
122+
#[inline]
123+
fn type_object_raw(_py: Python) -> *mut ffi::PyTypeObject {
124+
unsafe { npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type) }
125+
}
126+
127+
fn is_type_of(ob: &PyAny) -> bool {
128+
<&Self>::extract(ob).is_ok()
129+
}
130+
}
123131

124132
pyobject_native_type_named!(PyArray<T, D> ; T ; D);
125133

@@ -129,12 +137,12 @@ impl<T, D> IntoPy<PyObject> for PyArray<T, D> {
129137
}
130138
}
131139

132-
impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
140+
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for &'py PyArray<T, D> {
133141
// here we do type-check three times
134142
// 1. Checks if the object is PyArray
135143
// 2. Checks if the data type of the array is T
136144
// 3. Checks if the dimension is same as D
137-
fn extract(ob: &'a PyAny) -> PyResult<Self> {
145+
fn extract(ob: &'py PyAny) -> PyResult<Self> {
138146
let array = unsafe {
139147
if npyffi::PyArray_Check(ob.as_ptr()) == 0 {
140148
return Err(PyDowncastError::new(ob, "PyArray<T, D>").into());
@@ -207,7 +215,7 @@ impl<T, D> PyArray<T, D> {
207215
/// assert!(array.is_contiguous());
208216
/// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py);
209217
/// let not_contiguous: &numpy::PyArray1<f32> = py
210-
/// .eval("np.zeros((3, 5))[::2, 4]", Some(locals), None)
218+
/// .eval("np.zeros((3, 5), dtype='float32')[::2, 4]", Some(locals), None)
211219
/// .unwrap()
212220
/// .downcast()
213221
/// .unwrap();

src/readonly.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ impl<'py, T: Element, D: Dimension> PyReadonlyArray<'py, T, D> {
6666
/// assert_eq!(readonly.as_slice().unwrap(), &[0, 1, 2, 3]);
6767
/// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py);
6868
/// let not_contiguous: &PyArray1<i32> = py
69-
/// .eval("np.arange(10)[::2]", Some(locals), None)
69+
/// .eval("np.arange(10, dtype='int32')[::2]", Some(locals), None)
7070
/// .unwrap()
7171
/// .downcast()
7272
/// .unwrap();

tests/array.rs

+28-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ use ndarray::*;
22
use numpy::*;
33
use pyo3::{
44
prelude::*,
5-
types::PyList,
6-
types::{IntoPyDict, PyDict},
5+
types::{IntoPyDict, PyDict, PyList},
76
};
87

98
fn get_np_locals(py: Python) -> &PyDict {
@@ -300,3 +299,30 @@ fn borrow_from_array() {
300299
py_run!(py, array, "assert array.shape == (10,)");
301300
});
302301
}
302+
303+
#[test]
304+
fn downcasting_works() {
305+
Python::with_gil(|py| {
306+
let ob: &PyAny = PyArray::from_slice(py, &[1_i32, 2, 3]);
307+
308+
assert!(ob.downcast::<PyArray1<i32>>().is_ok());
309+
})
310+
}
311+
312+
#[test]
313+
fn downcasting_respects_element_type() {
314+
Python::with_gil(|py| {
315+
let ob: &PyAny = PyArray::from_slice(py, &[1_i32, 2, 3]);
316+
317+
assert!(ob.downcast::<PyArray1<f64>>().is_err());
318+
})
319+
}
320+
321+
#[test]
322+
fn downcasting_respects_dimensionality() {
323+
Python::with_gil(|py| {
324+
let ob: &PyAny = PyArray::from_slice(py, &[1_i32, 2, 3]);
325+
326+
assert!(ob.downcast::<PyArray2<i32>>().is_err());
327+
})
328+
}

0 commit comments

Comments
 (0)