Skip to content

Commit 05b1b32

Browse files
Icxoluadamreichold
authored andcommitted
deprecate inner, dot and einsum
1 parent 0b39d09 commit 05b1b32

File tree

3 files changed

+148
-69
lines changed

3 files changed

+148
-69
lines changed

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ pub use crate::dtype::{
114114
pub use crate::error::{BorrowError, FromVecError, NotContiguousError};
115115
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
116116
pub use crate::strings::{PyFixedString, PyFixedUnicode};
117+
#[allow(deprecated)]
117118
pub use crate::sum_products::{dot, einsum, inner};
119+
pub use crate::sum_products::{dot_bound, einsum_bound, inner_bound};
118120
pub use crate::untyped_array::{PyUntypedArray, PyUntypedArrayMethods};
119121

120122
pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};

src/sum_products.rs

+110-33
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::ptr::null_mut;
44

55
use ndarray::{Dimension, IxDyn};
66
use pyo3::types::PyAnyMethods;
7-
use pyo3::{AsPyPointer, Bound, FromPyObject, PyNativeType, PyResult};
7+
use pyo3::{Borrowed, Bound, FromPyObject, PyNativeType, PyResult};
88

99
use crate::array::PyArray;
1010
use crate::dtype::Element;
@@ -20,8 +20,33 @@ where
2020
{
2121
}
2222

23+
impl<'py, T, D> ArrayOrScalar<'py, T> for Bound<'py, PyArray<T, D>>
24+
where
25+
T: Element,
26+
D: Dimension,
27+
{
28+
}
29+
2330
impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}
2431

32+
/// Deprecated form of [`inner_bound`]
33+
#[deprecated(
34+
since = "0.21.0",
35+
note = "will be replaced by `inner_bound` in the future"
36+
)]
37+
pub fn inner<'py, T, DIN1, DIN2, OUT>(
38+
array1: &'py PyArray<T, DIN1>,
39+
array2: &'py PyArray<T, DIN2>,
40+
) -> PyResult<OUT>
41+
where
42+
T: Element,
43+
DIN1: Dimension,
44+
DIN2: Dimension,
45+
OUT: ArrayOrScalar<'py, T>,
46+
{
47+
inner_bound(&array1.as_borrowed(), &array2.as_borrowed())
48+
}
49+
2550
/// Return the inner product of two arrays.
2651
///
2752
/// [NumPy's documentation][inner] has the details.
@@ -31,33 +56,33 @@ impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}
3156
/// Note that this function can either return a scalar...
3257
///
3358
/// ```
34-
/// use pyo3::Python;
35-
/// use numpy::{inner, pyarray, PyArray0};
59+
/// use pyo3::{Python, PyNativeType};
60+
/// use numpy::{inner_bound, pyarray, PyArray0};
3661
///
3762
/// Python::with_gil(|py| {
38-
/// let vector = pyarray![py, 1.0, 2.0, 3.0];
39-
/// let result: f64 = inner(vector, vector).unwrap();
63+
/// let vector = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
64+
/// let result: f64 = inner_bound(&vector, &vector).unwrap();
4065
/// assert_eq!(result, 14.0);
4166
/// });
4267
/// ```
4368
///
4469
/// ...or an array depending on its arguments.
4570
///
4671
/// ```
47-
/// use pyo3::Python;
48-
/// use numpy::{inner, pyarray, PyArray0};
72+
/// use pyo3::{Python, Bound, PyNativeType};
73+
/// use numpy::{inner_bound, pyarray, PyArray0, PyArray0Methods};
4974
///
5075
/// Python::with_gil(|py| {
51-
/// let vector = pyarray![py, 1, 2, 3];
52-
/// let result: &PyArray0<_> = inner(vector, vector).unwrap();
76+
/// let vector = pyarray![py, 1, 2, 3].as_borrowed();
77+
/// let result: Bound<'_, PyArray0<_>> = inner_bound(&vector, &vector).unwrap();
5378
/// assert_eq!(result.item(), 14);
5479
/// });
5580
/// ```
5681
///
5782
/// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
58-
pub fn inner<'py, T, DIN1, DIN2, OUT>(
59-
array1: &'py PyArray<T, DIN1>,
60-
array2: &'py PyArray<T, DIN2>,
83+
pub fn inner_bound<'py, T, DIN1, DIN2, OUT>(
84+
array1: &Bound<'py, PyArray<T, DIN1>>,
85+
array2: &Bound<'py, PyArray<T, DIN2>>,
6186
) -> PyResult<OUT>
6287
where
6388
T: Element,
@@ -73,6 +98,24 @@ where
7398
obj.extract()
7499
}
75100

101+
/// Deprecated form of [`dot_bound`]
102+
#[deprecated(
103+
since = "0.21.0",
104+
note = "will be replaced by `dot_bound` in the future"
105+
)]
106+
pub fn dot<'py, T, DIN1, DIN2, OUT>(
107+
array1: &'py PyArray<T, DIN1>,
108+
array2: &'py PyArray<T, DIN2>,
109+
) -> PyResult<OUT>
110+
where
111+
T: Element,
112+
DIN1: Dimension,
113+
DIN2: Dimension,
114+
OUT: ArrayOrScalar<'py, T>,
115+
{
116+
dot_bound(&array1.as_borrowed(), &array2.as_borrowed())
117+
}
118+
76119
/// Return the dot product of two arrays.
77120
///
78121
/// [NumPy's documentation][dot] has the details.
@@ -82,15 +125,15 @@ where
82125
/// Note that this function can either return an array...
83126
///
84127
/// ```
85-
/// use pyo3::Python;
128+
/// use pyo3::{Python, Bound, PyNativeType};
86129
/// use ndarray::array;
87-
/// use numpy::{dot, pyarray, PyArray2};
130+
/// use numpy::{dot_bound, pyarray, PyArray2, PyArrayMethods};
88131
///
89132
/// Python::with_gil(|py| {
90-
/// let matrix = pyarray![py, [1, 0], [0, 1]];
91-
/// let another_matrix = pyarray![py, [4, 1], [2, 2]];
133+
/// let matrix = pyarray![py, [1, 0], [0, 1]].as_borrowed();
134+
/// let another_matrix = pyarray![py, [4, 1], [2, 2]].as_borrowed();
92135
///
93-
/// let result: &PyArray2<_> = numpy::dot(matrix, another_matrix).unwrap();
136+
/// let result: Bound<'_, PyArray2<_>> = dot_bound(&matrix, &another_matrix).unwrap();
94137
///
95138
/// assert_eq!(
96139
/// result.readonly().as_array(),
@@ -102,20 +145,20 @@ where
102145
/// ...or a scalar depending on its arguments.
103146
///
104147
/// ```
105-
/// use pyo3::Python;
106-
/// use numpy::{dot, pyarray, PyArray0};
148+
/// use pyo3::{Python, PyNativeType};
149+
/// use numpy::{dot_bound, pyarray, PyArray0};
107150
///
108151
/// Python::with_gil(|py| {
109-
/// let vector = pyarray![py, 1.0, 2.0, 3.0];
110-
/// let result: f64 = dot(vector, vector).unwrap();
152+
/// let vector = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
153+
/// let result: f64 = dot_bound(&vector, &vector).unwrap();
111154
/// assert_eq!(result, 14.0);
112155
/// });
113156
/// ```
114157
///
115158
/// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
116-
pub fn dot<'py, T, DIN1, DIN2, OUT>(
117-
array1: &'py PyArray<T, DIN1>,
118-
array2: &'py PyArray<T, DIN2>,
159+
pub fn dot_bound<'py, T, DIN1, DIN2, OUT>(
160+
array1: &Bound<'py, PyArray<T, DIN1>>,
161+
array2: &Bound<'py, PyArray<T, DIN2>>,
119162
) -> PyResult<OUT>
120163
where
121164
T: Element,
@@ -131,10 +174,30 @@ where
131174
obj.extract()
132175
}
133176

177+
/// Deprecated form of [`einsum_bound`]
178+
#[deprecated(
179+
since = "0.21.0",
180+
note = "will be replaced by `einsum_bound` in the future"
181+
)]
182+
pub fn einsum<'py, T, OUT>(subscripts: &str, arrays: &[&'py PyArray<T, IxDyn>]) -> PyResult<OUT>
183+
where
184+
T: Element,
185+
OUT: ArrayOrScalar<'py, T>,
186+
{
187+
// Safety: &PyArray<T, IxDyn> has the same size and layout in memory as
188+
// Borrowed<'_, '_, PyArray<T, IxDyn>>
189+
einsum_bound(subscripts, unsafe {
190+
std::slice::from_raw_parts(arrays.as_ptr().cast(), arrays.len())
191+
})
192+
}
193+
134194
/// Return the Einstein summation convention of given tensors.
135195
///
136196
/// This is usually invoked via the the [`einsum!`][crate::einsum!] macro.
137-
pub fn einsum<'py, T, OUT>(subscripts: &str, arrays: &[&'py PyArray<T, IxDyn>]) -> PyResult<OUT>
197+
pub fn einsum_bound<'py, T, OUT>(
198+
subscripts: &str,
199+
arrays: &[Borrowed<'_, 'py, PyArray<T, IxDyn>>],
200+
) -> PyResult<OUT>
138201
where
139202
T: Element,
140203
OUT: ArrayOrScalar<'py, T>,
@@ -161,6 +224,20 @@ where
161224
obj.extract()
162225
}
163226

227+
/// Deprecated form of [`einsum_bound!`]
228+
#[deprecated(
229+
since = "0.21.0",
230+
note = "will be replaced by `einsum_bound!` in the future"
231+
)]
232+
#[macro_export]
233+
macro_rules! einsum {
234+
($subscripts:literal $(,$array:ident)+ $(,)*) => {{
235+
use pyo3::PyNativeType;
236+
let arrays = [$($array.to_dyn().as_borrowed(),)+];
237+
$crate::einsum_bound(concat!($subscripts, "\0"), &arrays)
238+
}};
239+
}
240+
164241
/// Return the Einstein summation convention of given tensors.
165242
///
166243
/// For more about the Einstein summation convention, please refer to
@@ -169,15 +246,15 @@ where
169246
/// # Example
170247
///
171248
/// ```
172-
/// use pyo3::Python;
249+
/// use pyo3::{Python, Bound, PyNativeType};
173250
/// use ndarray::array;
174-
/// use numpy::{einsum, pyarray, PyArray, PyArray2, PyArrayMethods};
251+
/// use numpy::{einsum_bound, pyarray, PyArray, PyArray2, PyArrayMethods};
175252
///
176253
/// Python::with_gil(|py| {
177-
/// let tensor = PyArray::arange_bound(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap().into_gil_ref();
178-
/// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
254+
/// let tensor = PyArray::arange_bound(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
255+
/// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]].as_borrowed();
179256
///
180-
/// let result: &PyArray2<_> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
257+
/// let result: Bound<'_, PyArray2<_>> = einsum_bound!("ijk,ji->ik", tensor, another_tensor).unwrap();
181258
///
182259
/// assert_eq!(
183260
/// result.readonly().as_array(),
@@ -188,9 +265,9 @@ where
188265
///
189266
/// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
190267
#[macro_export]
191-
macro_rules! einsum {
268+
macro_rules! einsum_bound {
192269
($subscripts:literal $(,$array:ident)+ $(,)*) => {{
193-
let arrays = [$($array.to_dyn(),)+];
194-
$crate::einsum(concat!($subscripts, "\0"), &arrays)
270+
let arrays = [$($array.to_dyn().as_borrowed(),)+];
271+
$crate::einsum_bound(concat!($subscripts, "\0"), &arrays)
195272
}};
196273
}

tests/sum_products.rs

+36-36
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,55 @@
1-
use numpy::{array, dot, einsum, inner, pyarray, PyArray0, PyArray1, PyArray2, PyArrayMethods};
2-
use pyo3::Python;
1+
use numpy::prelude::*;
2+
use numpy::{array, dot_bound, einsum_bound, inner_bound, pyarray, PyArray0, PyArray1, PyArray2};
3+
use pyo3::{Bound, PyNativeType, Python};
34

45
#[test]
56
fn test_dot() {
67
Python::with_gil(|py| {
7-
let a = pyarray![py, [1, 0], [0, 1]];
8-
let b = pyarray![py, [4, 1], [2, 2]];
9-
let c: &PyArray2<_> = dot(a, b).unwrap();
8+
let a = pyarray![py, [1, 0], [0, 1]].as_borrowed();
9+
let b = pyarray![py, [4, 1], [2, 2]].as_borrowed();
10+
let c: Bound<'_, PyArray2<_>> = dot_bound(&a, &b).unwrap();
1011
assert_eq!(c.readonly().as_array(), array![[4, 1], [2, 2]]);
1112

12-
let a = pyarray![py, 1, 2, 3];
13-
let err = dot::<_, _, _, &PyArray2<_>>(a, b).unwrap_err();
13+
let a = pyarray![py, 1, 2, 3].as_borrowed();
14+
let err = dot_bound::<_, _, _, Bound<'_, PyArray2<_>>>(&a, &b).unwrap_err();
1415
assert!(err.to_string().contains("not aligned"), "{}", err);
1516

16-
let a = pyarray![py, 1, 2, 3];
17-
let b = pyarray![py, 0, 1, 0];
18-
let c: &PyArray0<_> = dot(a, b).unwrap();
17+
let a = pyarray![py, 1, 2, 3].as_borrowed();
18+
let b = pyarray![py, 0, 1, 0].as_borrowed();
19+
let c: Bound<'_, PyArray0<_>> = dot_bound(&a, &b).unwrap();
1920
assert_eq!(c.item(), 2);
20-
let c: i32 = dot(a, b).unwrap();
21+
let c: i32 = dot_bound(&a, &b).unwrap();
2122
assert_eq!(c, 2);
2223

23-
let a = pyarray![py, 1.0, 2.0, 3.0];
24-
let b = pyarray![py, 0.0, 0.0, 0.0];
25-
let c: f64 = dot(a, b).unwrap();
24+
let a = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
25+
let b = pyarray![py, 0.0, 0.0, 0.0].as_borrowed();
26+
let c: f64 = dot_bound(&a, &b).unwrap();
2627
assert_eq!(c, 0.0);
2728
});
2829
}
2930

3031
#[test]
3132
fn test_inner() {
3233
Python::with_gil(|py| {
33-
let a = pyarray![py, 1, 2, 3];
34-
let b = pyarray![py, 0, 1, 0];
35-
let c: &PyArray0<_> = inner(a, b).unwrap();
34+
let a = pyarray![py, 1, 2, 3].as_borrowed();
35+
let b = pyarray![py, 0, 1, 0].as_borrowed();
36+
let c: Bound<'_, PyArray0<_>> = inner_bound(&a, &b).unwrap();
3637
assert_eq!(c.item(), 2);
37-
let c: i32 = inner(a, b).unwrap();
38+
let c: i32 = inner_bound(&a, &b).unwrap();
3839
assert_eq!(c, 2);
3940

40-
let a = pyarray![py, 1.0, 2.0, 3.0];
41-
let b = pyarray![py, 0.0, 0.0, 0.0];
42-
let c: f64 = inner(a, b).unwrap();
41+
let a = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
42+
let b = pyarray![py, 0.0, 0.0, 0.0].as_borrowed();
43+
let c: f64 = inner_bound(&a, &b).unwrap();
4344
assert_eq!(c, 0.0);
4445

45-
let a = pyarray![py, [1, 0], [0, 1]];
46-
let b = pyarray![py, [4, 1], [2, 2]];
47-
let c: &PyArray2<_> = inner(a, b).unwrap();
46+
let a = pyarray![py, [1, 0], [0, 1]].as_borrowed();
47+
let b = pyarray![py, [4, 1], [2, 2]].as_borrowed();
48+
let c: Bound<'_, PyArray2<_>> = inner_bound(&a, &b).unwrap();
4849
assert_eq!(c.readonly().as_array(), array![[4, 2], [1, 2]]);
4950

50-
let a = pyarray![py, 1, 2, 3];
51-
let err = inner::<_, _, _, &PyArray2<_>>(a, b).unwrap_err();
51+
let a = pyarray![py, 1, 2, 3].as_borrowed();
52+
let err = inner_bound::<_, _, _, Bound<'_, PyArray2<_>>>(&a, &b).unwrap_err();
5253
assert!(err.to_string().contains("not aligned"), "{}", err);
5354
});
5455
}
@@ -58,27 +59,26 @@ fn test_einsum() {
5859
Python::with_gil(|py| {
5960
let a = PyArray1::<i32>::arange_bound(py, 0, 25, 1)
6061
.reshape([5, 5])
61-
.unwrap()
62-
.into_gil_ref();
63-
let b = pyarray![py, 0, 1, 2, 3, 4];
64-
let c = pyarray![py, [0, 1, 2], [3, 4, 5]];
62+
.unwrap();
63+
let b = pyarray![py, 0, 1, 2, 3, 4].as_borrowed();
64+
let c = pyarray![py, [0, 1, 2], [3, 4, 5]].as_borrowed();
6565

66-
let d: &PyArray0<_> = einsum!("ii", a).unwrap();
66+
let d: Bound<'_, PyArray0<_>> = einsum_bound!("ii", a).unwrap();
6767
assert_eq!(d.item(), 60);
6868

69-
let d: i32 = einsum!("ii", a).unwrap();
69+
let d: i32 = einsum_bound!("ii", a).unwrap();
7070
assert_eq!(d, 60);
7171

72-
let d: &PyArray1<_> = einsum!("ii->i", a).unwrap();
72+
let d: Bound<'_, PyArray1<_>> = einsum_bound!("ii->i", a).unwrap();
7373
assert_eq!(d.readonly().as_array(), array![0, 6, 12, 18, 24]);
7474

75-
let d: &PyArray1<_> = einsum!("ij->i", a).unwrap();
75+
let d: Bound<'_, PyArray1<_>> = einsum_bound!("ij->i", a).unwrap();
7676
assert_eq!(d.readonly().as_array(), array![10, 35, 60, 85, 110]);
7777

78-
let d: &PyArray2<_> = einsum!("ji", c).unwrap();
78+
let d: Bound<'_, PyArray2<_>> = einsum_bound!("ji", c).unwrap();
7979
assert_eq!(d.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);
8080

81-
let d: &PyArray1<_> = einsum!("ij,j", a, b).unwrap();
81+
let d: Bound<'_, PyArray1<_>> = einsum_bound!("ij,j", a, b).unwrap();
8282
assert_eq!(d.readonly().as_array(), array![30, 80, 130, 180, 230]);
8383
});
8484
}

0 commit comments

Comments
 (0)