1
- use crate :: npyffi:: { NPY_CASTING , NPY_ORDER } ;
2
- use crate :: { Element , PyArray , PY_ARRAY_API } ;
1
+ use std:: borrow:: Cow ;
2
+ use std:: ffi:: { CStr , CString } ;
3
+ use std:: ptr:: null_mut;
4
+
3
5
use ndarray:: { Dimension , IxDyn } ;
4
- use pyo3:: { AsPyPointer , FromPyPointer , PyAny , PyNativeType , PyResult } ;
5
- use std:: ffi:: CStr ;
6
+ use pyo3:: { AsPyPointer , FromPyObject , FromPyPointer , PyAny , PyNativeType , PyResult } ;
7
+
8
+ use crate :: array:: PyArray ;
9
+ use crate :: dtype:: Element ;
10
+ use crate :: npyffi:: { array:: PY_ARRAY_API , NPY_CASTING , NPY_ORDER } ;
11
+
12
+ /// Return value of a function that can yield either an array or a scalar.
13
+ pub trait ArrayOrScalar < ' py , T > : FromPyObject < ' py > { }
14
+
15
+ impl < ' py , T , D > ArrayOrScalar < ' py , T > for & ' py PyArray < T , D >
16
+ where
17
+ T : Element ,
18
+ D : Dimension ,
19
+ {
20
+ }
21
+
22
+ impl < ' py , T > ArrayOrScalar < ' py , T > for T where T : Element + FromPyObject < ' py > { }
6
23
7
24
/// Return the inner product of two arrays.
8
25
///
9
- /// # Example
26
+ /// [NumPy's documentation][inner] has the details.
27
+ ///
28
+ /// # Examples
29
+ ///
30
+ /// Note that this function can either return a scalar...
31
+ ///
10
32
/// ```
11
- /// pyo3::Python::with_gil(|py| {
12
- /// let array = numpy::pyarray![py, 1, 2, 3];
13
- /// let inner: &numpy::PyArray0::<_> = numpy::inner(array, array).unwrap();
14
- /// assert_eq!(inner.item(), 14);
33
+ /// use pyo3::Python;
34
+ /// use numpy::{inner, pyarray, PyArray0};
35
+ ///
36
+ /// Python::with_gil(|py| {
37
+ /// let vector = pyarray![py, 1.0, 2.0, 3.0];
38
+ /// let result: f64 = inner(vector, vector).unwrap();
39
+ /// assert_eq!(result, 14.0);
15
40
/// });
16
41
/// ```
17
- pub fn inner < ' py , T , DIN1 , DIN2 , DOUT > (
42
+ ///
43
+ /// ...or an array depending on its arguments.
44
+ ///
45
+ /// ```
46
+ /// use pyo3::Python;
47
+ /// use numpy::{inner, pyarray, PyArray0};
48
+ ///
49
+ /// Python::with_gil(|py| {
50
+ /// let vector = pyarray![py, 1, 2, 3];
51
+ /// let result: &PyArray0<_> = inner(vector, vector).unwrap();
52
+ /// assert_eq!(result.item(), 14);
53
+ /// });
54
+ /// ```
55
+ ///
56
+ /// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
57
+ pub fn inner < ' py , T , DIN1 , DIN2 , OUT > (
18
58
array1 : & ' py PyArray < T , DIN1 > ,
19
59
array2 : & ' py PyArray < T , DIN2 > ,
20
- ) -> PyResult < & ' py PyArray < T , DOUT > >
60
+ ) -> PyResult < OUT >
21
61
where
62
+ T : Element ,
22
63
DIN1 : Dimension ,
23
64
DIN2 : Dimension ,
24
- DOUT : Dimension ,
25
- T : Element ,
65
+ OUT : ArrayOrScalar < ' py , T > ,
26
66
{
27
67
let py = array1. py ( ) ;
28
68
let obj = unsafe {
@@ -34,27 +74,53 @@ where
34
74
35
75
/// Return the dot product of two arrays.
36
76
///
37
- /// # Example
77
+ /// [NumPy's documentation][dot] has the details.
78
+ ///
79
+ /// # Examples
80
+ ///
81
+ /// Note that this function can either return an array...
82
+ ///
38
83
/// ```
39
- /// pyo3::Python::with_gil(|py| {
40
- /// let a = numpy::pyarray![py, [1, 0], [0, 1]];
41
- /// let b = numpy::pyarray![py, [4, 1], [2, 2]];
42
- /// let dot: &numpy::PyArray2::<_> = numpy::dot(a, b).unwrap();
84
+ /// use pyo3::Python;
85
+ /// use ndarray::array;
86
+ /// use numpy::{dot, pyarray, PyArray2};
87
+ ///
88
+ /// Python::with_gil(|py| {
89
+ /// let matrix = pyarray![py, [1, 0], [0, 1]];
90
+ /// let another_matrix = pyarray![py, [4, 1], [2, 2]];
91
+ ///
92
+ /// let result: &PyArray2<_> = numpy::dot(matrix, another_matrix).unwrap();
93
+ ///
43
94
/// assert_eq!(
44
- /// dot .readonly().as_array(),
45
- /// ndarray:: array![[4, 1], [2, 2]]
95
+ /// result .readonly().as_array(),
96
+ /// array![[4, 1], [2, 2]]
46
97
/// );
47
98
/// });
48
99
/// ```
49
- pub fn dot < ' py , T , DIN1 , DIN2 , DOUT > (
100
+ ///
101
+ /// ...or a scalar depending on its arguments.
102
+ ///
103
+ /// ```
104
+ /// use pyo3::Python;
105
+ /// use numpy::{dot, pyarray, PyArray0};
106
+ ///
107
+ /// Python::with_gil(|py| {
108
+ /// let vector = pyarray![py, 1.0, 2.0, 3.0];
109
+ /// let result: f64 = dot(vector, vector).unwrap();
110
+ /// assert_eq!(result, 14.0);
111
+ /// });
112
+ /// ```
113
+ ///
114
+ /// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
115
+ pub fn dot < ' py , T , DIN1 , DIN2 , OUT > (
50
116
array1 : & ' py PyArray < T , DIN1 > ,
51
117
array2 : & ' py PyArray < T , DIN2 > ,
52
- ) -> PyResult < & ' py PyArray < T , DOUT > >
118
+ ) -> PyResult < OUT >
53
119
where
120
+ T : Element ,
54
121
DIN1 : Dimension ,
55
122
DIN2 : Dimension ,
56
- DOUT : Dimension ,
57
- T : Element ,
123
+ OUT : ArrayOrScalar < ' py , T > ,
58
124
{
59
125
let py = array1. py ( ) ;
60
126
let obj = unsafe {
@@ -66,31 +132,28 @@ where
66
132
67
133
/// Return the Einstein summation convention of given tensors.
68
134
///
69
- /// We also provide the [einsum macro](./macro.einsum.html).
70
- pub fn einsum_impl < ' py , T , DOUT > (
71
- subscripts : & str ,
72
- arrays : & [ & ' py PyArray < T , IxDyn > ] ,
73
- ) -> PyResult < & ' py PyArray < T , DOUT > >
135
+ /// This is usually invoked via the the [`einsum!`] macro.
136
+ pub fn einsum < ' py , T , OUT > ( subscripts : & str , arrays : & [ & ' py PyArray < T , IxDyn > ] ) -> PyResult < OUT >
74
137
where
75
- DOUT : Dimension ,
76
138
T : Element ,
139
+ OUT : ArrayOrScalar < ' py , T > ,
77
140
{
78
- let subscripts: std:: borrow:: Cow < CStr > = match CStr :: from_bytes_with_nul ( subscripts. as_bytes ( ) )
79
- {
80
- Ok ( subscripts) => subscripts. into ( ) ,
81
- Err ( _) => std:: ffi:: CString :: new ( subscripts) . unwrap ( ) . into ( ) ,
141
+ let subscripts = match CStr :: from_bytes_with_nul ( subscripts. as_bytes ( ) ) {
142
+ Ok ( subscripts) => Cow :: Borrowed ( subscripts) ,
143
+ Err ( _) => Cow :: Owned ( CString :: new ( subscripts) . unwrap ( ) ) ,
82
144
} ;
145
+
83
146
let py = arrays[ 0 ] . py ( ) ;
84
147
let obj = unsafe {
85
148
let result = PY_ARRAY_API . PyArray_EinsteinSum (
86
149
py,
87
150
subscripts. as_ptr ( ) as _ ,
88
151
arrays. len ( ) as _ ,
89
152
arrays. as_ptr ( ) as _ ,
90
- std :: ptr :: null_mut ( ) ,
153
+ null_mut ( ) ,
91
154
NPY_ORDER :: NPY_KEEPORDER ,
92
155
NPY_CASTING :: NPY_NO_CASTING ,
93
- std :: ptr :: null_mut ( ) ,
156
+ null_mut ( ) ,
94
157
) ;
95
158
PyAny :: from_owned_ptr_or_err ( py, result) ?
96
159
} ;
@@ -99,25 +162,34 @@ where
99
162
100
163
/// Return the Einstein summation convention of given tensors.
101
164
///
102
- /// For more about the Einstein summation convention, you may reffer to
103
- /// [the numpy document](https://numpy.org/doc/stable/reference/generated/numpy. einsum.html) .
165
+ /// For more about the Einstein summation convention, please refer to
166
+ /// [NumPy's documentation][ einsum] .
104
167
///
105
168
/// # Example
169
+ ///
106
170
/// ```
107
- /// pyo3::Python::with_gil(|py| {
108
- /// let a = numpy::PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
109
- /// let b = numpy::pyarray![py, [20, 30], [40, 50], [60, 70]];
110
- /// let einsum = numpy::einsum!("ijk,ji->ik", a, b).unwrap();
171
+ /// use pyo3::Python;
172
+ /// use ndarray::array;
173
+ /// use numpy::{einsum, pyarray, PyArray, PyArray2};
174
+ ///
175
+ /// Python::with_gil(|py| {
176
+ /// let tensor = PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
177
+ /// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
178
+ ///
179
+ /// let result: &PyArray2<_> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
180
+ ///
111
181
/// assert_eq!(
112
- /// einsum .readonly().as_array(),
113
- /// ndarray:: array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
182
+ /// result .readonly().as_array(),
183
+ /// array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
114
184
/// );
115
185
/// });
116
186
/// ```
187
+ ///
188
+ /// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
117
189
#[ macro_export]
118
190
macro_rules! einsum {
119
- ( $subscripts: literal $( , $array: ident) + $( , ) * ) => { {
191
+ ( $subscripts: literal $( , $array: ident) + $( , ) * ) => { {
120
192
let arrays = [ $( $array. to_dyn( ) , ) +] ;
121
- unsafe { $crate:: einsum_impl ( concat!( $subscripts, "\0 " ) , & arrays) }
193
+ $crate:: einsum ( concat!( $subscripts, "\0 " ) , & arrays)
122
194
} } ;
123
195
}
0 commit comments