Skip to content

Commit ae2d521

Browse files
authored
Add ability to specify multiple IP addresses for resolver overrides (#1622)
This change allows the `ClientBuilder::resolve_to_addrs` method to accept a slice of `SocketAddr`s for overriding resolution for a single domain. Allowing multiple IPs more accurately reflects behavior of `getaddrinfo` and allows users to rely on hyper's happy eyeballs algorithm to connect to a host that can accept traffic on IPv4 and IPv6.
1 parent 6ceb239 commit ae2d521

File tree

4 files changed

+110
-14
lines changed

4 files changed

+110
-14
lines changed

src/async_impl/client.rs

+18-4
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ struct Config {
120120
trust_dns: bool,
121121
error: Option<crate::Error>,
122122
https_only: bool,
123-
dns_overrides: HashMap<String, SocketAddr>,
123+
dns_overrides: HashMap<String, Vec<SocketAddr>>,
124124
}
125125

126126
impl Default for ClientBuilder {
@@ -1314,16 +1314,30 @@ impl ClientBuilder {
13141314
self
13151315
}
13161316

1317-
/// Override DNS resolution for specific domains to particular IP addresses.
1317+
/// Override DNS resolution for specific domains to a particular IP address.
13181318
///
13191319
/// Warning
13201320
///
13211321
/// Since the DNS protocol has no notion of ports, if you wish to send
13221322
/// traffic to a particular port you must include this port in the URL
13231323
/// itself, any port in the overridden addr will be ignored and traffic sent
13241324
/// to the conventional port for the given scheme (e.g. 80 for http).
1325-
pub fn resolve(mut self, domain: &str, addr: SocketAddr) -> ClientBuilder {
1326-
self.config.dns_overrides.insert(domain.to_string(), addr);
1325+
pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder {
1326+
self.resolve_to_addrs(domain, &[addr])
1327+
}
1328+
1329+
/// Override DNS resolution for specific domains to particular IP addresses.
1330+
///
1331+
/// Warning
1332+
///
1333+
/// Since the DNS protocol has no notion of ports, if you wish to send
1334+
/// traffic to a particular port you must include this port in the URL
1335+
/// itself, any port in the overridden addresses will be ignored and traffic sent
1336+
/// to the conventional port for the given scheme (e.g. 80 for http).
1337+
pub fn resolve_to_addrs(mut self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder {
1338+
self.config
1339+
.dns_overrides
1340+
.insert(domain.to_string(), addrs.to_vec());
13271341
self
13281342
}
13291343
}

src/blocking/client.rs

+14-2
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ impl ClientBuilder {
757757
self.with_inner(|inner| inner.https_only(enabled))
758758
}
759759

760-
/// Override DNS resolution for specific domains to particular IP addresses.
760+
/// Override DNS resolution for specific domains to a particular IP address.
761761
///
762762
/// Warning
763763
///
@@ -766,7 +766,19 @@ impl ClientBuilder {
766766
/// itself, any port in the overridden addr will be ignored and traffic sent
767767
/// to the conventional port for the given scheme (e.g. 80 for http).
768768
pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder {
769-
self.with_inner(|inner| inner.resolve(domain, addr))
769+
self.resolve_to_addrs(domain, &[addr])
770+
}
771+
772+
/// Override DNS resolution for specific domains to particular IP addresses.
773+
///
774+
/// Warning
775+
///
776+
/// Since the DNS protocol has no notion of ports, if you wish to send
777+
/// traffic to a particular port you must include this port in the URL
778+
/// itself, any port in the overridden addresses will be ignored and traffic sent
779+
/// to the conventional port for the given scheme (e.g. 80 for http).
780+
pub fn resolve_to_addrs(self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder {
781+
self.with_inner(|inner| inner.resolve_to_addrs(domain, addrs))
770782
}
771783

772784
// private

src/connect.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ impl HttpConnector {
4646
Self::Gai(hyper::client::HttpConnector::new())
4747
}
4848

49-
pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, SocketAddr>) -> Self {
49+
pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, Vec<SocketAddr>>) -> Self {
5050
let gai = hyper::client::connect::dns::GaiResolver::new();
5151
let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides);
5252
Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver(
@@ -64,7 +64,7 @@ impl HttpConnector {
6464

6565
#[cfg(feature = "trust-dns")]
6666
pub(crate) fn new_trust_dns_with_overrides(
67-
overrides: HashMap<String, SocketAddr>,
67+
overrides: HashMap<String, Vec<SocketAddr>>,
6868
) -> crate::Result<HttpConnector> {
6969
TrustDnsResolver::new()
7070
.map(|resolver| DnsResolverWithOverrides::new(resolver, overrides))
@@ -994,7 +994,7 @@ where
994994
Fut: std::future::Future<Output = Result<FutOutput, FutError>>,
995995
FutOutput: Iterator<Item = SocketAddr>,
996996
{
997-
type Output = Result<itertools::Either<FutOutput, std::iter::Once<SocketAddr>>, FutError>;
997+
type Output = Result<itertools::Either<FutOutput, std::vec::IntoIter<SocketAddr>>, FutError>;
998998

999999
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
10001000
let this = self.project();
@@ -1010,11 +1010,11 @@ where
10101010
Resolver: Clone,
10111011
{
10121012
dns_resolver: Resolver,
1013-
overrides: Arc<HashMap<String, SocketAddr>>,
1013+
overrides: Arc<HashMap<String, Vec<SocketAddr>>>,
10141014
}
10151015

10161016
impl<Resolver: Clone> DnsResolverWithOverrides<Resolver> {
1017-
fn new(dns_resolver: Resolver, overrides: HashMap<String, SocketAddr>) -> Self {
1017+
fn new(dns_resolver: Resolver, overrides: HashMap<String, Vec<SocketAddr>>) -> Self {
10181018
DnsResolverWithOverrides {
10191019
dns_resolver,
10201020
overrides: Arc::new(overrides),
@@ -1027,12 +1027,12 @@ where
10271027
Resolver: Service<Name, Response = Iter> + Clone,
10281028
Iter: Iterator<Item = SocketAddr>,
10291029
{
1030-
type Response = itertools::Either<Iter, std::iter::Once<SocketAddr>>;
1030+
type Response = itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>;
10311031
type Error = <Resolver as Service<Name>>::Error;
10321032
type Future = Either<
10331033
WrappedResolverFuture<<Resolver as Service<Name>>::Future>,
10341034
futures_util::future::Ready<
1035-
Result<itertools::Either<Iter, std::iter::Once<SocketAddr>>, Self::Error>,
1035+
Result<itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>, Self::Error>,
10361036
>,
10371037
>;
10381038

@@ -1044,7 +1044,7 @@ where
10441044
match self.overrides.get(name.as_str()) {
10451045
Some(dest) => {
10461046
let fut = futures_util::future::ready(Ok(itertools::Either::Right(
1047-
std::iter::once(dest.to_owned()),
1047+
dest.clone().into_iter(),
10481048
)));
10491049
Either::Right(fut)
10501050
}

tests/client.rs

+70
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,40 @@ async fn overridden_dns_resolution_with_gai() {
190190
assert_eq!("Hello", text);
191191
}
192192

193+
#[tokio::test]
194+
async fn overridden_dns_resolution_with_gai_multiple() {
195+
let _ = env_logger::builder().is_test(true).try_init();
196+
let server = server::http(move |_req| async { http::Response::new("Hello".into()) });
197+
198+
let overridden_domain = "rust-lang.org";
199+
let url = format!(
200+
"http://{}:{}/domain_override",
201+
overridden_domain,
202+
server.addr().port()
203+
);
204+
// the server runs on IPv4 localhost, so provide both IPv4 and IPv6 and let the happy eyeballs
205+
// algorithm decide which address to use.
206+
let client = reqwest::Client::builder()
207+
.resolve_to_addrs(
208+
overridden_domain,
209+
&[
210+
std::net::SocketAddr::new(
211+
std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
212+
server.addr().port(),
213+
),
214+
server.addr(),
215+
],
216+
)
217+
.build()
218+
.expect("client builder");
219+
let req = client.get(&url);
220+
let res = req.send().await.expect("request");
221+
222+
assert_eq!(res.status(), reqwest::StatusCode::OK);
223+
let text = res.text().await.expect("Failed to get text");
224+
assert_eq!("Hello", text);
225+
}
226+
193227
#[cfg(feature = "trust-dns")]
194228
#[tokio::test]
195229
async fn overridden_dns_resolution_with_trust_dns() {
@@ -215,6 +249,42 @@ async fn overridden_dns_resolution_with_trust_dns() {
215249
assert_eq!("Hello", text);
216250
}
217251

252+
#[cfg(feature = "trust-dns")]
253+
#[tokio::test]
254+
async fn overridden_dns_resolution_with_trust_dns_multiple() {
255+
let _ = env_logger::builder().is_test(true).try_init();
256+
let server = server::http(move |_req| async { http::Response::new("Hello".into()) });
257+
258+
let overridden_domain = "rust-lang.org";
259+
let url = format!(
260+
"http://{}:{}/domain_override",
261+
overridden_domain,
262+
server.addr().port()
263+
);
264+
// the server runs on IPv4 localhost, so provide both IPv4 and IPv6 and let the happy eyeballs
265+
// algorithm decide which address to use.
266+
let client = reqwest::Client::builder()
267+
.resolve_to_addrs(
268+
overridden_domain,
269+
&[
270+
std::net::SocketAddr::new(
271+
std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
272+
server.addr().port(),
273+
),
274+
server.addr(),
275+
],
276+
)
277+
.trust_dns(true)
278+
.build()
279+
.expect("client builder");
280+
let req = client.get(&url);
281+
let res = req.send().await.expect("request");
282+
283+
assert_eq!(res.status(), reqwest::StatusCode::OK);
284+
let text = res.text().await.expect("Failed to get text");
285+
assert_eq!("Hello", text);
286+
}
287+
218288
#[cfg(any(feature = "native-tls", feature = "__rustls",))]
219289
#[test]
220290
fn use_preconfigured_tls_with_bogus_backend() {

0 commit comments

Comments
 (0)