Skip to content

Commit 1e0fbef

Browse files
committed
Extend the usage of the ArrayOrScalar trait to the einsum macro and function.
1 parent ee084db commit 1e0fbef

File tree

4 files changed

+24
-27
lines changed

4 files changed

+24
-27
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Changelog
22

33
- Unreleased
4-
- The `inner` and `dot` functions can also return a scalar instead of a zero-dimensional array to match NumPy's types ([#285](https://github.com/PyO3/rust-numpy/pull/285))
4+
- The `inner`, `dot` and `einsum` functions can also return a scalar instead of a zero-dimensional array to match NumPy's types ([#285](https://github.com/PyO3/rust-numpy/pull/285))
55

66
- v0.16.1
77
- Fix build when PyO3's `multiple-pymethods` feature is used. ([#288](https://github.com/PyO3/rust-numpy/pull/288))

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ pub use crate::readonly::{
6464
PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4,
6565
PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn,
6666
};
67-
pub use crate::sum_products::{dot, einsum_impl, inner};
67+
pub use crate::sum_products::{dot, einsum, inner};
6868
pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
6969

7070
#[cfg(doctest)]

src/sum_products.rs

+5-8
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,10 @@ where
133133
/// Return the Einstein summation convention of given tensors.
134134
///
135135
/// This is usually invoked via the the [`einsum!`] macro.
136-
pub fn einsum_impl<'py, T, DOUT>(
137-
subscripts: &str,
138-
arrays: &[&'py PyArray<T, IxDyn>],
139-
) -> PyResult<&'py PyArray<T, DOUT>>
136+
pub fn einsum<'py, T, OUT>(subscripts: &str, arrays: &[&'py PyArray<T, IxDyn>]) -> PyResult<OUT>
140137
where
141-
DOUT: Dimension,
142138
T: Element,
139+
OUT: ArrayOrScalar<'py, T>,
143140
{
144141
let subscripts = match CStr::from_bytes_with_nul(subscripts.as_bytes()) {
145142
Ok(subscripts) => Cow::Borrowed(subscripts),
@@ -173,13 +170,13 @@ where
173170
/// ```
174171
/// use pyo3::Python;
175172
/// use ndarray::array;
176-
/// use numpy::{einsum, pyarray, PyArray};
173+
/// use numpy::{einsum, pyarray, PyArray, PyArray2};
177174
///
178175
/// Python::with_gil(|py| {
179176
/// let tensor = PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
180177
/// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
181178
///
182-
/// let result = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
179+
/// let result: &PyArray2<_> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
183180
///
184181
/// assert_eq!(
185182
/// result.readonly().as_array(),
@@ -193,6 +190,6 @@ where
193190
macro_rules! einsum {
194191
($subscripts:literal $(,$array:ident)+ $(,)*) => {{
195192
let arrays = [$($array.to_dyn(),)+];
196-
$crate::einsum_impl(concat!($subscripts, "\0"), &arrays)
193+
$crate::einsum(concat!($subscripts, "\0"), &arrays)
197194
}};
198195
}

tests/sum_products.rs

+17-17
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,22 @@ fn test_einsum() {
6262
let b = pyarray![py, 0, 1, 2, 3, 4];
6363
let c = pyarray![py, [0, 1, 2], [3, 4, 5]];
6464

65-
assert_eq!(einsum!("ii", a).unwrap().item(), 60);
66-
assert_eq!(
67-
einsum!("ii->i", a).unwrap().readonly().as_array(),
68-
array![0, 6, 12, 18, 24],
69-
);
70-
assert_eq!(
71-
einsum!("ij->i", a).unwrap().readonly().as_array(),
72-
array![10, 35, 60, 85, 110],
73-
);
74-
assert_eq!(
75-
einsum!("ji", c).unwrap().readonly().as_array(),
76-
array![[0, 3], [1, 4], [2, 5]],
77-
);
78-
assert_eq!(
79-
einsum!("ij,j", a, b).unwrap().readonly().as_array(),
80-
array![30, 80, 130, 180, 230],
81-
);
65+
let d: &PyArray0<_> = einsum!("ii", a).unwrap();
66+
assert_eq!(d.item(), 60);
67+
68+
let d: i32 = einsum!("ii", a).unwrap();
69+
assert_eq!(d, 60);
70+
71+
let d: &PyArray1<_> = einsum!("ii->i", a).unwrap();
72+
assert_eq!(d.readonly().as_array(), array![0, 6, 12, 18, 24]);
73+
74+
let d: &PyArray1<_> = einsum!("ij->i", a).unwrap();
75+
assert_eq!(d.readonly().as_array(), array![10, 35, 60, 85, 110]);
76+
77+
let d: &PyArray2<_> = einsum!("ji", c).unwrap();
78+
assert_eq!(d.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);
79+
80+
let d: &PyArray1<_> = einsum!("ij,j", a, b).unwrap();
81+
assert_eq!(d.readonly().as_array(), array![30, 80, 130, 180, 230]);
8282
});
8383
}

0 commit comments

Comments
 (0)