|
1 | 1 | use std::mem::size_of;
|
2 | 2 |
|
3 | 3 | #[cfg(feature = "half")]
|
4 |
| -use half::f16; |
| 4 | +use half::{bf16, f16}; |
5 | 5 | use ndarray::{array, s, Array1, Dim};
|
6 | 6 | use numpy::{
|
7 | 7 | dtype, get_array_module, npyffi::NPY_ORDER, pyarray, PyArray, PyArray1, PyArray2, PyArrayDescr,
|
@@ -527,7 +527,7 @@ fn reshape() {
|
527 | 527 |
|
528 | 528 | #[cfg(feature = "half")]
|
529 | 529 | #[test]
|
530 |
| -fn half_works() { |
| 530 | +fn half_f16_works() { |
531 | 531 | Python::with_gil(|py| {
|
532 | 532 | let np = py.eval("__import__('numpy')", None, None).unwrap();
|
533 | 533 | let locals = [("np", np)].into_py_dict(py);
|
@@ -558,7 +558,48 @@ fn half_works() {
|
558 | 558 | py_run!(
|
559 | 559 | py,
|
560 | 560 | array np,
|
561 |
| - "np.testing.assert_array_almost_equal(array, np.array([[2, 4], [6, 8]], dtype='float16'))" |
| 561 | + "assert np.all(array == np.array([[2, 4], [6, 8]], dtype='float16'))" |
| 562 | + ); |
| 563 | + }); |
| 564 | +} |
| 565 | + |
| 566 | +#[cfg(feature = "half")] |
| 567 | +#[test] |
| 568 | +fn half_bf16_works() { |
| 569 | + Python::with_gil(|py| { |
| 570 | + let np = py.eval("__import__('numpy')", None, None).unwrap(); |
| 571 | + // NumPy itself does not provide a `bfloat16` dtype itself, |
| 572 | + // so we import ml_dtypes which does register such a dtype. |
| 573 | + let mldt = py.eval("__import__('ml_dtypes')", None, None).unwrap(); |
| 574 | + let locals = [("np", np), ("mldt", mldt)].into_py_dict(py); |
| 575 | + |
| 576 | + let array = py |
| 577 | + .eval( |
| 578 | + "np.array([[1, 2], [3, 4]], dtype='bfloat16')", |
| 579 | + None, |
| 580 | + Some(locals), |
| 581 | + ) |
| 582 | + .unwrap() |
| 583 | + .downcast::<PyArray2<bf16>>() |
| 584 | + .unwrap(); |
| 585 | + |
| 586 | + assert_eq!( |
| 587 | + array.readonly().as_array(), |
| 588 | + array![ |
| 589 | + [bf16::from_f32(1.0), bf16::from_f32(2.0)], |
| 590 | + [bf16::from_f32(3.0), bf16::from_f32(4.0)] |
| 591 | + ] |
| 592 | + ); |
| 593 | + |
| 594 | + array |
| 595 | + .readwrite() |
| 596 | + .as_array_mut() |
| 597 | + .map_inplace(|value| *value *= bf16::from_f32(2.0)); |
| 598 | + |
| 599 | + py_run!( |
| 600 | + py, |
| 601 | + array np, |
| 602 | + "assert np.all(array == np.array([[2, 4], [6, 8]], dtype='bfloat16'))" |
562 | 603 | );
|
563 | 604 | });
|
564 | 605 | }
|
|
0 commit comments