Skip to content

Commit 08510a3

Browse files
committed
Test support for bfloat16 using ml_dtypes.
1 parent 2abafac commit 08510a3

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

.github/workflows/ci.yml

+4-4
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ jobs:
6767
shell: python
6868
- name: Test
6969
run: |
70-
pip install numpy
70+
pip install numpy ml_dtypes
7171
cargo test --all-features
7272
# Not on PyPy, because no embedding API
7373
if: ${{ !startsWith(matrix.python-version, 'pypy') }}
@@ -101,7 +101,7 @@ jobs:
101101
continue-on-error: true
102102
- uses: taiki-e/install-action@valgrind
103103
- run: |
104-
pip install numpy
104+
pip install numpy ml_dtypes
105105
cargo test --all-features --release
106106
env:
107107
CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: valgrind --leak-check=no --error-exitcode=1
@@ -115,7 +115,7 @@ jobs:
115115
- uses: Swatinem/rust-cache@v2
116116
continue-on-error: true
117117
- run: |
118-
pip install numpy
118+
pip install numpy ml_dtypes
119119
cargo install --locked cargo-careful
120120
cargo careful test --all-features
121121
@@ -201,7 +201,7 @@ jobs:
201201
python-version: 3.9
202202
architecture: x64
203203
- name: Install numpy
204-
run: pip install numpy
204+
run: pip install numpy ml_dtypes
205205
- uses: Swatinem/rust-cache@v2
206206
continue-on-error: true
207207
- uses: dtolnay/rust-toolchain@stable

tests/array.rs

+44-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::mem::size_of;
22

33
#[cfg(feature = "half")]
4-
use half::f16;
4+
use half::{bf16, f16};
55
use ndarray::{array, s, Array1, Dim};
66
use numpy::{
77
dtype, get_array_module, npyffi::NPY_ORDER, pyarray, PyArray, PyArray1, PyArray2, PyArrayDescr,
@@ -527,7 +527,7 @@ fn reshape() {
527527

528528
#[cfg(feature = "half")]
529529
#[test]
530-
fn half_works() {
530+
fn half_f16_works() {
531531
Python::with_gil(|py| {
532532
let np = py.eval("__import__('numpy')", None, None).unwrap();
533533
let locals = [("np", np)].into_py_dict(py);
@@ -558,7 +558,48 @@ fn half_works() {
558558
py_run!(
559559
py,
560560
array np,
561-
"np.testing.assert_array_almost_equal(array, np.array([[2, 4], [6, 8]], dtype='float16'))"
561+
"assert np.all(array == np.array([[2, 4], [6, 8]], dtype='float16'))"
562+
);
563+
});
564+
}
565+
566+
#[cfg(feature = "half")]
567+
#[test]
568+
fn half_bf16_works() {
569+
Python::with_gil(|py| {
570+
let np = py.eval("__import__('numpy')", None, None).unwrap();
571+
// NumPy itself does not provide a `bfloat16` dtype itself,
572+
// so we import ml_dtypes which does register such a dtype.
573+
let mldt = py.eval("__import__('ml_dtypes')", None, None).unwrap();
574+
let locals = [("np", np), ("mldt", mldt)].into_py_dict(py);
575+
576+
let array = py
577+
.eval(
578+
"np.array([[1, 2], [3, 4]], dtype='bfloat16')",
579+
None,
580+
Some(locals),
581+
)
582+
.unwrap()
583+
.downcast::<PyArray2<bf16>>()
584+
.unwrap();
585+
586+
assert_eq!(
587+
array.readonly().as_array(),
588+
array![
589+
[bf16::from_f32(1.0), bf16::from_f32(2.0)],
590+
[bf16::from_f32(3.0), bf16::from_f32(4.0)]
591+
]
592+
);
593+
594+
array
595+
.readwrite()
596+
.as_array_mut()
597+
.map_inplace(|value| *value *= bf16::from_f32(2.0));
598+
599+
py_run!(
600+
py,
601+
array np,
602+
"assert np.all(array == np.array([[2, 4], [6, 8]], dtype='bfloat16'))"
562603
);
563604
});
564605
}

0 commit comments

Comments
 (0)