Skip to content

Commit de9e02c

Browse files
committed
Add minimal support for BFloat16 dtype.
1 parent 9f10b61 commit de9e02c

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

src/dtype.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::os::raw::{
55
use std::ptr;
66

77
#[cfg(feature = "half")]
8-
use half::f16;
8+
use half::{bf16, f16};
99
use num_traits::{Bounded, Zero};
1010
use pyo3::{
1111
exceptions::{PyIndexError, PyValueError},
@@ -15,6 +15,8 @@ use pyo3::{
1515
AsPyPointer, FromPyObject, FromPyPointer, IntoPyPointer, PyAny, PyNativeType, PyObject,
1616
PyResult, PyTypeInfo, Python, ToPyObject,
1717
};
18+
#[cfg(feature = "half")]
19+
use pyo3::{sync::GILOnceCell, IntoPy, Py};
1820

1921
use crate::npyffi::{
2022
NpyTypes, PyArray_Descr, NPY_ALIGNED_STRUCT, NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES,
@@ -477,6 +479,22 @@ impl_element_scalar!(f64 => NPY_DOUBLE);
477479
#[cfg(feature = "half")]
478480
impl_element_scalar!(f16 => NPY_HALF);
479481

482+
#[cfg(feature = "half")]
483+
unsafe impl Element for bf16 {
484+
const IS_COPY: bool = true;
485+
486+
fn get_dtype(py: Python) -> &PyArrayDescr {
487+
static DTYPE: GILOnceCell<Py<PyArrayDescr>> = GILOnceCell::new();
488+
489+
DTYPE
490+
.get_or_init(py, || {
491+
PyArrayDescr::new(py, "bfloat16").expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.").into_py(py)
492+
})
493+
.clone()
494+
.into_ref(py)
495+
}
496+
}
497+
480498
impl_element_scalar!(Complex32 => NPY_CFLOAT,
481499
#[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]);
482500
impl_element_scalar!(Complex64 => NPY_CDOUBLE,

0 commit comments

Comments
 (0)