Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge fallbacks with the rest of the router #3158

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions axum/src/docs/routing/route.md
Original file line number Diff line number Diff line change
@@ -36,8 +36,7 @@ documentation for more details.
It is not possible to create segments that only match some types like numbers or
regular expression. You must handle that manually in your handlers.

[`MatchedPath`](crate::extract::MatchedPath) can be used to extract the matched
path rather than the actual path.
[`MatchedPath`] can be used to extract the matched path rather than the actual path.

# Wildcards

87 changes: 53 additions & 34 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
@@ -3,6 +3,8 @@
use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
#[cfg(feature = "matched-path")]
use crate::extract::MatchedPath;
use crate::{
body::{Body, HttpBody},
boxed::BoxedIntoRoute,
@@ -20,7 +22,8 @@ use std::{
sync::Arc,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower::service_fn;
use tower_layer::{layer_fn, Layer};
use tower_service::Service;

pub mod future;
@@ -72,8 +75,7 @@ impl<S> Clone for Router<S> {
}

struct RouterInner<S> {
path_router: PathRouter<S, false>,
fallback_router: PathRouter<S, true>,
path_router: PathRouter<S>,
default_fallback: bool,
catch_all_fallback: Fallback<S>,
}
@@ -91,7 +93,6 @@ impl<S> fmt::Debug for Router<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Router")
.field("path_router", &self.inner.path_router)
.field("fallback_router", &self.inner.fallback_router)
.field("default_fallback", &self.inner.default_fallback)
.field("catch_all_fallback", &self.inner.catch_all_fallback)
.finish()
@@ -141,7 +142,6 @@ where
Self {
inner: Arc::new(RouterInner {
path_router: Default::default(),
fallback_router: PathRouter::new_fallback(),
default_fallback: true,
catch_all_fallback: Fallback::Default(Route::new(NotFound)),
}),
@@ -153,7 +153,6 @@ where
Ok(inner) => inner,
Err(arc) => RouterInner {
path_router: arc.path_router.clone(),
fallback_router: arc.fallback_router.clone(),
default_fallback: arc.default_fallback,
catch_all_fallback: arc.catch_all_fallback.clone(),
},
@@ -207,8 +206,7 @@ where

let RouterInner {
path_router,
fallback_router,
default_fallback,
default_fallback: _,
// we don't need to inherit the catch-all fallback. It is only used for CONNECT
// requests with an empty path. If we were to inherit the catch-all fallback
// it would end up matching `/{path}/*` which doesn't match empty paths.
@@ -217,10 +215,6 @@ where

tap_inner!(self, mut this => {
panic_on_err!(this.path_router.nest(path, path_router));

if !default_fallback {
panic_on_err!(this.fallback_router.nest(path, fallback_router));
}
})
}

@@ -247,43 +241,33 @@ where
where
R: Into<Router<S>>,
{
const PANIC_MSG: &str =
"Failed to merge fallbacks. This is a bug in axum. Please file an issue";

let other: Router<S> = other.into();
let RouterInner {
path_router,
fallback_router: mut other_fallback,
default_fallback,
catch_all_fallback,
} = other.into_inner();

map_inner!(self, mut this => {
panic_on_err!(this.path_router.merge(path_router));

match (this.default_fallback, default_fallback) {
// both have the default fallback
// use the one from other
(true, true) => {
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
}
(true, true) => {}
// this has default fallback, other has a custom fallback
(true, false) => {
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
this.default_fallback = false;
}
// this has a custom fallback, other has a default
(false, true) => {
let fallback_router = std::mem::take(&mut this.fallback_router);
other_fallback.merge(fallback_router).expect(PANIC_MSG);
this.fallback_router = other_fallback;
}
// both have a custom fallback, not allowed
(false, false) => {
panic!("Cannot merge two `Router`s that both have a fallback")
}
};

panic_on_err!(this.path_router.merge(path_router));

this.catch_all_fallback = this
.catch_all_fallback
.merge(catch_all_fallback)
@@ -304,7 +288,6 @@ where
{
map_inner!(self, this => RouterInner {
path_router: this.path_router.layer(layer.clone()),
fallback_router: this.fallback_router.layer(layer.clone()),
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)),
})
@@ -322,7 +305,6 @@ where
{
map_inner!(self, this => RouterInner {
path_router: this.path_router.route_layer(layer),
fallback_router: this.fallback_router,
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback,
})
@@ -376,8 +358,51 @@ where
}

fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
// TODO make this better, get rid of the `unwrap`s.
// We need the returned `Service` to be `Clone` and the function inside `service_fn` to be
// `FnMut` so instead of just using the owned service, we do this trick with `Option`. We
// know this will be called just once so it's fine. We're doing that so that we avoid one
// clone inside `oneshot_inner` so that the `Router` and subsequently the `State` is not
// cloned too much.
tap_inner!(self, mut this => {
this.fallback_router.set_fallback(endpoint);
_ = this.path_router.route_endpoint(
"/",
endpoint.clone().layer(
layer_fn(
|service: Route| {
let mut service = Some(service);
service_fn(
#[cfg_attr(not(feature = "matched-path"), allow(unused_mut))]
move |mut request: Request| {
#[cfg(feature = "matched-path")]
request.extensions_mut().remove::<MatchedPath>();
service.take().unwrap().oneshot_inner_owned(request)
}
)
}
)
)
);

_ = this.path_router.route_endpoint(
FALLBACK_PARAM_PATH,
endpoint.layer(
layer_fn(
|service: Route| {
let mut service = Some(service);
service_fn(
#[cfg_attr(not(feature = "matched-path"), allow(unused_mut))]
move |mut request: Request| {
#[cfg(feature = "matched-path")]
request.extensions_mut().remove::<MatchedPath>();
service.take().unwrap().oneshot_inner_owned(request)
}
)
}
)
)
);

this.default_fallback = false;
})
}
@@ -386,7 +411,6 @@ where
pub fn with_state<S2>(self, state: S) -> Router<S2> {
map_inner!(self, this => RouterInner {
path_router: this.path_router.with_state(state.clone()),
fallback_router: this.fallback_router.with_state(state.clone()),
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback.with_state(state),
})
@@ -398,11 +422,6 @@ where
Err((req, state)) => (req, state),
};

let (req, state) = match self.inner.fallback_router.call_with_state(req, state) {
Ok(future) => return future,
Err((req, state)) => (req, state),
};

self.inner
.catch_all_fallback
.clone()
88 changes: 20 additions & 68 deletions axum/src/routing/path_router.rs
Original file line number Diff line number Diff line change
@@ -9,33 +9,17 @@ use tower_layer::Layer;
use tower_service::Service;

use super::{
future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route,
RouteId, NEST_TAIL_PARAM,
};

pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
pub(super) struct PathRouter<S> {
routes: HashMap<RouteId, Endpoint<S>>,
node: Arc<Node>,
prev_route_id: RouteId,
v7_checks: bool,
}

impl<S> PathRouter<S, true>
where
S: Clone + Send + Sync + 'static,
{
pub(super) fn new_fallback() -> Self {
let mut this = Self::default();
this.set_fallback(Endpoint::Route(Route::new(NotFound)));
this
}

pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
self.replace_endpoint("/", endpoint.clone());
self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
}
}

fn validate_path(v7_checks: bool, path: &str) -> Result<(), &'static str> {
if path.is_empty() {
return Err("Paths must start with a `/`. Use \"/\" for root routes");
@@ -72,7 +56,7 @@ fn validate_v07_paths(path: &str) -> Result<(), &'static str> {
.unwrap_or(Ok(()))
}

impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
impl<S> PathRouter<S>
where
S: Clone + Send + Sync + 'static,
{
@@ -159,10 +143,7 @@ where
.map_err(|err| format!("Invalid route {path:?}: {err}"))
}

pub(super) fn merge(
&mut self,
other: PathRouter<S, IS_FALLBACK>,
) -> Result<(), Cow<'static, str>> {
pub(super) fn merge(&mut self, other: PathRouter<S>) -> Result<(), Cow<'static, str>> {
let PathRouter {
routes,
node,
@@ -179,24 +160,9 @@ where
.get(&id)
.expect("no path for route id. This is a bug in axum. Please file an issue");

if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
// when merging two routers it doesn't matter if you do `a.merge(b)` or
// `b.merge(a)`. This must also be true for fallbacks.
//
// However all fallback routers will have routes for `/` and `/*` so when merging
// we have to ignore the top level fallbacks on one side otherwise we get
// conflicts.
//
// `Router::merge` makes sure that when merging fallbacks `other` always has the
// fallback we want to keep. It panics if both routers have a custom fallback. Thus
// it is always okay to ignore one fallback and `Router::merge` also makes sure the
// one we can ignore is that of `self`.
self.replace_endpoint(path, route);
} else {
match route {
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
Endpoint::Route(route) => self.route_service(path, route)?,
}
match route {
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
Endpoint::Route(route) => self.route_service(path, route)?,
}
}

@@ -206,7 +172,7 @@ where
pub(super) fn nest(
&mut self,
path_to_nest_at: &str,
router: PathRouter<S, IS_FALLBACK>,
router: PathRouter<S>,
) -> Result<(), Cow<'static, str>> {
let prefix = validate_nest_path(self.v7_checks, path_to_nest_at);

@@ -282,7 +248,7 @@ where
Ok(())
}

pub(super) fn layer<L>(self, layer: L) -> PathRouter<S, IS_FALLBACK>
pub(super) fn layer<L>(self, layer: L) -> PathRouter<S>
where
L: Layer<Route> + Clone + Send + Sync + 'static,
L::Service: Service<Request> + Clone + Send + Sync + 'static,
@@ -344,7 +310,7 @@ where
!self.routes.is_empty()
}

pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, IS_FALLBACK> {
pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2> {
let routes = self
.routes
.into_iter()
@@ -388,14 +354,12 @@ where
Ok(match_) => {
let id = *match_.value;

if !IS_FALLBACK {
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
&mut parts.extensions,
);
}
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
&mut parts.extensions,
);

url_params::insert_url_params(&mut parts.extensions, match_.params);

@@ -418,18 +382,6 @@ where
}
}

pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S>) {
match self.node.at(path) {
Ok(match_) => {
let id = *match_.value;
self.routes.insert(id, endpoint);
}
Err(_) => self
.route_endpoint(path, endpoint)
.expect("path wasn't matched so endpoint shouldn't exist"),
}
}

fn next_route_id(&mut self) -> RouteId {
let next_id = self
.prev_route_id
@@ -441,7 +393,7 @@ where
}
}

impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
impl<S> Default for PathRouter<S> {
fn default() -> Self {
Self {
routes: Default::default(),
@@ -452,7 +404,7 @@ impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
}
}

impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
impl<S> fmt::Debug for PathRouter<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PathRouter")
.field("routes", &self.routes)
@@ -461,7 +413,7 @@ impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
}
}

impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
impl<S> Clone for PathRouter<S> {
fn clone(&self) -> Self {
Self {
routes: self.routes.clone(),
81 changes: 81 additions & 0 deletions axum/src/routing/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -407,6 +407,87 @@ async fn what_matches_wildcard() {
assert_eq!(get("/x/a/b/").await, "x");
}

#[should_panic(
expected = "Invalid route \"/{*wild}\": Insertion failed due to conflict with previously registered route: /{*__private__axum_fallback}"
)]
#[test]
fn colliding_fallback_with_wildcard() {
_ = Router::<()>::new()
.fallback(|| async { "fallback" })
.route("/{*wild}", get(|| async { "wildcard" }));
}

// We might want to reject this too
#[crate::test]
async fn colliding_wildcard_with_fallback() {
let router = Router::new()
.route("/{*wild}", get(|| async { "wildcard" }))
.fallback(|| async { "fallback" });

let client = TestClient::new(router);

let res = client.get("/").await;
let body = res.text().await;
assert_eq!(body, "fallback");

let res = client.get("/x").await;
let body = res.text().await;
assert_eq!(body, "wildcard");
}

// We might want to reject this too
#[crate::test]
async fn colliding_fallback_with_fallback() {
let router = Router::new()
.fallback(|| async { "fallback1" })
.fallback(|| async { "fallback2" });

let client = TestClient::new(router);

let res = client.get("/").await;
let body = res.text().await;
assert_eq!(body, "fallback1");

let res = client.get("/x").await;
let body = res.text().await;
assert_eq!(body, "fallback1");
}

#[crate::test]
async fn colliding_root_with_fallback() {
let router = Router::new()
.route("/", get(|| async { "root" }))
.fallback(|| async { "fallback" });

let client = TestClient::new(router);

let res = client.get("/").await;
let body = res.text().await;
assert_eq!(body, "root");

let res = client.get("/x").await;
let body = res.text().await;
assert_eq!(body, "fallback");
}

#[crate::test]
async fn colliding_fallback_with_root() {
let router = Router::new()
.fallback(|| async { "fallback" })
.route("/", get(|| async { "root" }));

let client = TestClient::new(router);

// This works because fallback registers `any` so the `get` gets merged into it.
let res = client.get("/").await;
let body = res.text().await;
assert_eq!(body, "root");

let res = client.get("/x").await;
let body = res.text().await;
assert_eq!(body, "fallback");
}

#[crate::test]
async fn static_and_dynamic_paths() {
let app = Router::new()
97 changes: 97 additions & 0 deletions axum/src/routing/tests/nest.rs
Original file line number Diff line number Diff line change
@@ -387,3 +387,100 @@ async fn colon_in_route() {
async fn asterisk_in_route() {
_ = Router::<()>::new().nest("/*foo", Router::new());
}

#[crate::test]
async fn nesting_router_with_fallback() {
let nested = Router::new().fallback(|| async { "nested" });
let router = Router::new().route("/{x}/{y}", get(|| async { "two segments" }));

let client = TestClient::new(router.nest("/nest", nested));

let res = client.get("/a/b").await;
let body = res.text().await;
assert_eq!(body, "two segments");

let res = client.get("/nest/b").await;
let body = res.text().await;
assert_eq!(body, "nested");
}

#[crate::test]
async fn defining_missing_routes_in_nested_router() {
let router = Router::new()
.route("/nest/before", get(|| async { "before" }))
.nest(
"/nest",
Router::new()
.route("/mid", get(|| async { "nested mid" }))
.fallback(|| async { "nested fallback" }),
)
.route("/nest/after", get(|| async { "after" }));

let client = TestClient::new(router);

let res = client.get("/nest/before").await;
let body = res.text().await;
assert_eq!(body, "before");

let res = client.get("/nest/after").await;
let body = res.text().await;
assert_eq!(body, "after");

let res = client.get("/nest/mid").await;
let body = res.text().await;
assert_eq!(body, "nested mid");

let res = client.get("/nest/fallback").await;
let body = res.text().await;
assert_eq!(body, "nested fallback");
}

#[test]
#[should_panic(
expected = "Overlapping method route. Handler for `GET /nest/override` already exists"
)]
fn overriding_by_nested_router() {
_ = Router::<()>::new()
.route("/nest/override", get(|| async { "outer" }))
.nest(
"/nest",
Router::new().route("/override", get(|| async { "inner" })),
);
}

#[test]
#[should_panic(
expected = "Overlapping method route. Handler for `GET /nest/override` already exists"
)]
fn overriding_nested_router_() {
_ = Router::<()>::new()
.nest(
"/nest",
Router::new().route("/override", get(|| async { "inner" })),
)
.route("/nest/override", get(|| async { "outer" }));
}

// This is just documenting current state, not intended behavior.
#[crate::test]
async fn overriding_nested_service_router() {
let router = Router::new()
.route("/nest/before", get(|| async { "outer" }))
.nest_service(
"/nest",
Router::new()
.route("/before", get(|| async { "inner" }))
.route("/after", get(|| async { "inner" })),
)
.route("/nest/after", get(|| async { "outer" }));

let client = TestClient::new(router);

let res = client.get("/nest/before").await;
let body = res.text().await;
assert_eq!(body, "outer");

let res = client.get("/nest/after").await;
let body = res.text().await;
assert_eq!(body, "outer");
}