Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend our type signatures of inner and dot to match NumPy's types. #285

Merged
merged 3 commits into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- Unreleased
- Add dynamic borrow checking to safely construct references into the interior of NumPy arrays. ([#274](https://github.com/PyO3/rust-numpy/pull/274))
- The `inner`, `dot` and `einsum` functions can also return a scalar instead of a zero-dimensional array to match NumPy's types ([#285](https://github.com/PyO3/rust-numpy/pull/285))
- Deprecate `PyArray::from_exact_iter` after optimizing `PyArray::from_iter`. ([#292](https://github.com/PyO3/rust-numpy/pull/292))

- v0.16.2
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
pub use crate::npyiter::{
IterMode, NpyIterFlag, NpyMultiIter, NpyMultiIterBuilder, NpySingleIter, NpySingleIterBuilder,
};
pub use crate::sum_products::{dot, einsum_impl, inner};
pub use crate::sum_products::{dot, einsum, inner};

pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};

#[cfg(doctest)]
Expand Down
164 changes: 118 additions & 46 deletions src/sum_products.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,68 @@
use crate::npyffi::{NPY_CASTING, NPY_ORDER};
use crate::{Element, PyArray, PY_ARRAY_API};
use std::borrow::Cow;
use std::ffi::{CStr, CString};
use std::ptr::null_mut;

use ndarray::{Dimension, IxDyn};
use pyo3::{AsPyPointer, FromPyPointer, PyAny, PyNativeType, PyResult};
use std::ffi::CStr;
use pyo3::{AsPyPointer, FromPyObject, FromPyPointer, PyAny, PyNativeType, PyResult};

use crate::array::PyArray;
use crate::dtype::Element;
use crate::npyffi::{array::PY_ARRAY_API, NPY_CASTING, NPY_ORDER};

/// Return value of a function that can yield either an array or a scalar.
pub trait ArrayOrScalar<'py, T>: FromPyObject<'py> {}

impl<'py, T, D> ArrayOrScalar<'py, T> for &'py PyArray<T, D>
where
T: Element,
D: Dimension,
{
}

impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}

/// Return the inner product of two arrays.
///
/// # Example
/// [NumPy's documentation][inner] has the details.
///
/// # Examples
///
/// Note that this function can either return a scalar...
///
/// ```
/// pyo3::Python::with_gil(|py| {
/// let array = numpy::pyarray![py, 1, 2, 3];
/// let inner: &numpy::PyArray0::<_> = numpy::inner(array, array).unwrap();
/// assert_eq!(inner.item(), 14);
/// use pyo3::Python;
/// use numpy::{inner, pyarray, PyArray0};
///
/// Python::with_gil(|py| {
/// let vector = pyarray![py, 1.0, 2.0, 3.0];
/// let result: f64 = inner(vector, vector).unwrap();
/// assert_eq!(result, 14.0);
/// });
/// ```
pub fn inner<'py, T, DIN1, DIN2, DOUT>(
///
/// ...or an array depending on its arguments.
///
/// ```
/// use pyo3::Python;
/// use numpy::{inner, pyarray, PyArray0};
///
/// Python::with_gil(|py| {
/// let vector = pyarray![py, 1, 2, 3];
/// let result: &PyArray0<_> = inner(vector, vector).unwrap();
/// assert_eq!(result.item(), 14);
/// });
/// ```
///
/// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
pub fn inner<'py, T, DIN1, DIN2, OUT>(
array1: &'py PyArray<T, DIN1>,
array2: &'py PyArray<T, DIN2>,
) -> PyResult<&'py PyArray<T, DOUT>>
) -> PyResult<OUT>
where
T: Element,
DIN1: Dimension,
DIN2: Dimension,
DOUT: Dimension,
T: Element,
OUT: ArrayOrScalar<'py, T>,
{
let py = array1.py();
let obj = unsafe {
Expand All @@ -34,27 +74,53 @@ where

/// Return the dot product of two arrays.
///
/// # Example
/// [NumPy's documentation][dot] has the details.
///
/// # Examples
///
/// Note that this function can either return an array...
///
/// ```
/// pyo3::Python::with_gil(|py| {
/// let a = numpy::pyarray![py, [1, 0], [0, 1]];
/// let b = numpy::pyarray![py, [4, 1], [2, 2]];
/// let dot: &numpy::PyArray2::<_> = numpy::dot(a, b).unwrap();
/// use pyo3::Python;
/// use ndarray::array;
/// use numpy::{dot, pyarray, PyArray2};
///
/// Python::with_gil(|py| {
/// let matrix = pyarray![py, [1, 0], [0, 1]];
/// let another_matrix = pyarray![py, [4, 1], [2, 2]];
///
/// let result: &PyArray2<_> = numpy::dot(matrix, another_matrix).unwrap();
///
/// assert_eq!(
/// dot.readonly().as_array(),
/// ndarray::array![[4, 1], [2, 2]]
/// result.readonly().as_array(),
/// array![[4, 1], [2, 2]]
/// );
/// });
/// ```
pub fn dot<'py, T, DIN1, DIN2, DOUT>(
///
/// ...or a scalar depending on its arguments.
///
/// ```
/// use pyo3::Python;
/// use numpy::{dot, pyarray, PyArray0};
///
/// Python::with_gil(|py| {
/// let vector = pyarray![py, 1.0, 2.0, 3.0];
/// let result: f64 = dot(vector, vector).unwrap();
/// assert_eq!(result, 14.0);
/// });
/// ```
///
/// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
pub fn dot<'py, T, DIN1, DIN2, OUT>(
array1: &'py PyArray<T, DIN1>,
array2: &'py PyArray<T, DIN2>,
) -> PyResult<&'py PyArray<T, DOUT>>
) -> PyResult<OUT>
where
T: Element,
DIN1: Dimension,
DIN2: Dimension,
DOUT: Dimension,
T: Element,
OUT: ArrayOrScalar<'py, T>,
{
let py = array1.py();
let obj = unsafe {
Expand All @@ -66,31 +132,28 @@ where

/// Return the Einstein summation convention of given tensors.
///
/// We also provide the [einsum macro](./macro.einsum.html).
pub fn einsum_impl<'py, T, DOUT>(
subscripts: &str,
arrays: &[&'py PyArray<T, IxDyn>],
) -> PyResult<&'py PyArray<T, DOUT>>
/// This is usually invoked via the the [`einsum!`] macro.
pub fn einsum<'py, T, OUT>(subscripts: &str, arrays: &[&'py PyArray<T, IxDyn>]) -> PyResult<OUT>
where
DOUT: Dimension,
T: Element,
OUT: ArrayOrScalar<'py, T>,
{
let subscripts: std::borrow::Cow<CStr> = match CStr::from_bytes_with_nul(subscripts.as_bytes())
{
Ok(subscripts) => subscripts.into(),
Err(_) => std::ffi::CString::new(subscripts).unwrap().into(),
let subscripts = match CStr::from_bytes_with_nul(subscripts.as_bytes()) {
Ok(subscripts) => Cow::Borrowed(subscripts),
Err(_) => Cow::Owned(CString::new(subscripts).unwrap()),
};

let py = arrays[0].py();
let obj = unsafe {
let result = PY_ARRAY_API.PyArray_EinsteinSum(
py,
subscripts.as_ptr() as _,
arrays.len() as _,
arrays.as_ptr() as _,
std::ptr::null_mut(),
null_mut(),
NPY_ORDER::NPY_KEEPORDER,
NPY_CASTING::NPY_NO_CASTING,
std::ptr::null_mut(),
null_mut(),
);
PyAny::from_owned_ptr_or_err(py, result)?
};
Expand All @@ -99,25 +162,34 @@ where

/// Return the Einstein summation convention of given tensors.
///
/// For more about the Einstein summation convention, you may reffer to
/// [the numpy document](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
/// For more about the Einstein summation convention, please refer to
/// [NumPy's documentation][einsum].
///
/// # Example
///
/// ```
/// pyo3::Python::with_gil(|py| {
/// let a = numpy::PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
/// let b = numpy::pyarray![py, [20, 30], [40, 50], [60, 70]];
/// let einsum = numpy::einsum!("ijk,ji->ik", a, b).unwrap();
/// use pyo3::Python;
/// use ndarray::array;
/// use numpy::{einsum, pyarray, PyArray, PyArray2};
///
/// Python::with_gil(|py| {
/// let tensor = PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
/// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
///
/// let result: &PyArray2<_> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
///
/// assert_eq!(
/// einsum.readonly().as_array(),
/// ndarray::array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
/// result.readonly().as_array(),
/// array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
/// );
/// });
/// ```
///
/// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
#[macro_export]
macro_rules! einsum {
($subscripts: literal $(,$array: ident)+ $(,)*) => {{
($subscripts:literal $(,$array:ident)+ $(,)*) => {{
let arrays = [$($array.to_dyn(),)+];
unsafe { $crate::einsum_impl(concat!($subscripts, "\0"), &arrays) }
$crate::einsum(concat!($subscripts, "\0"), &arrays)
}};
}
70 changes: 43 additions & 27 deletions tests/sum_products.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
use numpy::{array, dot, einsum, inner, pyarray, Ix2, PyArray1};
use numpy::{array, dot, einsum, inner, pyarray, PyArray0, PyArray1, PyArray2};
use pyo3::Python;

#[test]
fn test_dot() {
Python::with_gil(|py| {
let a = pyarray![py, [1, 0], [0, 1]];
let b = pyarray![py, [4, 1], [2, 2]];
let c = dot(a, b).unwrap();
let c: &PyArray2<_> = dot(a, b).unwrap();
assert_eq!(c.readonly().as_array(), array![[4, 1], [2, 2]]);

let a = pyarray![py, 1, 2, 3];
let err = dot::<_, _, _, Ix2>(a, b).unwrap_err();
let err = dot::<_, _, _, &PyArray2<_>>(a, b).unwrap_err();
assert!(err.to_string().contains("not aligned"), "{}", err);

let a = pyarray![py, 1, 2, 3];
let b = pyarray![py, 0, 1, 0];
let c: &PyArray0<_> = dot(a, b).unwrap();
assert_eq!(c.item(), 2);
let c: i32 = dot(a, b).unwrap();
assert_eq!(c, 2);

let a = pyarray![py, 1.0, 2.0, 3.0];
let b = pyarray![py, 0.0, 0.0, 0.0];
let c: f64 = dot(a, b).unwrap();
assert_eq!(c, 0.0);
});
}

Expand All @@ -20,16 +32,23 @@ fn test_inner() {
Python::with_gil(|py| {
let a = pyarray![py, 1, 2, 3];
let b = pyarray![py, 0, 1, 0];
let c = inner(a, b).unwrap();
assert_eq!(c.readonly().as_array(), ndarray::arr0(2));
let c: &PyArray0<_> = inner(a, b).unwrap();
assert_eq!(c.item(), 2);
let c: i32 = inner(a, b).unwrap();
assert_eq!(c, 2);

let a = pyarray![py, 1.0, 2.0, 3.0];
let b = pyarray![py, 0.0, 0.0, 0.0];
let c: f64 = inner(a, b).unwrap();
assert_eq!(c, 0.0);

let a = pyarray![py, [1, 0], [0, 1]];
let b = pyarray![py, [4, 1], [2, 2]];
let c = inner(a, b).unwrap();
let c: &PyArray2<_> = inner(a, b).unwrap();
assert_eq!(c.readonly().as_array(), array![[4, 2], [1, 2]]);

let a = pyarray![py, 1, 2, 3];
let err = inner::<_, _, _, Ix2>(a, b).unwrap_err();
let err = inner::<_, _, _, &PyArray2<_>>(a, b).unwrap_err();
assert!(err.to_string().contains("not aligned"), "{}", err);
});
}
Expand All @@ -43,25 +62,22 @@ fn test_einsum() {
let b = pyarray![py, 0, 1, 2, 3, 4];
let c = pyarray![py, [0, 1, 2], [3, 4, 5]];

assert_eq!(
einsum!("ii", a).unwrap().readonly().as_array(),
ndarray::arr0(60)
);
assert_eq!(
einsum!("ii->i", a).unwrap().readonly().as_array(),
array![0, 6, 12, 18, 24],
);
assert_eq!(
einsum!("ij->i", a).unwrap().readonly().as_array(),
array![10, 35, 60, 85, 110],
);
assert_eq!(
einsum!("ji", c).unwrap().readonly().as_array(),
array![[0, 3], [1, 4], [2, 5]],
);
assert_eq!(
einsum!("ij,j", a, b).unwrap().readonly().as_array(),
array![30, 80, 130, 180, 230],
);
let d: &PyArray0<_> = einsum!("ii", a).unwrap();
assert_eq!(d.item(), 60);

let d: i32 = einsum!("ii", a).unwrap();
assert_eq!(d, 60);

let d: &PyArray1<_> = einsum!("ii->i", a).unwrap();
assert_eq!(d.readonly().as_array(), array![0, 6, 12, 18, 24]);

let d: &PyArray1<_> = einsum!("ij->i", a).unwrap();
assert_eq!(d.readonly().as_array(), array![10, 35, 60, 85, 110]);

let d: &PyArray2<_> = einsum!("ji", c).unwrap();
assert_eq!(d.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);

let d: &PyArray1<_> = einsum!("ij,j", a, b).unwrap();
assert_eq!(d.readonly().as_array(), array![30, 80, 130, 180, 230]);
});
}