Skip to content

feat(postgres): add PgAdvisoryLockGuardOwned #3442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 123 additions & 25 deletions sqlx-postgres/src/advisory_lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use hkdf::Hkdf;
use once_cell::sync::OnceCell;
use sha2::Sha256;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;

/// A mutex-like type utilizing [Postgres advisory locks].
///
Expand Down Expand Up @@ -82,6 +83,11 @@ pub struct PgAdvisoryLockGuard<'lock, C: AsMut<PgConnection>> {
conn: Option<C>,
}

pub struct PgAdvisoryLockGuardOwned<C: AsMut<PgConnection>> {
lock: Arc<PgAdvisoryLock>,
conn: Option<C>,
}

impl PgAdvisoryLock {
/// Construct a `PgAdvisoryLock` using the given string as a key.
///
Expand Down Expand Up @@ -203,22 +209,7 @@ impl PgAdvisoryLock {
&self,
mut conn: C,
) -> Result<PgAdvisoryLockGuard<'_, C>> {
match &self.key {
PgAdvisoryLockKey::BigInt(key) => {
crate::query::query("SELECT pg_advisory_lock($1)")
.bind(key)
.execute(conn.as_mut())
.await?;
}
PgAdvisoryLockKey::IntPair(key1, key2) => {
crate::query::query("SELECT pg_advisory_lock($1, $2)")
.bind(key1)
.bind(key2)
.execute(conn.as_mut())
.await?;
}
}

self.execute_acquire(conn.as_mut()).await?;
Ok(PgAdvisoryLockGuard::new(self, conn))
}

Expand Down Expand Up @@ -246,26 +237,68 @@ impl PgAdvisoryLock {
&self,
mut conn: C,
) -> Result<Either<PgAdvisoryLockGuard<'_, C>, C>> {
let locked: bool = match &self.key {
let locked = self.execute_try_acquire(conn.as_mut()).await?;
if locked {
Ok(Either::Left(PgAdvisoryLockGuard::new(self, conn)))
} else {
Ok(Either::Right(conn))
}
}

pub async fn acquire_owned<C: AsMut<PgConnection>>(
self: Arc<Self>,
mut conn: C,
) -> Result<PgAdvisoryLockGuardOwned<C>> {
self.execute_acquire(conn.as_mut()).await?;
Ok(PgAdvisoryLockGuardOwned::new(self, conn))
}

pub async fn try_acquire_owned<C: AsMut<PgConnection>>(
self: Arc<Self>,
mut conn: C,
) -> Result<Either<PgAdvisoryLockGuardOwned<C>, C>> {
let locked = self.execute_try_acquire(conn.as_mut()).await?;
if locked {
Ok(Either::Left(PgAdvisoryLockGuardOwned::new(self, conn)))
} else {
Ok(Either::Right(conn))
}
}

async fn execute_acquire(&self, conn: &mut PgConnection) -> Result<(), sqlx_core::Error> {
match &self.key {
PgAdvisoryLockKey::BigInt(key) => {
crate::query::query("SELECT pg_advisory_lock($1)")
.bind(key)
.execute(conn.as_mut())
.await?;
}
PgAdvisoryLockKey::IntPair(key1, key2) => {
crate::query::query("SELECT pg_advisory_lock($1, $2)")
.bind(key1)
.bind(key2)
.execute(conn.as_mut())
.await?;
}
}
Ok(())
}

async fn execute_try_acquire(&self, conn: &mut PgConnection) -> Result<bool, sqlx_core::Error> {
match &self.key {
PgAdvisoryLockKey::BigInt(key) => {
crate::query_scalar::query_scalar("SELECT pg_try_advisory_lock($1)")
.bind(key)
.fetch_one(conn.as_mut())
.await?
.await
}
PgAdvisoryLockKey::IntPair(key1, key2) => {
crate::query_scalar::query_scalar("SELECT pg_try_advisory_lock($1, $2)")
.bind(key1)
.bind(key2)
.fetch_one(conn.as_mut())
.await?
.await
}
};

if locked {
Ok(Either::Left(PgAdvisoryLockGuard::new(self, conn)))
} else {
Ok(Either::Right(conn))
}
}

Expand Down Expand Up @@ -419,3 +452,68 @@ impl<'lock, C: AsMut<PgConnection>> Drop for PgAdvisoryLockGuard<'lock, C> {
}
}
}

impl<C: AsMut<PgConnection>> PgAdvisoryLockGuardOwned<C> {
fn new(lock: Arc<PgAdvisoryLock>, conn: C) -> Self {
Self {
lock,
conn: Some(conn),
}
}

pub fn leak(mut self) -> C {
self.conn.take().expect(NONE_ERR)
}

pub async fn release_now(mut self) -> Result<C> {
let (conn, released) = self
.lock
.force_release(self.conn.take().expect(NONE_ERR))
.await?;

if !released {
tracing::warn!(
lock = ?self.lock.key,
"PgAdvisoryLockGuard: advisory lock was not held by the contained connection",
);
}

Ok(conn)
}
}

impl<C: AsMut<PgConnection>> Drop for PgAdvisoryLockGuardOwned<C> {
fn drop(&mut self) {
if let Some(mut conn) = self.conn.take() {
conn.as_mut()
.queue_simple_query(self.lock.get_release_query());
}
}
}

impl<C: AsRef<PgConnection> + AsMut<PgConnection>> Deref for PgAdvisoryLockGuardOwned<C> {
type Target = PgConnection;

fn deref(&self) -> &Self::Target {
self.as_ref()
}
}
impl<C: AsMut<PgConnection> + AsRef<PgConnection>> DerefMut for PgAdvisoryLockGuardOwned<C> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.as_mut()
}
}

impl<C: AsMut<PgConnection> + AsRef<PgConnection>> AsRef<PgConnection>
for PgAdvisoryLockGuardOwned<C>
{
fn as_ref(&self) -> &PgConnection {
self.conn.as_ref().expect(NONE_ERR).as_ref()
}
}

impl<C: AsMut<PgConnection>> AsMut<PgConnection> for PgAdvisoryLockGuardOwned<C> {
fn as_mut(&mut self) -> &mut PgConnection {
self.conn.as_mut().expect(NONE_ERR).as_mut()
}
}
66 changes: 66 additions & 0 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1833,6 +1833,72 @@ async fn test_advisory_locks() -> anyhow::Result<()> {
Ok(())
}

#[sqlx_macros::test]
async fn test_advisory_locks_with_owned_guards() -> anyhow::Result<()> {
let pool = PgPoolOptions::new()
.max_connections(2)
.connect(&dotenvy::var("DATABASE_URL")?)
.await?;

let lock1 = Arc::new(PgAdvisoryLock::new("sqlx-postgres-tests-1"));
let lock2 = Arc::new(PgAdvisoryLock::new("sqlx-postgres-tests-2"));

let conn1 = pool.acquire().await?;
let mut conn1_lock1 = lock1.clone().acquire_owned(conn1).await?;

// try acquiring a recursive lock through a mutable reference then dropping
drop(lock1.clone().acquire_owned(&mut conn1_lock1).await?);

let conn2 = pool.acquire().await?;
let conn2_lock2 = lock2.clone().acquire_owned(conn2).await?;

sqlx_core::rt::spawn({
let lock1 = lock1.clone();
let lock2 = lock2.clone();

async move {
let conn2_lock2 = lock1
.clone()
.try_acquire_owned(conn2_lock2)
.await?
.right_or_else(|_| {
panic!(
"acquired lock but wasn't supposed to! Key: {:?}",
lock1.key()
)
});

let (conn2, released) = lock2.force_release(conn2_lock2).await?;
assert!(released);

// acquire both locks but let the pool release them
let conn2_lock1 = lock1.acquire_owned(conn2).await?;
let _conn2_lock1and2 = lock2.acquire_owned(conn2_lock1).await?;

anyhow::Ok(())
}
});

// acquire lock2 on conn1, we leak the lock1 guard so we can manually release it before lock2
let conn1_lock1and2 = lock2.clone().acquire_owned(conn1_lock1.leak()).await?;

// release lock1 while holding lock2
let (conn1_lock2, released) = lock1.force_release(conn1_lock1and2).await?;
assert!(released);

let conn1 = conn1_lock2.release_now().await?;

// acquire both locks to be sure they were released
{
let conn1_lock1 = lock1.acquire_owned(conn1).await?;
let _conn1_lock1and2 = lock2.acquire_owned(conn1_lock1).await?;
}

pool.close().await;

Ok(())
}

#[sqlx_macros::test]
async fn test_postgres_bytea_hex_deserialization_errors() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
Expand Down