Skip to content

Commit 61474f4

Browse files
authored
Add HTTP Upgrade support to Response. (#1376)
1 parent e9ba0a9 commit 61474f4

File tree

5 files changed

+161
-42
lines changed

5 files changed

+161
-42
lines changed

src/async_impl/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ pub mod decoder;
1313
pub mod multipart;
1414
pub(crate) mod request;
1515
mod response;
16+
mod upgrade;

src/async_impl/response.rs

+30-42
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,10 @@ use crate::response::ResponseUrl;
2424

2525
/// A Response to a submitted `Request`.
2626
pub struct Response {
27-
status: StatusCode,
28-
headers: HeaderMap,
27+
pub(super) res: hyper::Response<Decoder>,
2928
// Boxed to save space (11 words to 1 word), and it's not accessed
3029
// frequently internally.
3130
url: Box<Url>,
32-
body: Decoder,
33-
version: Version,
34-
extensions: http::Extensions,
3531
}
3632

3733
impl Response {
@@ -41,46 +37,38 @@ impl Response {
4137
accepts: Accepts,
4238
timeout: Option<Pin<Box<Sleep>>>,
4339
) -> Response {
44-
let (parts, body) = res.into_parts();
45-
let status = parts.status;
46-
let version = parts.version;
47-
let extensions = parts.extensions;
48-
49-
let mut headers = parts.headers;
50-
let decoder = Decoder::detect(&mut headers, Body::response(body, timeout), accepts);
40+
let (mut parts, body) = res.into_parts();
41+
let decoder = Decoder::detect(&mut parts.headers, Body::response(body, timeout), accepts);
42+
let res = hyper::Response::from_parts(parts, decoder);
5143

5244
Response {
53-
status,
54-
headers,
45+
res,
5546
url: Box::new(url),
56-
body: decoder,
57-
version,
58-
extensions,
5947
}
6048
}
6149

6250
/// Get the `StatusCode` of this `Response`.
6351
#[inline]
6452
pub fn status(&self) -> StatusCode {
65-
self.status
53+
self.res.status()
6654
}
6755

6856
/// Get the HTTP `Version` of this `Response`.
6957
#[inline]
7058
pub fn version(&self) -> Version {
71-
self.version
59+
self.res.version()
7260
}
7361

7462
/// Get the `Headers` of this `Response`.
7563
#[inline]
7664
pub fn headers(&self) -> &HeaderMap {
77-
&self.headers
65+
self.res.headers()
7866
}
7967

8068
/// Get a mutable reference to the `Headers` of this `Response`.
8169
#[inline]
8270
pub fn headers_mut(&mut self) -> &mut HeaderMap {
83-
&mut self.headers
71+
self.res.headers_mut()
8472
}
8573

8674
/// Get the content-length of this response, if known.
@@ -93,7 +81,7 @@ impl Response {
9381
pub fn content_length(&self) -> Option<u64> {
9482
use hyper::body::HttpBody;
9583

96-
HttpBody::size_hint(&self.body).exact()
84+
HttpBody::size_hint(self.res.body()).exact()
9785
}
9886

9987
/// Retrieve the cookies contained in the response.
@@ -106,7 +94,7 @@ impl Response {
10694
#[cfg(feature = "cookies")]
10795
#[cfg_attr(docsrs, doc(cfg(feature = "cookies")))]
10896
pub fn cookies<'a>(&'a self) -> impl Iterator<Item = cookie::Cookie<'a>> + 'a {
109-
cookie::extract_response_cookies(&self.headers).filter_map(Result::ok)
97+
cookie::extract_response_cookies(self.res.headers()).filter_map(Result::ok)
11098
}
11199

112100
/// Get the final `Url` of this `Response`.
@@ -117,19 +105,20 @@ impl Response {
117105

118106
/// Get the remote address used to get this `Response`.
119107
pub fn remote_addr(&self) -> Option<SocketAddr> {
120-
self.extensions
108+
self.res
109+
.extensions()
121110
.get::<HttpInfo>()
122111
.map(|info| info.remote_addr())
123112
}
124113

125114
/// Returns a reference to the associated extensions.
126115
pub fn extensions(&self) -> &http::Extensions {
127-
&self.extensions
116+
self.res.extensions()
128117
}
129118

130119
/// Returns a mutable reference to the associated extensions.
131120
pub fn extensions_mut(&mut self) -> &mut http::Extensions {
132-
&mut self.extensions
121+
self.res.extensions_mut()
133122
}
134123

135124
// body methods
@@ -183,7 +172,7 @@ impl Response {
183172
/// ```
184173
pub async fn text_with_charset(self, default_encoding: &str) -> crate::Result<String> {
185174
let content_type = self
186-
.headers
175+
.headers()
187176
.get(crate::header::CONTENT_TYPE)
188177
.and_then(|value| value.to_str().ok())
189178
.and_then(|value| value.parse::<Mime>().ok());
@@ -271,7 +260,7 @@ impl Response {
271260
/// # }
272261
/// ```
273262
pub async fn bytes(self) -> crate::Result<Bytes> {
274-
hyper::body::to_bytes(self.body).await
263+
hyper::body::to_bytes(self.res.into_body()).await
275264
}
276265

277266
/// Stream a chunk of the response body.
@@ -291,7 +280,7 @@ impl Response {
291280
/// # }
292281
/// ```
293282
pub async fn chunk(&mut self) -> crate::Result<Option<Bytes>> {
294-
if let Some(item) = self.body.next().await {
283+
if let Some(item) = self.res.body_mut().next().await {
295284
Ok(Some(item?))
296285
} else {
297286
Ok(None)
@@ -323,7 +312,7 @@ impl Response {
323312
#[cfg(feature = "stream")]
324313
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
325314
pub fn bytes_stream(self) -> impl futures_core::Stream<Item = crate::Result<Bytes>> {
326-
self.body
315+
self.res.into_body()
327316
}
328317

329318
// util methods
@@ -350,8 +339,9 @@ impl Response {
350339
/// # fn main() {}
351340
/// ```
352341
pub fn error_for_status(self) -> crate::Result<Self> {
353-
if self.status.is_client_error() || self.status.is_server_error() {
354-
Err(crate::error::status_code(*self.url, self.status))
342+
let status = self.status();
343+
if status.is_client_error() || status.is_server_error() {
344+
Err(crate::error::status_code(*self.url, status))
355345
} else {
356346
Ok(self)
357347
}
@@ -379,8 +369,9 @@ impl Response {
379369
/// # fn main() {}
380370
/// ```
381371
pub fn error_for_status_ref(&self) -> crate::Result<&Self> {
382-
if self.status.is_client_error() || self.status.is_server_error() {
383-
Err(crate::error::status_code(*self.url.clone(), self.status))
372+
let status = self.status();
373+
if status.is_client_error() || status.is_server_error() {
374+
Err(crate::error::status_code(*self.url.clone(), status))
384375
} else {
385376
Ok(self)
386377
}
@@ -395,7 +386,7 @@ impl Response {
395386
// This method is just used by the blocking API.
396387
#[cfg(feature = "blocking")]
397388
pub(crate) fn body_mut(&mut self) -> &mut Decoder {
398-
&mut self.body
389+
self.res.body_mut()
399390
}
400391
}
401392

@@ -413,27 +404,24 @@ impl<T: Into<Body>> From<http::Response<T>> for Response {
413404
fn from(r: http::Response<T>) -> Response {
414405
let (mut parts, body) = r.into_parts();
415406
let body = body.into();
416-
let body = Decoder::detect(&mut parts.headers, body, Accepts::none());
407+
let decoder = Decoder::detect(&mut parts.headers, body, Accepts::none());
417408
let url = parts
418409
.extensions
419410
.remove::<ResponseUrl>()
420411
.unwrap_or_else(|| ResponseUrl(Url::parse("http://no.url.provided.local").unwrap()));
421412
let url = url.0;
413+
let res = hyper::Response::from_parts(parts, decoder);
422414
Response {
423-
status: parts.status,
424-
headers: parts.headers,
415+
res,
425416
url: Box::new(url),
426-
body,
427-
version: parts.version,
428-
extensions: parts.extensions,
429417
}
430418
}
431419
}
432420

433421
/// A `Response` can be piped as the `Body` of another request.
434422
impl From<Response> for Body {
435423
fn from(r: Response) -> Body {
436-
Body::stream(r.body)
424+
Body::stream(r.res.into_body())
437425
}
438426
}
439427

src/async_impl/upgrade.rs

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
use std::pin::Pin;
2+
use std::task::{self, Poll};
3+
use std::{fmt, io};
4+
5+
use futures_util::TryFutureExt;
6+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7+
8+
/// An upgraded HTTP connection.
9+
pub struct Upgraded {
10+
inner: hyper::upgrade::Upgraded,
11+
}
12+
13+
impl AsyncRead for Upgraded {
14+
fn poll_read(
15+
mut self: Pin<&mut Self>,
16+
cx: &mut task::Context<'_>,
17+
buf: &mut ReadBuf<'_>,
18+
) -> Poll<io::Result<()>> {
19+
Pin::new(&mut self.inner).poll_read(cx, buf)
20+
}
21+
}
22+
23+
impl AsyncWrite for Upgraded {
24+
fn poll_write(
25+
mut self: Pin<&mut Self>,
26+
cx: &mut task::Context<'_>,
27+
buf: &[u8],
28+
) -> Poll<io::Result<usize>> {
29+
Pin::new(&mut self.inner).poll_write(cx, buf)
30+
}
31+
32+
fn poll_write_vectored(
33+
mut self: Pin<&mut Self>,
34+
cx: &mut task::Context<'_>,
35+
bufs: &[io::IoSlice<'_>],
36+
) -> Poll<io::Result<usize>> {
37+
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
38+
}
39+
40+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
41+
Pin::new(&mut self.inner).poll_flush(cx)
42+
}
43+
44+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
45+
Pin::new(&mut self.inner).poll_shutdown(cx)
46+
}
47+
48+
fn is_write_vectored(&self) -> bool {
49+
self.inner.is_write_vectored()
50+
}
51+
}
52+
53+
impl fmt::Debug for Upgraded {
54+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55+
f.debug_struct("Upgraded").finish()
56+
}
57+
}
58+
59+
impl From<hyper::upgrade::Upgraded> for Upgraded {
60+
fn from(inner: hyper::upgrade::Upgraded) -> Self {
61+
Upgraded { inner }
62+
}
63+
}
64+
65+
impl super::response::Response {
66+
/// Consumes the response and returns a future for a possible HTTP upgrade.
67+
pub async fn upgrade(self) -> crate::Result<Upgraded> {
68+
hyper::upgrade::on(self.res)
69+
.map_ok(Upgraded::from)
70+
.map_err(crate::error::upgrade)
71+
.await
72+
}
73+
}

src/error.rs

+6
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ impl fmt::Display for Error {
185185
Kind::Body => f.write_str("request or response body error")?,
186186
Kind::Decode => f.write_str("error decoding response body")?,
187187
Kind::Redirect => f.write_str("error following redirect")?,
188+
Kind::Upgrade => f.write_str("error upgrading connection")?,
188189
Kind::Status(ref code) => {
189190
let prefix = if code.is_client_error() {
190191
"HTTP status client error"
@@ -236,6 +237,7 @@ pub(crate) enum Kind {
236237
Status(StatusCode),
237238
Body,
238239
Decode,
240+
Upgrade,
239241
}
240242

241243
// constructors
@@ -274,6 +276,10 @@ if_wasm! {
274276
}
275277
}
276278

279+
pub(crate) fn upgrade<E: Into<BoxError>>(e: E) -> Error {
280+
Error::new(Kind::Upgrade, Some(e))
281+
}
282+
277283
// io::Error helpers
278284

279285
#[allow(unused)]

tests/upgrade.rs

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#![cfg(not(target_arch = "wasm32"))]
2+
mod support;
3+
use support::*;
4+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
5+
6+
#[tokio::test]
7+
async fn http_upgrade() {
8+
let server = server::http(move |req| {
9+
assert_eq!(req.method(), "GET");
10+
assert_eq!(req.headers()["connection"], "upgrade");
11+
assert_eq!(req.headers()["upgrade"], "foobar");
12+
13+
tokio::spawn(async move {
14+
let mut upgraded = hyper::upgrade::on(req).await.unwrap();
15+
16+
let mut buf = vec![0; 7];
17+
upgraded.read_exact(&mut buf).await.unwrap();
18+
assert_eq!(buf, b"foo=bar");
19+
20+
upgraded.write_all(b"bar=foo").await.unwrap();
21+
});
22+
23+
async {
24+
http::Response::builder()
25+
.status(http::StatusCode::SWITCHING_PROTOCOLS)
26+
.header(http::header::CONNECTION, "upgrade")
27+
.header(http::header::UPGRADE, "foobar")
28+
.body(hyper::Body::empty())
29+
.unwrap()
30+
}
31+
});
32+
33+
let res = reqwest::Client::builder()
34+
.build()
35+
.unwrap()
36+
.get(format!("http://{}", server.addr()))
37+
.header(http::header::CONNECTION, "upgrade")
38+
.header(http::header::UPGRADE, "foobar")
39+
.send()
40+
.await
41+
.unwrap();
42+
43+
assert_eq!(res.status(), http::StatusCode::SWITCHING_PROTOCOLS);
44+
let mut upgraded = res.upgrade().await.unwrap();
45+
46+
upgraded.write_all(b"foo=bar").await.unwrap();
47+
48+
let mut buf = vec![];
49+
upgraded.read_to_end(&mut buf).await.unwrap();
50+
assert_eq!(buf, b"bar=foo");
51+
}

0 commit comments

Comments
 (0)