Skip to content

Commit 59769c9

Browse files
authored
Merge pull request #31 from wvanlint/header_provider
Introduce header provider trait
2 parents 63387fa + b06b37c commit 59769c9

File tree

6 files changed

+170
-15
lines changed

6 files changed

+170
-15
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
/target
22
/src/proto/
3+
/Cargo.lock

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ prost = "0.11.6"
1616
reqwest = { version = "0.11.13", default-features = false, features = ["rustls-tls"] }
1717
tokio = { version = "1", default-features = false, features = ["time"] }
1818
rand = "0.8.5"
19+
async-trait = "0.1.77"
1920

2021
[target.'cfg(genproto)'.build-dependencies]
2122
prost-build = { version = "0.11.3" }

src/client.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use prost::Message;
2-
use reqwest;
32
use reqwest::header::CONTENT_TYPE;
43
use reqwest::Client;
4+
use std::collections::HashMap;
55
use std::default::Default;
6+
use std::sync::Arc;
67

78
use crate::error::VssError;
9+
use crate::headers::{get_headermap, FixedHeaders, VssHeaderProvider};
810
use crate::types::{
911
DeleteObjectRequest, DeleteObjectResponse, GetObjectRequest, GetObjectResponse, ListKeyVersionsRequest,
1012
ListKeyVersionsResponse, PutObjectRequest, PutObjectResponse,
@@ -23,18 +25,27 @@ where
2325
base_url: String,
2426
client: Client,
2527
retry_policy: R,
28+
header_provider: Arc<dyn VssHeaderProvider>,
2629
}
2730

2831
impl<R: RetryPolicy<E = VssError>> VssClient<R> {
2932
/// Constructs a [`VssClient`] using `base_url` as the VSS server endpoint.
30-
pub fn new(base_url: &str, retry_policy: R) -> Self {
33+
pub fn new(base_url: String, retry_policy: R) -> Self {
3134
let client = Client::new();
3235
Self::from_client(base_url, client, retry_policy)
3336
}
3437

3538
/// Constructs a [`VssClient`] from a given [`reqwest::Client`], using `base_url` as the VSS server endpoint.
36-
pub fn from_client(base_url: &str, client: Client, retry_policy: R) -> Self {
37-
Self { base_url: String::from(base_url), client, retry_policy }
39+
pub fn from_client(base_url: String, client: Client, retry_policy: R) -> Self {
40+
Self { base_url, client, retry_policy, header_provider: Arc::new(FixedHeaders::new(HashMap::new())) }
41+
}
42+
43+
/// Constructs a [`VssClient`] using `base_url` as the VSS server endpoint.
44+
///
45+
/// HTTP headers will be provided by the given `header_provider`.
46+
pub fn new_with_headers(base_url: String, retry_policy: R, header_provider: Arc<dyn VssHeaderProvider>) -> Self {
47+
let client = Client::new();
48+
Self { base_url, client, retry_policy, header_provider }
3849
}
3950

4051
/// Returns the underlying base URL.
@@ -111,10 +122,17 @@ impl<R: RetryPolicy<E = VssError>> VssClient<R> {
111122

112123
async fn post_request<Rq: Message, Rs: Message + Default>(&self, request: &Rq, url: &str) -> Result<Rs, VssError> {
113124
let request_body = request.encode_to_vec();
125+
let headermap = self
126+
.header_provider
127+
.get_headers(&request_body)
128+
.await
129+
.and_then(get_headermap)
130+
.map_err(|e| VssError::AuthError(e.to_string()))?;
114131
let response_raw = self
115132
.client
116133
.post(url)
117134
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
135+
.headers(headermap)
118136
.body(request_body)
119137
.send()
120138
.await?;

src/headers/mod.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use async_trait::async_trait;
2+
use reqwest::header::HeaderMap;
3+
use std::collections::HashMap;
4+
use std::error::Error;
5+
use std::fmt::Display;
6+
use std::fmt::Formatter;
7+
use std::str::FromStr;
8+
9+
/// Defines a trait around how headers are provided for each VSS request.
10+
#[async_trait]
11+
pub trait VssHeaderProvider {
12+
/// Returns the HTTP headers to be used for a VSS request.
13+
/// This method is called on each request, and should likely perform some form of caching.
14+
///
15+
/// A reference to the serialized request body is given as `request`.
16+
/// It can be used to perform operations such as request signing.
17+
async fn get_headers(&self, request: &[u8]) -> Result<HashMap<String, String>, VssHeaderProviderError>;
18+
}
19+
20+
/// Errors around providing headers for each VSS request.
21+
#[derive(Debug)]
22+
pub enum VssHeaderProviderError {
23+
/// Invalid data was encountered.
24+
InvalidData {
25+
/// The error message.
26+
error: String,
27+
},
28+
}
29+
30+
impl Display for VssHeaderProviderError {
31+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32+
match self {
33+
Self::InvalidData { error } => {
34+
write!(f, "invalid data: {}", error)
35+
}
36+
}
37+
}
38+
}
39+
40+
impl Error for VssHeaderProviderError {}
41+
42+
/// A header provider returning an given, fixed set of headers.
43+
pub struct FixedHeaders {
44+
headers: HashMap<String, String>,
45+
}
46+
47+
impl FixedHeaders {
48+
/// Creates a new header provider returning the given, fixed set of headers.
49+
pub fn new(headers: HashMap<String, String>) -> FixedHeaders {
50+
FixedHeaders { headers }
51+
}
52+
}
53+
54+
#[async_trait]
55+
impl VssHeaderProvider for FixedHeaders {
56+
async fn get_headers(&self, _request: &[u8]) -> Result<HashMap<String, String>, VssHeaderProviderError> {
57+
Ok(self.headers.clone())
58+
}
59+
}
60+
61+
pub(crate) fn get_headermap(headers: HashMap<String, String>) -> Result<HeaderMap, VssHeaderProviderError> {
62+
let mut headermap = HeaderMap::new();
63+
for (name, value) in headers {
64+
headermap.insert(
65+
reqwest::header::HeaderName::from_str(&name)
66+
.map_err(|e| VssHeaderProviderError::InvalidData { error: e.to_string() })?,
67+
reqwest::header::HeaderValue::from_str(&value)
68+
.map_err(|e| VssHeaderProviderError::InvalidData { error: e.to_string() })?,
69+
);
70+
}
71+
Ok(headermap)
72+
}

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ pub mod util;
2525

2626
// Encryption-Decryption related crate-only helpers.
2727
pub(crate) mod crypto;
28+
29+
/// A collection of header providers.
30+
pub mod headers;

tests/tests.rs

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
#[cfg(test)]
22
mod tests {
3+
use async_trait::async_trait;
34
use mockito::{self, Matcher};
45
use prost::Message;
56
use reqwest::header::CONTENT_TYPE;
7+
use std::collections::HashMap;
8+
use std::sync::Arc;
69
use std::time::Duration;
710
use vss_client::client::VssClient;
811
use vss_client::error::VssError;
12+
use vss_client::headers::FixedHeaders;
13+
use vss_client::headers::VssHeaderProvider;
14+
use vss_client::headers::VssHeaderProviderError;
915

1016
use vss_client::types::{
1117
DeleteObjectRequest, DeleteObjectResponse, ErrorCode, ErrorResponse, GetObjectRequest, GetObjectResponse,
@@ -41,7 +47,42 @@ mod tests {
4147
.create();
4248

4349
// Create a new VssClient with the mock server URL.
44-
let client = VssClient::new(&base_url, retry_policy());
50+
let client = VssClient::new(base_url, retry_policy());
51+
52+
let actual_result = client.get_object(&get_request).await.unwrap();
53+
54+
let expected_result = &mock_response;
55+
assert_eq!(actual_result, *expected_result);
56+
57+
// Verify server endpoint was called exactly once.
58+
mock_server.expect(1).assert();
59+
}
60+
61+
#[tokio::test]
62+
async fn test_get_with_headers() {
63+
// Spin-up mock server with mock response for given request.
64+
let base_url = mockito::server_url().to_string();
65+
66+
// Set up the mock request/response.
67+
let get_request = GetObjectRequest { store_id: "store".to_string(), key: "k1".to_string() };
68+
let mock_response = GetObjectResponse {
69+
value: Some(KeyValue { key: "k1".to_string(), version: 2, value: b"k1v2".to_vec() }),
70+
..Default::default()
71+
};
72+
73+
// Register the mock endpoint with the mockito server and provide expected headers.
74+
let mock_server = mockito::mock("POST", GET_OBJECT_ENDPOINT)
75+
.match_header(CONTENT_TYPE.as_str(), APPLICATION_OCTET_STREAM)
76+
.match_header("headerkey", "headervalue")
77+
.match_body(get_request.encode_to_vec())
78+
.with_status(200)
79+
.with_body(mock_response.encode_to_vec())
80+
.create();
81+
82+
// Create a new VssClient with the mock server URL and fixed headers.
83+
let header_provider =
84+
Arc::new(FixedHeaders::new(HashMap::from([("headerkey".to_string(), "headervalue".to_string())])));
85+
let client = VssClient::new_with_headers(base_url, retry_policy(), header_provider);
4586

4687
let actual_result = client.get_object(&get_request).await.unwrap();
4788

@@ -75,7 +116,7 @@ mod tests {
75116
.create();
76117

77118
// Create a new VssClient with the mock server URL.
78-
let vss_client = VssClient::new(&base_url, retry_policy());
119+
let vss_client = VssClient::new(base_url, retry_policy());
79120
let actual_result = vss_client.put_object(&request).await.unwrap();
80121

81122
let expected_result = &mock_response;
@@ -106,7 +147,7 @@ mod tests {
106147
.create();
107148

108149
// Create a new VssClient with the mock server URL.
109-
let vss_client = VssClient::new(&base_url, retry_policy());
150+
let vss_client = VssClient::new(base_url, retry_policy());
110151
let actual_result = vss_client.delete_object(&request).await.unwrap();
111152

112153
let expected_result = &mock_response;
@@ -147,7 +188,7 @@ mod tests {
147188
.create();
148189

149190
// Create a new VssClient with the mock server URL.
150-
let client = VssClient::new(&base_url, retry_policy());
191+
let client = VssClient::new(base_url, retry_policy());
151192

152193
let actual_result = client.list_key_versions(&request).await.unwrap();
153194

@@ -161,7 +202,7 @@ mod tests {
161202
#[tokio::test]
162203
async fn test_no_such_key_err_handling() {
163204
let base_url = mockito::server_url();
164-
let vss_client = VssClient::new(&base_url, retry_policy());
205+
let vss_client = VssClient::new(base_url, retry_policy());
165206

166207
// NoSuchKeyError
167208
let error_response = ErrorResponse {
@@ -185,7 +226,7 @@ mod tests {
185226
#[tokio::test]
186227
async fn test_get_response_without_value() {
187228
let base_url = mockito::server_url();
188-
let vss_client = VssClient::new(&base_url, retry_policy());
229+
let vss_client = VssClient::new(base_url, retry_policy());
189230

190231
// GetObjectResponse with None value
191232
let mock_response = GetObjectResponse { value: None, ..Default::default() };
@@ -206,7 +247,7 @@ mod tests {
206247
#[tokio::test]
207248
async fn test_invalid_request_err_handling() {
208249
let base_url = mockito::server_url();
209-
let vss_client = VssClient::new(&base_url, retry_policy());
250+
let vss_client = VssClient::new(base_url, retry_policy());
210251

211252
// Invalid Request Error
212253
let error_response = ErrorResponse {
@@ -258,7 +299,7 @@ mod tests {
258299
#[tokio::test]
259300
async fn test_auth_err_handling() {
260301
let base_url = mockito::server_url();
261-
let vss_client = VssClient::new(&base_url, retry_policy());
302+
let vss_client = VssClient::new(base_url, retry_policy());
262303

263304
// Invalid Request Error
264305
let error_response =
@@ -305,10 +346,29 @@ mod tests {
305346
mock_server.expect(4).assert();
306347
}
307348

349+
struct FailingHeaderProvider {}
350+
351+
#[async_trait]
352+
impl VssHeaderProvider for FailingHeaderProvider {
353+
async fn get_headers(&self, _request: &[u8]) -> Result<HashMap<String, String>, VssHeaderProviderError> {
354+
Err(VssHeaderProviderError::InvalidData { error: "test".to_string() })
355+
}
356+
}
357+
358+
#[tokio::test]
359+
async fn test_header_provider_error() {
360+
let get_request = GetObjectRequest { store_id: "store".to_string(), key: "k1".to_string() };
361+
let header_provider = Arc::new(FailingHeaderProvider {});
362+
let client = VssClient::new_with_headers("notused".to_string(), retry_policy(), header_provider);
363+
let result = client.get_object(&get_request).await;
364+
365+
assert!(matches!(result, Err(VssError::AuthError { .. })));
366+
}
367+
308368
#[tokio::test]
309369
async fn test_conflict_err_handling() {
310370
let base_url = mockito::server_url();
311-
let vss_client = VssClient::new(&base_url, retry_policy());
371+
let vss_client = VssClient::new(base_url, retry_policy());
312372

313373
// Conflict Error
314374
let error_response =
@@ -335,7 +395,7 @@ mod tests {
335395
#[tokio::test]
336396
async fn test_internal_server_err_handling() {
337397
let base_url = mockito::server_url();
338-
let vss_client = VssClient::new(&base_url, retry_policy());
398+
let vss_client = VssClient::new(base_url, retry_policy());
339399

340400
// Internal Server Error
341401
let error_response = ErrorResponse {
@@ -387,7 +447,7 @@ mod tests {
387447
#[tokio::test]
388448
async fn test_internal_err_handling() {
389449
let base_url = mockito::server_url();
390-
let vss_client = VssClient::new(&base_url, retry_policy());
450+
let vss_client = VssClient::new(base_url, retry_policy());
391451

392452
let error_response = ErrorResponse { error_code: 999, message: "UnknownException".to_string() };
393453
let mut _mock_server = mockito::mock("POST", Matcher::Any)

0 commit comments

Comments
 (0)