Skip to content

Commit aec633a

Browse files
committed
Extend our type signatures of inner and dot to match NumPy's types.
1 parent b9f01fa commit aec633a

File tree

3 files changed

+121
-45
lines changed

3 files changed

+121
-45
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Changelog
22

33
- Unreleased
4+
- The `inner` and `dot` functions can also return a scalar instead of a zero-dimensional array to match NumPy's types ([#284](https://github.com/PyO3/rust-numpy/pull/284))
45

56
- v0.16.0
67
- Bump PyO3 version to 0.16 ([#259](https://github.com/PyO3/rust-numpy/pull/212))

src/sum_products.rs

+116-41
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,68 @@
1-
use crate::npyffi::{NPY_CASTING, NPY_ORDER};
2-
use crate::{Element, PyArray, PY_ARRAY_API};
1+
use std::borrow::Cow;
2+
use std::ffi::{CStr, CString};
3+
use std::ptr::null_mut;
4+
35
use ndarray::{Dimension, IxDyn};
4-
use pyo3::{AsPyPointer, FromPyPointer, PyAny, PyNativeType, PyResult};
5-
use std::ffi::CStr;
6+
use pyo3::{AsPyPointer, FromPyObject, FromPyPointer, PyAny, PyNativeType, PyResult};
7+
8+
use crate::array::PyArray;
9+
use crate::dtype::Element;
10+
use crate::npyffi::{array::PY_ARRAY_API, NPY_CASTING, NPY_ORDER};
11+
12+
/// Return value of a function that can yield either an array or a scalar.
13+
pub trait ArrayOrScalar<'py, T>: FromPyObject<'py> {}
14+
15+
impl<'py, T, D> ArrayOrScalar<'py, T> for &'py PyArray<T, D>
16+
where
17+
T: Element,
18+
D: Dimension,
19+
{
20+
}
21+
22+
impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}
623

724
/// Return the inner product of two arrays.
825
///
9-
/// # Example
26+
/// [NumPy's documentation][inner] has the details.
27+
///
28+
/// # Examples
29+
///
30+
/// Note that this function can either return a scalar...
31+
///
1032
/// ```
11-
/// pyo3::Python::with_gil(|py| {
12-
/// let array = numpy::pyarray![py, 1, 2, 3];
13-
/// let inner: &numpy::PyArray0::<_> = numpy::inner(array, array).unwrap();
14-
/// assert_eq!(inner.item(), 14);
33+
/// use pyo3::Python;
34+
/// use numpy::{inner, pyarray, PyArray0};
35+
///
36+
/// Python::with_gil(|py| {
37+
/// let vector = pyarray![py, 1.0, 2.0, 3.0];
38+
/// let result: f64 = inner(vector, vector).unwrap();
39+
/// assert_eq!(result, 14.0);
40+
/// });
41+
/// ```
42+
///
43+
/// ...or an array depending on its arguments.
44+
///
45+
/// ```
46+
/// use pyo3::Python;
47+
/// use numpy::{inner, pyarray, PyArray0};
48+
///
49+
/// Python::with_gil(|py| {
50+
/// let vector = pyarray![py, 1, 2, 3];
51+
/// let result: &PyArray0<_> = inner(vector, vector).unwrap();
52+
/// assert_eq!(result.item(), 14);
1553
/// });
1654
/// ```
17-
pub fn inner<'py, T, DIN1, DIN2, DOUT>(
55+
///
56+
/// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
57+
pub fn inner<'py, T, DIN1, DIN2, OUT>(
1858
array1: &'py PyArray<T, DIN1>,
1959
array2: &'py PyArray<T, DIN2>,
20-
) -> PyResult<&'py PyArray<T, DOUT>>
60+
) -> PyResult<OUT>
2161
where
62+
T: Element,
2263
DIN1: Dimension,
2364
DIN2: Dimension,
24-
DOUT: Dimension,
25-
T: Element,
65+
OUT: ArrayOrScalar<'py, T>,
2666
{
2767
let py = array1.py();
2868
let obj = unsafe {
@@ -34,27 +74,53 @@ where
3474

3575
/// Return the dot product of two arrays.
3676
///
37-
/// # Example
77+
/// [NumPy's documentation][dot] has the details.
78+
///
79+
/// # Examples
80+
///
81+
/// Note that this function can either return an array...
82+
///
3883
/// ```
39-
/// pyo3::Python::with_gil(|py| {
40-
/// let a = numpy::pyarray![py, [1, 0], [0, 1]];
41-
/// let b = numpy::pyarray![py, [4, 1], [2, 2]];
42-
/// let dot: &numpy::PyArray2::<_> = numpy::dot(a, b).unwrap();
84+
/// use pyo3::Python;
85+
/// use ndarray::array;
86+
/// use numpy::{dot, pyarray, PyArray2};
87+
///
88+
/// Python::with_gil(|py| {
89+
/// let matrix = pyarray![py, [1, 0], [0, 1]];
90+
/// let another_matrix = pyarray![py, [4, 1], [2, 2]];
91+
///
92+
/// let result: &PyArray2<_> = numpy::dot(matrix, another_matrix).unwrap();
93+
///
4394
/// assert_eq!(
44-
/// dot.readonly().as_array(),
45-
/// ndarray::array![[4, 1], [2, 2]]
95+
/// result.readonly().as_array(),
96+
/// array![[4, 1], [2, 2]]
4697
/// );
4798
/// });
4899
/// ```
49-
pub fn dot<'py, T, DIN1, DIN2, DOUT>(
100+
///
101+
/// ...or a scalar depending on its arguments.
102+
///
103+
/// ```
104+
/// use pyo3::Python;
105+
/// use numpy::{dot, pyarray, PyArray0};
106+
///
107+
/// Python::with_gil(|py| {
108+
/// let vector = pyarray![py, 1.0, 2.0, 3.0];
109+
/// let result: f64 = dot(vector, vector).unwrap();
110+
/// assert_eq!(result, 14.0);
111+
/// });
112+
/// ```
113+
///
114+
/// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
115+
pub fn dot<'py, T, DIN1, DIN2, OUT>(
50116
array1: &'py PyArray<T, DIN1>,
51117
array2: &'py PyArray<T, DIN2>,
52-
) -> PyResult<&'py PyArray<T, DOUT>>
118+
) -> PyResult<OUT>
53119
where
120+
T: Element,
54121
DIN1: Dimension,
55122
DIN2: Dimension,
56-
DOUT: Dimension,
57-
T: Element,
123+
OUT: ArrayOrScalar<'py, T>,
58124
{
59125
let py = array1.py();
60126
let obj = unsafe {
@@ -66,7 +132,7 @@ where
66132

67133
/// Return the Einstein summation convention of given tensors.
68134
///
69-
/// We also provide the [einsum macro](./macro.einsum.html).
135+
/// This is usually invoked via the the [`einsum!`] macro.
70136
pub fn einsum_impl<'py, T, DOUT>(
71137
subscripts: &str,
72138
arrays: &[&'py PyArray<T, IxDyn>],
@@ -75,22 +141,22 @@ where
75141
DOUT: Dimension,
76142
T: Element,
77143
{
78-
let subscripts: std::borrow::Cow<CStr> = match CStr::from_bytes_with_nul(subscripts.as_bytes())
79-
{
80-
Ok(subscripts) => subscripts.into(),
81-
Err(_) => std::ffi::CString::new(subscripts).unwrap().into(),
144+
let subscripts = match CStr::from_bytes_with_nul(subscripts.as_bytes()) {
145+
Ok(subscripts) => Cow::Borrowed(subscripts),
146+
Err(_) => Cow::Owned(CString::new(subscripts).unwrap()),
82147
};
148+
83149
let py = arrays[0].py();
84150
let obj = unsafe {
85151
let result = PY_ARRAY_API.PyArray_EinsteinSum(
86152
py,
87153
subscripts.as_ptr() as _,
88154
arrays.len() as _,
89155
arrays.as_ptr() as _,
90-
std::ptr::null_mut(),
156+
null_mut(),
91157
NPY_ORDER::NPY_KEEPORDER,
92158
NPY_CASTING::NPY_NO_CASTING,
93-
std::ptr::null_mut(),
159+
null_mut(),
94160
);
95161
PyAny::from_owned_ptr_or_err(py, result)?
96162
};
@@ -99,25 +165,34 @@ where
99165

100166
/// Return the Einstein summation convention of given tensors.
101167
///
102-
/// For more about the Einstein summation convention, you may reffer to
103-
/// [the numpy document](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
168+
/// For more about the Einstein summation convention, please refer to
169+
/// [NumPy's documentation][einsum].
104170
///
105171
/// # Example
172+
///
106173
/// ```
107-
/// pyo3::Python::with_gil(|py| {
108-
/// let a = numpy::PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
109-
/// let b = numpy::pyarray![py, [20, 30], [40, 50], [60, 70]];
110-
/// let einsum = numpy::einsum!("ijk,ji->ik", a, b).unwrap();
174+
/// use pyo3::Python;
175+
/// use ndarray::array;
176+
/// use numpy::{einsum, pyarray, PyArray};
177+
///
178+
/// Python::with_gil(|py| {
179+
/// let tensor = PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
180+
/// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
181+
///
182+
/// let result = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
183+
///
111184
/// assert_eq!(
112-
/// einsum.readonly().as_array(),
113-
/// ndarray::array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
185+
/// result.readonly().as_array(),
186+
/// array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
114187
/// );
115188
/// });
116189
/// ```
190+
///
191+
/// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
117192
#[macro_export]
118193
macro_rules! einsum {
119-
($subscripts: literal $(,$array: ident)+ $(,)*) => {{
194+
($subscripts:literal $(,$array:ident)+ $(,)*) => {{
120195
let arrays = [$($array.to_dyn(),)+];
121-
unsafe { $crate::einsum_impl(concat!($subscripts, "\0"), &arrays) }
196+
$crate::einsum_impl(concat!($subscripts, "\0"), &arrays)
122197
}};
123198
}

tests/sum_products.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
use numpy::{array, dot, einsum, inner, pyarray, PyArray1, PyArray2};
1+
use numpy::{array, dot, einsum, inner, pyarray, PyArray0, PyArray1, PyArray2};
22

33
#[test]
44
fn test_dot() {
55
pyo3::Python::with_gil(|py| {
66
let a = pyarray![py, [1, 0], [0, 1]];
77
let b = pyarray![py, [4, 1], [2, 2]];
8-
let c = dot(a, b).unwrap();
8+
let c: &PyArray2<_> = dot(a, b).unwrap();
99
assert_eq!(c.readonly().as_array(), array![[4, 1], [2, 2]]);
1010
let a = pyarray![py, 1, 2, 3];
1111
let err: pyo3::PyResult<&PyArray2<_>> = dot(a, b);
@@ -19,11 +19,11 @@ fn test_inner() {
1919
pyo3::Python::with_gil(|py| {
2020
let a = pyarray![py, 1, 2, 3];
2121
let b = pyarray![py, 0, 1, 0];
22-
let c = inner(a, b).unwrap();
22+
let c: &PyArray0<_> = inner(a, b).unwrap();
2323
assert_eq!(c.readonly().as_array(), ndarray::arr0(2));
2424
let a = pyarray![py, [1, 0], [0, 1]];
2525
let b = pyarray![py, [4, 1], [2, 2]];
26-
let c = inner(a, b).unwrap();
26+
let c: &PyArray2<_> = inner(a, b).unwrap();
2727
assert_eq!(c.readonly().as_array(), array![[4, 2], [1, 2]]);
2828
let a = pyarray![py, 1, 2, 3];
2929
let err: pyo3::PyResult<&PyArray2<_>> = inner(a, b);

0 commit comments

Comments
 (0)