Skip to content

Commit 327ab10

Browse files
jakelishmandavidhewitt
authored andcommitted
Add From<PyReadwriteArray> for PyReadonlyArray
The ordering of drops is important for the dynamic checker, so this is a convenience for conversion.
1 parent a4b922f commit 327ab10

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

src/borrow/mod.rs

+15-2
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ use crate::array::{PyArray, PyArrayMethods};
175175
use crate::convert::NpyIndex;
176176
use crate::dtype::Element;
177177
use crate::error::{BorrowError, NotContiguousError};
178-
use crate::untyped_array::PyUntypedArrayMethods;
179178
use crate::npyffi::flags;
179+
use crate::untyped_array::PyUntypedArrayMethods;
180180

181181
use shared::{acquire, acquire_mut, release, release_mut};
182182

@@ -454,6 +454,18 @@ where
454454
unsafe { &*(self as *const Self as *const Self::Target) }
455455
}
456456
}
457+
impl<'py, T, D> From<PyReadwriteArray<'py, T, D>> for PyReadonlyArray<'py, T, D>
458+
where
459+
T: Element,
460+
D: Dimension,
461+
{
462+
fn from(value: PyReadwriteArray<'py, T, D>) -> Self {
463+
let array = value.array.clone();
464+
::std::mem::drop(value);
465+
Self::try_new(array)
466+
.expect("releasing an exclusive reference should immediately permit a shared reference")
467+
}
468+
}
457469

458470
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadwriteArray<'py, T, D> {
459471
fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
@@ -504,12 +516,13 @@ where
504516
///
505517
/// [writeable]: https://numpy.org/doc/stable/reference/c-api/array.html#c.NPY_ARRAY_WRITEABLE
506518
/// [owndata]: https://numpy.org/doc/stable/reference/c-api/array.html#c.NPY_ARRAY_OWNDATA
507-
pub fn make_nonwriteable(self) {
519+
pub fn make_nonwriteable(self) -> PyReadonlyArray<'py, T, D> {
508520
// SAFETY: consuming the only extant mutable reference guarantees we cannot invalidate an
509521
// existing reference, nor allow the caller to keep hold of one.
510522
unsafe {
511523
(*self.as_array_ptr()).flags &= !flags::NPY_ARRAY_WRITEABLE;
512524
}
525+
self.into()
513526
}
514527
}
515528

0 commit comments

Comments
 (0)