diff --git a/bindings/rust/extended/s2n-tls/src/callbacks.rs b/bindings/rust/extended/s2n-tls/src/callbacks.rs index 5d871c34844..c1d476ec1b8 100644 --- a/bindings/rust/extended/s2n-tls/src/callbacks.rs +++ b/bindings/rust/extended/s2n-tls/src/callbacks.rs @@ -51,15 +51,15 @@ where F: FnOnce(&mut Connection, &mut Context) -> T, { let raw = NonNull::new(conn_ptr).expect("connection should not be null"); - let mut conn = Connection::from_raw(raw); - let mut config = conn.config().expect("config should not be null"); - let context = config.context_mut(); - let r = action(&mut conn, context); // Since this is a callback, it receives a pointer to the connection // but doesn't own that connection or control its lifecycle. // Do not drop / free the connection. - let _ = ManuallyDrop::new(conn); - r + // We must make the connection `ManuallyDrop` before `action`, otherwise a panic + // in `action` would cause the unwind mechanism to drop the connection. + let mut conn = ManuallyDrop::new(Connection::from_raw(raw)); + let mut config = conn.config().expect("config should not be null"); + let context = config.context_mut(); + action(&mut conn, context) } /// A trait for the callback used to verify host name(s) during X509 @@ -100,3 +100,39 @@ pub(crate) unsafe fn verify_host( Err(_) => 0, // If the host name can't be parsed, fail closed. } } + +#[cfg(test)] +mod tests { + use crate::{callbacks::with_context, config::Config, connection::Builder, enums::Mode}; + + // The temporary connection created in `with_context` should never be freed, + // even if customer code panics. + #[test] + fn panic_does_not_free_connection() -> Result<(), crate::error::Error> { + let config = Config::new(); + let mut connection = config.build_connection(Mode::Server)?; + + // 1 connection + 1 self reference + assert_eq!(config.test_get_refcount()?, 2); + + let conn_ptr = connection.as_ptr(); + let unwind_result = std::panic::catch_unwind(|| { + unsafe { + with_context(conn_ptr, |_conn, _context| { + panic!("force unwind"); + }) + }; + }); + + // a panic happened + assert!(unwind_result.is_err()); + + // the connection hasn't been freed yet + assert_eq!(config.test_get_refcount()?, 2); + + drop(connection); + assert_eq!(config.test_get_refcount()?, 1); + + Ok(()) + } +}