Skip to content

Commit 836c345

Browse files
committed
Support producing nalgebra matrix slices pointing into NumPy arrays.
1 parent dbb4b30 commit 836c345

File tree

4 files changed

+275
-1
lines changed

4 files changed

+275
-1
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Changelog
22

33
- Unreleased
4-
- Add conversions from datatypes provided by the [`nalgebra` crate](https://nalgebra.org/). ([#347](https://github.com/PyO3/rust-numpy/pull/347))
4+
- Add conversions from and to datatypes provided by the [`nalgebra` crate](https://nalgebra.org/). ([#347](https://github.com/PyO3/rust-numpy/pull/347))
55
- Drop our wrapper for NumPy iterators which were deprecated in v0.16.0 as ndarray's iteration facilities are almost always preferable. ([#324](https://github.com/PyO3/rust-numpy/pull/324))
66

77
- v0.17.1

src/array.rs

+104
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,110 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
10361036
}
10371037
}
10381038

1039+
#[cfg(feature = "nalgebra")]
1040+
impl<N, D> PyArray<N, D>
1041+
where
1042+
N: nalgebra::Scalar + Element,
1043+
D: Dimension,
1044+
{
1045+
fn try_as_matrix_shape_strides<R, C, RStride, CStride>(
1046+
&self,
1047+
) -> Option<((R, C), (RStride, CStride))>
1048+
where
1049+
R: nalgebra::Dim,
1050+
C: nalgebra::Dim,
1051+
RStride: nalgebra::Dim,
1052+
CStride: nalgebra::Dim,
1053+
{
1054+
let ndim = self.ndim();
1055+
let shape = self.shape();
1056+
let strides = self.strides();
1057+
1058+
if ndim != 1 && ndim != 2 {
1059+
return None;
1060+
}
1061+
1062+
if strides.iter().any(|strides| *strides < 0) {
1063+
return None;
1064+
}
1065+
1066+
let rows = shape[0];
1067+
let cols = *shape.get(1).unwrap_or(&1);
1068+
1069+
if R::try_to_usize().map(|expected| rows == expected) == Some(false) {
1070+
return None;
1071+
}
1072+
1073+
if C::try_to_usize().map(|expected| cols == expected) == Some(false) {
1074+
return None;
1075+
}
1076+
1077+
let row_stride = strides[0] as usize / mem::size_of::<N>();
1078+
let col_stride = strides
1079+
.get(1)
1080+
.map_or(rows, |stride| *stride as usize / mem::size_of::<N>());
1081+
1082+
if RStride::try_to_usize().map(|expected| row_stride == expected) == Some(false) {
1083+
return None;
1084+
}
1085+
1086+
if CStride::try_to_usize().map(|expected| col_stride == expected) == Some(false) {
1087+
return None;
1088+
}
1089+
1090+
let shape = (R::from_usize(rows), C::from_usize(cols));
1091+
1092+
let strides = (
1093+
RStride::from_usize(row_stride),
1094+
CStride::from_usize(col_stride),
1095+
);
1096+
1097+
Some((shape, strides))
1098+
}
1099+
1100+
/// Try to convert this array into a [`nalgebra::MatrixSlice`] using the given shape and strides.
1101+
///
1102+
/// # Safety
1103+
///
1104+
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
1105+
pub unsafe fn try_as_matrix<R, C, RStride, CStride>(
1106+
&self,
1107+
) -> Option<nalgebra::MatrixSlice<N, R, C, RStride, CStride>>
1108+
where
1109+
R: nalgebra::Dim,
1110+
C: nalgebra::Dim,
1111+
RStride: nalgebra::Dim,
1112+
CStride: nalgebra::Dim,
1113+
{
1114+
let (shape, strides) = self.try_as_matrix_shape_strides()?;
1115+
1116+
let storage = nalgebra::SliceStorage::from_raw_parts(self.data(), shape, strides);
1117+
1118+
Some(nalgebra::Matrix::from_data(storage))
1119+
}
1120+
1121+
/// Try to convert this array into a [`nalgebra::MatrixSliceMut`] using the given shape and strides.
1122+
///
1123+
/// # Safety
1124+
///
1125+
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
1126+
pub unsafe fn try_as_matrix_mut<R, C, RStride, CStride>(
1127+
&self,
1128+
) -> Option<nalgebra::MatrixSliceMut<N, R, C, RStride, CStride>>
1129+
where
1130+
R: nalgebra::Dim,
1131+
C: nalgebra::Dim,
1132+
RStride: nalgebra::Dim,
1133+
CStride: nalgebra::Dim,
1134+
{
1135+
let (shape, strides) = self.try_as_matrix_shape_strides()?;
1136+
1137+
let storage = nalgebra::SliceStorageMut::from_raw_parts(self.data(), shape, strides);
1138+
1139+
Some(nalgebra::Matrix::from_data(storage))
1140+
}
1141+
}
1142+
10391143
impl<D: Dimension> PyArray<PyObject, D> {
10401144
/// Construct a NumPy array containing objects stored in a [`ndarray::Array`]
10411145
///

src/borrow.rs

+104
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,56 @@ where
474474
}
475475
}
476476

477+
#[cfg(feature = "nalgebra")]
478+
impl<'py, N, D> PyReadonlyArray<'py, N, D>
479+
where
480+
N: nalgebra::Scalar + Element,
481+
D: Dimension,
482+
{
483+
/// Try to convert this array into a [`nalgebra::MatrixSlice`] using the given shape and strides.
484+
pub fn try_as_matrix<R, C, RStride, CStride>(
485+
&self,
486+
) -> Option<nalgebra::MatrixSlice<N, R, C, RStride, CStride>>
487+
where
488+
R: nalgebra::Dim,
489+
C: nalgebra::Dim,
490+
RStride: nalgebra::Dim,
491+
CStride: nalgebra::Dim,
492+
{
493+
unsafe { self.array.try_as_matrix() }
494+
}
495+
}
496+
497+
#[cfg(feature = "nalgebra")]
498+
impl<'py, N> PyReadonlyArray<'py, N, Ix1>
499+
where
500+
N: nalgebra::Scalar + Element,
501+
{
502+
/// Convert this one-dimensional array into a [`nalgebra::DMatrixSlice`] using dynamic strides.
503+
///
504+
/// # Panics
505+
///
506+
/// Panics if the array has negative strides.
507+
pub fn as_matrix(&self) -> nalgebra::DMatrixSlice<N, nalgebra::Dynamic, nalgebra::Dynamic> {
508+
self.try_as_matrix().unwrap()
509+
}
510+
}
511+
512+
#[cfg(feature = "nalgebra")]
513+
impl<'py, N> PyReadonlyArray<'py, N, Ix2>
514+
where
515+
N: nalgebra::Scalar + Element,
516+
{
517+
/// Convert this two-dimensional array into a [`nalgebra::DMatrixSlice`] using dynamic strides.
518+
///
519+
/// # Panics
520+
///
521+
/// Panics if the array has negative strides.
522+
pub fn as_matrix(&self) -> nalgebra::DMatrixSlice<N, nalgebra::Dynamic, nalgebra::Dynamic> {
523+
self.try_as_matrix().unwrap()
524+
}
525+
}
526+
477527
impl<'a, T, D> Clone for PyReadonlyArray<'a, T, D>
478528
where
479529
T: Element,
@@ -622,6 +672,60 @@ where
622672
}
623673
}
624674

675+
#[cfg(feature = "nalgebra")]
676+
impl<'py, N, D> PyReadwriteArray<'py, N, D>
677+
where
678+
N: nalgebra::Scalar + Element,
679+
D: Dimension,
680+
{
681+
/// Try to convert this array into a [`nalgebra::MatrixSliceMut`] using the given shape and strides.
682+
pub fn try_as_matrix_mut<R, C, RStride, CStride>(
683+
&self,
684+
) -> Option<nalgebra::MatrixSliceMut<N, R, C, RStride, CStride>>
685+
where
686+
R: nalgebra::Dim,
687+
C: nalgebra::Dim,
688+
RStride: nalgebra::Dim,
689+
CStride: nalgebra::Dim,
690+
{
691+
unsafe { self.array.try_as_matrix_mut() }
692+
}
693+
}
694+
695+
#[cfg(feature = "nalgebra")]
696+
impl<'py, N> PyReadwriteArray<'py, N, Ix1>
697+
where
698+
N: nalgebra::Scalar + Element,
699+
{
700+
/// Convert this one-dimensional array into a [`nalgebra::DMatrixSliceMut`] using dynamic strides.
701+
///
702+
/// # Panics
703+
///
704+
/// Panics if the array has negative strides.
705+
pub fn as_matrix_mut(
706+
&self,
707+
) -> nalgebra::DMatrixSliceMut<N, nalgebra::Dynamic, nalgebra::Dynamic> {
708+
self.try_as_matrix_mut().unwrap()
709+
}
710+
}
711+
712+
#[cfg(feature = "nalgebra")]
713+
impl<'py, N> PyReadwriteArray<'py, N, Ix2>
714+
where
715+
N: nalgebra::Scalar + Element,
716+
{
717+
/// Convert this two-dimensional array into a [`nalgebra::DMatrixSliceMut`] using dynamic strides.
718+
///
719+
/// # Panics
720+
///
721+
/// Panics if the array has negative strides.
722+
pub fn as_matrix_mut(
723+
&self,
724+
) -> nalgebra::DMatrixSliceMut<N, nalgebra::Dynamic, nalgebra::Dynamic> {
725+
self.try_as_matrix_mut().unwrap()
726+
}
727+
}
728+
625729
impl<'py, T> PyReadwriteArray<'py, T, Ix1>
626730
where
627731
T: Element,

tests/borrow.rs

+66
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,69 @@ fn resize_using_exclusive_borrow() {
342342
assert_eq!(array.as_slice_mut().unwrap(), &[0.0; 5]);
343343
});
344344
}
345+
346+
#[cfg(feature = "nalgebra")]
347+
#[test]
348+
fn matrix_from_numpy() {
349+
Python::with_gil(|py| {
350+
let array = numpy::pyarray![py, [0, 1, 2], [3, 4, 5], [6, 7, 8]];
351+
352+
{
353+
let array = array.readonly();
354+
355+
let matrix = array.as_matrix();
356+
assert_eq!(matrix, nalgebra::Matrix3::new(0, 1, 2, 3, 4, 5, 6, 7, 8));
357+
358+
let matrix: nalgebra::MatrixSlice<
359+
i32,
360+
nalgebra::Const<3>,
361+
nalgebra::Const<3>,
362+
nalgebra::Const<3>,
363+
nalgebra::Const<1>,
364+
> = array.try_as_matrix().unwrap();
365+
assert_eq!(matrix, nalgebra::Matrix3::new(0, 1, 2, 3, 4, 5, 6, 7, 8));
366+
}
367+
368+
{
369+
let array = array.readwrite();
370+
371+
let matrix = array.as_matrix_mut();
372+
assert_eq!(matrix, nalgebra::Matrix3::new(0, 1, 2, 3, 4, 5, 6, 7, 8));
373+
374+
let matrix: nalgebra::MatrixSliceMut<
375+
i32,
376+
nalgebra::Const<3>,
377+
nalgebra::Const<3>,
378+
nalgebra::Const<3>,
379+
nalgebra::Const<1>,
380+
> = array.try_as_matrix_mut().unwrap();
381+
assert_eq!(matrix, nalgebra::Matrix3::new(0, 1, 2, 3, 4, 5, 6, 7, 8));
382+
}
383+
});
384+
385+
Python::with_gil(|py| {
386+
let array = numpy::pyarray![py, 0, 1, 2];
387+
388+
{
389+
let array = array.readonly();
390+
391+
let matrix = array.as_matrix();
392+
assert_eq!(matrix, nalgebra::Matrix3x1::new(0, 1, 2));
393+
394+
let matrix: nalgebra::MatrixSlice<i32, nalgebra::Const<3>, nalgebra::Const<1>> =
395+
array.try_as_matrix().unwrap();
396+
assert_eq!(matrix, nalgebra::Matrix3x1::new(0, 1, 2));
397+
}
398+
399+
{
400+
let array = array.readwrite();
401+
402+
let matrix = array.as_matrix_mut();
403+
assert_eq!(matrix, nalgebra::Matrix3x1::new(0, 1, 2));
404+
405+
let matrix: nalgebra::MatrixSliceMut<i32, nalgebra::Const<3>, nalgebra::Const<1>> =
406+
array.try_as_matrix_mut().unwrap();
407+
assert_eq!(matrix, nalgebra::Matrix3x1::new(0, 1, 2));
408+
}
409+
});
410+
}

0 commit comments

Comments
 (0)