Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 95 additions & 12 deletions dragonfly-client-backend/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,19 @@ impl HTTP {
return;
}

debug!(
"caching temporary redirect {} -> {}",
original_url, target_url
);
// Only cache absolute URLs. A relative path (e.g. "/new/path") cannot be used
// as a standalone redirect target in subsequent requests, so skip caching it.
match Url::parse(target_url) {
Ok(parsed) if parsed.has_host() => {}
_ => {
debug!(
"skipping cache for relative redirect {} -> {}",
original_url, target_url
);

return;
}
}

let mut temporary_redirects = self.temporary_redirects.lock().await;
temporary_redirects.put(
Expand All @@ -351,6 +360,11 @@ impl HTTP {
created_at: Instant::now(),
},
);

debug!(
"caching temporary redirect {} -> {}",
original_url, target_url
);
}
}

Expand Down Expand Up @@ -411,26 +425,28 @@ impl Backend for HTTP {
{
Ok(response) if response.status() == reqwest::StatusCode::TEMPORARY_REDIRECT => {
if let Some(location) = response.headers().get(LOCATION) {
let redirect_url = location.to_str().or_err(ErrorType::ParseError)?;
let location = location.to_str().or_err(ErrorType::ParseError)?;
let base_url = Url::parse(&request.url).or_err(ErrorType::ParseError)?;
let redirect_url = base_url.join(location).or_err(ErrorType::ParseError)?;
debug!(
"stat request got 307 Temporary Redirect, following redirect {} -> {}",
request.url, redirect_url
);

self.store_temporary_redirect_url(&request.url, redirect_url)
self.store_temporary_redirect_url(&request.url, location.as_ref())
.await;

// Strips sensitive headers when following a cross-origin redirect.
let mut redirect_headers = request_header.clone();
remove_sensitive_headers(
&mut redirect_headers,
&redirect_url.parse()?,
&redirect_url,
&request.url.parse()?,
);

match self
.client(request.client_cert.clone(), self.enable_hickory_dns)?
.get(redirect_url)
.get(redirect_url.clone())
.headers(redirect_headers)
.timeout(request.timeout)
.send()
Expand Down Expand Up @@ -613,26 +629,28 @@ impl Backend for HTTP {
// If the response is a 307 Temporary Redirect, follow the redirect manually.
if response.status() == reqwest::StatusCode::TEMPORARY_REDIRECT {
if let Some(location) = response.headers().get(LOCATION) {
let redirect_url = location.to_str().or_err(ErrorType::ParseError)?;
let location = location.to_str().or_err(ErrorType::ParseError)?;
let base_url = Url::parse(&request.url).or_err(ErrorType::ParseError)?;
let redirect_url = base_url.join(location).or_err(ErrorType::ParseError)?;
debug!(
"get request got 307 Temporary Redirect, following redirect {} -> {}",
request.url, redirect_url
);

self.store_temporary_redirect_url(&request.url, redirect_url)
self.store_temporary_redirect_url(&request.url, location.as_ref())
.await;

// Strips sensitive headers when following a cross-origin redirect.
let mut redirect_headers = request_header.clone();
remove_sensitive_headers(
&mut redirect_headers,
&redirect_url.parse()?,
&redirect_url,
&request.url.parse()?,
);

response = match self
.client(request.client_cert.clone(), self.enable_hickory_dns)?
.get(redirect_url)
.get(redirect_url.clone())
.headers(redirect_headers)
.timeout(request.timeout)
.send()
Expand Down Expand Up @@ -1462,6 +1480,71 @@ LJ8gCHKBOJy9dW62DcRWw6zzlTtt9y18/Btx0Hpawg==
assert_eq!(response.text().await.unwrap(), "target content");
}

#[tokio::test]
async fn should_not_cache_relative_307_redirect_location() {
let server = wiremock::MockServer::start().await;
Mock::given(method("GET"))
.and(path("/target"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("target content")
.insert_header("Content-Type", "text/plain"),
)
.mount(&server)
.await;

// Return a 307 with a relative Location path each time it is called.
Mock::given(method("GET"))
.and(path("/redirect"))
.respond_with(ResponseTemplate::new(307).insert_header("Location", "/target"))
.expect(2)
.mount(&server)
.await;

let backend = HTTP::new(HTTP_SCHEME, None, true, Duration::from_secs(600), true).unwrap();

// First request - relative Location should NOT be cached.
let mut response = backend
.get(GetRequest {
task_id: "test".to_string(),
piece_id: "1".to_string(),
url: format!("{}/redirect", server.uri()),
range: None,
http_header: Some(HeaderMap::new()),
timeout: Duration::from_secs(5),
client_cert: None,
object_storage: None,
hdfs: None,
hugging_face: None,
model_scope: None,
})
.await
.unwrap();
assert_eq!(response.http_status_code, Some(StatusCode::OK));
assert_eq!(response.text().await.unwrap(), "target content");

// Second request - because the relative URL was not cached, the origin server
// must be contacted again (wiremock expects exactly 2 calls to /redirect).
let mut response = backend
.get(GetRequest {
task_id: "test".to_string(),
piece_id: "1".to_string(),
url: format!("{}/redirect", server.uri()),
range: None,
http_header: Some(HeaderMap::new()),
timeout: Duration::from_secs(5),
client_cert: None,
object_storage: None,
hdfs: None,
hugging_face: None,
model_scope: None,
})
.await
.unwrap();
assert_eq!(response.http_status_code, Some(StatusCode::OK));
assert_eq!(response.text().await.unwrap(), "target content");
}

#[tokio::test]
async fn should_expire_cached_redirect_after_ttl() {
let server = wiremock::MockServer::start().await;
Expand Down
Loading