Skip to content

Commit a5e3bc8

Browse files
committed
axum: allow body types other than axum::body::Body in Services passed to serve
1 parent 0e6e96f commit a5e3bc8

File tree

1 file changed

+61
-19
lines changed

1 file changed

+61
-19
lines changed

axum/src/serve/mod.rs

+61-19
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use std::{
44
convert::Infallible,
5+
error::Error as StdError,
56
fmt::Debug,
67
future::{poll_fn, Future, IntoFuture},
78
io,
@@ -11,6 +12,7 @@ use std::{
1112

1213
use axum_core::{body::Body, extract::Request, response::Response};
1314
use futures_util::{pin_mut, FutureExt};
15+
use http_body::Body as HttpBody;
1416
use hyper::body::Incoming;
1517
use hyper_util::rt::{TokioExecutor, TokioIo};
1618
#[cfg(any(feature = "http1", feature = "http2"))]
@@ -94,12 +96,15 @@ pub use self::listener::{Listener, ListenerExt, TapIo};
9496
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
9597
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
9698
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
97-
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
99+
pub fn serve<L, M, S, B>(listener: L, make_service: M) -> Serve<L, M, S, B>
98100
where
99101
L: Listener,
100102
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
101-
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
103+
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
102104
S::Future: Send,
105+
B: HttpBody + Send + 'static,
106+
B::Data: Send,
107+
B::Error: Into<Box<dyn StdError + Send + Sync>>,
103108
{
104109
Serve {
105110
listener,
@@ -111,14 +116,14 @@ where
111116
/// Future returned by [`serve`].
112117
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
113118
#[must_use = "futures must be awaited or polled"]
114-
pub struct Serve<L, M, S> {
119+
pub struct Serve<L, M, S, B> {
115120
listener: L,
116121
make_service: M,
117-
_marker: PhantomData<S>,
122+
_marker: PhantomData<(S, B)>,
118123
}
119124

120125
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
121-
impl<L, M, S> Serve<L, M, S>
126+
impl<L, M, S, B> Serve<L, M, S, B>
122127
where
123128
L: Listener,
124129
{
@@ -148,7 +153,7 @@ where
148153
///
149154
/// Similarly to [`serve`], although this future resolves to `io::Result<()>`, it will never
150155
/// error. It returns `Ok(())` only after the `signal` future completes.
151-
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
156+
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F, B>
152157
where
153158
F: Future<Output = ()> + Send + 'static,
154159
{
@@ -167,7 +172,7 @@ where
167172
}
168173

169174
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
170-
impl<L, M, S> Debug for Serve<L, M, S>
175+
impl<L, M, S, B> Debug for Serve<L, M, S, B>
171176
where
172177
L: Debug + 'static,
173178
M: Debug,
@@ -188,14 +193,17 @@ where
188193
}
189194

190195
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
191-
impl<L, M, S> IntoFuture for Serve<L, M, S>
196+
impl<L, M, S, B> IntoFuture for Serve<L, M, S, B>
192197
where
193198
L: Listener,
194199
L::Addr: Debug,
195200
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
196201
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
197-
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
202+
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
198203
S::Future: Send,
204+
B: HttpBody + Send + 'static,
205+
B::Data: Send,
206+
B::Error: Into<Box<dyn StdError + Send + Sync>>,
199207
{
200208
type Output = io::Result<()>;
201209
type IntoFuture = private::ServeFuture;
@@ -209,15 +217,15 @@ where
209217
/// Serve future with graceful shutdown enabled.
210218
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
211219
#[must_use = "futures must be awaited or polled"]
212-
pub struct WithGracefulShutdown<L, M, S, F> {
220+
pub struct WithGracefulShutdown<L, M, S, F, B> {
213221
listener: L,
214222
make_service: M,
215223
signal: F,
216-
_marker: PhantomData<S>,
224+
_marker: PhantomData<(S, B)>,
217225
}
218226

219227
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
220-
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
228+
impl<L, M, S, F, B> WithGracefulShutdown<L, M, S, F, B>
221229
where
222230
L: Listener,
223231
{
@@ -228,7 +236,7 @@ where
228236
}
229237

230238
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
231-
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
239+
impl<L, M, S, F, B> Debug for WithGracefulShutdown<L, M, S, F, B>
232240
where
233241
L: Debug + 'static,
234242
M: Debug,
@@ -252,15 +260,18 @@ where
252260
}
253261

254262
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
255-
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
263+
impl<L, M, S, F, B> IntoFuture for WithGracefulShutdown<L, M, S, F, B>
256264
where
257265
L: Listener,
258266
L::Addr: Debug,
259267
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
260268
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
261-
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
269+
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
262270
S::Future: Send,
263271
F: Future<Output = ()> + Send + 'static,
272+
B: HttpBody + Send + 'static,
273+
B::Data: Send,
274+
B::Error: Into<Box<dyn StdError + Send + Sync>>,
264275
{
265276
type Output = io::Result<()>;
266277
type IntoFuture = private::ServeFuture;
@@ -274,15 +285,18 @@ where
274285
}
275286

276287
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
277-
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
288+
impl<L, M, S, F, B> WithGracefulShutdown<L, M, S, F, B>
278289
where
279290
L: Listener,
280291
L::Addr: Debug,
281292
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
282293
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
283-
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
294+
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
284295
S::Future: Send,
285296
F: Future<Output = ()> + Send + 'static,
297+
B: HttpBody + Send + 'static,
298+
B::Data: Send,
299+
B::Error: Into<Box<dyn StdError + Send + Sync>>,
286300
{
287301
async fn run(self) {
288302
let Self {
@@ -439,14 +453,15 @@ mod tests {
439453
};
440454

441455
use axum_core::{body::Body, extract::Request};
442-
use http::StatusCode;
456+
use http::{Response, StatusCode};
443457
use hyper_util::rt::TokioIo;
444458
#[cfg(unix)]
445459
use tokio::net::UnixListener;
446460
use tokio::{
447461
io::{self, AsyncRead, AsyncWrite},
448462
net::TcpListener,
449463
};
464+
use tower::ServiceBuilder;
450465

451466
#[cfg(unix)]
452467
use super::IncomingStream;
@@ -458,7 +473,7 @@ mod tests {
458473
handler::{Handler, HandlerWithoutStateExt},
459474
routing::get,
460475
serve::ListenerExt,
461-
Router,
476+
Router, ServiceExt,
462477
};
463478

464479
#[allow(dead_code, unused_must_use)]
@@ -686,4 +701,31 @@ mod tests {
686701
let body = String::from_utf8(body.to_vec()).unwrap();
687702
assert_eq!(body, "Hello, World!");
688703
}
704+
705+
#[crate::test]
706+
async fn serving_with_custom_body_type() {
707+
struct CustomBody;
708+
impl http_body::Body for CustomBody {
709+
type Data = bytes::Bytes;
710+
type Error = std::convert::Infallible;
711+
fn poll_frame(
712+
self: std::pin::Pin<&mut Self>,
713+
_cx: &mut std::task::Context<'_>,
714+
) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>>
715+
{
716+
#![allow(clippy::unreachable)] // The implementation is not used, we just need to provide one.
717+
unreachable!();
718+
}
719+
}
720+
721+
let app = ServiceBuilder::new()
722+
.layer_fn(|_| tower::service_fn(|_| std::future::ready(Ok(Response::new(CustomBody)))))
723+
.service(Router::<()>::new());
724+
let addr = "0.0.0.0:0";
725+
726+
_ = serve(
727+
TcpListener::bind(addr).await.unwrap(),
728+
app.into_make_service(),
729+
);
730+
}
689731
}

0 commit comments

Comments
 (0)