Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minimal support for BFloat16 dtype. #381

Merged
merged 2 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
shell: python
- name: Test
run: |
pip install numpy
pip install numpy ml_dtypes
cargo test --all-features
# Not on PyPy, because no embedding API
if: ${{ !startsWith(matrix.python-version, 'pypy') }}
Expand Down Expand Up @@ -101,7 +101,7 @@ jobs:
continue-on-error: true
- uses: taiki-e/install-action@valgrind
- run: |
pip install numpy
pip install numpy ml_dtypes
cargo test --all-features --release
env:
CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: valgrind --leak-check=no --error-exitcode=1
Expand All @@ -115,7 +115,7 @@ jobs:
- uses: Swatinem/rust-cache@v2
continue-on-error: true
- run: |
pip install numpy
pip install numpy ml_dtypes
cargo install --locked cargo-careful
cargo careful test --all-features

Expand Down Expand Up @@ -201,7 +201,7 @@ jobs:
python-version: 3.9
architecture: x64
- name: Install numpy
run: pip install numpy
run: pip install numpy ml_dtypes
- uses: Swatinem/rust-cache@v2
continue-on-error: true
- uses: dtolnay/rust-toolchain@stable
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
- Unreleased
- Increase MSRV to 1.56 released in October 2021 and available in Debain 12, RHEL 9 and Alpine 3.17 following the same change for PyO3. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
- Add support for ASCII (`PyFixedString<N>`) and Unicode (`PyFixedUnicode<N>`) string arrays, i.e. dtypes `SN` and `UN` where `N` is the number of characters. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
- Add support for the `bfloat16` dtype by extending the optional integration with the `half` crate. Note that the `bfloat16` dtype is not part of NumPy itself so that usage requires third-party packages like Tensorflow. ([#381](https://github.com/PyO3/rust-numpy/pull/381))

- v0.19.0
- Add `PyUntypedArray` as an untyped base type for `PyArray` which can be used to inspect arguments before more targeted downcasts. This is accompanied by some methods like `dtype` and `shape` moving from `PyArray` to `PyUntypedArray`. They are still accessible though, as `PyArray` dereferences to `PyUntypedArray` via the `Deref` trait. ([#369](https://github.com/PyO3/rust-numpy/pull/369))
Expand Down
20 changes: 19 additions & 1 deletion src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::os::raw::{
use std::ptr;

#[cfg(feature = "half")]
use half::f16;
use half::{bf16, f16};
use num_traits::{Bounded, Zero};
use pyo3::{
exceptions::{PyIndexError, PyValueError},
Expand All @@ -15,6 +15,8 @@ use pyo3::{
AsPyPointer, FromPyObject, FromPyPointer, IntoPyPointer, PyAny, PyNativeType, PyObject,
PyResult, PyTypeInfo, Python, ToPyObject,
};
#[cfg(feature = "half")]
use pyo3::{sync::GILOnceCell, IntoPy, Py};

use crate::npyffi::{
NpyTypes, PyArray_Descr, NPY_ALIGNED_STRUCT, NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES,
Expand Down Expand Up @@ -477,6 +479,22 @@ impl_element_scalar!(f64 => NPY_DOUBLE);
#[cfg(feature = "half")]
impl_element_scalar!(f16 => NPY_HALF);

#[cfg(feature = "half")]
unsafe impl Element for bf16 {
const IS_COPY: bool = true;

fn get_dtype(py: Python) -> &PyArrayDescr {
static DTYPE: GILOnceCell<Py<PyArrayDescr>> = GILOnceCell::new();

DTYPE
.get_or_init(py, || {
PyArrayDescr::new(py, "bfloat16").expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.").into_py(py)
})
.clone()
.into_ref(py)
}
}

impl_element_scalar!(Complex32 => NPY_CFLOAT,
#[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]);
impl_element_scalar!(Complex64 => NPY_CDOUBLE,
Expand Down
47 changes: 44 additions & 3 deletions tests/array.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::mem::size_of;

#[cfg(feature = "half")]
use half::f16;
use half::{bf16, f16};
use ndarray::{array, s, Array1, Dim};
use numpy::{
dtype, get_array_module, npyffi::NPY_ORDER, pyarray, PyArray, PyArray1, PyArray2, PyArrayDescr,
Expand Down Expand Up @@ -527,7 +527,7 @@ fn reshape() {

#[cfg(feature = "half")]
#[test]
fn half_works() {
fn half_f16_works() {
Python::with_gil(|py| {
let np = py.eval("__import__('numpy')", None, None).unwrap();
let locals = [("np", np)].into_py_dict(py);
Expand Down Expand Up @@ -558,7 +558,48 @@ fn half_works() {
py_run!(
py,
array np,
"np.testing.assert_array_almost_equal(array, np.array([[2, 4], [6, 8]], dtype='float16'))"
"assert np.all(array == np.array([[2, 4], [6, 8]], dtype='float16'))"
);
});
}

#[cfg(feature = "half")]
#[test]
fn half_bf16_works() {
Python::with_gil(|py| {
let np = py.eval("__import__('numpy')", None, None).unwrap();
// NumPy itself does not provide a `bfloat16` dtype itself,
// so we import ml_dtypes which does register such a dtype.
let mldt = py.eval("__import__('ml_dtypes')", None, None).unwrap();
let locals = [("np", np), ("mldt", mldt)].into_py_dict(py);

let array = py
.eval(
"np.array([[1, 2], [3, 4]], dtype='bfloat16')",
None,
Some(locals),
)
.unwrap()
.downcast::<PyArray2<bf16>>()
.unwrap();

assert_eq!(
array.readonly().as_array(),
array![
[bf16::from_f32(1.0), bf16::from_f32(2.0)],
[bf16::from_f32(3.0), bf16::from_f32(4.0)]
]
);

array
.readwrite()
.as_array_mut()
.map_inplace(|value| *value *= bf16::from_f32(2.0));

py_run!(
py,
array np,
"assert np.all(array == np.array([[2, 4], [6, 8]], dtype='bfloat16'))"
);
});
}
Expand Down