Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

axum: allow body types other than axum::body::Body in Services passed to serve #3205

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

# Unreleased

- **changed:** `serve` has an additional generic argument and can now work with any response body
type, not just `axum::body::Body` ([3205])

# 0.8.2

- **added:** Implement `OptionalFromRequest` for `Json` ([#3142])
Expand Down
80 changes: 61 additions & 19 deletions axum/src/serve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use std::{
convert::Infallible,
error::Error as StdError,
fmt::Debug,
future::{poll_fn, Future, IntoFuture},
io,
Expand All @@ -11,6 +12,7 @@ use std::{

use axum_core::{body::Body, extract::Request, response::Response};
use futures_util::{pin_mut, FutureExt};
use http_body::Body as HttpBody;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
#[cfg(any(feature = "http1", feature = "http2"))]
Expand Down Expand Up @@ -94,12 +96,15 @@ pub use self::listener::{Listener, ListenerExt, TapIo};
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
pub fn serve<L, M, S, B>(listener: L, make_service: M) -> Serve<L, M, S, B>
where
L: Listener,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
Serve {
listener,
Expand All @@ -111,14 +116,14 @@ where
/// Future returned by [`serve`].
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<L, M, S> {
pub struct Serve<L, M, S, B> {
listener: L,
make_service: M,
_marker: PhantomData<S>,
_marker: PhantomData<(S, B)>,
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Serve<L, M, S>
impl<L, M, S, B> Serve<L, M, S, B>
where
L: Listener,
{
Expand Down Expand Up @@ -148,7 +153,7 @@ where
///
/// Similarly to [`serve`], although this future resolves to `io::Result<()>`, it will never
/// error. It returns `Ok(())` only after the `signal` future completes.
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F, B>
where
F: Future<Output = ()> + Send + 'static,
{
Expand All @@ -167,7 +172,7 @@ where
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Debug for Serve<L, M, S>
impl<L, M, S, B> Debug for Serve<L, M, S, B>
where
L: Debug + 'static,
M: Debug,
Expand All @@ -188,14 +193,17 @@ where
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> IntoFuture for Serve<L, M, S>
impl<L, M, S, B> IntoFuture for Serve<L, M, S, B>
where
L: Listener,
L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
type Output = io::Result<()>;
type IntoFuture = private::ServeFuture;
Expand All @@ -209,15 +217,15 @@ where
/// Serve future with graceful shutdown enabled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<L, M, S, F> {
pub struct WithGracefulShutdown<L, M, S, F, B> {
listener: L,
make_service: M,
signal: F,
_marker: PhantomData<S>,
_marker: PhantomData<(S, B)>,
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
impl<L, M, S, F, B> WithGracefulShutdown<L, M, S, F, B>
where
L: Listener,
{
Expand All @@ -228,7 +236,7 @@ where
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
impl<L, M, S, F, B> Debug for WithGracefulShutdown<L, M, S, F, B>
where
L: Debug + 'static,
M: Debug,
Expand All @@ -252,15 +260,18 @@ where
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
impl<L, M, S, F, B> IntoFuture for WithGracefulShutdown<L, M, S, F, B>
where
L: Listener,
L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
F: Future<Output = ()> + Send + 'static,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
type Output = io::Result<()>;
type IntoFuture = private::ServeFuture;
Expand All @@ -274,15 +285,18 @@ where
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
impl<L, M, S, F, B> WithGracefulShutdown<L, M, S, F, B>
where
L: Listener,
L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S: Service<Request, Response = Response<B>, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
F: Future<Output = ()> + Send + 'static,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
async fn run(self) {
let Self {
Expand Down Expand Up @@ -439,14 +453,15 @@ mod tests {
};

use axum_core::{body::Body, extract::Request};
use http::StatusCode;
use http::{Response, StatusCode};
use hyper_util::rt::TokioIo;
#[cfg(unix)]
use tokio::net::UnixListener;
use tokio::{
io::{self, AsyncRead, AsyncWrite},
net::TcpListener,
};
use tower::ServiceBuilder;

#[cfg(unix)]
use super::IncomingStream;
Expand All @@ -458,7 +473,7 @@ mod tests {
handler::{Handler, HandlerWithoutStateExt},
routing::get,
serve::ListenerExt,
Router,
Router, ServiceExt,
};

#[allow(dead_code, unused_must_use)]
Expand Down Expand Up @@ -686,4 +701,31 @@ mod tests {
let body = String::from_utf8(body.to_vec()).unwrap();
assert_eq!(body, "Hello, World!");
}

#[crate::test]
async fn serving_with_custom_body_type() {
struct CustomBody;
impl http_body::Body for CustomBody {
type Data = bytes::Bytes;
type Error = std::convert::Infallible;
fn poll_frame(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>>
{
#![allow(clippy::unreachable)] // The implementation is not used, we just need to provide one.
unreachable!();
}
}

let app = ServiceBuilder::new()
.layer_fn(|_| tower::service_fn(|_| std::future::ready(Ok(Response::new(CustomBody)))))
.service(Router::<()>::new().route("/hello", get(|| async {})));
let addr = "0.0.0.0:0";

_ = serve(
TcpListener::bind(addr).await.unwrap(),
app.into_make_service(),
);
}
}