|
| 1 | +//! DNS Resolution used by the `HttpConnector`. |
| 2 | +//! |
| 3 | +//! This module contains: |
| 4 | +//! |
| 5 | +//! - A [`GaiResolver`](GaiResolver) that is the default resolver for the |
| 6 | +//! `HttpConnector`. |
| 7 | +//! - The `Name` type used as an argument to custom resolvers. |
| 8 | +//! |
| 9 | +//! # Resolvers are `Service`s |
| 10 | +//! |
| 11 | +//! A resolver is just a |
| 12 | +//! `Service<Name, Response = impl Iterator<Item = SocketAddr>>`. |
| 13 | +//! |
| 14 | +//! A simple resolver that ignores the name and always returns a specific |
| 15 | +//! address: |
| 16 | +//! |
| 17 | +//! ```rust,ignore |
| 18 | +//! use std::{convert::Infallible, iter, net::SocketAddr}; |
| 19 | +//! |
| 20 | +//! let resolver = tower::service_fn(|_name| async { |
| 21 | +//! Ok::<_, Infallible>(iter::once(SocketAddr::from(([127, 0, 0, 1], 8080)))) |
| 22 | +//! }); |
| 23 | +//! ``` |
| 24 | +use std::error::Error; |
| 25 | +use std::future::Future; |
| 26 | +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; |
| 27 | +use std::pin::Pin; |
| 28 | +use std::str::FromStr; |
| 29 | +use std::task::{self, Poll}; |
| 30 | +use std::{fmt, io, vec}; |
| 31 | + |
| 32 | +use tokio::task::JoinHandle; |
| 33 | +use tower_service::Service; |
| 34 | +use tracing::debug; |
| 35 | + |
| 36 | +pub(super) use self::sealed::Resolve; |
| 37 | + |
| 38 | +/// A domain name to resolve into IP addresses. |
| 39 | +#[derive(Clone, Hash, Eq, PartialEq)] |
| 40 | +pub struct Name { |
| 41 | + host: Box<str>, |
| 42 | +} |
| 43 | + |
| 44 | +/// A resolver using blocking `getaddrinfo` calls in a threadpool. |
| 45 | +#[derive(Clone)] |
| 46 | +pub struct GaiResolver { |
| 47 | + _priv: (), |
| 48 | +} |
| 49 | + |
| 50 | +/// An iterator of IP addresses returned from `getaddrinfo`. |
| 51 | +pub struct GaiAddrs { |
| 52 | + inner: SocketAddrs, |
| 53 | +} |
| 54 | + |
| 55 | +/// A future to resolve a name returned by `GaiResolver`. |
| 56 | +pub struct GaiFuture { |
| 57 | + inner: JoinHandle<Result<SocketAddrs, io::Error>>, |
| 58 | +} |
| 59 | + |
| 60 | +impl Name { |
| 61 | + pub(super) fn new(host: Box<str>) -> Name { |
| 62 | + Name { host } |
| 63 | + } |
| 64 | + |
| 65 | + /// View the hostname as a string slice. |
| 66 | + pub fn as_str(&self) -> &str { |
| 67 | + &self.host |
| 68 | + } |
| 69 | +} |
| 70 | + |
| 71 | +impl fmt::Debug for Name { |
| 72 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 73 | + fmt::Debug::fmt(&self.host, f) |
| 74 | + } |
| 75 | +} |
| 76 | + |
| 77 | +impl fmt::Display for Name { |
| 78 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 79 | + fmt::Display::fmt(&self.host, f) |
| 80 | + } |
| 81 | +} |
| 82 | + |
| 83 | +impl FromStr for Name { |
| 84 | + type Err = InvalidNameError; |
| 85 | + |
| 86 | + fn from_str(host: &str) -> Result<Self, Self::Err> { |
| 87 | + // Possibly add validation later |
| 88 | + Ok(Name::new(host.into())) |
| 89 | + } |
| 90 | +} |
| 91 | + |
| 92 | +/// Error indicating a given string was not a valid domain name. |
| 93 | +#[derive(Debug)] |
| 94 | +pub struct InvalidNameError(()); |
| 95 | + |
| 96 | +impl fmt::Display for InvalidNameError { |
| 97 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 98 | + f.write_str("Not a valid domain name") |
| 99 | + } |
| 100 | +} |
| 101 | + |
| 102 | +impl Error for InvalidNameError {} |
| 103 | + |
| 104 | +impl GaiResolver { |
| 105 | + /// Construct a new `GaiResolver`. |
| 106 | + pub fn new() -> Self { |
| 107 | + GaiResolver { _priv: () } |
| 108 | + } |
| 109 | +} |
| 110 | + |
| 111 | +impl Service<Name> for GaiResolver { |
| 112 | + type Response = GaiAddrs; |
| 113 | + type Error = io::Error; |
| 114 | + type Future = GaiFuture; |
| 115 | + |
| 116 | + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> { |
| 117 | + Poll::Ready(Ok(())) |
| 118 | + } |
| 119 | + |
| 120 | + fn call(&mut self, name: Name) -> Self::Future { |
| 121 | + let blocking = tokio::task::spawn_blocking(move || { |
| 122 | + debug!("resolving host={:?}", name.host); |
| 123 | + (&*name.host, 0) |
| 124 | + .to_socket_addrs() |
| 125 | + .map(|i| SocketAddrs { iter: i }) |
| 126 | + }); |
| 127 | + |
| 128 | + GaiFuture { inner: blocking } |
| 129 | + } |
| 130 | +} |
| 131 | + |
| 132 | +impl fmt::Debug for GaiResolver { |
| 133 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 134 | + f.pad("GaiResolver") |
| 135 | + } |
| 136 | +} |
| 137 | + |
| 138 | +impl Future for GaiFuture { |
| 139 | + type Output = Result<GaiAddrs, io::Error>; |
| 140 | + |
| 141 | + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { |
| 142 | + Pin::new(&mut self.inner).poll(cx).map(|res| match res { |
| 143 | + Ok(Ok(addrs)) => Ok(GaiAddrs { inner: addrs }), |
| 144 | + Ok(Err(err)) => Err(err), |
| 145 | + Err(join_err) => { |
| 146 | + if join_err.is_cancelled() { |
| 147 | + Err(io::Error::new(io::ErrorKind::Interrupted, join_err)) |
| 148 | + } else { |
| 149 | + panic!("gai background task failed: {:?}", join_err) |
| 150 | + } |
| 151 | + } |
| 152 | + }) |
| 153 | + } |
| 154 | +} |
| 155 | + |
| 156 | +impl fmt::Debug for GaiFuture { |
| 157 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 158 | + f.pad("GaiFuture") |
| 159 | + } |
| 160 | +} |
| 161 | + |
| 162 | +impl Drop for GaiFuture { |
| 163 | + fn drop(&mut self) { |
| 164 | + self.inner.abort(); |
| 165 | + } |
| 166 | +} |
| 167 | + |
| 168 | +impl Iterator for GaiAddrs { |
| 169 | + type Item = SocketAddr; |
| 170 | + |
| 171 | + fn next(&mut self) -> Option<Self::Item> { |
| 172 | + self.inner.next() |
| 173 | + } |
| 174 | +} |
| 175 | + |
| 176 | +impl fmt::Debug for GaiAddrs { |
| 177 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 178 | + f.pad("GaiAddrs") |
| 179 | + } |
| 180 | +} |
| 181 | + |
| 182 | +pub(super) struct SocketAddrs { |
| 183 | + iter: vec::IntoIter<SocketAddr>, |
| 184 | +} |
| 185 | + |
| 186 | +impl SocketAddrs { |
| 187 | + pub(super) fn new(addrs: Vec<SocketAddr>) -> Self { |
| 188 | + SocketAddrs { |
| 189 | + iter: addrs.into_iter(), |
| 190 | + } |
| 191 | + } |
| 192 | + |
| 193 | + pub(super) fn try_parse(host: &str, port: u16) -> Option<SocketAddrs> { |
| 194 | + if let Ok(addr) = host.parse::<Ipv4Addr>() { |
| 195 | + let addr = SocketAddrV4::new(addr, port); |
| 196 | + return Some(SocketAddrs { |
| 197 | + iter: vec![SocketAddr::V4(addr)].into_iter(), |
| 198 | + }); |
| 199 | + } |
| 200 | + if let Ok(addr) = host.parse::<Ipv6Addr>() { |
| 201 | + let addr = SocketAddrV6::new(addr, port, 0, 0); |
| 202 | + return Some(SocketAddrs { |
| 203 | + iter: vec![SocketAddr::V6(addr)].into_iter(), |
| 204 | + }); |
| 205 | + } |
| 206 | + None |
| 207 | + } |
| 208 | + |
| 209 | + #[inline] |
| 210 | + fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> SocketAddrs { |
| 211 | + SocketAddrs::new(self.iter.filter(predicate).collect()) |
| 212 | + } |
| 213 | + |
| 214 | + pub(super) fn split_by_preference( |
| 215 | + self, |
| 216 | + local_addr_ipv4: Option<Ipv4Addr>, |
| 217 | + local_addr_ipv6: Option<Ipv6Addr>, |
| 218 | + ) -> (SocketAddrs, SocketAddrs) { |
| 219 | + match (local_addr_ipv4, local_addr_ipv6) { |
| 220 | + (Some(_), None) => (self.filter(SocketAddr::is_ipv4), SocketAddrs::new(vec![])), |
| 221 | + (None, Some(_)) => (self.filter(SocketAddr::is_ipv6), SocketAddrs::new(vec![])), |
| 222 | + _ => { |
| 223 | + let preferring_v6 = self |
| 224 | + .iter |
| 225 | + .as_slice() |
| 226 | + .first() |
| 227 | + .map(SocketAddr::is_ipv6) |
| 228 | + .unwrap_or(false); |
| 229 | + |
| 230 | + let (preferred, fallback) = self |
| 231 | + .iter |
| 232 | + .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6); |
| 233 | + |
| 234 | + (SocketAddrs::new(preferred), SocketAddrs::new(fallback)) |
| 235 | + } |
| 236 | + } |
| 237 | + } |
| 238 | + |
| 239 | + pub(super) fn is_empty(&self) -> bool { |
| 240 | + self.iter.as_slice().is_empty() |
| 241 | + } |
| 242 | + |
| 243 | + pub(super) fn len(&self) -> usize { |
| 244 | + self.iter.as_slice().len() |
| 245 | + } |
| 246 | +} |
| 247 | + |
| 248 | +impl Iterator for SocketAddrs { |
| 249 | + type Item = SocketAddr; |
| 250 | + #[inline] |
| 251 | + fn next(&mut self) -> Option<SocketAddr> { |
| 252 | + self.iter.next() |
| 253 | + } |
| 254 | +} |
| 255 | + |
| 256 | +mod sealed { |
| 257 | + use std::future::Future; |
| 258 | + use std::task::{self, Poll}; |
| 259 | + |
| 260 | + use super::{Name, SocketAddr}; |
| 261 | + use tower_service::Service; |
| 262 | + |
| 263 | + // "Trait alias" for `Service<Name, Response = Addrs>` |
| 264 | + pub trait Resolve { |
| 265 | + type Addrs: Iterator<Item = SocketAddr>; |
| 266 | + type Error: Into<Box<dyn std::error::Error + Send + Sync>>; |
| 267 | + type Future: Future<Output = Result<Self::Addrs, Self::Error>>; |
| 268 | + |
| 269 | + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>; |
| 270 | + fn resolve(&mut self, name: Name) -> Self::Future; |
| 271 | + } |
| 272 | + |
| 273 | + impl<S> Resolve for S |
| 274 | + where |
| 275 | + S: Service<Name>, |
| 276 | + S::Response: Iterator<Item = SocketAddr>, |
| 277 | + S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, |
| 278 | + { |
| 279 | + type Addrs = S::Response; |
| 280 | + type Error = S::Error; |
| 281 | + type Future = S::Future; |
| 282 | + |
| 283 | + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 284 | + Service::poll_ready(self, cx) |
| 285 | + } |
| 286 | + |
| 287 | + fn resolve(&mut self, name: Name) -> Self::Future { |
| 288 | + Service::call(self, name) |
| 289 | + } |
| 290 | + } |
| 291 | +} |
| 292 | + |
| 293 | +pub(super) async fn resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error> |
| 294 | +where |
| 295 | + R: Resolve, |
| 296 | +{ |
| 297 | + futures_util::future::poll_fn(|cx| resolver.poll_ready(cx)).await?; |
| 298 | + resolver.resolve(name).await |
| 299 | +} |
| 300 | + |
| 301 | +#[cfg(test)] |
| 302 | +mod tests { |
| 303 | + use super::*; |
| 304 | + use std::net::{Ipv4Addr, Ipv6Addr}; |
| 305 | + |
| 306 | + #[test] |
| 307 | + fn test_ip_addrs_split_by_preference() { |
| 308 | + let ip_v4 = Ipv4Addr::new(127, 0, 0, 1); |
| 309 | + let ip_v6 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1); |
| 310 | + let v4_addr = (ip_v4, 80).into(); |
| 311 | + let v6_addr = (ip_v6, 80).into(); |
| 312 | + |
| 313 | + let (mut preferred, mut fallback) = SocketAddrs { |
| 314 | + iter: vec![v4_addr, v6_addr].into_iter(), |
| 315 | + } |
| 316 | + .split_by_preference(None, None); |
| 317 | + assert!(preferred.next().unwrap().is_ipv4()); |
| 318 | + assert!(fallback.next().unwrap().is_ipv6()); |
| 319 | + |
| 320 | + let (mut preferred, mut fallback) = SocketAddrs { |
| 321 | + iter: vec![v6_addr, v4_addr].into_iter(), |
| 322 | + } |
| 323 | + .split_by_preference(None, None); |
| 324 | + assert!(preferred.next().unwrap().is_ipv6()); |
| 325 | + assert!(fallback.next().unwrap().is_ipv4()); |
| 326 | + |
| 327 | + let (mut preferred, mut fallback) = SocketAddrs { |
| 328 | + iter: vec![v4_addr, v6_addr].into_iter(), |
| 329 | + } |
| 330 | + .split_by_preference(Some(ip_v4), Some(ip_v6)); |
| 331 | + assert!(preferred.next().unwrap().is_ipv4()); |
| 332 | + assert!(fallback.next().unwrap().is_ipv6()); |
| 333 | + |
| 334 | + let (mut preferred, mut fallback) = SocketAddrs { |
| 335 | + iter: vec![v6_addr, v4_addr].into_iter(), |
| 336 | + } |
| 337 | + .split_by_preference(Some(ip_v4), Some(ip_v6)); |
| 338 | + assert!(preferred.next().unwrap().is_ipv6()); |
| 339 | + assert!(fallback.next().unwrap().is_ipv4()); |
| 340 | + |
| 341 | + let (mut preferred, fallback) = SocketAddrs { |
| 342 | + iter: vec![v4_addr, v6_addr].into_iter(), |
| 343 | + } |
| 344 | + .split_by_preference(Some(ip_v4), None); |
| 345 | + assert!(preferred.next().unwrap().is_ipv4()); |
| 346 | + assert!(fallback.is_empty()); |
| 347 | + |
| 348 | + let (mut preferred, fallback) = SocketAddrs { |
| 349 | + iter: vec![v4_addr, v6_addr].into_iter(), |
| 350 | + } |
| 351 | + .split_by_preference(None, Some(ip_v6)); |
| 352 | + assert!(preferred.next().unwrap().is_ipv6()); |
| 353 | + assert!(fallback.is_empty()); |
| 354 | + } |
| 355 | + |
| 356 | + #[test] |
| 357 | + fn test_name_from_str() { |
| 358 | + const DOMAIN: &str = "test.example.com"; |
| 359 | + let name = Name::from_str(DOMAIN).expect("Should be a valid domain"); |
| 360 | + assert_eq!(name.as_str(), DOMAIN); |
| 361 | + assert_eq!(name.to_string(), DOMAIN); |
| 362 | + } |
| 363 | +} |
0 commit comments