Skip to content

Commit 8cf5e84

Browse files
committed
Extend the usage of the ArrayOrScalar trait to the einsum macro and function.
1 parent 4566c8a commit 8cf5e84

File tree

3 files changed

+23
-26
lines changed

3 files changed

+23
-26
lines changed

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ pub use crate::readonly::{
6464
PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4,
6565
PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn,
6666
};
67-
pub use crate::sum_products::{dot, einsum_impl, inner};
67+
pub use crate::sum_products::{dot, einsum, inner};
6868
pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
6969

7070
#[cfg(doctest)]

src/sum_products.rs

+5-8
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,10 @@ where
133133
/// Return the Einstein summation convention of given tensors.
134134
///
135135
/// This is usually invoked via the the [`einsum!`] macro.
136-
pub fn einsum_impl<'py, T, DOUT>(
137-
subscripts: &str,
138-
arrays: &[&'py PyArray<T, IxDyn>],
139-
) -> PyResult<&'py PyArray<T, DOUT>>
136+
pub fn einsum<'py, T, OUT>(subscripts: &str, arrays: &[&'py PyArray<T, IxDyn>]) -> PyResult<OUT>
140137
where
141-
DOUT: Dimension,
142138
T: Element,
139+
OUT: ArrayOrScalar<'py, T>,
143140
{
144141
let subscripts = match CStr::from_bytes_with_nul(subscripts.as_bytes()) {
145142
Ok(subscripts) => Cow::Borrowed(subscripts),
@@ -173,13 +170,13 @@ where
173170
/// ```
174171
/// use pyo3::Python;
175172
/// use ndarray::array;
176-
/// use numpy::{einsum, pyarray, PyArray};
173+
/// use numpy::{einsum, pyarray, PyArray, PyArray2};
177174
///
178175
/// Python::with_gil(|py| {
179176
/// let tensor = PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
180177
/// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
181178
///
182-
/// let result = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
179+
/// let result: &PyArray2<_> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
183180
///
184181
/// assert_eq!(
185182
/// result.readonly().as_array(),
@@ -193,6 +190,6 @@ where
193190
macro_rules! einsum {
194191
($subscripts:literal $(,$array:ident)+ $(,)*) => {{
195192
let arrays = [$($array.to_dyn(),)+];
196-
$crate::einsum_impl(concat!($subscripts, "\0"), &arrays)
193+
$crate::einsum(concat!($subscripts, "\0"), &arrays)
197194
}};
198195
}

tests/sum_products.rs

+17-17
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,22 @@ fn test_einsum() {
6363
let b = pyarray![py, 0, 1, 2, 3, 4];
6464
let c = pyarray![py, [0, 1, 2], [3, 4, 5]];
6565

66-
assert_eq!(einsum!("ii", a).unwrap().readonly().as_array(), arr0(60));
67-
assert_eq!(
68-
einsum!("ii->i", a).unwrap().readonly().as_array(),
69-
array![0, 6, 12, 18, 24],
70-
);
71-
assert_eq!(
72-
einsum!("ij->i", a).unwrap().readonly().as_array(),
73-
array![10, 35, 60, 85, 110],
74-
);
75-
assert_eq!(
76-
einsum!("ji", c).unwrap().readonly().as_array(),
77-
array![[0, 3], [1, 4], [2, 5]],
78-
);
79-
assert_eq!(
80-
einsum!("ij,j", a, b).unwrap().readonly().as_array(),
81-
array![30, 80, 130, 180, 230],
82-
);
66+
let d: &PyArray0<_> = einsum!("ii", a).unwrap();
67+
assert_eq!(d.readonly().as_array(), arr0(60));
68+
69+
let d: i32 = einsum!("ii", a).unwrap();
70+
assert_eq!(d, 60);
71+
72+
let d: &PyArray1<_> = einsum!("ii->i", a).unwrap();
73+
assert_eq!(d.readonly().as_array(), array![0, 6, 12, 18, 24]);
74+
75+
let d: &PyArray1<_> = einsum!("ij->i", a).unwrap();
76+
assert_eq!(d.readonly().as_array(), array![10, 35, 60, 85, 110]);
77+
78+
let d: &PyArray2<_> = einsum!("ji", c).unwrap();
79+
assert_eq!(d.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);
80+
81+
let d: &PyArray1<_> = einsum!("ij,j", a, b).unwrap();
82+
assert_eq!(d.readonly().as_array(), array![30, 80, 130, 180, 230]);
8383
});
8484
}

0 commit comments

Comments
 (0)