1
- use crate :: npyffi:: { NPY_CASTING , NPY_ORDER } ;
2
- use crate :: { Element , PyArray , PY_ARRAY_API } ;
1
+ use std:: borrow:: Cow ;
2
+ use std:: ffi:: { CStr , CString } ;
3
+ use std:: ptr:: null_mut;
4
+
3
5
use ndarray:: { Dimension , IxDyn } ;
4
- use pyo3:: { AsPyPointer , FromPyPointer , PyAny , PyNativeType , PyResult } ;
5
- use std:: ffi:: CStr ;
6
+ use pyo3:: { AsPyPointer , FromPyObject , FromPyPointer , PyAny , PyNativeType , PyResult } ;
7
+
8
+ use crate :: array:: PyArray ;
9
+ use crate :: dtype:: Element ;
10
+ use crate :: npyffi:: { array:: PY_ARRAY_API , NPY_CASTING , NPY_ORDER } ;
11
+
12
+ /// Return value of a function that can yield either an array or a scalar.
13
+ pub trait ArrayOrScalar < ' py , T > : FromPyObject < ' py > { }
14
+
15
+ impl < ' py , T , D > ArrayOrScalar < ' py , T > for & ' py PyArray < T , D >
16
+ where
17
+ T : Element ,
18
+ D : Dimension ,
19
+ {
20
+ }
21
+
22
+ impl < ' py , T > ArrayOrScalar < ' py , T > for T where T : Element + FromPyObject < ' py > { }
6
23
7
24
/// Return the inner product of two arrays.
8
25
///
9
- /// # Example
26
+ /// [NumPy's documentation][inner] has the details.
27
+ ///
28
+ /// # Examples
29
+ ///
30
+ /// Note that this function can either return a scalar...
31
+ ///
10
32
/// ```
11
- /// pyo3::Python::with_gil(|py| {
12
- /// let array = numpy::pyarray![py, 1, 2, 3];
13
- /// let inner: &numpy::PyArray0::<_> = numpy::inner(array, array).unwrap();
14
- /// assert_eq!(inner.item(), 14);
33
+ /// use pyo3::Python;
34
+ /// use numpy::{inner, pyarray, PyArray0};
35
+ ///
36
+ /// Python::with_gil(|py| {
37
+ /// let vector = pyarray![py, 1.0, 2.0, 3.0];
38
+ /// let result: f64 = inner(vector, vector).unwrap();
39
+ /// assert_eq!(result, 14.0);
40
+ /// });
41
+ /// ```
42
+ ///
43
+ /// ...or an array depending on its arguments.
44
+ ///
45
+ /// ```
46
+ /// use pyo3::Python;
47
+ /// use numpy::{inner, pyarray, PyArray0};
48
+ ///
49
+ /// Python::with_gil(|py| {
50
+ /// let vector = pyarray![py, 1, 2, 3];
51
+ /// let result: &PyArray0<_> = inner(vector, vector).unwrap();
52
+ /// assert_eq!(result.item(), 14);
15
53
/// });
16
54
/// ```
17
- pub fn inner < ' py , T , DIN1 , DIN2 , DOUT > (
55
+ ///
56
+ /// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
57
+ pub fn inner < ' py , T , DIN1 , DIN2 , OUT > (
18
58
array1 : & ' py PyArray < T , DIN1 > ,
19
59
array2 : & ' py PyArray < T , DIN2 > ,
20
- ) -> PyResult < & ' py PyArray < T , DOUT > >
60
+ ) -> PyResult < OUT >
21
61
where
62
+ T : Element ,
22
63
DIN1 : Dimension ,
23
64
DIN2 : Dimension ,
24
- DOUT : Dimension ,
25
- T : Element ,
65
+ OUT : ArrayOrScalar < ' py , T > ,
26
66
{
27
67
let py = array1. py ( ) ;
28
68
let obj = unsafe {
@@ -34,27 +74,53 @@ where
34
74
35
75
/// Return the dot product of two arrays.
36
76
///
37
- /// # Example
77
+ /// [NumPy's documentation][dot] has the details.
78
+ ///
79
+ /// # Examples
80
+ ///
81
+ /// Note that this function can either return an array...
82
+ ///
38
83
/// ```
39
- /// pyo3::Python::with_gil(|py| {
40
- /// let a = numpy::pyarray![py, [1, 0], [0, 1]];
41
- /// let b = numpy::pyarray![py, [4, 1], [2, 2]];
42
- /// let dot: &numpy::PyArray2::<_> = numpy::dot(a, b).unwrap();
84
+ /// use pyo3::Python;
85
+ /// use ndarray::array;
86
+ /// use numpy::{dot, pyarray, PyArray2};
87
+ ///
88
+ /// Python::with_gil(|py| {
89
+ /// let matrix = pyarray![py, [1, 0], [0, 1]];
90
+ /// let another_matrix = pyarray![py, [4, 1], [2, 2]];
91
+ ///
92
+ /// let result: &PyArray2<_> = numpy::dot(matrix, another_matrix).unwrap();
93
+ ///
43
94
/// assert_eq!(
44
- /// dot .readonly().as_array(),
45
- /// ndarray:: array![[4, 1], [2, 2]]
95
+ /// result .readonly().as_array(),
96
+ /// array![[4, 1], [2, 2]]
46
97
/// );
47
98
/// });
48
99
/// ```
49
- pub fn dot < ' py , T , DIN1 , DIN2 , DOUT > (
100
+ ///
101
+ /// ...or a scalar depending on its arguments.
102
+ ///
103
+ /// ```
104
+ /// use pyo3::Python;
105
+ /// use numpy::{dot, pyarray, PyArray0};
106
+ ///
107
+ /// Python::with_gil(|py| {
108
+ /// let vector = pyarray![py, 1.0, 2.0, 3.0];
109
+ /// let result: f64 = dot(vector, vector).unwrap();
110
+ /// assert_eq!(result, 14.0);
111
+ /// });
112
+ /// ```
113
+ ///
114
+ /// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
115
+ pub fn dot < ' py , T , DIN1 , DIN2 , OUT > (
50
116
array1 : & ' py PyArray < T , DIN1 > ,
51
117
array2 : & ' py PyArray < T , DIN2 > ,
52
- ) -> PyResult < & ' py PyArray < T , DOUT > >
118
+ ) -> PyResult < OUT >
53
119
where
120
+ T : Element ,
54
121
DIN1 : Dimension ,
55
122
DIN2 : Dimension ,
56
- DOUT : Dimension ,
57
- T : Element ,
123
+ OUT : ArrayOrScalar < ' py , T > ,
58
124
{
59
125
let py = array1. py ( ) ;
60
126
let obj = unsafe {
66
132
67
133
/// Return the Einstein summation convention of given tensors.
68
134
///
69
- /// We also provide the [ einsum macro](./macro.einsum.html) .
135
+ /// This is usually invoked via the the [` einsum!`] macro.
70
136
pub fn einsum_impl < ' py , T , DOUT > (
71
137
subscripts : & str ,
72
138
arrays : & [ & ' py PyArray < T , IxDyn > ] ,
@@ -75,22 +141,22 @@ where
75
141
DOUT : Dimension ,
76
142
T : Element ,
77
143
{
78
- let subscripts: std:: borrow:: Cow < CStr > = match CStr :: from_bytes_with_nul ( subscripts. as_bytes ( ) )
79
- {
80
- Ok ( subscripts) => subscripts. into ( ) ,
81
- Err ( _) => std:: ffi:: CString :: new ( subscripts) . unwrap ( ) . into ( ) ,
144
+ let subscripts = match CStr :: from_bytes_with_nul ( subscripts. as_bytes ( ) ) {
145
+ Ok ( subscripts) => Cow :: Borrowed ( subscripts) ,
146
+ Err ( _) => Cow :: Owned ( CString :: new ( subscripts) . unwrap ( ) ) ,
82
147
} ;
148
+
83
149
let py = arrays[ 0 ] . py ( ) ;
84
150
let obj = unsafe {
85
151
let result = PY_ARRAY_API . PyArray_EinsteinSum (
86
152
py,
87
153
subscripts. as_ptr ( ) as _ ,
88
154
arrays. len ( ) as _ ,
89
155
arrays. as_ptr ( ) as _ ,
90
- std :: ptr :: null_mut ( ) ,
156
+ null_mut ( ) ,
91
157
NPY_ORDER :: NPY_KEEPORDER ,
92
158
NPY_CASTING :: NPY_NO_CASTING ,
93
- std :: ptr :: null_mut ( ) ,
159
+ null_mut ( ) ,
94
160
) ;
95
161
PyAny :: from_owned_ptr_or_err ( py, result) ?
96
162
} ;
@@ -99,25 +165,34 @@ where
99
165
100
166
/// Return the Einstein summation convention of given tensors.
101
167
///
102
- /// For more about the Einstein summation convention, you may reffer to
103
- /// [the numpy document](https://numpy.org/doc/stable/reference/generated/numpy. einsum.html) .
168
+ /// For more about the Einstein summation convention, please refer to
169
+ /// [NumPy's documentation][ einsum] .
104
170
///
105
171
/// # Example
172
+ ///
106
173
/// ```
107
- /// pyo3::Python::with_gil(|py| {
108
- /// let a = numpy::PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
109
- /// let b = numpy::pyarray![py, [20, 30], [40, 50], [60, 70]];
110
- /// let einsum = numpy::einsum!("ijk,ji->ik", a, b).unwrap();
174
+ /// use pyo3::Python;
175
+ /// use ndarray::array;
176
+ /// use numpy::{einsum, pyarray, PyArray};
177
+ ///
178
+ /// Python::with_gil(|py| {
179
+ /// let tensor = PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
180
+ /// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
181
+ ///
182
+ /// let result = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
183
+ ///
111
184
/// assert_eq!(
112
- /// einsum .readonly().as_array(),
113
- /// ndarray:: array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
185
+ /// result .readonly().as_array(),
186
+ /// array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
114
187
/// );
115
188
/// });
116
189
/// ```
190
+ ///
191
+ /// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
117
192
#[ macro_export]
118
193
macro_rules! einsum {
119
- ( $subscripts: literal $( , $array: ident) + $( , ) * ) => { {
194
+ ( $subscripts: literal $( , $array: ident) + $( , ) * ) => { {
120
195
let arrays = [ $( $array. to_dyn( ) , ) +] ;
121
- unsafe { $crate:: einsum_impl( concat!( $subscripts, "\0 " ) , & arrays) }
196
+ $crate:: einsum_impl( concat!( $subscripts, "\0 " ) , & arrays)
122
197
} } ;
123
198
}
0 commit comments