Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional support for nalgebra types #347

Merged
merged 10 commits into from
Sep 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
- name: Test
run: |
pip install numpy
cargo test
cargo test --all-features
# Not on PyPy, because no embedding API
if: ${{ !startsWith(matrix.python-version, 'pypy') }}
- name: Test example
Expand Down Expand Up @@ -215,7 +215,7 @@ jobs:
- name: Install cargo-llvm-cov
uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
run: cargo llvm-cov --lcov --output-path coverage.lcov
run: cargo llvm-cov --all-features --lcov --output-path coverage.lcov
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
with:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

- Unreleased
- Add conversions from and to datatypes provided by the [`nalgebra` crate](https://nalgebra.org/). ([#347](https://github.com/PyO3/rust-numpy/pull/347))
- Drop our wrapper for NumPy iterators which were deprecated in v0.16.0 as ndarray's iteration facilities are almost always preferable. ([#324](https://github.com/PyO3/rust-numpy/pull/324))

- v0.17.2
Expand Down
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ license = "BSD-2-Clause"
ahash = "0.7"
half = { version = "1.8", default-features = false, optional = true }
libc = "0.2"
nalgebra = { version = "0.31", default-features = false, optional = true }
num-complex = ">= 0.2, < 0.5"
num-integer = "0.1"
num-traits = "0.2"
Expand All @@ -26,6 +27,7 @@ pyo3 = { version = "0.17", default-features = false, features = ["macros"] }

[dev-dependencies]
pyo3 = { version = "0.17", default-features = false, features = ["auto-initialize"] }
nalgebra = { version = "0.31", default-features = false, features = ["std"] }

[package.metadata.docs.rs]
all-features = true
110 changes: 108 additions & 2 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
///
/// # Safety
///
/// The existence of an exclusive reference to the internal data, e.g. `&mut [T]` or `ArrayViewMut`, implies undefined behavior.
/// Calling this method invalidates all exclusive references to the internal data, e.g. `&mut [T]` or `ArrayViewMut`.
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
self.as_view(|shape, ptr| ArrayView::from_shape_ptr(shape, ptr))
}
Expand All @@ -1002,7 +1002,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
///
/// # Safety
///
/// The existence of another reference to the internal data, e.g. `&[T]` or `ArrayView`, implies undefined behavior.
/// Calling this method invalidates all other references to the internal data, e.g. `ArrayView` or `ArrayViewMut`.
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
self.as_view(|shape, ptr| ArrayViewMut::from_shape_ptr(shape, ptr))
}
Expand Down Expand Up @@ -1040,6 +1040,112 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
}
}

#[cfg(feature = "nalgebra")]
impl<N, D> PyArray<N, D>
where
N: nalgebra::Scalar + Element,
D: Dimension,
{
fn try_as_matrix_shape_strides<R, C, RStride, CStride>(
&self,
) -> Option<((R, C), (RStride, CStride))>
where
R: nalgebra::Dim,
C: nalgebra::Dim,
RStride: nalgebra::Dim,
CStride: nalgebra::Dim,
{
let ndim = self.ndim();
let shape = self.shape();
let strides = self.strides();

if ndim != 1 && ndim != 2 {
return None;
}

if strides.iter().any(|strides| *strides < 0) {
return None;
}

let rows = shape[0];
let cols = *shape.get(1).unwrap_or(&1);

if R::try_to_usize().map(|expected| rows == expected) == Some(false) {
return None;
}

if C::try_to_usize().map(|expected| cols == expected) == Some(false) {
return None;
}

let row_stride = strides[0] as usize / mem::size_of::<N>();
let col_stride = strides
.get(1)
.map_or(rows, |stride| *stride as usize / mem::size_of::<N>());

if RStride::try_to_usize().map(|expected| row_stride == expected) == Some(false) {
return None;
}

if CStride::try_to_usize().map(|expected| col_stride == expected) == Some(false) {
return None;
}

let shape = (R::from_usize(rows), C::from_usize(cols));

let strides = (
RStride::from_usize(row_stride),
CStride::from_usize(col_stride),
);

Some((shape, strides))
}

/// Try to convert this array into a [`nalgebra::MatrixSlice`] using the given shape and strides.
///
/// # Safety
///
/// Calling this method invalidates all exclusive references to the internal data, e.g. `ArrayViewMut` or `MatrixSliceMut`.
#[doc(alias = "nalgebra")]
pub unsafe fn try_as_matrix<R, C, RStride, CStride>(
&self,
) -> Option<nalgebra::MatrixSlice<N, R, C, RStride, CStride>>
where
R: nalgebra::Dim,
C: nalgebra::Dim,
RStride: nalgebra::Dim,
CStride: nalgebra::Dim,
{
let (shape, strides) = self.try_as_matrix_shape_strides()?;

let storage = nalgebra::SliceStorage::from_raw_parts(self.data(), shape, strides);

Some(nalgebra::Matrix::from_data(storage))
}

/// Try to convert this array into a [`nalgebra::MatrixSliceMut`] using the given shape and strides.
///
/// # Safety
///
/// Calling this method invalidates all other references to the internal data, e.g. `ArrayView`, `MatrixSlice`, `ArrayViewMut` or `MatrixSliceMut`.
#[doc(alias = "nalgebra")]
pub unsafe fn try_as_matrix_mut<R, C, RStride, CStride>(
&self,
) -> Option<nalgebra::MatrixSliceMut<N, R, C, RStride, CStride>>
where
R: nalgebra::Dim,
C: nalgebra::Dim,
RStride: nalgebra::Dim,
CStride: nalgebra::Dim,
{
let (shape, strides) = self.try_as_matrix_shape_strides()?;

let storage = nalgebra::SliceStorageMut::from_raw_parts(self.data(), shape, strides);

Some(nalgebra::Matrix::from_data(storage))
}
}

impl<D: Dimension> PyArray<PyObject, D> {
/// Construct a NumPy array containing objects stored in a [`ndarray::Array`]
///
Expand Down
110 changes: 110 additions & 0 deletions src/borrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,59 @@ where
}
}

#[cfg(feature = "nalgebra")]
impl<'py, N, D> PyReadonlyArray<'py, N, D>
where
N: nalgebra::Scalar + Element,
D: Dimension,
{
/// Try to convert this array into a [`nalgebra::MatrixSlice`] using the given shape and strides.
#[doc(alias = "nalgebra")]
pub fn try_as_matrix<R, C, RStride, CStride>(
&self,
) -> Option<nalgebra::MatrixSlice<N, R, C, RStride, CStride>>
where
R: nalgebra::Dim,
C: nalgebra::Dim,
RStride: nalgebra::Dim,
CStride: nalgebra::Dim,
{
unsafe { self.array.try_as_matrix() }
}
}

#[cfg(feature = "nalgebra")]
impl<'py, N> PyReadonlyArray<'py, N, Ix1>
where
N: nalgebra::Scalar + Element,
{
/// Convert this one-dimensional array into a [`nalgebra::DMatrixSlice`] using dynamic strides.
///
/// # Panics
///
/// Panics if the array has negative strides.
#[doc(alias = "nalgebra")]
pub fn as_matrix(&self) -> nalgebra::DMatrixSlice<N, nalgebra::Dynamic, nalgebra::Dynamic> {
self.try_as_matrix().unwrap()
}
}

#[cfg(feature = "nalgebra")]
impl<'py, N> PyReadonlyArray<'py, N, Ix2>
where
N: nalgebra::Scalar + Element,
{
/// Convert this two-dimensional array into a [`nalgebra::DMatrixSlice`] using dynamic strides.
///
/// # Panics
///
/// Panics if the array has negative strides.
#[doc(alias = "nalgebra")]
pub fn as_matrix(&self) -> nalgebra::DMatrixSlice<N, nalgebra::Dynamic, nalgebra::Dynamic> {
self.try_as_matrix().unwrap()
}
}

impl<'a, T, D> Clone for PyReadonlyArray<'a, T, D>
where
T: Element,
Expand Down Expand Up @@ -622,6 +675,63 @@ where
}
}

#[cfg(feature = "nalgebra")]
impl<'py, N, D> PyReadwriteArray<'py, N, D>
where
N: nalgebra::Scalar + Element,
D: Dimension,
{
/// Try to convert this array into a [`nalgebra::MatrixSliceMut`] using the given shape and strides.
#[doc(alias = "nalgebra")]
pub fn try_as_matrix_mut<R, C, RStride, CStride>(
&self,
) -> Option<nalgebra::MatrixSliceMut<N, R, C, RStride, CStride>>
where
R: nalgebra::Dim,
C: nalgebra::Dim,
RStride: nalgebra::Dim,
CStride: nalgebra::Dim,
{
unsafe { self.array.try_as_matrix_mut() }
}
}

#[cfg(feature = "nalgebra")]
impl<'py, N> PyReadwriteArray<'py, N, Ix1>
where
N: nalgebra::Scalar + Element,
{
/// Convert this one-dimensional array into a [`nalgebra::DMatrixSliceMut`] using dynamic strides.
///
/// # Panics
///
/// Panics if the array has negative strides.
#[doc(alias = "nalgebra")]
pub fn as_matrix_mut(
&self,
) -> nalgebra::DMatrixSliceMut<N, nalgebra::Dynamic, nalgebra::Dynamic> {
self.try_as_matrix_mut().unwrap()
}
}

#[cfg(feature = "nalgebra")]
impl<'py, N> PyReadwriteArray<'py, N, Ix2>
where
N: nalgebra::Scalar + Element,
{
/// Convert this two-dimensional array into a [`nalgebra::DMatrixSliceMut`] using dynamic strides.
///
/// # Panics
///
/// Panics if the array has negative strides.
#[doc(alias = "nalgebra")]
pub fn as_matrix_mut(
&self,
) -> nalgebra::DMatrixSliceMut<N, nalgebra::Dynamic, nalgebra::Dynamic> {
self.try_as_matrix_mut().unwrap()
}
}

impl<'py, T> PyReadwriteArray<'py, T, Ix1>
where
T: Element,
Expand Down
32 changes: 32 additions & 0 deletions src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,38 @@ where
}
}

#[cfg(feature = "nalgebra")]
impl<N, R, C, S> ToPyArray for nalgebra::Matrix<N, R, C, S>
where
N: nalgebra::Scalar + Element,
R: nalgebra::Dim,
C: nalgebra::Dim,
S: nalgebra::Storage<N, R, C>,
{
type Item = N;
type Dim = crate::Ix2;

/// Note that the NumPy array always has Fortran memory layout
/// matching the [memory layout][memory-layout] used by [`nalgebra`].
///
/// [memory-layout]: https://nalgebra.org/docs/faq/#what-is-the-memory-layout-of-matrices
fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
unsafe {
let array = PyArray::<N, _>::new(py, (self.nrows(), self.ncols()), true);
let mut data_ptr = array.data();
if self.data.is_contiguous() {
ptr::copy_nonoverlapping(self.data.ptr(), data_ptr, self.len());
} else {
for item in self.iter() {
data_ptr.write(item.clone());
data_ptr = data_ptr.add(1);
}
}
array
}
}
}

pub(crate) trait ArrayExt {
fn npy_strides(&self) -> [npyffi::npy_intp; 32];
fn order(&self) -> Option<c_int>;
Expand Down
43 changes: 42 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
//! Loading NumPy is done automatically and on demand. So if it is not installed, the functions
//! provided by this crate will panic instead of returning a result.
//!
#![cfg_attr(
feature = "nalgebra",
doc = "Integration with [`nalgebra`] is rovided via an implementation of [`ToPyArray`] for [`nalgebra::Matrix`] to convert nalgebra matrices into NumPy arrays
as well as the [`PyReadonlyArray::try_as_matrix`] and [`PyReadwriteArray::try_as_matrix_mut`] methods to treat NumPy array as nalgebra matrix slices.
"
)]
//! # Example
//!
//! ```
Expand All @@ -26,7 +32,39 @@
//! py_array.readonly().as_array(),
//! array![[1i64, 2], [3, 4]]
//! );
//! })
//! });
//! ```
//!
#![cfg_attr(feature = "nalgebra", doc = "```")]
#![cfg_attr(not(feature = "nalgebra"), doc = "```rust,ignore")]
//! use numpy::pyo3::Python;
//! use numpy::nalgebra::Matrix3;
//! use numpy::{pyarray, ToPyArray};
//!
//! Python::with_gil(|py| {
//! let py_array = pyarray![py, [0, 1, 2], [3, 4, 5], [6, 7, 8]];
//!
//! let py_array_square;
//!
//! {
//! let py_array = py_array.readwrite();
//! let mut na_matrix = py_array.as_matrix_mut();
//!
//! na_matrix.add_scalar_mut(1);
//!
//! py_array_square = na_matrix.pow(2).to_pyarray(py);
//! }
//!
//! assert_eq!(
//! py_array.readonly().as_matrix(),
//! Matrix3::new(1, 2, 3, 4, 5, 6, 7, 8, 9)
//! );
//!
//! assert_eq!(
//! py_array_square.readonly().as_matrix(),
//! Matrix3::new(30, 36, 42, 66, 81, 96, 102, 126, 150)
//! );
//! });
//! ```
//!
//! [c-api]: https://numpy.org/doc/stable/reference/c-api
Expand All @@ -49,6 +87,9 @@ mod sum_products;
pub use ndarray;
pub use pyo3;

#[cfg(feature = "nalgebra")]
pub use nalgebra;

pub use crate::array::{
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
PyArray6, PyArrayDyn,
Expand Down
Loading