Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ thiserror = "2.0.12"
lazy_static = "1.5"
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
pg_walstream = "0.4.0"
pg_walstream = "0.4.1"
tiberius = { version = "0.12.3", features = ["tds73", "sql-browser-tokio", "bigdecimal", "rust_decimal", "time", "chrono"] }
sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "mysql", "sqlite", "chrono", "uuid"] }
flate2 = "1.1.5"
Expand Down
39 changes: 26 additions & 13 deletions pg2any-lib/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ use crate::destinations::{DestinationFactory, DestinationHandler};
use crate::error::{CdcError, Result};
use crate::lsn_tracker::{LsnTracker, SharedLsnFeedback};
use crate::monitoring::{MetricsCollector, MetricsCollectorTrait};
use crate::pg_replication::{ReplicationManager, ReplicationStream};
use crate::transaction_manager::{
PendingTransactionFile, TransactionFileMetadata, TransactionManager,
};
use crate::types::{EventType, Lsn};
use chrono::{DateTime, Utc};
use pg_walstream::{LogicalReplicationStream, ReplicationStreamConfig};
use std::collections::{BinaryHeap, HashMap};
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot, Mutex};
Expand Down Expand Up @@ -78,7 +78,7 @@ pub struct CdcClient {
/// Transaction file manager for file-based workflow
transaction_file_manager: Arc<TransactionManager>,
/// Replication stream for PostgreSQL connection
replication_stream: Arc<Mutex<ReplicationStream>>,
replication_stream: Arc<Mutex<LogicalReplicationStream>>,
}

impl CdcClient {
Expand Down Expand Up @@ -124,8 +124,9 @@ impl CdcClient {

// Create replication stream
info!("Creating replication stream");
let replication_manager = ReplicationManager::new(config.clone());
let replication_stream = replication_manager.create_stream_async().await?;
let stream_config = ReplicationStreamConfig::from(&config);
let replication_stream =
LogicalReplicationStream::new(&config.source_connection_string, stream_config).await?;

let client = Self {
config,
Expand Down Expand Up @@ -173,10 +174,11 @@ impl CdcClient {

// Start the replication stream
{
let start_xlog = start_lsn.map(|lsn| lsn.0);
self.replication_stream
.lock()
.await
.start(start_lsn)
.start(start_xlog)
.await?;
}

Expand Down Expand Up @@ -209,7 +211,7 @@ impl CdcClient {
// Get shared_lsn_feedback from stored replication_stream
let shared_lsn_feedback = {
let stream_guard = self.replication_stream.as_ref().lock().await;
stream_guard.shared_lsn_feedback().clone()
stream_guard.shared_lsn_feedback.clone()
};

if let Some(ref mut handler) = self.destination_handler {
Expand Down Expand Up @@ -456,9 +458,9 @@ impl CdcClient {
/// ## Shutdown Coordination
///
/// During graceful shutdown, the producer simply exits gracefully without transferring
/// ownership. The ReplicationStream remains stored in CdcClient. The main thread's stop()
/// ownership. The LogicalReplicationStream remains stored in CdcClient. The main thread's stop()
/// function waits for both producer and consumer to complete, ensuring all transactions
/// are committed, then calls stop() on the stored ReplicationStream to send final ACK.
/// are committed, then calls stop() on the stored LogicalReplicationStream to send final ACK.
/// This ensures the ACK includes all transactions successfully applied by the consumer,
/// preventing re-download of already applied transactions on restart.
///
Expand All @@ -470,7 +472,7 @@ impl CdcClient {
/// 3. Consumer receives None from mpsc (channel closed) and processes remaining queue
/// 4. Consumer then exits after draining all pending transactions
async fn run_producer(
replication_stream: Arc<Mutex<ReplicationStream>>,
replication_stream: Arc<Mutex<LogicalReplicationStream>>,
cancellation_token: CancellationToken,
start_lsn: Lsn,
metrics_collector: Arc<MetricsCollector>,
Expand Down Expand Up @@ -538,7 +540,7 @@ impl CdcClient {
// Get the next event from the replication stream (lock for the duration of the call)
let event_result = {
let mut stream = replication_stream.lock().await;
stream.next_event(&cancellation_token).await
stream.next_event_with_retry(&cancellation_token).await
};

match event_result {
Expand Down Expand Up @@ -677,7 +679,7 @@ impl CdcClient {
info!("Producer: StreamCommit for transaction {}", transaction_id);

// Move streaming transaction file to pending and notify consumer
if let Some(_) = streaming_txs.remove(transaction_id) {
if streaming_txs.remove(transaction_id).is_some() {
// Use helper function to handle commit logic
if let Err(e) = Self::handle_transaction_commit(
*transaction_id,
Expand Down Expand Up @@ -1231,12 +1233,23 @@ impl CdcClient {
info!("Sending final ACK to PostgreSQL before shutdown");
let mut stream = self.replication_stream.as_ref().lock().await;
stream
.shared_lsn_feedback()
.shared_lsn_feedback
.log_state("Final shutdown - LSN state before ACK");

if let Err(e) = stream.send_feedback() {
warn!("Failed to send final feedback: {}", e);
}

info!(
"Stopping logical replication stream (last received LSN: {})",
pg_walstream::format_lsn(stream.current_lsn())
);

if let Err(e) = stream.stop().await {
error!("Failed to stop replication stream: {}", e);
return Err(e);
return Err(CdcError::from(e));
}

info!("Final ACK sent successfully to PostgreSQL");
}

Expand Down
72 changes: 72 additions & 0 deletions pg2any-lib/src/destinations/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
/// - Maps source schemas to destination databases/schemas
/// - Essential for cross-database replication
use std::collections::HashMap;

/// Map a source schema to destination schema using provided mappings
pub fn map_schema(schema_mappings: &HashMap<String, String>, source_schema: &str) -> String {
schema_mappings
Expand All @@ -14,6 +15,77 @@ pub fn map_schema(schema_mappings: &HashMap<String, String>, source_schema: &str
.unwrap_or_else(|| source_schema.to_string())
}

/// Run a batch of SQL statements atomically inside a single sqlx transaction,
/// optionally invoking a pre-commit callback before COMMIT.
///
/// Shared by the MySQL and SQLite destination handlers. A failure at any
/// stage — statement execution, the hook, or COMMIT itself — leaves the
/// destination unchanged, because the open transaction is rolled back.
///
/// # Arguments
/// * `pool` - sqlx connection pool for the target database
/// * `commands` - SQL statements executed in order within one transaction
/// * `pre_commit_hook` - Optional async callback run BEFORE COMMIT; a hook
///   failure rolls the whole batch back
/// * `db_name` - Human-readable database name used in error messages
///   (e.g. "MySQL", "SQLite")
#[cfg(any(feature = "mysql", feature = "sqlite"))]
pub(crate) async fn execute_sqlx_batch_with_hook<DB>(
    pool: &sqlx::Pool<DB>,
    commands: &[String],
    pre_commit_hook: Option<super::destination_factory::PreCommitHook>,
    db_name: &str,
) -> crate::error::Result<()>
where
    DB: sqlx::Database,
    for<'c> &'c mut <DB as sqlx::Database>::Connection: sqlx::Executor<'c, Database = DB>,
    for<'q> <DB as sqlx::Database>::Arguments<'q>: sqlx::IntoArguments<'q, DB>,
{
    // Open the transaction that will carry the entire batch.
    let mut txn = pool.begin().await.map_err(|e| {
        crate::error::CdcError::generic(format!("{db_name} BEGIN transaction failed: {e}"))
    })?;

    // Execute every statement in order; the first failure aborts the batch.
    for (i, sql) in commands.iter().enumerate() {
        match sqlx::query(sql).execute(&mut *txn).await {
            Ok(_) => {}
            Err(e) => {
                // Best-effort rollback; log (but do not mask) a rollback failure.
                if let Err(rollback_err) = txn.rollback().await {
                    tracing::error!(
                        "{db_name} ROLLBACK failed after execution error: {}",
                        rollback_err
                    );
                }
                return Err(crate::error::CdcError::generic(format!(
                    "{db_name} execute_sql_batch failed at command {}/{}: {}",
                    i + 1,
                    commands.len(),
                    e
                )));
            }
        }
    }

    // Give the caller's hook a chance to veto the batch before it becomes durable.
    if let Some(hook) = pre_commit_hook {
        if let Err(e) = hook().await {
            // Hook rejected the batch: undo everything executed above.
            if let Err(rollback_err) = txn.rollback().await {
                tracing::error!(
                    "{db_name} ROLLBACK failed after pre-commit hook error: {}",
                    rollback_err
                );
            }
            return Err(crate::error::CdcError::generic(format!(
                "{db_name} pre-commit hook failed, transaction rolled back: {}",
                e
            )));
        }
    }

    // Make the batch durable.
    txn.commit().await.map_err(|e| {
        crate::error::CdcError::generic(format!("{db_name} COMMIT transaction failed: {e}"))
    })?;

    Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
63 changes: 1 addition & 62 deletions pg2any-lib/src/destinations/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,58 +82,7 @@ impl DestinationHandler for MySQLDestination {
.as_ref()
.ok_or_else(|| CdcError::generic("MySQL pool not initialized"))?;

// Begin a transaction
let mut tx = pool
.begin()
.await
.map_err(|e| CdcError::generic(format!("MySQL BEGIN transaction failed: {e}")))?;

// Execute all commands in the transaction
for (idx, sql) in commands.iter().enumerate() {
if let Err(e) = sqlx::query(sql).execute(&mut *tx).await {
// Rollback on error
if let Err(rollback_err) = tx.rollback().await {
tracing::error!(
"MySQL ROLLBACK failed after execution error: {}",
rollback_err
);
}
return Err(CdcError::generic(format!(
"MySQL execute_sql_batch failed at command {}/{}: {}",
idx + 1,
commands.len(),
e
)));
}
}

// CRITICAL: Execute pre-commit hook BEFORE transaction COMMIT
// This ensures checkpoint updates are atomic with data changes:
// - If hook fails: transaction rolls back, checkpoint not updated
// - If COMMIT fails: transaction rolls back, checkpoint not persisted
// - If crash before COMMIT: transaction rolls back, checkpoint file not written
// - If crash after COMMIT: both data and checkpoint are durable
if let Some(hook) = pre_commit_hook {
if let Err(e) = hook().await {
// Rollback transaction if hook fails
if let Err(rollback_err) = tx.rollback().await {
tracing::error!(
"MySQL ROLLBACK failed after pre-commit hook error: {}",
rollback_err
);
}
return Err(CdcError::generic(format!(
"MySQL pre-commit hook failed, transaction rolled back: {}",
e
)));
}
}

tx.commit()
.await
.map_err(|e| CdcError::generic(format!("MySQL COMMIT transaction failed: {e}")))?;

Ok(())
super::common::execute_sqlx_batch_with_hook(pool, commands, pre_commit_hook, "MySQL").await
}

async fn close(&mut self) -> Result<()> {
Expand All @@ -152,21 +101,11 @@ impl DestinationHandler for MySQLDestination {

#[cfg(test)]
mod tests {
use super::super::common;
use super::*;

#[test]
fn test_mysql_destination_creation() {
let destination = MySQLDestination::new();
assert!(destination.pool.is_none());
}

#[test]
fn test_map_schema() {
let mut mappings = HashMap::new();
mappings.insert("public".to_string(), "cdc_db".to_string());

assert_eq!(common::map_schema(&mappings, "public"), "cdc_db");
assert_eq!(common::map_schema(&mappings, "other"), "other");
}
}
49 changes: 1 addition & 48 deletions pg2any-lib/src/destinations/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,54 +113,7 @@ impl DestinationHandler for SQLiteDestination {
.as_ref()
.ok_or_else(|| CdcError::generic("SQLite pool not initialized"))?;

// Begin a transaction
let mut tx = pool
.begin()
.await
.map_err(|e| CdcError::generic(format!("SQLite BEGIN transaction failed: {e}")))?;

// Execute all commands in the transaction
for (idx, sql) in commands.iter().enumerate() {
if let Err(e) = sqlx::query(sql).execute(&mut *tx).await {
// Rollback on error
if let Err(rollback_err) = tx.rollback().await {
tracing::error!(
"SQLite ROLLBACK failed after execution error: {}",
rollback_err
);
}
return Err(CdcError::generic(format!(
"SQLite execute_sql_batch failed at command {}/{}: {}",
idx + 1,
commands.len(),
e
)));
}
}

// Execute pre-commit hook BEFORE transaction COMMIT
if let Some(hook) = pre_commit_hook {
if let Err(e) = hook().await {
// Rollback transaction if hook fails
if let Err(rollback_err) = tx.rollback().await {
tracing::error!(
"SQLite ROLLBACK failed after pre-commit hook error: {}",
rollback_err
);
}
return Err(CdcError::generic(format!(
"SQLite pre-commit hook failed, transaction rolled back: {}",
e
)));
}
}

// Commit the transaction
tx.commit()
.await
.map_err(|e| CdcError::generic(format!("SQLite COMMIT transaction failed: {e}")))?;

Ok(())
super::common::execute_sqlx_batch_with_hook(pool, commands, pre_commit_hook, "SQLite").await
}

async fn close(&mut self) -> Result<()> {
Expand Down
7 changes: 0 additions & 7 deletions pg2any-lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@ pub mod error;
// Destination handlers
pub mod types;

// Low-level PostgreSQL replication using libpq-sys
pub mod pg_replication;

pub mod lsn_tracker;

// High-level client interface
Expand All @@ -88,7 +85,6 @@ pub use config::{Config, ConfigBuilder};
pub use env::load_config_from_env;
pub use error::CdcError;
pub use lsn_tracker::{create_lsn_tracker_with_load, LsnTracker};
pub use pg_replication::{PgReplicationConnection, ReplicationConnectionRetry, RetryConfig};
pub type CdcResult<T> = Result<T, CdcError>;

pub mod destinations;
Expand Down Expand Up @@ -130,9 +126,6 @@ pub use pg_walstream::{
PG_EPOCH_OFFSET_SECS,
};

// Re-export PgResult from pg_replication (pg2any-lib's version with libpq)
pub use pg_replication::PgResult;

// Re-export SharedLsnFeedback from lsn_tracker (pg2any-lib's version with log_status method)
pub use lsn_tracker::SharedLsnFeedback;

Expand Down
Loading