Skip to content

Commit 865e15a

Browse files
committed
Comments
1 parent 35a9bc9 commit 865e15a

File tree

6 files changed

+207
-159
lines changed

6 files changed

+207
-159
lines changed

src/client.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
use prost::Message;
2-
use reqwest;
3-
use reqwest::header::HeaderMap;
42
use reqwest::header::CONTENT_TYPE;
53
use reqwest::Client;
64
use std::default::Default;
75
use std::sync::Arc;
86

97
use crate::error::VssError;
8+
use crate::headers::get_headermap;
109
use crate::headers::FixedHeaders;
11-
use crate::headers::HeaderProvider;
10+
use crate::headers::VssHeaderProvider;
1211
use crate::types::{
1312
DeleteObjectRequest, DeleteObjectResponse, GetObjectRequest, GetObjectResponse, ListKeyVersionsRequest,
1413
ListKeyVersionsResponse, PutObjectRequest, PutObjectResponse,
@@ -27,7 +26,7 @@ where
2726
base_url: String,
2827
client: Client,
2928
retry_policy: R,
30-
header_provider: Arc<dyn HeaderProvider>,
29+
header_provider: Arc<dyn VssHeaderProvider>,
3130
}
3231

3332
impl<R: RetryPolicy<E = VssError>> VssClient<R> {
@@ -43,13 +42,13 @@ impl<R: RetryPolicy<E = VssError>> VssClient<R> {
4342
base_url: String::from(base_url),
4443
client,
4544
retry_policy,
46-
header_provider: Arc::new(FixedHeaders::new(HeaderMap::new())),
45+
header_provider: Arc::new(FixedHeaders::new(Vec::new())),
4746
}
4847
}
4948

5049
/// Constructs a [`VssClient`] using `base_url` as the VSS server endpoint.
5150
/// HTTP headers will be provided by the given `header_provider`.
52-
pub fn new_with_headers(base_url: &str, retry_policy: R, header_provider: Arc<dyn HeaderProvider>) -> Self {
51+
pub fn new_with_headers(base_url: &str, retry_policy: R, header_provider: Arc<dyn VssHeaderProvider>) -> Self {
5352
let client = Client::new();
5453
Self { base_url: String::from(base_url), client, retry_policy, header_provider }
5554
}
@@ -133,11 +132,12 @@ impl<R: RetryPolicy<E = VssError>> VssClient<R> {
133132
.get_headers()
134133
.await
135134
.map_err(|e| VssError::InternalError(e.to_string()))?;
135+
let headermap = get_headermap(&headers).map_err(VssError::InternalError)?;
136136
let response_raw = self
137137
.client
138138
.post(url)
139139
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
140-
.headers(headers)
140+
.headers(headermap)
141141
.body(request_body)
142142
.send()
143143
.await?;

src/headers/lnurl_auth_jwt.rs

Lines changed: 96 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
use crate::headers::HeaderProvider;
2-
use crate::headers::HeaderProviderError;
1+
use crate::headers::get_headermap;
2+
use crate::headers::VssHeaderProvider;
3+
use crate::headers::VssHeaderProviderError;
4+
use crate::util::string::UntrustedString;
35
use async_trait::async_trait;
46
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
57
use base64::Engine;
@@ -10,11 +12,9 @@ use bitcoin::hashes::{Hash, HashEngine, Hmac, HmacEngine};
1012
use bitcoin::secp256k1::{All, Message, Secp256k1};
1113
use bitcoin::Network;
1214
use bitcoin::PrivateKey;
13-
use reqwest::header::HeaderMap;
14-
use reqwest::header::AUTHORIZATION;
1515
use serde::Deserialize;
16-
use std::str::FromStr;
17-
use std::sync::Mutex;
16+
use std::ops::Deref;
17+
use std::sync::RwLock;
1818
use std::time::SystemTime;
1919
use url::Url;
2020

@@ -30,17 +30,24 @@ const K1_QUERY_PARAM: &str = "k1";
3030
const SIG_QUERY_PARAM: &str = "sig";
3131
// The key of the LNURL key query parameter.
3232
const KEY_QUERY_PARAM: &str = "key";
33+
// The authorization header name.
34+
const AUTHORIZATION: &str = "authorization";
35+
36+
#[derive(Debug, Clone)]
37+
struct JwtToken {
38+
token_str: String,
39+
expiry: Option<u64>,
40+
}
3341

3442
/// Provides a JWT token based on LNURL Auth.
3543
/// The LNURL and JWT token are exchanged over a Websocket connection.
3644
pub struct LnurlAuthJwt {
3745
engine: Secp256k1<All>,
3846
parent_key: ExtendedPrivKey,
3947
url: String,
40-
headers: HeaderMap,
48+
default_headers: Vec<(String, String)>,
4149
client: reqwest::Client,
42-
jwt_token: Mutex<Option<String>>,
43-
expiry: Mutex<Option<u64>>,
50+
jwt_token: RwLock<Option<JwtToken>>,
4451
}
4552

4653
impl LnurlAuthJwt {
@@ -51,48 +58,38 @@ impl LnurlAuthJwt {
5158
/// The JWT token will be returned in response to the signed LNURL request under a token field.
5259
/// The given set of headers will be used for LNURL requests, and will also be returned together
5360
/// with the JWT authorization header for VSS requests.
54-
pub fn new(seed: &[u8], url: String, headers: Vec<(String, String)>) -> Result<LnurlAuthJwt, HeaderProviderError> {
61+
pub fn new(
62+
seed: &[u8], url: String, default_headers: Vec<(String, String)>,
63+
) -> Result<LnurlAuthJwt, VssHeaderProviderError> {
5564
let engine = Secp256k1::new();
56-
let master = ExtendedPrivKey::new_master(Network::Testnet, seed).map_err(HeaderProviderError::from)?;
65+
let master = ExtendedPrivKey::new_master(Network::Testnet, seed).map_err(VssHeaderProviderError::from)?;
5766
let child_number =
58-
ChildNumber::from_hardened_idx(PARENT_DERIVATION_INDEX).map_err(HeaderProviderError::from)?;
67+
ChildNumber::from_hardened_idx(PARENT_DERIVATION_INDEX).map_err(VssHeaderProviderError::from)?;
5968
let parent_key = master
6069
.derive_priv(&engine, &vec![child_number])
61-
.map_err(HeaderProviderError::from)?;
62-
let mut headermap = HeaderMap::new();
63-
for (name, value) in headers {
64-
headermap.insert(
65-
reqwest::header::HeaderName::from_str(&name).map_err(HeaderProviderError::from)?,
66-
reqwest::header::HeaderValue::from_str(&value).map_err(HeaderProviderError::from)?,
67-
);
68-
}
70+
.map_err(VssHeaderProviderError::from)?;
71+
let default_headermap =
72+
get_headermap(&default_headers).map_err(|error| VssHeaderProviderError::InvalidData { error })?;
6973
let client = reqwest::Client::builder()
70-
.default_headers(headermap.clone())
74+
.default_headers(default_headermap)
7175
.build()
72-
.map_err(HeaderProviderError::from)?;
76+
.map_err(VssHeaderProviderError::from)?;
7377

74-
Ok(LnurlAuthJwt {
75-
engine,
76-
parent_key,
77-
url,
78-
headers: headermap,
79-
client,
80-
jwt_token: Mutex::new(None),
81-
expiry: Mutex::new(None),
82-
})
78+
Ok(LnurlAuthJwt { engine, parent_key, url, default_headers, client, jwt_token: RwLock::new(None) })
8379
}
8480

85-
async fn fetch_jwt_token(&self) -> Result<String, HeaderProviderError> {
81+
async fn fetch_jwt_token(&self) -> Result<JwtToken, VssHeaderProviderError> {
8682
// Fetch the LNURL.
87-
let lnurl_str = self
88-
.client
89-
.get(&self.url)
90-
.send()
91-
.await
92-
.map_err(HeaderProviderError::from)?
93-
.text()
94-
.await
95-
.map_err(HeaderProviderError::from)?;
83+
let lnurl_str = UntrustedString::new(
84+
self.client
85+
.get(&self.url)
86+
.send()
87+
.await
88+
.map_err(VssHeaderProviderError::from)?
89+
.text()
90+
.await
91+
.map_err(VssHeaderProviderError::from)?,
92+
);
9693

9794
// Sign the LNURL and perform the request.
9895
let signed_lnurl = sign_lnurl(&self.engine, &self.parent_key, &lnurl_str)?;
@@ -101,40 +98,45 @@ impl LnurlAuthJwt {
10198
.get(&signed_lnurl)
10299
.send()
103100
.await
104-
.map_err(HeaderProviderError::from)?
101+
.map_err(VssHeaderProviderError::from)?
105102
.json()
106103
.await
107-
.map_err(HeaderProviderError::from)?;
104+
.map_err(VssHeaderProviderError::from)?;
108105

109-
match lnurl_auth_response {
110-
LnurlAuthResponse { token: Some(token), .. } => Ok(token),
106+
let untrusted_token = match lnurl_auth_response {
107+
LnurlAuthResponse { token: Some(token), .. } => token,
111108
LnurlAuthResponse { reason: Some(reason), .. } => {
112-
Err(HeaderProviderError::ApplicationError(format!("LNURL Auth failed, reason is: {}", reason)))
109+
return Err(VssHeaderProviderError::ApplicationError {
110+
error: format!("LNURL Auth failed, reason is: {}", reason),
111+
});
113112
}
114-
_ => Err(HeaderProviderError::InvalidData(
115-
"LNURL Auth response did not contain a token nor an error".to_string(),
116-
)),
117-
}
113+
_ => {
114+
return Err(VssHeaderProviderError::InvalidData {
115+
error: "LNURL Auth response did not contain a token nor an error".to_string(),
116+
});
117+
}
118+
};
119+
parse_jwt_token(untrusted_token)
118120
}
119121

120-
async fn get_jwt_token(&self, force_refresh: bool) -> Result<String, HeaderProviderError> {
122+
async fn get_jwt_token(&self, force_refresh: bool) -> Result<String, VssHeaderProviderError> {
121123
if !self.is_expired() && !force_refresh {
122-
let jwt_token = self.jwt_token.lock().unwrap();
123-
if let Some(jwt_token) = jwt_token.as_deref() {
124-
return Ok(jwt_token.to_string());
124+
let jwt_token = self.jwt_token.read().unwrap();
125+
if let Some(jwt_token) = jwt_token.deref() {
126+
return Ok(jwt_token.token_str.clone());
125127
}
126128
}
127129
let jwt_token = self.fetch_jwt_token().await?;
128-
let expiry = parse_expiry(&jwt_token)?;
129-
*self.jwt_token.lock().unwrap() = Some(jwt_token.clone());
130-
*self.expiry.lock().unwrap() = expiry;
131-
Ok(jwt_token)
130+
*self.jwt_token.write().unwrap() = Some(jwt_token.clone());
131+
Ok(jwt_token.token_str)
132132
}
133133

134134
fn is_expired(&self) -> bool {
135-
self.expiry
136-
.lock()
135+
self.jwt_token
136+
.read()
137137
.unwrap()
138+
.as_ref()
139+
.and_then(|token| token.expiry)
138140
.map(|expiry| {
139141
SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + EXPIRY_BUFFER_SECS
140142
> expiry
@@ -144,41 +146,39 @@ impl LnurlAuthJwt {
144146
}
145147

146148
#[async_trait]
147-
impl HeaderProvider for LnurlAuthJwt {
148-
async fn get_headers(&self) -> Result<HeaderMap, HeaderProviderError> {
149+
impl VssHeaderProvider for LnurlAuthJwt {
150+
async fn get_headers(&self) -> Result<Vec<(String, String)>, VssHeaderProviderError> {
149151
let jwt_token = self.get_jwt_token(false).await?;
150-
let mut headers = self.headers.clone();
151-
let value = format!("Bearer {}", jwt_token).parse().map_err(HeaderProviderError::from)?;
152-
headers.insert(AUTHORIZATION, value);
152+
let mut headers = self.default_headers.clone();
153+
headers.push((AUTHORIZATION.to_string(), format!("Bearer {}", jwt_token)));
153154
Ok(headers)
154155
}
155156
}
156157

157-
fn hashing_key(engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey) -> Result<PrivateKey, HeaderProviderError> {
158+
fn hashing_key(engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey) -> Result<PrivateKey, VssHeaderProviderError> {
158159
let hashing_child_number =
159-
ChildNumber::from_normal_idx(HASHING_DERIVATION_INDEX).map_err(HeaderProviderError::from)?;
160+
ChildNumber::from_normal_idx(HASHING_DERIVATION_INDEX).map_err(VssHeaderProviderError::from)?;
160161
parent_key
161162
.derive_priv(engine, &vec![hashing_child_number])
162163
.map(|xpriv| xpriv.to_priv())
163-
.map_err(HeaderProviderError::from)
164+
.map_err(VssHeaderProviderError::from)
164165
}
165166

166-
fn linking_key_path(hashing_key: &PrivateKey, domain_name: &str) -> Result<DerivationPath, HeaderProviderError> {
167+
fn linking_key_path(hashing_key: &PrivateKey, domain_name: &str) -> Result<DerivationPath, VssHeaderProviderError> {
167168
let mut engine = HmacEngine::<sha256::Hash>::new(&hashing_key.inner[..]);
168169
engine.input(domain_name.as_bytes());
169170
let result = Hmac::<sha256::Hash>::from_engine(engine).to_byte_array();
170-
let children: Vec<ChildNumber> = (0..4)
171+
let children = (0..4)
171172
.map(|i| u32::from_be_bytes(result[(i * 4)..((i + 1) * 4)].try_into().unwrap()))
172-
.map(ChildNumber::from)
173-
.collect::<Vec<_>>();
174-
Ok(DerivationPath::from(children))
173+
.map(ChildNumber::from);
174+
Ok(DerivationPath::from_iter(children))
175175
}
176176

177177
fn sign_lnurl(
178-
engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey, lnurl_str: &str,
179-
) -> Result<String, HeaderProviderError> {
178+
engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey, lnurl_str: &UntrustedString,
179+
) -> Result<String, VssHeaderProviderError> {
180180
// Parse k1 parameter to sign.
181-
let invalid_lnurl = || HeaderProviderError::InvalidData(format!("invalid lnurl: {}", lnurl_str));
181+
let invalid_lnurl = || VssHeaderProviderError::InvalidData { error: format!("invalid lnurl: {}", lnurl_str) };
182182
let mut lnurl = Url::parse(lnurl_str).map_err(|_| invalid_lnurl())?;
183183
let domain = lnurl.domain().ok_or(invalid_lnurl())?;
184184
let k1_str = lnurl
@@ -194,11 +194,11 @@ fn sign_lnurl(
194194
let linking_key_path = linking_key_path(&hashing_key, domain)?;
195195
let private_key = parent_key
196196
.derive_priv(engine, &linking_key_path)
197-
.map_err(HeaderProviderError::from)?
197+
.map_err(VssHeaderProviderError::from)?
198198
.to_priv();
199199
let public_key = private_key.public_key(engine);
200-
let message =
201-
Message::from_slice(&k1).map_err(|_| HeaderProviderError::InvalidData(format!("invalid k1: {:?}", k1)))?;
200+
let message = Message::from_slice(&k1)
201+
.map_err(|_| VssHeaderProviderError::InvalidData { error: format!("invalid k1: {:?}", k1) })?;
202202
let sig = engine.sign_ecdsa(&message, &private_key.inner);
203203

204204
// Compose LNURL with signature and linking key.
@@ -209,55 +209,46 @@ fn sign_lnurl(
209209
Ok(lnurl.to_string())
210210
}
211211

212-
#[derive(Deserialize)]
212+
#[derive(Deserialize, Debug, Clone)]
213213
struct LnurlAuthResponse {
214-
reason: Option<String>,
215-
token: Option<String>,
214+
reason: Option<UntrustedString>,
215+
token: Option<UntrustedString>,
216216
}
217217

218-
#[derive(Deserialize)]
218+
#[derive(Deserialize, Debug, Clone)]
219219
struct ExpiryClaim {
220220
exp: Option<u64>,
221221
}
222222

223-
fn parse_expiry(jwt_token: &str) -> Result<Option<u64>, HeaderProviderError> {
223+
fn parse_jwt_token(jwt_token: UntrustedString) -> Result<JwtToken, VssHeaderProviderError> {
224224
let parts: Vec<&str> = jwt_token.split('.').collect();
225-
let invalid = || HeaderProviderError::InvalidData(format!("invalid JWT token: {}", jwt_token));
225+
let invalid = || VssHeaderProviderError::InvalidData { error: format!("invalid JWT token: {}", jwt_token) };
226226
if parts.len() != 3 {
227227
return Err(invalid());
228228
}
229+
let _ = URL_SAFE_NO_PAD.decode(parts[0]).map_err(|_| invalid())?;
229230
let bytes = URL_SAFE_NO_PAD.decode(parts[1]).map_err(|_| invalid())?;
231+
let _ = URL_SAFE_NO_PAD.decode(parts[2]).map_err(|_| invalid())?;
230232
let claim: ExpiryClaim = serde_json::from_slice(&bytes).map_err(|_| invalid())?;
231-
Ok(claim.exp)
232-
}
233-
234-
impl From<bitcoin::bip32::Error> for HeaderProviderError {
235-
fn from(e: bitcoin::bip32::Error) -> HeaderProviderError {
236-
HeaderProviderError::InvalidData(e.to_string())
237-
}
238-
}
239-
240-
impl From<reqwest::header::InvalidHeaderName> for HeaderProviderError {
241-
fn from(e: reqwest::header::InvalidHeaderName) -> HeaderProviderError {
242-
HeaderProviderError::InvalidData(e.to_string())
243-
}
233+
Ok(JwtToken { token_str: jwt_token.into_inner(), expiry: claim.exp })
244234
}
245235

246-
impl From<reqwest::header::InvalidHeaderValue> for HeaderProviderError {
247-
fn from(e: reqwest::header::InvalidHeaderValue) -> HeaderProviderError {
248-
HeaderProviderError::InvalidData(e.to_string())
236+
impl From<bitcoin::bip32::Error> for VssHeaderProviderError {
237+
fn from(e: bitcoin::bip32::Error) -> VssHeaderProviderError {
238+
VssHeaderProviderError::InvalidData { error: e.to_string() }
249239
}
250240
}
251241

252-
impl From<reqwest::Error> for HeaderProviderError {
253-
fn from(e: reqwest::Error) -> HeaderProviderError {
254-
HeaderProviderError::RequestError(e.to_string())
242+
impl From<reqwest::Error> for VssHeaderProviderError {
243+
fn from(e: reqwest::Error) -> VssHeaderProviderError {
244+
VssHeaderProviderError::RequestError { error: e.to_string() }
255245
}
256246
}
257247

258248
#[cfg(test)]
259249
mod test {
260250
use crate::headers::lnurl_auth_jwt::{linking_key_path, sign_lnurl};
251+
use crate::util::string::UntrustedString;
261252
use bitcoin::bip32::ExtendedPrivKey;
262253
use bitcoin::hashes::hex::FromHex;
263254
use bitcoin::secp256k1::Secp256k1;
@@ -288,7 +279,7 @@ mod test {
288279
let signed = sign_lnurl(
289280
&engine,
290281
&master,
291-
"https://example.com/path?tag=login&k1=e2af6254a8df433264fa23f67eb8188635d15ce883e8fc020989d5f82ae6f11e",
282+
&UntrustedString::new("https://example.com/path?tag=login&k1=e2af6254a8df433264fa23f67eb8188635d15ce883e8fc020989d5f82ae6f11e".to_string()),
292283
)
293284
.unwrap();
294285
assert_eq!(

0 commit comments

Comments
 (0)