Skip to content

Commit ba2a144

Browse files
authored
fix(client): strip path from Uri before calling Connector (#2109)
1 parent a5720fa commit ba2a144

File tree

4 files changed

+59
-34
lines changed

4 files changed

+59
-34
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ serde_derive = "1.0"
5151
serde_json = "1.0"
5252
tokio = { version = "0.2.2", features = ["fs", "macros", "io-std", "rt-util", "sync", "time", "test-util"] }
5353
tokio-test = "0.2"
54+
tower-util = "0.3"
5455
url = "1.0"
5556

5657
[features]

src/client/mod.rs

+21-15
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
5151
use std::fmt;
5252
use std::mem;
53-
use std::sync::Arc;
5453
use std::time::Duration;
5554

5655
use futures_channel::oneshot;
@@ -230,14 +229,13 @@ where
230229
other => return ResponseFuture::error_version(other),
231230
};
232231

233-
let domain = match extract_domain(req.uri_mut(), is_http_connect) {
232+
let pool_key = match extract_domain(req.uri_mut(), is_http_connect) {
234233
Ok(s) => s,
235234
Err(err) => {
236235
return ResponseFuture::new(Box::new(future::err(err)));
237236
}
238237
};
239238

240-
let pool_key = Arc::new(domain);
241239
ResponseFuture::new(Box::new(self.retryably_send_request(req, pool_key)))
242240
}
243241

@@ -281,7 +279,7 @@ where
281279
mut req: Request<B>,
282280
pool_key: PoolKey,
283281
) -> impl Future<Output = Result<Response<Body>, ClientError<B>>> + Unpin {
284-
let conn = self.connection_for(req.uri().clone(), pool_key);
282+
let conn = self.connection_for(pool_key);
285283

286284
let set_host = self.config.set_host;
287285
let executor = self.conn_builder.exec.clone();
@@ -377,7 +375,6 @@ where
377375

378376
fn connection_for(
379377
&self,
380-
uri: Uri,
381378
pool_key: PoolKey,
382379
) -> impl Future<Output = Result<Pooled<PoolClient<B>>, ClientError<B>>> {
383380
// This actually races 2 different futures to try to get a ready
@@ -394,7 +391,7 @@ where
394391
// connection future is spawned into the runtime to complete,
395392
// and then be inserted into the pool as an idle connection.
396393
let checkout = self.pool.checkout(pool_key.clone());
397-
let connect = self.connect_to(uri, pool_key);
394+
let connect = self.connect_to(pool_key);
398395

399396
let executor = self.conn_builder.exec.clone();
400397
// The order of the `select` is depended on below...
@@ -455,7 +452,6 @@ where
455452

456453
fn connect_to(
457454
&self,
458-
uri: Uri,
459455
pool_key: PoolKey,
460456
) -> impl Lazy<Output = crate::Result<Pooled<PoolClient<B>>>> + Unpin {
461457
let executor = self.conn_builder.exec.clone();
@@ -464,7 +460,7 @@ where
464460
let ver = self.config.ver;
465461
let is_ver_h2 = ver == Ver::Http2;
466462
let connector = self.connector.clone();
467-
let dst = uri;
463+
let dst = domain_as_uri(pool_key.clone());
468464
hyper_lazy(move || {
469465
// Try to take a "connecting lock".
470466
//
@@ -794,22 +790,22 @@ fn authority_form(uri: &mut Uri) {
794790
};
795791
}
796792

797-
fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result<String> {
793+
fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result<PoolKey> {
798794
let uri_clone = uri.clone();
799795
match (uri_clone.scheme(), uri_clone.authority()) {
800-
(Some(scheme), Some(auth)) => Ok(format!("{}://{}", scheme, auth)),
796+
(Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())),
801797
(None, Some(auth)) if is_http_connect => {
802798
let scheme = match auth.port_u16() {
803799
Some(443) => {
804800
set_scheme(uri, Scheme::HTTPS);
805-
"https"
801+
Scheme::HTTPS
806802
}
807803
_ => {
808804
set_scheme(uri, Scheme::HTTP);
809-
"http"
805+
Scheme::HTTP
810806
}
811807
};
812-
Ok(format!("{}://{}", scheme, auth))
808+
Ok((scheme, auth.clone()))
813809
}
814810
_ => {
815811
debug!("Client requires absolute-form URIs, received: {:?}", uri);
@@ -818,6 +814,15 @@ fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> crate::Result<String>
818814
}
819815
}
820816

817+
fn domain_as_uri((scheme, auth): PoolKey) -> Uri {
818+
http::uri::Builder::new()
819+
.scheme(scheme)
820+
.authority(auth)
821+
.path_and_query("/")
822+
.build()
823+
.expect("domain is valid Uri")
824+
}
825+
821826
fn set_scheme(uri: &mut Uri, scheme: Scheme) {
822827
debug_assert!(
823828
uri.scheme().is_none(),
@@ -1126,7 +1131,8 @@ mod unit_tests {
11261131
#[test]
11271132
fn test_extract_domain_connect_no_port() {
11281133
let mut uri = "hyper.rs".parse().unwrap();
1129-
let domain = extract_domain(&mut uri, true).expect("extract domain");
1130-
assert_eq!(domain, "http://hyper.rs");
1134+
let (scheme, host) = extract_domain(&mut uri, true).expect("extract domain");
1135+
assert_eq!(scheme, *"http");
1136+
assert_eq!(host, "hyper.rs");
11311137
}
11321138
}

src/client/pool.rs

+13-10
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ pub(super) enum Reservation<T> {
5252
}
5353

5454
/// Simple type alias in case the key type needs to be adjusted.
55-
pub(super) type Key = Arc<String>;
55+
pub(super) type Key = (http::uri::Scheme, http::uri::Authority); //Arc<String>;
5656

5757
struct PoolInner<T> {
5858
// A flag that a connection is being established, and the connection
@@ -755,7 +755,6 @@ impl<T> WeakOpt<T> {
755755

756756
#[cfg(test)]
757757
mod tests {
758-
use std::sync::Arc;
759758
use std::task::Poll;
760759
use std::time::Duration;
761760

@@ -787,6 +786,10 @@ mod tests {
787786
}
788787
}
789788

789+
fn host_key(s: &str) -> Key {
790+
(http::uri::Scheme::HTTP, s.parse().expect("host key"))
791+
}
792+
790793
fn pool_no_timer<T>() -> Pool<T> {
791794
pool_max_idle_no_timer(::std::usize::MAX)
792795
}
@@ -807,7 +810,7 @@ mod tests {
807810
#[tokio::test]
808811
async fn test_pool_checkout_smoke() {
809812
let pool = pool_no_timer();
810-
let key = Arc::new("foo".to_string());
813+
let key = host_key("foo");
811814
let pooled = pool.pooled(c(key.clone()), Uniq(41));
812815

813816
drop(pooled);
@@ -839,7 +842,7 @@ mod tests {
839842
#[tokio::test]
840843
async fn test_pool_checkout_returns_none_if_expired() {
841844
let pool = pool_no_timer();
842-
let key = Arc::new("foo".to_string());
845+
let key = host_key("foo");
843846
let pooled = pool.pooled(c(key.clone()), Uniq(41));
844847

845848
drop(pooled);
@@ -854,7 +857,7 @@ mod tests {
854857
#[tokio::test]
855858
async fn test_pool_checkout_removes_expired() {
856859
let pool = pool_no_timer();
857-
let key = Arc::new("foo".to_string());
860+
let key = host_key("foo");
858861

859862
pool.pooled(c(key.clone()), Uniq(41));
860863
pool.pooled(c(key.clone()), Uniq(5));
@@ -876,7 +879,7 @@ mod tests {
876879
#[test]
877880
fn test_pool_max_idle_per_host() {
878881
let pool = pool_max_idle_no_timer(2);
879-
let key = Arc::new("foo".to_string());
882+
let key = host_key("foo");
880883

881884
pool.pooled(c(key.clone()), Uniq(41));
882885
pool.pooled(c(key.clone()), Uniq(5));
@@ -904,7 +907,7 @@ mod tests {
904907
&Exec::Default,
905908
);
906909

907-
let key = Arc::new("foo".to_string());
910+
let key = host_key("foo");
908911

909912
pool.pooled(c(key.clone()), Uniq(41));
910913
pool.pooled(c(key.clone()), Uniq(5));
@@ -929,7 +932,7 @@ mod tests {
929932
use futures_util::FutureExt;
930933

931934
let pool = pool_no_timer();
932-
let key = Arc::new("foo".to_string());
935+
let key = host_key("foo");
933936
let pooled = pool.pooled(c(key.clone()), Uniq(41));
934937

935938
let checkout = join(pool.checkout(key), async {
@@ -948,7 +951,7 @@ mod tests {
948951
#[tokio::test]
949952
async fn test_pool_checkout_drop_cleans_up_waiters() {
950953
let pool = pool_no_timer::<Uniq<i32>>();
951-
let key = Arc::new("localhost:12345".to_string());
954+
let key = host_key("foo");
952955

953956
let mut checkout1 = pool.checkout(key.clone());
954957
let mut checkout2 = pool.checkout(key.clone());
@@ -993,7 +996,7 @@ mod tests {
993996
#[test]
994997
fn pooled_drop_if_closed_doesnt_reinsert() {
995998
let pool = pool_no_timer();
996-
let key = Arc::new("localhost:12345".to_string());
999+
let key = host_key("foo");
9971000
pool.pooled(
9981001
c(key.clone()),
9991002
CanClose {

src/client/tests.rs

+24-9
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
1-
// FIXME: re-implement tests with `async/await`
2-
/*
3-
#![cfg(feature = "runtime")]
1+
use std::io;
2+
3+
use futures_util::future;
4+
use tokio::net::TcpStream;
45

5-
use futures::{Async, Future, Stream};
6-
use futures::future::poll_fn;
7-
use futures::sync::oneshot;
8-
use tokio::runtime::current_thread::Runtime;
6+
use super::Client;
97

10-
use crate::mock::MockConnector;
11-
use super::*;
8+
#[tokio::test]
9+
async fn client_connect_uri_argument() {
10+
let connector = tower_util::service_fn(|dst: http::Uri| {
11+
assert_eq!(dst.scheme(), Some(&http::uri::Scheme::HTTP));
12+
assert_eq!(dst.host(), Some("example.local"));
13+
assert_eq!(dst.port(), None);
14+
assert_eq!(dst.path(), "/", "path should be removed");
15+
16+
future::err::<TcpStream, _>(io::Error::new(io::ErrorKind::Other, "expect me"))
17+
});
1218

19+
let client = Client::builder().build::<_, crate::Body>(connector);
20+
let _ = client
21+
.get("http://example.local/and/a/path".parse().unwrap())
22+
.await
23+
.expect_err("response should fail");
24+
}
25+
26+
/*
27+
// FIXME: re-implement tests with `async/await`
1328
#[test]
1429
fn retryable_request() {
1530
let _ = pretty_env_logger::try_init();

0 commit comments

Comments
 (0)