Skip to content

feat!: add configurable hashers #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions postgresql_archive/src/hasher/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod registry;
pub mod sha2_256;
134 changes: 134 additions & 0 deletions postgresql_archive/src/hasher/registry.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use crate::hasher::sha2_256;
use crate::Result;
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};

lazy_static! {
static ref REGISTRY: Arc<Mutex<HasherRegistry>> =
Arc::new(Mutex::new(HasherRegistry::default()));
}

pub type HasherFn = fn(&Vec<u8>) -> Result<String>;

/// Singleton struct to store hashers
struct HasherRegistry {
hashers: HashMap<String, Arc<RwLock<HasherFn>>>,
}

impl HasherRegistry {
/// Creates a new hasher registry.
///
/// # Returns
/// * The hasher registry.
fn new() -> Self {
Self {
hashers: HashMap::new(),
}
}

/// Registers a hasher for an extension. Newly registered hashers with the same extension will
/// override existing ones.
///
/// # Arguments
/// * `extension` - The extension to register the hasher for.
/// * `hasher_fn` - The hasher function to register.
fn register<S: AsRef<str>>(&mut self, extension: S, hasher_fn: HasherFn) {
let extension = extension.as_ref().to_string();
self.hashers
.insert(extension, Arc::new(RwLock::new(hasher_fn)));
}

/// Get a hasher for the specified extension.
///
/// # Arguments
/// * `extension` - The extension to locate a hasher for.
///
/// # Returns
/// * The hasher for the extension or [None] if not found.
fn get<S: AsRef<str>>(&self, extension: S) -> Option<HasherFn> {
let extension = extension.as_ref().to_string();
if let Some(hasher) = self.hashers.get(&extension) {
return Some(*hasher.read().unwrap());
}

None
}
}

impl Default for HasherRegistry {
fn default() -> Self {
let mut registry = Self::new();
registry.register("sha256", sha2_256::hash);
registry
}
}

/// Registers a hasher for an extension. Newly registered hashers with the same extension will
/// override existing ones.
///
/// # Arguments
/// * `extension` - The extension to register the hasher for.
/// * `hasher_fn` - The hasher function to register.
///
/// # Panics
/// * If the registry is poisoned.
#[allow(dead_code)]
pub fn register<S: AsRef<str>>(extension: S, hasher_fn: HasherFn) {
let mut registry = REGISTRY.lock().unwrap();
registry.register(extension, hasher_fn);
}

/// Get a hasher for the specified extension.
///
/// # Arguments
/// * `extension` - The extension to locate a hasher for.
///
/// # Returns
/// * The hasher for the extension or [None] if not found.
///
/// # Panics
/// * If the registry is poisoned.
pub fn get<S: AsRef<str>>(extension: S) -> Option<HasherFn> {
let registry = REGISTRY.lock().unwrap();
registry.get(extension)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_register() -> Result<()> {
let extension = "sha256";
let hashers = REGISTRY.lock().unwrap().hashers.len();
assert!(!REGISTRY.lock().unwrap().hashers.is_empty());
REGISTRY.lock().unwrap().hashers.remove(extension);
assert_ne!(hashers, REGISTRY.lock().unwrap().hashers.len());
register(extension, sha2_256::hash);
assert_eq!(hashers, REGISTRY.lock().unwrap().hashers.len());

let hasher = get(extension).unwrap();
let data = vec![1, 2, 3];
let hash = hasher(&data)?;

assert_eq!(
"039058c6f2c0cb492c533b0a4d14ef77cc0f78abccced5287d84a1a2011cfb81",
hash
);
Ok(())
}

#[test]
fn test_sha2_256() -> Result<()> {
let hasher = get("sha256").unwrap();
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0];
let hash = hasher(&data)?;

assert_eq!(
"9a89c68c4c5e28b8c4a5567673d462fff515db46116f9900624d09c474f593fb",
hash
);
Ok(())
}
}
35 changes: 35 additions & 0 deletions postgresql_archive/src/hasher/sha2_256.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use crate::Result;
use sha2::{Digest, Sha256};

/// Hashes the data using SHA2-256.
///
/// # Arguments
/// * `data` - The data to hash.
///
/// # Returns
/// * The hash of the data.
///
/// # Errors
/// * If the data cannot be hashed.
pub fn hash(data: &Vec<u8>) -> Result<String> {
let mut hasher = Sha256::new();
hasher.update(data);
let hash = hex::encode(hasher.finalize());
Ok(hash)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_hash() -> Result<()> {
let data = vec![4, 2];
let hash = hash(&data)?;
assert_eq!(
"b7586d310e5efb1b7d10a917ba5af403adbf54f4f77fe7fdcb4880a95dac7e7e",
hash
);
Ok(())
}
}
1 change: 1 addition & 0 deletions postgresql_archive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ mod archive;
#[cfg(feature = "blocking")]
pub mod blocking;
mod error;
pub mod hasher;
pub mod matcher;
pub mod repository;
mod version;
Expand Down
6 changes: 3 additions & 3 deletions postgresql_archive/src/matcher/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
mod default;
mod postgresql_binaries;
pub(crate) mod registry;
pub mod default;
pub mod postgresql_binaries;
pub mod registry;
16 changes: 11 additions & 5 deletions postgresql_archive/src/matcher/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ lazy_static! {
Arc::new(Mutex::new(MatchersRegistry::default()));
}

type MatcherFn = fn(&str, &Version) -> Result<bool>;
pub type MatcherFn = fn(&str, &Version) -> Result<bool>;

/// Singleton struct to store matchers
struct MatchersRegistry {
Expand Down Expand Up @@ -75,6 +75,9 @@ impl Default for MatchersRegistry {
/// # Arguments
/// * `url` - The URL to register the matcher for; [None] to register the default.
/// * `matcher_fn` - The matcher function to register.
///
/// # Panics
/// * If the registry is poisoned.
#[allow(dead_code)]
pub fn register<S: AsRef<str>>(url: Option<S>, matcher_fn: MatcherFn) {
let mut registry = REGISTRY.lock().unwrap();
Expand All @@ -89,6 +92,9 @@ pub fn register<S: AsRef<str>>(url: Option<S>, matcher_fn: MatcherFn) {
///
/// # Returns
/// * The matcher for the URL, or the default matcher.
///
/// # Panics
/// * If the registry is poisoned.
pub fn get<S: AsRef<str>>(url: S) -> MatcherFn {
let registry = REGISTRY.lock().unwrap();
registry.get(url)
Expand All @@ -99,8 +105,8 @@ mod tests {
use super::*;
use std::env;

#[tokio::test]
async fn test_register() -> Result<()> {
#[test]
fn test_register() -> Result<()> {
let matchers = REGISTRY.lock().unwrap().matchers.len();
assert!(!REGISTRY.lock().unwrap().matchers.is_empty());
REGISTRY.lock().unwrap().matchers.remove(&None::<String>);
Expand All @@ -117,8 +123,8 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_default_matcher() -> Result<()> {
#[test]
fn test_default_matcher() -> Result<()> {
let matcher = get("https://foo.com");
let version = Version::new(16, 3, 0);
let os = env::consts::OS;
Expand Down
37 changes: 26 additions & 11 deletions postgresql_archive/src/repository/github/repository.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::hasher::registry::HasherFn;
use crate::repository::github::models::{Asset, Release};
use crate::repository::model::Repository;
use crate::repository::Archive;
use crate::Error::{
ArchiveHashMismatch, AssetHashNotFound, AssetNotFound, RepositoryFailure, VersionNotFound,
};
use crate::{matcher, Result};
use crate::{hasher, matcher, Result};
use async_trait::async_trait;
use bytes::Bytes;
use http::{header, Extensions};
Expand All @@ -16,7 +17,6 @@
use reqwest_retry::RetryTransientMiddleware;
use reqwest_tracing::TracingMiddleware;
use semver::{Version, VersionReq};
use sha2::{Digest, Sha256};
use std::env;
use std::str::FromStr;
use tracing::{debug, instrument, warn};
Expand All @@ -26,7 +26,7 @@
const GITHUB_API_VERSION: &str = "2022-11-28";

lazy_static! {
static ref GITHUB_TOKEN: Option<String> = match std::env::var("GITHUB_TOKEN") {
static ref GITHUB_TOKEN: Option<String> = match env::var("GITHUB_TOKEN") {
Ok(token) => {
debug!("GITHUB_TOKEN environment variable found");
Some(token)
Expand Down Expand Up @@ -200,7 +200,11 @@
/// # Errors
/// * If the asset is not found.
#[instrument(level = "debug", skip(version, release))]
fn get_asset(&self, version: &Version, release: &Release) -> Result<(Asset, Option<Asset>)> {
fn get_asset(
&self,
version: &Version,
release: &Release,
) -> Result<(Asset, Option<Asset>, Option<HasherFn>)> {
let matcher = matcher::registry::get(&self.url);
let mut release_asset: Option<Asset> = None;
for asset in &release.assets {
Expand All @@ -214,16 +218,26 @@
return Err(AssetNotFound);
};

// Attempt to find the asset hash for the asset.
let mut asset_hash: Option<Asset> = None;
let hash_name = format!("{}.sha256", asset.name);
let mut asset_hasher_fn: Option<HasherFn> = None;
for release_asset in &release.assets {
if release_asset.name == hash_name {
let release_asset_name = release_asset.name.as_str();
if !release_asset_name.starts_with(&asset.name) {
continue;
}
let extension = release_asset_name
.strip_prefix(format!("{}.", asset.name.as_str()).as_str())
.unwrap_or_default();

if let Some(hasher_fn) = hasher::registry::get(extension) {
asset_hash = Some(release_asset.clone());
asset_hasher_fn = Some(hasher_fn);
break;
}
}

Ok((asset, asset_hash))
Ok((asset, asset_hash, asset_hasher_fn))
}
}

Expand All @@ -246,7 +260,7 @@
async fn get_archive(&self, version_req: &VersionReq) -> Result<Archive> {
let release = self.get_release(version_req).await?;
let version = Self::get_version_from_tag_name(release.tag_name.as_str())?;
let (asset, asset_hash) = self.get_asset(&version, &release)?;
let (asset, asset_hash, asset_hasher_fn) = self.get_asset(&version, &release)?;
let name = asset.name.clone();

let client = reqwest_client();
Expand Down Expand Up @@ -280,9 +294,10 @@
human_bytes(text.len() as f64)
);

let mut hasher = Sha256::new();
hasher.update(&archive);
let archive_hash = hex::encode(hasher.finalize());
let archive_hash = match asset_hasher_fn {
Some(hasher_fn) => hasher_fn(&bytes)?,
None => String::new(),

Check warning on line 299 in postgresql_archive/src/repository/github/repository.rs

View check run for this annotation

Codecov / codecov/patch

postgresql_archive/src/repository/github/repository.rs#L299

Added line #L299 was not covered by tests
};

if archive_hash != hash {
return Err(ArchiveHashMismatch { archive_hash, hash });
Expand Down
Loading