Skip to content

Commit ec75ee3

Browse files
authored
Add a separate trait for optional extractors (#2475)
1 parent fd11d8e commit ec75ee3

File tree

17 files changed

+306
-84
lines changed

17 files changed

+306
-84
lines changed

axum-core/CHANGELOG.md

+8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
# Unreleased
9+
10+
- **breaking:**: `Option<T>` as an extractor now requires `T` to implement the
11+
new trait `OptionalFromRequest` (if used as the last extractor) or
12+
`OptionalFromRequestParts` (other extractors) ([#2475])
13+
14+
[#2475]: https://github.com/tokio-rs/axum/pull/2475
15+
816
# 0.5.0
917

1018
## alpha.1

axum-core/src/extract/mod.rs

+6-28
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,16 @@ pub mod rejection;
1313

1414
mod default_body_limit;
1515
mod from_ref;
16+
mod option;
1617
mod request_parts;
1718
mod tuple;
1819

1920
pub(crate) use self::default_body_limit::DefaultBodyLimitKind;
20-
pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef};
21+
pub use self::{
22+
default_body_limit::DefaultBodyLimit,
23+
from_ref::FromRef,
24+
option::{OptionalFromRequest, OptionalFromRequestParts},
25+
};
2126

2227
/// Type alias for [`http::Request`] whose body type defaults to [`Body`], the most common body
2328
/// type used with axum.
@@ -102,33 +107,6 @@ where
102107
}
103108
}
104109

105-
impl<S, T> FromRequestParts<S> for Option<T>
106-
where
107-
T: FromRequestParts<S>,
108-
S: Send + Sync,
109-
{
110-
type Rejection = Infallible;
111-
112-
async fn from_request_parts(
113-
parts: &mut Parts,
114-
state: &S,
115-
) -> Result<Option<T>, Self::Rejection> {
116-
Ok(T::from_request_parts(parts, state).await.ok())
117-
}
118-
}
119-
120-
impl<S, T> FromRequest<S> for Option<T>
121-
where
122-
T: FromRequest<S>,
123-
S: Send + Sync,
124-
{
125-
type Rejection = Infallible;
126-
127-
async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> {
128-
Ok(T::from_request(req, state).await.ok())
129-
}
130-
}
131-
132110
impl<S, T> FromRequestParts<S> for Result<T, T::Rejection>
133111
where
134112
T: FromRequestParts<S>,

axum-core/src/extract/option.rs

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
use std::future::Future;
2+
3+
use http::request::Parts;
4+
5+
use crate::response::IntoResponse;
6+
7+
use super::{private, FromRequest, FromRequestParts, Request};
8+
9+
/// Customize the behavior of `Option<Self>` as a [`FromRequestParts`]
10+
/// extractor.
11+
pub trait OptionalFromRequestParts<S>: Sized {
12+
/// If the extractor fails, it will use this "rejection" type.
13+
///
14+
/// A rejection is a kind of error that can be converted into a response.
15+
type Rejection: IntoResponse;
16+
17+
/// Perform the extraction.
18+
fn from_request_parts(
19+
parts: &mut Parts,
20+
state: &S,
21+
) -> impl Future<Output = Result<Option<Self>, Self::Rejection>> + Send;
22+
}
23+
24+
/// Customize the behavior of `Option<Self>` as a [`FromRequest`] extractor.
25+
pub trait OptionalFromRequest<S, M = private::ViaRequest>: Sized {
26+
/// If the extractor fails, it will use this "rejection" type.
27+
///
28+
/// A rejection is a kind of error that can be converted into a response.
29+
type Rejection: IntoResponse;
30+
31+
/// Perform the extraction.
32+
fn from_request(
33+
req: Request,
34+
state: &S,
35+
) -> impl Future<Output = Result<Option<Self>, Self::Rejection>> + Send;
36+
}
37+
38+
impl<S, T> FromRequestParts<S> for Option<T>
39+
where
40+
T: OptionalFromRequestParts<S>,
41+
S: Send + Sync,
42+
{
43+
type Rejection = T::Rejection;
44+
45+
fn from_request_parts(
46+
parts: &mut Parts,
47+
state: &S,
48+
) -> impl Future<Output = Result<Option<T>, Self::Rejection>> {
49+
T::from_request_parts(parts, state)
50+
}
51+
}
52+
53+
impl<S, T> FromRequest<S> for Option<T>
54+
where
55+
T: OptionalFromRequest<S>,
56+
S: Send + Sync,
57+
{
58+
type Rejection = T::Rejection;
59+
60+
async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> {
61+
T::from_request(req, state).await
62+
}
63+
}

axum-extra/CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning].
77

88
# Unreleased
99

10+
- **breaking:** `Option<Query<T>>` no longer swallows all error conditions, instead rejecting the
11+
request in many cases; see its documentation for details ([#2475])
12+
- **changed:** Deprecated `OptionalPath<T>` and `OptionalQuery<T>` ([#2475])
1013
- **fixed:** `Host` extractor includes port number when parsing authority ([#2242])
1114
- **changed:** The `multipart` feature is no longer on by default ([#3058])
1215
- **added:** Add `RouterExt::typed_connect` ([#2961])
@@ -16,6 +19,7 @@ and this project adheres to [Semantic Versioning].
1619
- **added:** Add `FileStream` for easy construction of file stream responses ([#3047])
1720

1821
[#2242]: https://github.com/tokio-rs/axum/pull/2242
22+
[#2475]: https://github.com/tokio-rs/axum/pull/2475
1923
[#3058]: https://github.com/tokio-rs/axum/pull/3058
2024
[#2961]: https://github.com/tokio-rs/axum/pull/2961
2125
[#2962]: https://github.com/tokio-rs/axum/pull/2962

axum-extra/src/extract/mod.rs

+7-4
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ pub mod multipart;
2424
#[cfg(feature = "scheme")]
2525
mod scheme;
2626

27-
pub use self::{
28-
cached::Cached, host::Host, optional_path::OptionalPath, with_rejection::WithRejection,
29-
};
27+
#[allow(deprecated)]
28+
pub use self::optional_path::OptionalPath;
29+
pub use self::{cached::Cached, host::Host, with_rejection::WithRejection};
3030

3131
#[cfg(feature = "cookie")]
3232
pub use self::cookie::CookieJar;
@@ -41,7 +41,10 @@ pub use self::cookie::SignedCookieJar;
4141
pub use self::form::{Form, FormRejection};
4242

4343
#[cfg(feature = "query")]
44-
pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejection};
44+
#[allow(deprecated)]
45+
pub use self::query::OptionalQuery;
46+
#[cfg(feature = "query")]
47+
pub use self::query::{OptionalQueryRejection, Query, QueryRejection};
4548

4649
#[cfg(feature = "multipart")]
4750
pub use self::multipart::Multipart;

axum-extra/src/extract/optional_path.rs

+8-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use axum::{
2-
extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts, Path},
2+
extract::{rejection::PathRejection, FromRequestParts, Path},
33
RequestPartsExt,
44
};
55
use serde::de::DeserializeOwned;
@@ -31,9 +31,11 @@ use serde::de::DeserializeOwned;
3131
/// .route("/blog/{page}", get(render_blog));
3232
/// # let app: Router = app;
3333
/// ```
34+
#[deprecated = "Use Option<Path<_>> instead"]
3435
#[derive(Debug)]
3536
pub struct OptionalPath<T>(pub Option<T>);
3637

38+
#[allow(deprecated)]
3739
impl<T, S> FromRequestParts<S> for OptionalPath<T>
3840
where
3941
T: DeserializeOwned + Send + 'static,
@@ -45,19 +47,15 @@ where
4547
parts: &mut http::request::Parts,
4648
_: &S,
4749
) -> Result<Self, Self::Rejection> {
48-
match parts.extract::<Path<T>>().await {
49-
Ok(Path(params)) => Ok(Self(Some(params))),
50-
Err(PathRejection::FailedToDeserializePathParams(e))
51-
if matches!(e.kind(), ErrorKind::WrongNumberOfParameters { got: 0, .. }) =>
52-
{
53-
Ok(Self(None))
54-
}
55-
Err(e) => Err(e),
56-
}
50+
parts
51+
.extract::<Option<Path<T>>>()
52+
.await
53+
.map(|opt| Self(opt.map(|Path(x)| x)))
5754
}
5855
}
5956

6057
#[cfg(test)]
58+
#[allow(deprecated)]
6159
mod tests {
6260
use std::num::NonZeroU32;
6361

axum-extra/src/extract/query.rs

+40-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use axum::{
2-
extract::FromRequestParts,
2+
extract::{FromRequestParts, OptionalFromRequestParts},
33
response::{IntoResponse, Response},
44
Error,
55
};
@@ -18,6 +18,19 @@ use std::fmt;
1818
/// with the `multiple` attribute. Those values can be collected into a `Vec` or other sequential
1919
/// container.
2020
///
21+
/// # `Option<Query<T>>` behavior
22+
///
23+
/// If `Query<T>` itself is used as an extractor and there is no query string in
24+
/// the request URL, `T`'s `Deserialize` implementation is called on an empty
25+
/// string instead.
26+
///
27+
/// You can avoid this by using `Option<Query<T>>`, which gives you `None` in
28+
/// the case that there is no query string in the request URL.
29+
///
30+
/// Note that an empty query string is not the same as no query string, that is
31+
/// `https://example.org/` and `https://example.org/?` are not treated the same
32+
/// in this case.
33+
///
2134
/// # Example
2235
///
2336
/// ```rust,no_run
@@ -96,6 +109,27 @@ where
96109
}
97110
}
98111

112+
impl<T, S> OptionalFromRequestParts<S> for Query<T>
113+
where
114+
T: DeserializeOwned,
115+
S: Send + Sync,
116+
{
117+
type Rejection = QueryRejection;
118+
119+
async fn from_request_parts(
120+
parts: &mut Parts,
121+
_state: &S,
122+
) -> Result<Option<Self>, Self::Rejection> {
123+
if let Some(query) = parts.uri.query() {
124+
let value = serde_html_form::from_str(query)
125+
.map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?;
126+
Ok(Some(Self(value)))
127+
} else {
128+
Ok(None)
129+
}
130+
}
131+
}
132+
99133
axum_core::__impl_deref!(Query);
100134

101135
/// Rejection used for [`Query`].
@@ -182,9 +216,11 @@ impl std::error::Error for QueryRejection {
182216
///
183217
/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs
184218
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
219+
#[deprecated = "Use Option<Query<_>> instead"]
185220
#[derive(Debug, Clone, Copy, Default)]
186221
pub struct OptionalQuery<T>(pub Option<T>);
187222

223+
#[allow(deprecated)]
188224
impl<T, S> FromRequestParts<S> for OptionalQuery<T>
189225
where
190226
T: DeserializeOwned,
@@ -204,6 +240,7 @@ where
204240
}
205241
}
206242

243+
#[allow(deprecated)]
207244
impl<T> std::ops::Deref for OptionalQuery<T> {
208245
type Target = Option<T>;
209246

@@ -213,6 +250,7 @@ impl<T> std::ops::Deref for OptionalQuery<T> {
213250
}
214251
}
215252

253+
#[allow(deprecated)]
216254
impl<T> std::ops::DerefMut for OptionalQuery<T> {
217255
#[inline]
218256
fn deref_mut(&mut self) -> &mut Self::Target {
@@ -260,6 +298,7 @@ impl std::error::Error for OptionalQueryRejection {
260298
}
261299

262300
#[cfg(test)]
301+
#[allow(deprecated)]
263302
mod tests {
264303
use super::*;
265304
use crate::test_helpers::*;

axum-extra/src/typed_header.rs

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Extractor and response for typed headers.
22
33
use axum::{
4-
extract::FromRequestParts,
4+
extract::{FromRequestParts, OptionalFromRequestParts},
55
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
66
};
77
use headers::{Header, HeaderMapExt};
@@ -78,6 +78,30 @@ where
7878
}
7979
}
8080

81+
impl<T, S> OptionalFromRequestParts<S> for TypedHeader<T>
82+
where
83+
T: Header,
84+
S: Send + Sync,
85+
{
86+
type Rejection = TypedHeaderRejection;
87+
88+
async fn from_request_parts(
89+
parts: &mut Parts,
90+
_state: &S,
91+
) -> Result<Option<Self>, Self::Rejection> {
92+
let mut values = parts.headers.get_all(T::name()).iter();
93+
let is_missing = values.size_hint() == (0, Some(0));
94+
match T::decode(&mut values) {
95+
Ok(res) => Ok(Some(Self(res))),
96+
Err(_) if is_missing => Ok(None),
97+
Err(err) => Err(TypedHeaderRejection {
98+
name: T::name(),
99+
reason: TypedHeaderRejectionReason::Error(err),
100+
}),
101+
}
102+
}
103+
}
104+
81105
axum_core::__impl_deref!(TypedHeader);
82106

83107
impl<T> IntoResponseParts for TypedHeader<T>

axum-macros/tests/typed_path/pass/option_result.rs axum-macros/tests/typed_path/pass/result_handler.rs

-3
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ struct UsersShow {
88
id: String,
99
}
1010

11-
async fn option_handler(_: Option<UsersShow>) {}
12-
1311
async fn result_handler(_: Result<UsersShow, PathRejection>) {}
1412

1513
#[derive(TypedPath, Deserialize)]
@@ -20,7 +18,6 @@ async fn result_handler_unit_struct(_: Result<UsersIndex, StatusCode>) {}
2018

2119
fn main() {
2220
_ = axum::Router::<()>::new()
23-
.typed_get(option_handler)
2421
.typed_post(result_handler)
2522
.typed_post(result_handler_unit_struct);
2623
}

axum/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2020
This allows middleware to add bodies to requests without needing to manually set `content-length` ([#2897])
2121
- **breaking:** Remove `WebSocket::close`.
2222
Users should explicitly send close messages themselves. ([#2974])
23+
- **breaking:** `Option<Path<T>>` and `Option<Query<T>>` no longer swallow all error conditions,
24+
instead rejecting the request in many cases; see their documentation for details ([#2475])
2325
- **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`)
2426
This new variant captures both `key`, `value`, and `message` from named path parameters parse errors,
2527
instead of only deserialization error message in `ErrorKind::Message`. ([#2720])
2628
- **breaking:** Make `serve` generic over the listener and IO types ([#2941])
2729

30+
[#2475]: https://github.com/tokio-rs/axum/pull/2475
2831
[#2897]: https://github.com/tokio-rs/axum/pull/2897
2932
[#2903]: https://github.com/tokio-rs/axum/pull/2903
3033
[#2894]: https://github.com/tokio-rs/axum/pull/2894

axum/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ features = [
116116

117117
[dev-dependencies]
118118
anyhow = "1.0"
119+
axum-extra = { path = "../axum-extra", features = ["typed-header"] }
119120
axum-macros = { path = "../axum-macros", features = ["__private"] }
120121
hyper = { version = "1.1.0", features = ["client"] }
121122
quickcheck = "1.0"

0 commit comments

Comments
 (0)