Skip to content

Commit f5dbf5a

Browse files
committed
Add permute and transpose methods for changing the order of axes of a PyArray
1 parent 91b1c4e commit f5dbf5a

File tree

3 files changed

+159
-31
lines changed

3 files changed

+159
-31
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Changelog
22

33
- Unreleased
4+
- Add `permute` and `transpose` methods for changing the order of axes of a `PyArray`. ([#428](https://github.com/PyO3/rust-numpy/pull/428))
45

56
- v0.21.0
67
- Migrate to the new `Bound` API introduced by PyO3 0.21. ([#410](https://github.com/PyO3/rust-numpy/pull/410)) ([#411](https://github.com/PyO3/rust-numpy/pull/411)) ([#412](https://github.com/PyO3/rust-numpy/pull/412)) ([#415](https://github.com/PyO3/rust-numpy/pull/415)) ([#416](https://github.com/PyO3/rust-numpy/pull/416)) ([#418](https://github.com/PyO3/rust-numpy/pull/418)) ([#419](https://github.com/PyO3/rust-numpy/pull/419)) ([#420](https://github.com/PyO3/rust-numpy/pull/420)) ([#421](https://github.com/PyO3/rust-numpy/pull/421)) ([#422](https://github.com/PyO3/rust-numpy/pull/422))

src/array.rs

+122-26
Original file line numberDiff line numberDiff line change
@@ -1336,8 +1336,45 @@ impl<T: Element, D> PyArray<T, D> {
13361336
self.as_borrowed().cast(is_fortran).map(Bound::into_gil_ref)
13371337
}
13381338

1339+
/// A view of `self` with a different order of axes determined by `axes`.
1340+
///
1341+
/// If `axes` is `None`, the order of axes is reversed which corresponds to the standard matix transpose.
1342+
///
1343+
/// See also [`numpy.transpose`][numpy-transpose] and [`PyArray_Transpose`][PyArray_Transpose].
1344+
///
1345+
/// # Example
1346+
///
1347+
/// ```
1348+
/// use numpy::prelude::*;
1349+
/// use numpy::PyArray;
1350+
/// use pyo3::Python;
1351+
/// use ndarray::array;
1352+
///
1353+
/// Python::with_gil(|py| {
1354+
/// let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray(py);
1355+
///
1356+
/// let array = array.permute(Some([1, 0])).unwrap();
1357+
///
1358+
/// assert_eq!(array.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);
1359+
/// });
1360+
/// ```
1361+
///
1362+
/// [numpy-transpose]: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
1363+
/// [PyArray_Transpose]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Transpose
1364+
pub fn permute<'py, ID: IntoDimension>(
1365+
&'py self,
1366+
axes: Option<ID>,
1367+
) -> PyResult<&'py PyArray<T, D>> {
1368+
self.as_borrowed().permute(axes).map(Bound::into_gil_ref)
1369+
}
1370+
1371+
/// Special case of [`permute`][Self::permute] which reverses the order the axes.
1372+
pub fn transpose<'py>(&'py self) -> PyResult<&'py PyArray<T, D>> {
1373+
self.as_borrowed().transpose().map(Bound::into_gil_ref)
1374+
}
1375+
13391376
/// Construct a new array which has same values as self,
1340-
/// but has different dimensions specified by `dims`
1377+
/// but has different dimensions specified by `shape`
13411378
/// and a possibly different memory order specified by `order`.
13421379
///
13431380
/// See also [`numpy.reshape`][numpy-reshape] and [`PyArray_Newshape`][PyArray_Newshape].
@@ -1365,21 +1402,21 @@ impl<T: Element, D> PyArray<T, D> {
13651402
/// [PyArray_Newshape]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Newshape
13661403
pub fn reshape_with_order<'py, ID: IntoDimension>(
13671404
&'py self,
1368-
dims: ID,
1405+
shape: ID,
13691406
order: NPY_ORDER,
13701407
) -> PyResult<&'py PyArray<T, ID::Dim>> {
13711408
self.as_borrowed()
1372-
.reshape_with_order(dims, order)
1409+
.reshape_with_order(shape, order)
13731410
.map(Bound::into_gil_ref)
13741411
}
13751412

13761413
/// Special case of [`reshape_with_order`][Self::reshape_with_order] which keeps the memory order the same.
13771414
#[inline(always)]
13781415
pub fn reshape<'py, ID: IntoDimension>(
13791416
&'py self,
1380-
dims: ID,
1417+
shape: ID,
13811418
) -> PyResult<&'py PyArray<T, ID::Dim>> {
1382-
self.as_borrowed().reshape(dims).map(Bound::into_gil_ref)
1419+
self.as_borrowed().reshape(shape).map(Bound::into_gil_ref)
13831420
}
13841421

13851422
/// Extends or truncates the dimensions of an array.
@@ -1414,8 +1451,8 @@ impl<T: Element, D> PyArray<T, D> {
14141451
///
14151452
/// [ndarray-resize]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.resize.html
14161453
/// [PyArray_Resize]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Resize
1417-
pub unsafe fn resize<ID: IntoDimension>(&self, dims: ID) -> PyResult<()> {
1418-
self.as_borrowed().resize(dims)
1454+
pub unsafe fn resize<ID: IntoDimension>(&self, newshape: ID) -> PyResult<()> {
1455+
self.as_borrowed().resize(newshape)
14191456
}
14201457
}
14211458

@@ -1879,8 +1916,45 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
18791916
where
18801917
T: Element;
18811918

1882-
/// Construct a new array which has same values as self,
1883-
/// but has different dimensions specified by `dims`
1919+
/// A view of `self` with a different order of axes determined by `axes`.
1920+
///
1921+
/// If `axes` is `None`, the order of axes is reversed which corresponds to the standard matix transpose.
1922+
///
1923+
/// See also [`numpy.transpose`][numpy-transpose] and [`PyArray_Transpose`][PyArray_Transpose].
1924+
///
1925+
/// # Example
1926+
///
1927+
/// ```
1928+
/// use numpy::prelude::*;
1929+
/// use numpy::PyArray;
1930+
/// use pyo3::Python;
1931+
/// use ndarray::array;
1932+
///
1933+
/// Python::with_gil(|py| {
1934+
/// let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray_bound(py);
1935+
///
1936+
/// let array = array.permute(Some([1, 0])).unwrap();
1937+
///
1938+
/// assert_eq!(array.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);
1939+
/// });
1940+
/// ```
1941+
///
1942+
/// [numpy-transpose]: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
1943+
/// [PyArray_Transpose]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Transpose
1944+
fn permute<ID: IntoDimension>(&self, axes: Option<ID>) -> PyResult<Bound<'py, PyArray<T, D>>>
1945+
where
1946+
T: Element;
1947+
1948+
/// Special case of [`permute`][Self::permute] which reverses the order the axes.
1949+
fn transpose(&self) -> PyResult<Bound<'py, PyArray<T, D>>>
1950+
where
1951+
T: Element,
1952+
{
1953+
self.permute::<()>(None)
1954+
}
1955+
1956+
/// Construct a new array which has same values as `self`,
1957+
/// but has different dimensions specified by `shape`
18841958
/// and a possibly different memory order specified by `order`.
18851959
///
18861960
/// See also [`numpy.reshape`][numpy-reshape] and [`PyArray_Newshape`][PyArray_Newshape].
@@ -1908,19 +1982,19 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
19081982
/// [PyArray_Newshape]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Newshape
19091983
fn reshape_with_order<ID: IntoDimension>(
19101984
&self,
1911-
dims: ID,
1985+
shape: ID,
19121986
order: NPY_ORDER,
19131987
) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
19141988
where
19151989
T: Element;
19161990

19171991
/// Special case of [`reshape_with_order`][Self::reshape_with_order] which keeps the memory order the same.
19181992
#[inline(always)]
1919-
fn reshape<ID: IntoDimension>(&self, dims: ID) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
1993+
fn reshape<ID: IntoDimension>(&self, shape: ID) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
19201994
where
19211995
T: Element,
19221996
{
1923-
self.reshape_with_order(dims, NPY_ORDER::NPY_ANYORDER)
1997+
self.reshape_with_order(shape, NPY_ORDER::NPY_ANYORDER)
19241998
}
19251999

19262000
/// Extends or truncates the dimensions of an array.
@@ -1955,7 +2029,7 @@ pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
19552029
///
19562030
/// [ndarray-resize]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.resize.html
19572031
/// [PyArray_Resize]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Resize
1958-
unsafe fn resize<ID: IntoDimension>(&self, dims: ID) -> PyResult<()>
2032+
unsafe fn resize<ID: IntoDimension>(&self, newshape: ID) -> PyResult<()>
19592033
where
19602034
T: Element;
19612035

@@ -2256,48 +2330,70 @@ impl<'py, T, D> PyArrayMethods<'py, T, D> for Bound<'py, PyArray<T, D>> {
22562330
}
22572331
}
22582332

2333+
fn permute<ID: IntoDimension>(&self, axes: Option<ID>) -> PyResult<Bound<'py, PyArray<T, D>>> {
2334+
let mut axes = axes.map(|axes| axes.into_dimension());
2335+
let mut axes = axes.as_mut().map(|axes| axes.to_npy_dims());
2336+
let axes = axes
2337+
.as_mut()
2338+
.map_or_else(ptr::null_mut, |axes| axes as *mut npyffi::PyArray_Dims);
2339+
2340+
let py = self.py();
2341+
let ptr = unsafe { PY_ARRAY_API.PyArray_Transpose(py, self.as_array_ptr(), axes) };
2342+
if !ptr.is_null() {
2343+
Ok(unsafe { Bound::from_owned_ptr(py, ptr).downcast_into_unchecked() })
2344+
} else {
2345+
Err(PyErr::fetch(py))
2346+
}
2347+
}
2348+
22592349
fn reshape_with_order<ID: IntoDimension>(
22602350
&self,
2261-
dims: ID,
2351+
shape: ID,
22622352
order: NPY_ORDER,
22632353
) -> PyResult<Bound<'py, PyArray<T, ID::Dim>>>
22642354
where
22652355
T: Element,
22662356
{
2267-
let mut dims = dims.into_dimension();
2268-
let mut dims = dims.to_npy_dims();
2357+
let mut shape = shape.into_dimension();
2358+
let mut shape = shape.to_npy_dims();
2359+
2360+
let py = self.py();
22692361
let ptr = unsafe {
22702362
PY_ARRAY_API.PyArray_Newshape(
2271-
self.py(),
2363+
py,
22722364
self.as_array_ptr(),
2273-
&mut dims as *mut npyffi::PyArray_Dims,
2365+
&mut shape as *mut npyffi::PyArray_Dims,
22742366
order,
22752367
)
22762368
};
2369+
22772370
if !ptr.is_null() {
2278-
Ok(unsafe { Bound::from_owned_ptr(self.py(), ptr).downcast_into_unchecked() })
2371+
Ok(unsafe { Bound::from_owned_ptr(py, ptr).downcast_into_unchecked() })
22792372
} else {
2280-
Err(PyErr::fetch(self.py()))
2373+
Err(PyErr::fetch(py))
22812374
}
22822375
}
22832376

2284-
unsafe fn resize<ID: IntoDimension>(&self, dims: ID) -> PyResult<()>
2377+
unsafe fn resize<ID: IntoDimension>(&self, newshape: ID) -> PyResult<()>
22852378
where
22862379
T: Element,
22872380
{
2288-
let mut dims = dims.into_dimension();
2289-
let mut dims = dims.to_npy_dims();
2381+
let mut newshape = newshape.into_dimension();
2382+
let mut newshape = newshape.to_npy_dims();
2383+
2384+
let py = self.py();
22902385
let res = PY_ARRAY_API.PyArray_Resize(
2291-
self.py(),
2386+
py,
22922387
self.as_array_ptr(),
2293-
&mut dims as *mut npyffi::PyArray_Dims,
2388+
&mut newshape as *mut npyffi::PyArray_Dims,
22942389
1,
22952390
NPY_ORDER::NPY_ANYORDER,
22962391
);
2392+
22972393
if !res.is_null() {
22982394
Ok(())
22992395
} else {
2300-
Err(PyErr::fetch(self.py()))
2396+
Err(PyErr::fetch(py))
23012397
}
23022398
}
23032399

tests/array.rs

+36-5
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@ use std::mem::size_of;
33
#[cfg(feature = "half")]
44
use half::{bf16, f16};
55
use ndarray::{array, s, Array1, Dim};
6+
use numpy::prelude::*;
67
use numpy::{
7-
array::{PyArray0Methods, PyArrayMethods},
8-
dtype_bound, get_array_module,
9-
npyffi::NPY_ORDER,
10-
pyarray_bound, PyArray, PyArray1, PyArray2, PyArrayDescr, PyArrayDescrMethods, PyArrayDyn,
11-
PyFixedString, PyFixedUnicode, PyUntypedArrayMethods, ToPyArray,
8+
dtype_bound, get_array_module, npyffi::NPY_ORDER, pyarray_bound, PyArray, PyArray1, PyArray2,
9+
PyArrayDescr, PyArrayDyn, PyFixedString, PyFixedUnicode,
1210
};
1311
use pyo3::{
1412
py_run, pyclass, pymethods,
@@ -522,6 +520,39 @@ fn get_works() {
522520
});
523521
}
524522

523+
#[test]
524+
fn permute_and_transpose() {
525+
Python::with_gil(|py| {
526+
let array = array![[0, 1, 2], [3, 4, 5]].into_pyarray_bound(py);
527+
528+
let permuted = array.permute(Some([1, 0])).unwrap();
529+
assert_eq!(
530+
permuted.readonly().as_array(),
531+
array![[0, 3], [1, 4], [2, 5]]
532+
);
533+
534+
let permuted = array.permute::<()>(None).unwrap();
535+
assert_eq!(
536+
permuted.readonly().as_array(),
537+
array![[0, 3], [1, 4], [2, 5]]
538+
);
539+
540+
let transposed = array.transpose().unwrap();
541+
assert_eq!(
542+
transposed.readonly().as_array(),
543+
array![[0, 3], [1, 4], [2, 5]]
544+
);
545+
546+
let array = pyarray_bound![py, [[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]];
547+
548+
let permuted = array.permute(Some([0, 2, 1])).unwrap();
549+
assert_eq!(
550+
permuted.readonly().as_array(),
551+
array![[[1, 3], [2, 4]], [[5, 7], [6, 8]], [[9, 11], [10, 12]]]
552+
);
553+
});
554+
}
555+
525556
#[test]
526557
fn reshape() {
527558
Python::with_gil(|py| {

0 commit comments

Comments
 (0)