Skip to content

Commit a9a4d94

Browse files
authored
Merge pull request #21 from jessekrubin/main
pool `conn_for_each`
2 parents 075cc3e + 66f8825 commit a9a4d94

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed

src/pool.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,4 +238,37 @@ impl Pool {
238238
let n = self.state.counter.fetch_add(1, Relaxed);
239239
&self.state.clients[n as usize % self.state.clients.len()]
240240
}
241+
242+
/// Runs a function on all connections in the pool asynchronously.
243+
///
244+
/// The function is executed on each connection concurrently.
245+
pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
246+
where
247+
F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
248+
T: Send + 'static,
249+
{
250+
let func = Arc::new(func);
251+
let futures = self.state.clients.iter().map(|client| {
252+
let func = func.clone();
253+
async move { client.conn(move |conn| func(conn)).await }
254+
});
255+
join_all(futures).await
256+
}
257+
258+
/// Runs a function on all connections in the pool, blocking the current thread.
259+
pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
260+
where
261+
F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
262+
T: Send + 'static,
263+
{
264+
let func = Arc::new(func);
265+
self.state
266+
.clients
267+
.iter()
268+
.map(|client| {
269+
let func = func.clone();
270+
client.conn_blocking(move |conn| func(conn))
271+
})
272+
.collect()
273+
}
241274
}

tests/tests.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ macro_rules! async_test {
8383
async_test!(test_journal_mode);
8484
async_test!(test_concurrency);
8585
async_test!(test_pool);
86+
async_test!(test_pool_conn_for_each);
8687

8788
async fn test_journal_mode() {
8889
let tmp_dir = tempfile::tempdir().unwrap();
@@ -166,3 +167,63 @@ async fn test_pool() {
166167
.collect::<Result<(), Error>>()
167168
.expect("collecting query results");
168169
}
170+
171+
async fn test_pool_conn_for_each() {
172+
// make dummy db
173+
let tmp_dir = tempfile::tempdir().unwrap();
174+
{
175+
let client = ClientBuilder::new()
176+
.journal_mode(JournalMode::Wal)
177+
.path(tmp_dir.path().join("sqlite.db"))
178+
.open_blocking()
179+
.expect("client unable to be opened");
180+
181+
client
182+
.conn_blocking(|conn| {
183+
conn.execute(
184+
"CREATE TABLE testing (id INTEGER PRIMARY KEY, val TEXT NOT NULL)",
185+
(),
186+
)?;
187+
conn.execute("INSERT INTO testing VALUES (1, ?)", ["value1"])
188+
})
189+
.expect("writing schema and seed data");
190+
}
191+
192+
let pool = PoolBuilder::new()
193+
.path(tmp_dir.path().join("another-sqlite.db"))
194+
.num_conns(2)
195+
.open()
196+
.await
197+
.expect("pool unable to be opened");
198+
199+
let dummy_db_path = tmp_dir.path().join("sqlite.db");
200+
let attach_fn = move |conn: &rusqlite::Connection| {
201+
conn.execute(
202+
"ATTACH DATABASE ? AS dummy",
203+
[dummy_db_path.to_str().unwrap()],
204+
)
205+
};
206+
// attach to the dummy db via conn_for_each
207+
pool.conn_for_each(attach_fn).await;
208+
209+
// check that the dummy db is attached
210+
fn check_fn(conn: &rusqlite::Connection) -> Result<Vec<String>, rusqlite::Error> {
211+
let mut stmt = conn
212+
.prepare_cached("SELECT name FROM dummy.sqlite_master WHERE type='table'")
213+
.unwrap();
214+
let names = stmt
215+
.query_map([], |row| row.get(0))
216+
.unwrap()
217+
.map(|r| r.unwrap())
218+
.collect::<Vec<String>>();
219+
220+
Ok(names)
221+
}
222+
let res = pool.conn_for_each(check_fn).await;
223+
for r in res {
224+
assert_eq!(r.unwrap(), vec!["testing"]);
225+
}
226+
227+
// cleanup
228+
pool.close().await.expect("closing client conn");
229+
}

0 commit comments

Comments
 (0)