Skip to content

Commit 98b74c4

Browse files
committed
Introduce DatabaseFactory trait
This is part of #486 to add multi-descriptor wallet support to BDK.
1 parent 9165fae commit 98b74c4

File tree

6 files changed

+210
-2
lines changed

6 files changed

+210
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88

9+
- Introduce `DatabaseFactory` trait.
910

1011
## [v0.20.0] - [v0.19.0]
1112

src/database/any.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,42 @@ impl ConfigurableDatabase for AnyDatabase {
425425
impl_from!((), AnyDatabaseConfig, Memory,);
426426
impl_from!(SledDbConfiguration, AnyDatabaseConfig, Sled, #[cfg(feature = "key-value-db")]);
427427
impl_from!(SqliteDbConfiguration, AnyDatabaseConfig, Sqlite, #[cfg(feature = "sqlite")]);
428+
429+
/// Type that implements [`DatabaseFactory`] that builds [`AnyDatabase`].
430+
pub enum AnyDatabaseFactory {
431+
/// Memory database factory
432+
Memory(memory::MemoryDatabaseFactory),
433+
#[cfg(feature = "key-value-db")]
434+
#[cfg_attr(docsrs, doc(cfg(feature = "key-value-db")))]
435+
/// Key-value database factory
436+
Sled(sled::Db),
437+
#[cfg(feature = "sqlite")]
438+
#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))]
439+
/// Sqlite database factory
440+
Sqlite(sqlite::SqliteDatabaseFactory<String>),
441+
}
442+
443+
impl DatabaseFactory for AnyDatabaseFactory {
444+
type Inner = AnyDatabase;
445+
446+
fn build(
447+
&self,
448+
descriptor: &ExtendedDescriptor,
449+
network: Network,
450+
secp: &SecpCtx,
451+
) -> Result<Self::Inner, Error> {
452+
match self {
453+
AnyDatabaseFactory::Memory(f) => {
454+
f.build(descriptor, network, secp).map(Self::Inner::Memory)
455+
}
456+
#[cfg(feature = "key-value-db")]
457+
AnyDatabaseFactory::Sled(f) => {
458+
f.build(descriptor, network, secp).map(Self::Inner::Sled)
459+
}
460+
#[cfg(feature = "sqlite")]
461+
AnyDatabaseFactory::Sqlite(f) => {
462+
f.build(descriptor, network, secp).map(Self::Inner::Sqlite)
463+
}
464+
}
465+
}
466+
}

src/database/keyvalue.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ use crate::database::memory::MapKey;
2121
use crate::database::{BatchDatabase, BatchOperations, Database, SyncTime};
2222
use crate::error::Error;
2323
use crate::types::*;
24+
use crate::wallet::wallet_name_from_descriptor;
25+
26+
use super::DatabaseFactory;
2427

2528
macro_rules! impl_batch_operations {
2629
( { $($after_insert:tt)* }, $process_delete:ident ) => {
@@ -402,6 +405,21 @@ impl BatchDatabase for Tree {
402405
}
403406
}
404407

408+
/// A [`DatabaseFactory`] implementation that builds [`Tree`]
409+
impl DatabaseFactory for sled::Db {
410+
type Inner = sled::Tree;
411+
412+
fn build(
413+
&self,
414+
descriptor: &crate::descriptor::ExtendedDescriptor,
415+
network: bitcoin::Network,
416+
secp: &crate::wallet::utils::SecpCtx,
417+
) -> Result<Self::Inner, Error> {
418+
let name = wallet_name_from_descriptor(descriptor.clone(), None, network, secp)?;
419+
self.open_tree(&name).map_err(Error::Sled)
420+
}
421+
}
422+
405423
#[cfg(test)]
406424
mod test {
407425
use lazy_static::lazy_static;
@@ -492,4 +510,16 @@ mod test {
492510
fn test_sync_time() {
493511
crate::database::test::test_sync_time(get_tree());
494512
}
513+
514+
#[test]
515+
fn test_factory() {
516+
let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
517+
let mut dir = std::env::temp_dir();
518+
dir.push(format!("bdk_{}", time.as_nanos()));
519+
520+
let fac = sled::open(&dir).unwrap();
521+
crate::database::test::test_factory(&fac);
522+
523+
std::fs::remove_dir_all(&dir).unwrap();
524+
}
495525
}

src/database/memory.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ use crate::database::{BatchDatabase, BatchOperations, ConfigurableDatabase, Data
2626
use crate::error::Error;
2727
use crate::types::*;
2828

29+
use super::DatabaseFactory;
30+
2931
// path -> script p{i,e}<path> -> script
3032
// script -> path s<script> -> {i,e}<path>
3133
// outpoint u<outpoint> -> txout
@@ -475,6 +477,23 @@ impl ConfigurableDatabase for MemoryDatabase {
475477
}
476478
}
477479

480+
/// A [`DatabaseFactory`] implementation that builds [`MemoryDatabase`].
481+
#[derive(Debug, Default)]
482+
pub struct MemoryDatabaseFactory;
483+
484+
impl DatabaseFactory for MemoryDatabaseFactory {
485+
type Inner = MemoryDatabase;
486+
487+
fn build(
488+
&self,
489+
_descriptor: &crate::descriptor::ExtendedDescriptor,
490+
_network: bitcoin::Network,
491+
_secp: &crate::wallet::utils::SecpCtx,
492+
) -> Result<Self::Inner, Error> {
493+
Ok(MemoryDatabase::default())
494+
}
495+
}
496+
478497
#[macro_export]
479498
#[doc(hidden)]
480499
/// Artificially insert a tx in the database, as if we had found it with a `sync`. This is a hidden
@@ -579,7 +598,7 @@ macro_rules! doctest_wallet {
579598

580599
#[cfg(test)]
581600
mod test {
582-
use super::MemoryDatabase;
601+
use super::{MemoryDatabase, MemoryDatabaseFactory};
583602

584603
fn get_tree() -> MemoryDatabase {
585604
MemoryDatabase::new()
@@ -629,4 +648,10 @@ mod test {
629648
fn test_sync_time() {
630649
crate::database::test::test_sync_time(get_tree());
631650
}
651+
652+
#[test]
653+
fn test_factory() {
654+
let fac = MemoryDatabaseFactory;
655+
crate::database::test::test_factory(&fac);
656+
}
632657
}

src/database/mod.rs

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
use serde::{Deserialize, Serialize};
2828

2929
use bitcoin::hash_types::Txid;
30-
use bitcoin::{OutPoint, Script, Transaction, TxOut};
30+
use bitcoin::{Network, OutPoint, Script, Transaction, TxOut};
3131

32+
use crate::descriptor::ExtendedDescriptor;
3233
use crate::error::Error;
3334
use crate::types::*;
35+
use crate::wallet::utils::SecpCtx;
3436

3537
pub mod any;
3638
pub use any::{AnyDatabase, AnyDatabaseConfig};
@@ -212,12 +214,28 @@ pub(crate) trait DatabaseUtils: Database {
212214

213215
impl<T: Database> DatabaseUtils for T {}
214216

217+
/// A factory trait that builds databases which share underlying configurations and/or storage
218+
/// paths.
219+
pub trait DatabaseFactory: Sized {
220+
/// Inner type to build
221+
type Inner: BatchDatabase;
222+
223+
/// Builds the defined [`DatabaseFactory::Inner`] type.
224+
fn build(
225+
&self,
226+
descriptor: &ExtendedDescriptor,
227+
network: Network,
228+
secp: &SecpCtx,
229+
) -> Result<Self::Inner, Error>;
230+
}
231+
215232
#[cfg(test)]
216233
pub mod test {
217234
use std::str::FromStr;
218235

219236
use bitcoin::consensus::encode::deserialize;
220237
use bitcoin::hashes::hex::*;
238+
use bitcoin::util::bip32::{self, DerivationPath, ExtendedPubKey};
221239
use bitcoin::*;
222240

223241
use super::*;
@@ -441,5 +459,42 @@ pub mod test {
441459
assert!(tree.get_sync_time().unwrap().is_none());
442460
}
443461

462+
pub fn test_factory<F: DatabaseFactory>(fac: &F) {
463+
let secp = SecpCtx::new();
464+
let network = Network::Regtest;
465+
let master_privkey = bip32::ExtendedPrivKey::from_str("tprv8ZgxMBicQKsPdowxEXJxXqPYd7i7WN3jG8NTVsq9MYVaR7qnLgi5xo1KZq4z1T89GfGs7BwQTVrtVKWozxwuQLgFNcd3snADMeivux1Y5u5").unwrap();
466+
467+
let descriptor = |acc: usize| -> ExtendedDescriptor {
468+
let path = DerivationPath::from_str(&format!("m/84h/1h/{}h", acc)).unwrap();
469+
let sk = master_privkey.derive_priv(&secp, &path).unwrap();
470+
let pk = ExtendedPubKey::from_priv(&secp, &sk);
471+
ExtendedDescriptor::from_str(&format!(
472+
"wpkh([{}/84h/1h/{}h]{}/0/*)",
473+
master_privkey.fingerprint(&secp),
474+
acc,
475+
pk
476+
))
477+
.unwrap()
478+
};
479+
480+
let mut acc_index = 0_usize;
481+
let mut database = || {
482+
let db = fac.build(&descriptor(acc_index), network, &secp).unwrap();
483+
acc_index += 1;
484+
db
485+
};
486+
487+
test_script_pubkey(database());
488+
test_batch_script_pubkey(database());
489+
test_iter_script_pubkey(database());
490+
test_del_script_pubkey(database());
491+
test_utxo(database());
492+
test_raw_tx(database());
493+
test_tx(database());
494+
test_list_transaction(database());
495+
test_last_index(database());
496+
test_sync_time(database());
497+
}
498+
444499
// TODO: more tests...
445500
}

src/database/sqlite.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,21 @@
99
// You may not use this file except in accordance with one or both of these
1010
// licenses.
1111

12+
use std::path::Path;
13+
1214
use bitcoin::consensus::encode::{deserialize, serialize};
1315
use bitcoin::hash_types::Txid;
1416
use bitcoin::{OutPoint, Script, Transaction, TxOut};
1517

1618
use crate::database::{BatchDatabase, BatchOperations, Database, SyncTime};
1719
use crate::error::Error;
1820
use crate::types::*;
21+
use crate::wallet::wallet_name_from_descriptor;
1922

2023
use rusqlite::{named_params, Connection};
2124

25+
use super::DatabaseFactory;
26+
2227
static MIGRATIONS: &[&str] = &[
2328
"CREATE TABLE version (version INTEGER)",
2429
"INSERT INTO version VALUES (1)",
@@ -970,11 +975,48 @@ pub fn migrate(conn: &Connection) -> rusqlite::Result<()> {
970975
Ok(())
971976
}
972977

978+
/// A [`DatabaseFactory`] implementation that builds [`SqliteDatabase`].
979+
///
980+
/// Each built database is stored in path of format: `<path_root>_<hash>.<path_ext>`
981+
/// Where `hash` contains identifying data derived with inputs provided by
982+
/// [`DatabaseFactory::build`] or [`DatabaseFactory::build_with_change`] calls.
983+
pub struct SqliteDatabaseFactory<P> {
984+
pub dir: P,
985+
pub ext: String,
986+
}
987+
988+
impl<P: AsRef<Path>> DatabaseFactory for SqliteDatabaseFactory<P> {
989+
type Inner = SqliteDatabase;
990+
991+
fn build(
992+
&self,
993+
descriptor: &crate::descriptor::ExtendedDescriptor,
994+
network: bitcoin::Network,
995+
secp: &crate::wallet::utils::SecpCtx,
996+
) -> Result<Self::Inner, Error> {
997+
// ensure dir exists
998+
std::fs::create_dir_all(&self.dir).map_err(|e| Error::Generic(e.to_string()))?;
999+
1000+
let name = wallet_name_from_descriptor(descriptor.clone(), None, network, secp)?;
1001+
let ext = self.ext.trim_start_matches('.');
1002+
1003+
let mut path = std::path::PathBuf::new();
1004+
path.push(&self.dir);
1005+
path.push(name);
1006+
path.set_extension(ext);
1007+
1008+
// TODO: This is stupid, fix this
1009+
Ok(Self::Inner::new(path.to_str().unwrap().to_string()))
1010+
}
1011+
}
1012+
9731013
#[cfg(test)]
9741014
pub mod test {
9751015
use crate::database::SqliteDatabase;
9761016
use std::time::{SystemTime, UNIX_EPOCH};
9771017

1018+
use super::SqliteDatabaseFactory;
1019+
9781020
fn get_database() -> SqliteDatabase {
9791021
let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
9801022
let mut dir = std::env::temp_dir();
@@ -1031,4 +1073,20 @@ pub mod test {
10311073
fn test_txs() {
10321074
crate::database::test::test_list_transaction(get_database());
10331075
}
1076+
1077+
#[test]
1078+
fn test_factory() {
1079+
let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
1080+
let mut dir = std::env::temp_dir();
1081+
dir.push(format!("bdk_{}", time.as_nanos()));
1082+
1083+
let fac = SqliteDatabaseFactory {
1084+
dir: dir.clone(),
1085+
ext: "db".to_string(),
1086+
};
1087+
1088+
crate::database::test::test_factory(&fac);
1089+
1090+
std::fs::remove_dir_all(&dir).unwrap();
1091+
}
10341092
}

0 commit comments

Comments
 (0)