Skip to content

Commit e75190b

Browse files
committed
Fix returning invalid strides and dimensions for rank zero arrays.
1 parent 9ec102d commit e75190b

File tree

4 files changed

+24
-0
lines changed

4 files changed

+24
-0
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Changelog
22

33
- Unreleased
4+
- Fix returning invalid slices from `PyArray::{strides,shape}` for rank zero arrays. ([#???](https://github.com/PyO3/rust-numpy/pull/???))
45

56
- v0.16.2
67
- Fix build on platforms where `c_char` is `u8` like Linux/AArch64. ([#296](https://github.com/PyO3/rust-numpy/pull/296))

src/array.rs

+9
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use pyo3::{
1919
Python, ToPyObject,
2020
};
2121

22+
use crate::cold;
2223
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
2324
use crate::dtype::{Element, PyArrayDescr};
2425
use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
@@ -314,6 +315,10 @@ impl<T, D> PyArray<T, D> {
314315
// C API: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_STRIDES
315316
pub fn strides(&self) -> &[isize] {
316317
let n = self.ndim();
318+
if n == 0 {
319+
cold();
320+
return &[];
321+
}
317322
let ptr = self.as_array_ptr();
318323
unsafe {
319324
let p = (*ptr).strides;
@@ -335,6 +340,10 @@ impl<T, D> PyArray<T, D> {
335340
// C API: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DIMS
336341
pub fn shape(&self) -> &[usize] {
337342
let n = self.ndim();
343+
if n == 0 {
344+
cold();
345+
return &[];
346+
}
338347
let ptr = self.as_array_ptr();
339348
unsafe {
340349
let p = (*ptr).dimensions as *mut usize;

src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ mod doctest {
7878
doc_comment!(include_str!("../README.md"), readme);
7979
}
8080

81+
#[cold]
82+
fn cold() {}
83+
8184
/// Create a [`PyArray`] with one, two or three dimensions.
8285
///
8386
/// This macro is backed by [`ndarray::array`].

tests/array.rs

+11
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ fn tuple_as_dim() {
6868
});
6969
}
7070

71+
#[test]
72+
fn rank_zero_array_has_invalid_strides_dimensions() {
73+
Python::with_gil(|py| {
74+
let arr = PyArray::<f64, _>::zeros(py, (), false);
75+
76+
assert_eq!(arr.ndim(), 0);
77+
assert_eq!(arr.strides(), &[]);
78+
assert_eq!(arr.shape(), &[]);
79+
})
80+
}
81+
7182
#[test]
7283
fn zeros() {
7384
Python::with_gil(|py| {

0 commit comments

Comments
 (0)