feat!: add configurable extractors #92

Merged: 1 commit, Jun 29, 2024
examples/archive_async/src/main.rs (6 changes: 3 additions & 3 deletions)
@@ -7,11 +7,11 @@ use postgresql_archive::{

#[tokio::main]
async fn main() -> Result<()> {
let url = THESEUS_POSTGRESQL_BINARIES_URL;
let version_req = VersionReq::STAR;
let (archive_version, archive) =
get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &version_req).await?;
let (archive_version, archive) = get_archive(url, &version_req).await?;
let out_dir = tempfile::tempdir()?.into_path();
extract(&archive, &out_dir).await?;
extract(url, &archive, &out_dir).await?;
println!(
"PostgreSQL {} extracted to {}",
archive_version,
examples/archive_sync/src/main.rs (5 changes: 3 additions & 2 deletions)
@@ -5,10 +5,11 @@ use postgresql_archive::blocking::{extract, get_archive};
use postgresql_archive::{Result, VersionReq, THESEUS_POSTGRESQL_BINARIES_URL};

fn main() -> Result<()> {
let url = THESEUS_POSTGRESQL_BINARIES_URL;
let version_req = VersionReq::STAR;
let (archive_version, archive) = get_archive(THESEUS_POSTGRESQL_BINARIES_URL, &version_req)?;
let (archive_version, archive) = get_archive(url, &version_req)?;
let out_dir = tempfile::tempdir()?.into_path();
extract(&archive, &out_dir)?;
extract(url, &archive, &out_dir)?;
println!(
"PostgreSQL {} extracted to {}",
archive_version,
postgresql_archive/benches/archive.rs (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@ fn bench_extract(criterion: &mut Criterion) -> Result<()> {
fn extract_archive(archive: &Vec<u8>) -> Result<()> {
let out_dir = tempfile::tempdir()?.path().to_path_buf();
create_dir_all(&out_dir)?;
extract(archive, &out_dir)?;
extract(THESEUS_POSTGRESQL_BINARIES_URL, archive, &out_dir)?;
remove_dir_all(&out_dir)?;
Ok(())
}
postgresql_archive/src/archive.rs (170 changes: 6 additions & 164 deletions)
@@ -1,20 +1,11 @@
//! Manage PostgreSQL archives
#![allow(dead_code)]

use crate::error::Error::Unexpected;
use crate::error::Result;
use crate::repository;
use flate2::bufread::GzDecoder;
use human_bytes::human_bytes;
use num_format::{Locale, ToFormattedString};
use crate::{extractor, repository};
use semver::{Version, VersionReq};
use std::fs::{create_dir_all, remove_dir_all, remove_file, rename, File};
use std::io::{copy, BufReader, Cursor};
use std::path::{Path, PathBuf};
use std::thread::sleep;
use std::time::Duration;
use tar::Archive;
use tracing::{debug, instrument, warn};
use std::path::Path;
use tracing::instrument;

pub const THESEUS_POSTGRESQL_BINARIES_URL: &str =
"https://github.com/theseus-rs/postgresql-binaries";
@@ -47,164 +38,15 @@ pub async fn get_archive(url: &str, version_req: &VersionReq) -> Result<(Version
Ok((version, bytes))
}

/// Acquires a lock file in the [out_dir](Path) to prevent multiple processes from extracting the
/// archive at the same time.
///
/// # Errors
/// * If the lock file cannot be acquired.
#[instrument(level = "debug")]
fn acquire_lock(out_dir: &Path) -> Result<PathBuf> {
let lock_file = out_dir.join("postgresql-archive.lock");

if lock_file.is_file() {
let metadata = lock_file.metadata()?;
let created = metadata.created()?;

if created.elapsed()?.as_secs() > 300 {
warn!(
"Stale lock file detected; removing file to attempt process recovery: {}",
lock_file.to_string_lossy()
);
remove_file(&lock_file)?;
}
}

debug!(
"Attempting to acquire lock: {}",
lock_file.to_string_lossy()
);

for _ in 0..30 {
let lock = std::fs::OpenOptions::new()
.create(true)
.truncate(true)
.write(true)
.open(&lock_file);

match lock {
Ok(_) => {
debug!("Lock acquired: {}", lock_file.to_string_lossy());
return Ok(lock_file);
}
Err(error) => {
warn!("unable to acquire lock: {error}");
sleep(Duration::from_secs(1));
}
}
}

Err(Unexpected("Failed to acquire lock".to_string()))
}

/// Extracts the compressed tar `bytes` to the [out_dir](Path).
///
/// # Errors
/// Returns an error if the extraction fails.
#[allow(clippy::cast_precision_loss)]
#[instrument(skip(bytes))]
pub async fn extract(bytes: &Vec<u8>, out_dir: &Path) -> Result<()> {
let input = BufReader::new(Cursor::new(bytes));
let decoder = GzDecoder::new(input);
let mut archive = Archive::new(decoder);
let mut files = 0;
let mut extracted_bytes = 0;

let parent_dir = if let Some(parent) = out_dir.parent() {
parent
} else {
debug!("No parent directory for {}", out_dir.to_string_lossy());
out_dir
};

create_dir_all(parent_dir)?;

let lock_file = acquire_lock(parent_dir)?;
// If the directory already exists, then the archive has already been
// extracted by another process.
if out_dir.exists() {
debug!(
"Directory already exists {}; skipping extraction: ",
out_dir.to_string_lossy()
);
remove_file(&lock_file)?;
return Ok(());
}

let extract_dir = tempfile::tempdir_in(parent_dir)?.into_path();
debug!("Extracting archive to {}", extract_dir.to_string_lossy());

for archive_entry in archive.entries()? {
let mut entry = archive_entry?;
let entry_header = entry.header();
let entry_type = entry_header.entry_type();
let entry_size = entry_header.size()?;
#[cfg(unix)]
let file_mode = entry_header.mode()?;

let entry_header_path = entry_header.path()?.to_path_buf();
let prefix = match entry_header_path.components().next() {
Some(component) => component.as_os_str().to_str().unwrap_or_default(),
None => {
return Err(Unexpected(
"Failed to get file header path prefix".to_string(),
));
}
};
let stripped_entry_header_path = entry_header_path.strip_prefix(prefix)?.to_path_buf();
let mut entry_name = extract_dir.clone();
entry_name.push(stripped_entry_header_path);

if entry_type.is_dir() || entry_name.is_dir() {
create_dir_all(&entry_name)?;
} else if entry_type.is_file() {
let mut output_file = File::create(&entry_name)?;
copy(&mut entry, &mut output_file)?;

files += 1;
extracted_bytes += entry_size;

#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
output_file.set_permissions(std::fs::Permissions::from_mode(file_mode))?;
}
} else if entry_type.is_symlink() {
#[cfg(unix)]
if let Some(symlink_target) = entry.link_name()? {
let symlink_path = entry_name;
std::os::unix::fs::symlink(symlink_target.as_ref(), symlink_path)?;
}
}
}

if out_dir.exists() {
debug!(
"Directory already exists {}; skipping rename and removing extraction directory: {}",
out_dir.to_string_lossy(),
extract_dir.to_string_lossy()
);
remove_dir_all(&extract_dir)?;
} else {
debug!(
"Renaming {} to {}",
extract_dir.to_string_lossy(),
out_dir.to_string_lossy()
);
rename(extract_dir, out_dir)?;
}

if lock_file.is_file() {
debug!("Removing lock file: {}", lock_file.to_string_lossy());
remove_file(lock_file)?;
}

debug!(
"Extracting {} files totalling {}",
files.to_formatted_string(&Locale::en),
human_bytes(extracted_bytes as f64)
);

Ok(())
pub async fn extract(url: &str, bytes: &Vec<u8>, out_dir: &Path) -> Result<()> {
let extractor_fn = extractor::registry::get(url)?;
extractor_fn(bytes, out_dir)
}

#[cfg(test)]
postgresql_archive/src/blocking/archive.rs (4 changes: 2 additions & 2 deletions)
@@ -34,8 +34,8 @@ pub fn get_archive(url: &str, version_req: &VersionReq) -> crate::Result<(Versio
///
/// # Errors
/// Returns an error if the extraction fails.
pub fn extract(bytes: &Vec<u8>, out_dir: &Path) -> crate::Result<()> {
pub fn extract(url: &str, bytes: &Vec<u8>, out_dir: &Path) -> crate::Result<()> {
RUNTIME
.handle()
.block_on(async move { crate::extract(bytes, out_dir).await })
.block_on(async move { crate::extract(url, bytes, out_dir).await })
}
postgresql_archive/src/error.rs (3 changes: 3 additions & 0 deletions)
@@ -31,6 +31,9 @@ pub enum Error {
/// Unexpected error
#[error("{0}")]
Unexpected(String),
/// Unsupported extractor
#[error("unsupported extractor for '{0}'")]
UnsupportedExtractor(String),
/// Unsupported hasher
#[error("unsupported hasher for '{0}'")]
UnsupportedHasher(String),
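The new variant surfaces to callers when `extract` is handed a URL that no registered extractor claims. A minimal sketch of handling it, assuming `Error` is re-exported at the crate root alongside `Result` (an assumption; only `Result` appears in the examples above):

```rust
use postgresql_archive::{extract, Error};
use std::path::Path;

// Hypothetical helper: report an unregistered URL separately from other
// extraction failures. The `Error` re-export is an assumption.
async fn extract_or_report(url: &str, bytes: &Vec<u8>, out_dir: &Path) {
    match extract(url, bytes, out_dir).await {
        Ok(()) => {}
        Err(Error::UnsupportedExtractor(url)) => {
            eprintln!("no extractor registered for '{url}'");
        }
        Err(error) => eprintln!("extraction failed: {error}"),
    }
}
```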
postgresql_archive/src/extractor/mod.rs (2 changes: 2 additions & 0 deletions)
@@ -0,0 +1,2 @@
pub mod registry;
pub mod theseus_postgresql_binary;
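The body of `theseus_postgresql_binary.rs` is not shown in this diff. A minimal sketch of an `ExtractFn`-compatible function, assuming it carries over the gzip/tar logic removed from `archive.rs` above; the lock-file handling, path-prefix stripping, permissions, and symlink support from the removed code are omitted here:

```rust
use crate::Result;
use flate2::bufread::GzDecoder;
use std::io::{BufReader, Cursor};
use std::path::Path;
use tar::Archive;

/// Extracts a gzip-compressed tar archive into `out_dir`.
/// Sketch only, not the file from this PR.
pub fn extract(bytes: &Vec<u8>, out_dir: &Path) -> Result<()> {
    let input = BufReader::new(Cursor::new(bytes));
    let decoder = GzDecoder::new(input);
    let mut archive = Archive::new(decoder);
    std::fs::create_dir_all(out_dir)?;
    archive.unpack(out_dir)?;
    Ok(())
}
```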
postgresql_archive/src/extractor/registry.rs (121 changes: 121 additions & 0 deletions)
@@ -0,0 +1,121 @@
use crate::extractor::theseus_postgresql_binary;
use crate::Error::{PoisonedLock, UnsupportedExtractor};
use crate::{Result, THESEUS_POSTGRESQL_BINARIES_URL};
use lazy_static::lazy_static;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};

lazy_static! {
static ref REGISTRY: Arc<Mutex<RepositoryRegistry>> =
Arc::new(Mutex::new(RepositoryRegistry::default()));
}

type SupportsFn = fn(&str) -> Result<bool>;
type ExtractFn = fn(&Vec<u8>, &Path) -> Result<()>;

/// Singleton struct to store extractors
#[allow(clippy::type_complexity)]
struct RepositoryRegistry {
extractors: Vec<(Arc<RwLock<SupportsFn>>, Arc<RwLock<ExtractFn>>)>,
}

impl RepositoryRegistry {
/// Creates a new extractor registry.
fn new() -> Self {
Self {
extractors: Vec::new(),
}
}

/// Registers an extractor. Newly registered extractors take precedence over existing ones.
fn register(&mut self, supports_fn: SupportsFn, extract_fn: ExtractFn) {
self.extractors.insert(
0,
(
Arc::new(RwLock::new(supports_fn)),
Arc::new(RwLock::new(extract_fn)),
),
);
}

/// Gets an extractor that supports the specified URL
///
/// # Errors
/// * If the URL is not supported.
fn get(&self, url: &str) -> Result<ExtractFn> {
for (supports_fn, extractor_fn) in &self.extractors {
let supports_function = supports_fn
.read()
.map_err(|error| PoisonedLock(error.to_string()))?;
if supports_function(url)? {
let extractor_function = extractor_fn
.read()
.map_err(|error| PoisonedLock(error.to_string()))?;
return Ok(*extractor_function);
}
}

Err(UnsupportedExtractor(url.to_string()))
}
}

impl Default for RepositoryRegistry {
/// Creates a new repository registry with the default repositories registered.
fn default() -> Self {
let mut registry = Self::new();
registry.register(
|url| Ok(url.starts_with(THESEUS_POSTGRESQL_BINARIES_URL)),
theseus_postgresql_binary::extract,
);
registry
}
}

/// Registers an extractor. Newly registered extractors take precedence over existing ones.
///
/// # Errors
/// * If the registry is poisoned.
#[allow(dead_code)]
pub fn register(supports_fn: SupportsFn, extractor_fn: ExtractFn) -> Result<()> {
let mut registry = REGISTRY
.lock()
.map_err(|error| PoisonedLock(error.to_string()))?;
registry.register(supports_fn, extractor_fn);
Ok(())
}

/// Gets an extractor that supports the specified URL
///
/// # Errors
/// * If the URL is not supported.
pub fn get(url: &str) -> Result<ExtractFn> {
let registry = REGISTRY
.lock()
.map_err(|error| PoisonedLock(error.to_string()))?;
registry.get(url)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_register() -> Result<()> {
register(|url| Ok(url == "https://foo.com"), |_, _| Ok(()))?;
let url = "https://foo.com";
let extractor = get(url)?;
assert!(extractor(&Vec::new(), Path::new("foo")).is_ok());
Ok(())
}

#[test]
fn test_get_error() {
let error = get("foo").unwrap_err();
assert_eq!("unsupported extractor for 'foo'", error.to_string());
}

#[test]
fn test_get_theseus_postgresql_binaries() {
assert!(get(THESEUS_POSTGRESQL_BINARIES_URL).is_ok());
}
}
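With the registry in place, a consumer can plug in its own extractor without forking the crate. A hypothetical registration, assuming the `extractor::registry` module is reachable through the crate's public API (the mirror URL and the write-through extractor are invented for illustration):

```rust
use postgresql_archive::extractor::registry; // path is an assumption
use postgresql_archive::Result;

// Hypothetical mirror that serves a single pre-built payload per version.
const CUSTOM_BINARIES_URL: &str = "https://example.com/postgresql-binaries";

fn register_custom_extractor() -> Result<()> {
    registry::register(
        // supports_fn: claim any URL under the hypothetical mirror.
        |url| Ok(url.starts_with(CUSTOM_BINARIES_URL)),
        // extract_fn: just persist the raw bytes; a real extractor would
        // decompress and unpack the archive into `out_dir`.
        |bytes, out_dir| {
            std::fs::create_dir_all(out_dir)?;
            std::fs::write(out_dir.join("archive.bin"), bytes)?;
            Ok(())
        },
    )
}
```

Since `register` inserts at index 0, this extractor would be consulted before the built-in Theseus one whenever both `supports_fn` checks match a URL.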