forked from PyO3/rust-numpy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path · lib.rs
55 lines (48 loc) · 1.39 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
use numpy::{Complex64, IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn};
use pyo3::prelude::{pymodule, PyModule, PyResult, Python};
#[pymodule]
fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// immutable example
fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD<f64> {
a * &x + &y
}
// mutable example (no return)
fn mult(a: f64, mut x: ArrayViewMutD<'_, f64>) {
x *= a;
}
// complex example
fn conj(x: ArrayViewD<'_, Complex64>) -> ArrayD<Complex64> {
x.map(|c| c.conj())
}
// wrapper of `axpy`
#[pyfn(m)]
#[pyo3(name = "axpy")]
fn axpy_py<'py>(
py: Python<'py>,
a: f64,
x: PyReadonlyArrayDyn<'_, f64>,
y: PyReadonlyArrayDyn<'_, f64>,
) -> &'py PyArrayDyn<f64> {
let x = x.as_array();
let y = y.as_array();
axpy(a, x, y).into_pyarray(py)
}
// wrapper of `mult`
#[pyfn(m)]
#[pyo3(name = "mult")]
fn mult_py(a: f64, x: &PyArrayDyn<f64>) {
let x = unsafe { x.as_array_mut() };
mult(a, x);
}
// wrapper of `conj`
#[pyfn(m)]
#[pyo3(name = "conj")]
fn conj_py<'py>(
py: Python<'py>,
x: PyReadonlyArrayDyn<'_, Complex64>,
) -> &'py PyArrayDyn<Complex64> {
conj(x.as_array()).into_pyarray(py)
}
Ok(())
}