Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(extract): prefixed signed&private cookies #3251

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions axum-extra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@ version = "0.11.0"
default = ["tracing"]

async-read-body = ["dep:tokio-util", "tokio-util?/io", "dep:tokio"]
file-stream = ["dep:tokio-util", "tokio-util?/io", "dep:tokio", "tokio?/fs", "tokio?/io-util"]
file-stream = [
"dep:tokio-util",
"tokio-util?/io",
"dep:tokio",
"tokio?/fs",
"tokio?/io-util",
]
attachment = ["dep:tracing"]
error-response = ["dep:tracing", "tracing/std"]
cookie = ["dep:cookie"]
Expand All @@ -36,10 +42,19 @@ json-lines = [
multipart = ["dep:multer", "dep:fastrand"]
protobuf = ["dep:prost"]
scheme = []
query = ["dep:form_urlencoded", "dep:serde_html_form", "dep:serde_path_to_error"]
query = [
"dep:form_urlencoded",
"dep:serde_html_form",
"dep:serde_path_to_error",
]
tracing = ["axum-core/tracing", "axum/tracing"]
typed-header = ["dep:headers"]
typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"]
typed-routing = [
"dep:axum-macros",
"dep:percent-encoding",
"dep:serde_html_form",
"dep:form_urlencoded",
]

# Enabled by docs.rs because it uses all-features
__private_docs = [
Expand All @@ -48,10 +63,14 @@ __private_docs = [
]

[dependencies]
axum = { path = "../axum", version = "0.8.2", default-features = false, features = ["original-uri"] }
axum = { path = "../axum", version = "0.8.2", default-features = false, features = [
"original-uri",
] }
axum-core = { path = "../axum-core", version = "0.5.1" }
bytes = "1.1.0"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
futures-util = { version = "0.3", default-features = false, features = [
"alloc",
] }
http = "1.0.0"
http-body = "1.0.0"
http-body-util = "0.1.0"
Expand All @@ -64,7 +83,9 @@ tower-service = "0.3"

# optional dependencies
axum-macros = { path = "../axum-macros", version = "0.5.0", optional = true }
cookie = { package = "cookie", version = "0.18.0", features = ["percent-encode"], optional = true }
cookie = { package = "cookie", version = "0.18.0", features = [
"percent-encode",
], optional = true }
fastrand = { version = "2.1.0", optional = true }
form_urlencoded = { version = "1.1.0", optional = true }
headers = { version = "0.4.0", optional = true }
Expand All @@ -84,7 +105,11 @@ typed-json = { version = "0.1.1", optional = true }
axum = { path = "../axum", features = ["macros", "__private"] }
axum-macros = { path = "../axum-macros", features = ["__private"] }
hyper = "1.0.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }
reqwest = { version = "0.12", default-features = false, features = [
"json",
"stream",
"multipart",
] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.71"
tokio = { version = "1.14", features = ["full"] }
Expand Down
188 changes: 188 additions & 0 deletions axum-extra/src/extract/cookie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,82 @@ impl CookieJar {
pub fn iter(&self) -> impl Iterator<Item = &'_ Cookie<'static>> {
self.jar.iter()
}

/// Add a cookie with the specified prefix to the jar.
///
/// # Example
/// ```rust
/// use axum_extra::extract::cookie::{CookieJar, Cookie};
/// use cookie::prefix::{Host, Secure};
///
/// async fn handler(jar: CookieJar) -> CookieJar {
/// // Add a cookie with the "__Host-" prefix
/// let with_host = jar.clone().add_prefixed(Host, Cookie::new("session_id", "value"));
///
/// // Add a cookie with the "__Secure-" prefix
/// let _with_secure = jar.add_prefixed(Secure, Cookie::new("auth", "token"));
///
/// with_host
/// }
/// ```
#[must_use]
pub fn add_prefixed<P: cookie::prefix::Prefix>(
mut self,
prefix: P,
cookie: Cookie<'static>,
) -> Self {
let mut prefixed_jar = self.jar.prefixed_mut(prefix);
prefixed_jar.add(cookie);
self
}

/// Get a signed cookie with the specified prefix from the jar.
///
/// If the cookie exists and its signature is valid, it is returned with its original name
/// (without the prefix) and plaintext value.
///
/// # Example
/// ```rust
/// use axum_extra::extract::cookie::{CookieJar, Cookie};
/// use cookie::prefix::{Host, Secure};
///
/// async fn handler(jar: CookieJar) {
/// if let Some(cookie) = jar.get_prefixed(cookie::prefix::Host, "session_id") {
/// let value = cookie.value();
/// }
/// }
/// ```
pub fn get_prefixed<P: cookie::prefix::Prefix>(
&self,
prefix: P,
name: &str,
) -> Option<Cookie<'static>> {
let prefixed_jar = self.jar.prefixed(prefix);
prefixed_jar.get(name)
}

/// Remove a cookie with the specified prefix from the jar.
///
/// # Example
/// ```rust
/// use axum_extra::extract::cookie::CookieJar;
/// use cookie::prefix::{Host, Secure};
///
/// async fn handler(jar: CookieJar) -> CookieJar {
/// // Remove a cookie with the "__Host-" prefix
/// jar.remove_prefixed(Host, "session_id")
/// }
/// ```
#[must_use]
pub fn remove_prefixed<P, S>(mut self, prefix: P, name: S) -> Self
where
P: cookie::prefix::Prefix,
S: Into<String>,
{
let mut prefixed_jar = self.jar.prefixed_mut(prefix);
prefixed_jar.remove(name.into());
self
}
}

impl IntoResponseParts for CookieJar {
Expand Down Expand Up @@ -232,6 +308,7 @@ fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) {
mod tests {
use super::*;
use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router};
use cookie::prefix::Host;
use http_body_util::BodyExt;
use tower::ServiceExt;

Expand Down Expand Up @@ -268,6 +345,19 @@ mod tests {
.await
.unwrap();
let cookie_value = res.headers()["set-cookie"].to_str().unwrap();
println!("Set Cookie value: {}", cookie_value);

assert!(cookie_value.starts_with("key="));

// For signed/private cookies, verify that the plaintext value is not directly visible
// (only for signed and private jars, not for the regular CookieJar)
if std::any::type_name::<$jar>().contains("Private")
|| std::any::type_name::<$jar>().contains("Signed")
{
assert!(!cookie_value.contains("key=value"));
} else {
assert!(cookie_value.contains("key=value"));
}

let res = app
.clone()
Expand Down Expand Up @@ -302,17 +392,115 @@ mod tests {
};
}

macro_rules! cookie_prefixed_test {
($name:ident, $jar:ty) => {
#[tokio::test]
async fn $name() {
async fn set_cookie_prefixed(jar: $jar) -> impl IntoResponse {
jar.add_prefixed(Host, Cookie::new("key", "value"))
}

async fn get_cookie_prefixed(jar: $jar) -> impl IntoResponse {
jar.get_prefixed(Host, "key").unwrap().value().to_owned()
}

async fn remove_cookie_prefixed(jar: $jar) -> impl IntoResponse {
jar.remove_prefixed(Host, "key")
}

let state = AppState {
key: Key::generate(),
custom_key: CustomKey(Key::generate()),
};

let app = Router::new()
.route("/set", get(set_cookie_prefixed))
.route("/get", get(get_cookie_prefixed))
.route("/remove", get(remove_cookie_prefixed))
.with_state(state);

let res = app
.clone()
.oneshot(Request::builder().uri("/set").body(Body::empty()).unwrap())
.await
.unwrap();
let cookie_value = res.headers()["set-cookie"].to_str().unwrap();
println!("Set Cookie value: {}", cookie_value);
assert!(cookie_value.contains("__Host-key"));

// For signed/private cookies, verify that the plaintext value is not directly visible
// (only for signed and private jars, not for the regular CookieJar)
if std::any::type_name::<$jar>().contains("Private")
|| std::any::type_name::<$jar>().contains("Signed")
{
assert!(!cookie_value.contains("key=value"));
} else {
assert!(cookie_value.contains("key=value"));
}

// Extract just the cookie part (before the first semicolon)
// Set-Cookie: __Host-key=value; Secure; Path=/ -> __Host-key=value
let cookie_header_value = cookie_value.split(';').next().unwrap().trim();
println!("Using Cookie header value: {}", cookie_header_value);

let res = app
.clone()
.oneshot(
Request::builder()
.uri("/get")
.header("cookie", cookie_header_value)
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
let body = body_text(res).await;
assert_eq!(body, "value");

let res = app
.clone()
.oneshot(
Request::builder()
.uri("/remove")
.header("cookie", cookie_value)
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert!(res.headers()["set-cookie"]
.to_str()
.unwrap()
.contains("__Host-key=;"));
}
};
}

cookie_test!(plaintext_cookies, CookieJar);

#[cfg(feature = "cookie-signed")]
cookie_test!(signed_cookies, SignedCookieJar);
#[cfg(feature = "cookie-signed")]
cookie_prefixed_test!(signed_cookies_prefixed, SignedCookieJar);
#[cfg(feature = "cookie-signed")]
cookie_test!(signed_cookies_with_custom_key, SignedCookieJar<CustomKey>);
#[cfg(feature = "cookie-signed")]
cookie_prefixed_test!(
signed_cookies_prefixed_with_custom_key,
SignedCookieJar<CustomKey>
);

#[cfg(feature = "cookie-private")]
cookie_test!(private_cookies, PrivateCookieJar);
#[cfg(feature = "cookie-private")]
cookie_prefixed_test!(private_cookies_prefixed, PrivateCookieJar);
#[cfg(feature = "cookie-private")]
cookie_test!(private_cookies_with_custom_key, PrivateCookieJar<CustomKey>);
#[cfg(feature = "cookie-private")]
cookie_prefixed_test!(
private_cookies_prefixed_with_custom_key,
PrivateCookieJar<CustomKey>
);

#[derive(Clone)]
struct AppState {
Expand Down
Loading
Loading