Skip to content

Commit 54c4118

Browse files
Merge pull request #86 from theseus-rs/add-configurable-hashers
feat!: add configurable hashers
2 parents 9630a68 + a2c891d commit 54c4118

File tree

7 files changed

+212
-19
lines changed

7 files changed

+212
-19
lines changed

postgresql_archive/src/hasher/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub mod registry;
2+
pub mod sha2_256;
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
use crate::hasher::sha2_256;
2+
use crate::Result;
3+
use lazy_static::lazy_static;
4+
use std::collections::HashMap;
5+
use std::sync::{Arc, Mutex, RwLock};
6+
7+
lazy_static! {
    // Process-wide hasher registry, created lazily on first access.
    // `Arc::default()` builds `Arc::new(Mutex::new(HasherRegistry::default()))`.
    static ref REGISTRY: Arc<Mutex<HasherRegistry>> = Arc::default();
}
11+
12+
/// Signature of a hashing function: takes the bytes to hash and returns the hash as a
/// [String], or an error if the data cannot be hashed.
pub type HasherFn = fn(&Vec<u8>) -> Result<String>;
13+
14+
/// Singleton struct to store hashers
15+
struct HasherRegistry {
16+
hashers: HashMap<String, Arc<RwLock<HasherFn>>>,
17+
}
18+
19+
impl HasherRegistry {
20+
/// Creates a new hasher registry.
21+
///
22+
/// # Returns
23+
/// * The hasher registry.
24+
fn new() -> Self {
25+
Self {
26+
hashers: HashMap::new(),
27+
}
28+
}
29+
30+
/// Registers a hasher for an extension. Newly registered hashers with the same extension will
31+
/// override existing ones.
32+
///
33+
/// # Arguments
34+
/// * `extension` - The extension to register the hasher for.
35+
/// * `hasher_fn` - The hasher function to register.
36+
fn register<S: AsRef<str>>(&mut self, extension: S, hasher_fn: HasherFn) {
37+
let extension = extension.as_ref().to_string();
38+
self.hashers
39+
.insert(extension, Arc::new(RwLock::new(hasher_fn)));
40+
}
41+
42+
/// Get a hasher for the specified extension.
43+
///
44+
/// # Arguments
45+
/// * `extension` - The extension to locate a hasher for.
46+
///
47+
/// # Returns
48+
/// * The hasher for the extension or [None] if not found.
49+
fn get<S: AsRef<str>>(&self, extension: S) -> Option<HasherFn> {
50+
let extension = extension.as_ref().to_string();
51+
if let Some(hasher) = self.hashers.get(&extension) {
52+
return Some(*hasher.read().unwrap());
53+
}
54+
55+
None
56+
}
57+
}
58+
59+
impl Default for HasherRegistry {
60+
fn default() -> Self {
61+
let mut registry = Self::new();
62+
registry.register("sha256", sha2_256::hash);
63+
registry
64+
}
65+
}
66+
67+
/// Registers a hasher for an extension. Newly registered hashers with the same extension will
68+
/// override existing ones.
69+
///
70+
/// # Arguments
71+
/// * `extension` - The extension to register the hasher for.
72+
/// * `hasher_fn` - The hasher function to register.
73+
///
74+
/// # Panics
75+
/// * If the registry is poisoned.
76+
#[allow(dead_code)]
77+
pub fn register<S: AsRef<str>>(extension: S, hasher_fn: HasherFn) {
78+
let mut registry = REGISTRY.lock().unwrap();
79+
registry.register(extension, hasher_fn);
80+
}
81+
82+
/// Get a hasher for the specified extension.
83+
///
84+
/// # Arguments
85+
/// * `extension` - The extension to locate a hasher for.
86+
///
87+
/// # Returns
88+
/// * The hasher for the extension or [None] if not found.
89+
///
90+
/// # Panics
91+
/// * If the registry is poisoned.
92+
pub fn get<S: AsRef<str>>(extension: S) -> Option<HasherFn> {
93+
let registry = REGISTRY.lock().unwrap();
94+
registry.get(extension)
95+
}
96+
97+
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering a hasher makes it retrievable through [`get`].
    #[test]
    fn test_register() -> Result<()> {
        // The registry is a process-wide singleton and tests run in parallel. Use an
        // extension that only this test touches; removing the built-in "sha256" entry,
        // even briefly, can make concurrent `get("sha256")` calls in other tests fail.
        let extension = "test_register";
        assert!(!REGISTRY.lock().unwrap().hashers.is_empty());
        assert!(get(extension).is_none());
        register(extension, sha2_256::hash);
        assert!(REGISTRY.lock().unwrap().hashers.contains_key(extension));

        let hasher = get(extension).unwrap();
        let data = vec![1, 2, 3];
        let hash = hasher(&data)?;

        assert_eq!(
            "039058c6f2c0cb492c533b0a4d14ef77cc0f78abccced5287d84a1a2011cfb81",
            hash
        );

        // Leave the shared registry as it was found.
        REGISTRY.lock().unwrap().hashers.remove(extension);
        Ok(())
    }

    /// The default registry resolves "sha256" to the SHA2-256 hasher.
    #[test]
    fn test_sha2_256() -> Result<()> {
        let hasher = get("sha256").unwrap();
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0];
        let hash = hasher(&data)?;

        assert_eq!(
            "9a89c68c4c5e28b8c4a5567673d462fff515db46116f9900624d09c474f593fb",
            hash
        );
        Ok(())
    }
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
use crate::Result;
2+
use sha2::{Digest, Sha256};
3+
4+
/// Hashes the data using SHA2-256.
5+
///
6+
/// # Arguments
7+
/// * `data` - The data to hash.
8+
///
9+
/// # Returns
10+
/// * The hash of the data.
11+
///
12+
/// # Errors
13+
/// * If the data cannot be hashed.
14+
pub fn hash(data: &Vec<u8>) -> Result<String> {
15+
let mut hasher = Sha256::new();
16+
hasher.update(data);
17+
let hash = hex::encode(hasher.finalize());
18+
Ok(hash)
19+
}
20+
21+
#[cfg(test)]
mod tests {
    use super::*;

    /// Hashing a fixed two-byte input yields the expected hex digest.
    #[test]
    fn test_hash() -> Result<()> {
        let input = vec![4, 2];
        let digest = hash(&input)?;
        assert_eq!(
            "b7586d310e5efb1b7d10a917ba5af403adbf54f4f77fe7fdcb4880a95dac7e7e",
            digest
        );
        Ok(())
    }
}

postgresql_archive/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ mod archive;
113113
#[cfg(feature = "blocking")]
114114
pub mod blocking;
115115
mod error;
116+
pub mod hasher;
116117
pub mod matcher;
117118
pub mod repository;
118119
mod version;

postgresql_archive/src/matcher/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
mod default;
2-
mod postgresql_binaries;
3-
pub(crate) mod registry;
1+
pub mod default;
2+
pub mod postgresql_binaries;
3+
pub mod registry;

postgresql_archive/src/matcher/registry.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ lazy_static! {
1010
Arc::new(Mutex::new(MatchersRegistry::default()));
1111
}
1212

13-
type MatcherFn = fn(&str, &Version) -> Result<bool>;
13+
pub type MatcherFn = fn(&str, &Version) -> Result<bool>;
1414

1515
/// Singleton struct to store matchers
1616
struct MatchersRegistry {
@@ -75,6 +75,9 @@ impl Default for MatchersRegistry {
7575
/// # Arguments
7676
/// * `url` - The URL to register the matcher for; [None] to register the default.
7777
/// * `matcher_fn` - The matcher function to register.
78+
///
79+
/// # Panics
80+
/// * If the registry is poisoned.
7881
#[allow(dead_code)]
7982
pub fn register<S: AsRef<str>>(url: Option<S>, matcher_fn: MatcherFn) {
8083
let mut registry = REGISTRY.lock().unwrap();
@@ -89,6 +92,9 @@ pub fn register<S: AsRef<str>>(url: Option<S>, matcher_fn: MatcherFn) {
8992
///
9093
/// # Returns
9194
/// * The matcher for the URL, or the default matcher.
95+
///
96+
/// # Panics
97+
/// * If the registry is poisoned.
9298
pub fn get<S: AsRef<str>>(url: S) -> MatcherFn {
9399
let registry = REGISTRY.lock().unwrap();
94100
registry.get(url)
@@ -99,8 +105,8 @@ mod tests {
99105
use super::*;
100106
use std::env;
101107

102-
#[tokio::test]
103-
async fn test_register() -> Result<()> {
108+
#[test]
109+
fn test_register() -> Result<()> {
104110
let matchers = REGISTRY.lock().unwrap().matchers.len();
105111
assert!(!REGISTRY.lock().unwrap().matchers.is_empty());
106112
REGISTRY.lock().unwrap().matchers.remove(&None::<String>);
@@ -117,8 +123,8 @@ mod tests {
117123
Ok(())
118124
}
119125

120-
#[tokio::test]
121-
async fn test_default_matcher() -> Result<()> {
126+
#[test]
127+
fn test_default_matcher() -> Result<()> {
122128
let matcher = get("https://foo.com");
123129
let version = Version::new(16, 3, 0);
124130
let os = env::consts::OS;

postgresql_archive/src/repository/github/repository.rs

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
use crate::hasher::registry::HasherFn;
12
use crate::repository::github::models::{Asset, Release};
23
use crate::repository::model::Repository;
34
use crate::repository::Archive;
45
use crate::Error::{
56
ArchiveHashMismatch, AssetHashNotFound, AssetNotFound, RepositoryFailure, VersionNotFound,
67
};
7-
use crate::{matcher, Result};
8+
use crate::{hasher, matcher, Result};
89
use async_trait::async_trait;
910
use bytes::Bytes;
1011
use http::{header, Extensions};
@@ -16,7 +17,6 @@ use reqwest_retry::policies::ExponentialBackoff;
1617
use reqwest_retry::RetryTransientMiddleware;
1718
use reqwest_tracing::TracingMiddleware;
1819
use semver::{Version, VersionReq};
19-
use sha2::{Digest, Sha256};
2020
use std::env;
2121
use std::str::FromStr;
2222
use tracing::{debug, instrument, warn};
@@ -26,7 +26,7 @@ const GITHUB_API_VERSION_HEADER: &str = "X-GitHub-Api-Version";
2626
const GITHUB_API_VERSION: &str = "2022-11-28";
2727

2828
lazy_static! {
29-
static ref GITHUB_TOKEN: Option<String> = match std::env::var("GITHUB_TOKEN") {
29+
static ref GITHUB_TOKEN: Option<String> = match env::var("GITHUB_TOKEN") {
3030
Ok(token) => {
3131
debug!("GITHUB_TOKEN environment variable found");
3232
Some(token)
@@ -200,7 +200,11 @@ impl GitHub {
200200
/// # Errors
201201
/// * If the asset is not found.
202202
#[instrument(level = "debug", skip(version, release))]
203-
fn get_asset(&self, version: &Version, release: &Release) -> Result<(Asset, Option<Asset>)> {
203+
fn get_asset(
204+
&self,
205+
version: &Version,
206+
release: &Release,
207+
) -> Result<(Asset, Option<Asset>, Option<HasherFn>)> {
204208
let matcher = matcher::registry::get(&self.url);
205209
let mut release_asset: Option<Asset> = None;
206210
for asset in &release.assets {
@@ -214,16 +218,26 @@ impl GitHub {
214218
return Err(AssetNotFound);
215219
};
216220

221+
// Attempt to find the asset hash for the asset.
217222
let mut asset_hash: Option<Asset> = None;
218-
let hash_name = format!("{}.sha256", asset.name);
223+
let mut asset_hasher_fn: Option<HasherFn> = None;
219224
for release_asset in &release.assets {
220-
if release_asset.name == hash_name {
225+
let release_asset_name = release_asset.name.as_str();
226+
if !release_asset_name.starts_with(&asset.name) {
227+
continue;
228+
}
229+
let extension = release_asset_name
230+
.strip_prefix(format!("{}.", asset.name.as_str()).as_str())
231+
.unwrap_or_default();
232+
233+
if let Some(hasher_fn) = hasher::registry::get(extension) {
221234
asset_hash = Some(release_asset.clone());
235+
asset_hasher_fn = Some(hasher_fn);
222236
break;
223237
}
224238
}
225239

226-
Ok((asset, asset_hash))
240+
Ok((asset, asset_hash, asset_hasher_fn))
227241
}
228242
}
229243

@@ -246,7 +260,7 @@ impl Repository for GitHub {
246260
async fn get_archive(&self, version_req: &VersionReq) -> Result<Archive> {
247261
let release = self.get_release(version_req).await?;
248262
let version = Self::get_version_from_tag_name(release.tag_name.as_str())?;
249-
let (asset, asset_hash) = self.get_asset(&version, &release)?;
263+
let (asset, asset_hash, asset_hasher_fn) = self.get_asset(&version, &release)?;
250264
let name = asset.name.clone();
251265

252266
let client = reqwest_client();
@@ -280,9 +294,10 @@ impl Repository for GitHub {
280294
human_bytes(text.len() as f64)
281295
);
282296

283-
let mut hasher = Sha256::new();
284-
hasher.update(&archive);
285-
let archive_hash = hex::encode(hasher.finalize());
297+
let archive_hash = match asset_hasher_fn {
298+
Some(hasher_fn) => hasher_fn(&bytes)?,
299+
None => String::new(),
300+
};
286301

287302
if archive_hash != hash {
288303
return Err(ArchiveHashMismatch { archive_hash, hash });

0 commit comments

Comments
 (0)