Skip to content

Commit 68da4b3

Browse files
committed
Introduce DatabaseFactory trait
This is part of #486 to add multi-descriptor wallet support to BDK.
1 parent 9165fae commit 68da4b3

File tree

6 files changed

+225
-2
lines changed

6 files changed

+225
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88

9+
- Introduce `DatabaseFactory` trait.
910

1011
## [v0.20.0] - [v0.19.0]
1112

src/database/any.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,43 @@ impl ConfigurableDatabase for AnyDatabase {
425425
impl_from!((), AnyDatabaseConfig, Memory,);
426426
impl_from!(SledDbConfiguration, AnyDatabaseConfig, Sled, #[cfg(feature = "key-value-db")]);
427427
impl_from!(SqliteDbConfiguration, AnyDatabaseConfig, Sqlite, #[cfg(feature = "sqlite")]);
428+
429+
/// Type that implements [`DatabaseFactory`] that builds [`AnyDatabase`].
430+
pub enum AnyDatabaseFactory {
431+
/// Memory database factory
432+
Memory(memory::MemoryDatabaseFactory),
433+
#[cfg(feature = "key-value-db")]
434+
#[cfg_attr(docsrs, doc(cfg(feature = "key-value-db")))]
435+
/// Key-value database factory
436+
Sled(sled::Db),
437+
#[cfg(feature = "sqlite")]
438+
#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))]
439+
/// Sqlite database factory
440+
Sqlite(sqlite::SqliteDatabaseFactory<String>),
441+
}
442+
443+
impl DatabaseFactory for AnyDatabaseFactory {
444+
type Inner = AnyDatabase;
445+
446+
fn build_with_change(
447+
&self,
448+
descriptor: ExtendedDescriptor,
449+
change_descriptor: Option<ExtendedDescriptor>,
450+
network: Network,
451+
secp: &SecpCtx,
452+
) -> Result<Self::Inner, Error> {
453+
match self {
454+
AnyDatabaseFactory::Memory(f) => f
455+
.build_with_change(descriptor, change_descriptor, network, secp)
456+
.map(Self::Inner::Memory),
457+
#[cfg(feature = "key-value-db")]
458+
AnyDatabaseFactory::Sled(f) => f
459+
.build_with_change(descriptor, change_descriptor, network, secp)
460+
.map(Self::Inner::Sled),
461+
#[cfg(feature = "sqlite")]
462+
AnyDatabaseFactory::Sqlite(f) => f
463+
.build_with_change(descriptor, change_descriptor, network, secp)
464+
.map(Self::Inner::Sqlite),
465+
}
466+
}
467+
}

src/database/keyvalue.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ use crate::database::memory::MapKey;
2121
use crate::database::{BatchDatabase, BatchOperations, Database, SyncTime};
2222
use crate::error::Error;
2323
use crate::types::*;
24+
use crate::wallet::wallet_name_from_descriptor;
25+
26+
use super::DatabaseFactory;
2427

2528
macro_rules! impl_batch_operations {
2629
( { $($after_insert:tt)* }, $process_delete:ident ) => {
@@ -402,6 +405,22 @@ impl BatchDatabase for Tree {
402405
}
403406
}
404407

408+
/// A [`DatabaseFactory`] implementation that builds [`Tree`]
409+
impl DatabaseFactory for sled::Db {
410+
type Inner = sled::Tree;
411+
412+
fn build_with_change(
413+
&self,
414+
descriptor: crate::descriptor::ExtendedDescriptor,
415+
change_descriptor: Option<crate::descriptor::ExtendedDescriptor>,
416+
network: bitcoin::Network,
417+
secp: &crate::wallet::utils::SecpCtx,
418+
) -> Result<Self::Inner, Error> {
419+
let name = wallet_name_from_descriptor(descriptor, change_descriptor, network, secp)?;
420+
self.open_tree(&name).map_err(Error::Sled)
421+
}
422+
}
423+
405424
#[cfg(test)]
406425
mod test {
407426
use lazy_static::lazy_static;
@@ -492,4 +511,16 @@ mod test {
492511
fn test_sync_time() {
493512
crate::database::test::test_sync_time(get_tree());
494513
}
514+
515+
#[test]
516+
fn test_factory() {
517+
let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
518+
let mut dir = std::env::temp_dir();
519+
dir.push(format!("bdk_{}", time.as_nanos()));
520+
521+
let fac = sled::open(&dir).unwrap();
522+
crate::database::test::test_factory(&fac);
523+
524+
std::fs::remove_dir_all(&dir).unwrap();
525+
}
495526
}

src/database/memory.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ use crate::database::{BatchDatabase, BatchOperations, ConfigurableDatabase, Data
2626
use crate::error::Error;
2727
use crate::types::*;
2828

29+
use super::DatabaseFactory;
30+
2931
// path -> script p{i,e}<path> -> script
3032
// script -> path s<script> -> {i,e}<path>
3133
// outpoint u<outpoint> -> txout
@@ -475,6 +477,24 @@ impl ConfigurableDatabase for MemoryDatabase {
475477
}
476478
}
477479

480+
/// A [`DatabaseFactory`] implementation that builds [`MemoryDatabase`].
481+
#[derive(Debug, Default)]
482+
pub struct MemoryDatabaseFactory;
483+
484+
impl DatabaseFactory for MemoryDatabaseFactory {
485+
type Inner = MemoryDatabase;
486+
487+
fn build_with_change(
488+
&self,
489+
_descriptor: crate::descriptor::ExtendedDescriptor,
490+
_change_descriptor: Option<crate::descriptor::ExtendedDescriptor>,
491+
_network: bitcoin::Network,
492+
_secp: &crate::wallet::utils::SecpCtx,
493+
) -> Result<Self::Inner, Error> {
494+
Ok(MemoryDatabase::default())
495+
}
496+
}
497+
478498
#[macro_export]
479499
#[doc(hidden)]
480500
/// Artificially insert a tx in the database, as if we had found it with a `sync`. This is a hidden
@@ -579,7 +599,7 @@ macro_rules! doctest_wallet {
579599

580600
#[cfg(test)]
581601
mod test {
582-
use super::MemoryDatabase;
602+
use super::{MemoryDatabase, MemoryDatabaseFactory};
583603

584604
fn get_tree() -> MemoryDatabase {
585605
MemoryDatabase::new()
@@ -629,4 +649,10 @@ mod test {
629649
fn test_sync_time() {
630650
crate::database::test::test_sync_time(get_tree());
631651
}
652+
653+
#[test]
654+
fn test_factory() {
655+
let fac = MemoryDatabaseFactory;
656+
crate::database::test::test_factory(&fac);
657+
}
632658
}

src/database/mod.rs

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
use serde::{Deserialize, Serialize};
2828

2929
use bitcoin::hash_types::Txid;
30-
use bitcoin::{OutPoint, Script, Transaction, TxOut};
30+
use bitcoin::{Network, OutPoint, Script, Transaction, TxOut};
3131

32+
use crate::descriptor::ExtendedDescriptor;
3233
use crate::error::Error;
3334
use crate::types::*;
35+
use crate::wallet::utils::SecpCtx;
3436

3537
pub mod any;
3638
pub use any::{AnyDatabase, AnyDatabaseConfig};
@@ -212,12 +214,39 @@ pub(crate) trait DatabaseUtils: Database {
212214

213215
impl<T: Database> DatabaseUtils for T {}
214216

217+
/// A factory trait that builds databases which share underlying configurations and/or storage
218+
/// paths.
219+
pub trait DatabaseFactory: Sized {
220+
/// Inner type to build
221+
type Inner: BatchDatabase;
222+
223+
/// Builds the defined [`DatabaseFactory::Inner`] type.
224+
fn build(
225+
&self,
226+
descriptor: ExtendedDescriptor,
227+
network: Network,
228+
secp: &SecpCtx,
229+
) -> Result<Self::Inner, Error> {
230+
self.build_with_change(descriptor, None, network, secp)
231+
}
232+
233+
/// Builds the defined [`DatabaseFactory::Inner`] type with the addition of a change descriptor.
234+
fn build_with_change(
235+
&self,
236+
descriptor: ExtendedDescriptor,
237+
change_descriptor: Option<ExtendedDescriptor>,
238+
network: Network,
239+
secp: &SecpCtx,
240+
) -> Result<Self::Inner, Error>;
241+
}
242+
215243
#[cfg(test)]
216244
pub mod test {
217245
use std::str::FromStr;
218246

219247
use bitcoin::consensus::encode::deserialize;
220248
use bitcoin::hashes::hex::*;
249+
use bitcoin::util::bip32::{self, DerivationPath, ExtendedPubKey};
221250
use bitcoin::*;
222251

223252
use super::*;
@@ -441,5 +470,42 @@ pub mod test {
441470
assert!(tree.get_sync_time().unwrap().is_none());
442471
}
443472

473+
pub fn test_factory<F: DatabaseFactory>(fac: &F) {
474+
let secp = SecpCtx::new();
475+
let network = Network::Regtest;
476+
let master_privkey = bip32::ExtendedPrivKey::from_str("tprv8ZgxMBicQKsPdowxEXJxXqPYd7i7WN3jG8NTVsq9MYVaR7qnLgi5xo1KZq4z1T89GfGs7BwQTVrtVKWozxwuQLgFNcd3snADMeivux1Y5u5").unwrap();
477+
478+
let descriptor = |acc: usize| -> ExtendedDescriptor {
479+
let path = DerivationPath::from_str(&format!("m/84h/1h/{}h", acc)).unwrap();
480+
let sk = master_privkey.derive_priv(&secp, &path).unwrap();
481+
let pk = ExtendedPubKey::from_priv(&secp, &sk);
482+
ExtendedDescriptor::from_str(&format!(
483+
"wpkh([{}/84h/1h/{}h]{}/0/*)",
484+
master_privkey.fingerprint(&secp),
485+
acc,
486+
pk
487+
))
488+
.unwrap()
489+
};
490+
491+
let mut acc_index = 0_usize;
492+
let mut database = || {
493+
let db = fac.build(descriptor(acc_index), network, &secp).unwrap();
494+
acc_index += 1;
495+
db
496+
};
497+
498+
test_script_pubkey(database());
499+
test_batch_script_pubkey(database());
500+
test_iter_script_pubkey(database());
501+
test_del_script_pubkey(database());
502+
test_utxo(database());
503+
test_raw_tx(database());
504+
test_tx(database());
505+
test_list_transaction(database());
506+
test_last_index(database());
507+
test_sync_time(database());
508+
}
509+
444510
// TODO: more tests...
445511
}

src/database/sqlite.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,21 @@
99
// You may not use this file except in accordance with one or both of these
1010
// licenses.
1111

12+
use std::path::Path;
13+
1214
use bitcoin::consensus::encode::{deserialize, serialize};
1315
use bitcoin::hash_types::Txid;
1416
use bitcoin::{OutPoint, Script, Transaction, TxOut};
1517

1618
use crate::database::{BatchDatabase, BatchOperations, Database, SyncTime};
1719
use crate::error::Error;
1820
use crate::types::*;
21+
use crate::wallet::wallet_name_from_descriptor;
1922

2023
use rusqlite::{named_params, Connection};
2124

25+
use super::DatabaseFactory;
26+
2227
static MIGRATIONS: &[&str] = &[
2328
"CREATE TABLE version (version INTEGER)",
2429
"INSERT INTO version VALUES (1)",
@@ -970,11 +975,49 @@ pub fn migrate(conn: &Connection) -> rusqlite::Result<()> {
970975
Ok(())
971976
}
972977

978+
/// A [`DatabaseFactory`] implementation that builds [`SqliteDatabase`].
979+
///
980+
/// Each built database is stored in path of format: `<path_root>_<hash>.<path_ext>`
981+
/// Where `hash` contains identifying data derived with inputs provided by
982+
/// [`DatabaseFactory::build`] or [`DatabaseFactory::build_with_change`] calls.
983+
pub struct SqliteDatabaseFactory<P> {
984+
pub dir: P,
985+
pub ext: String,
986+
}
987+
988+
impl<P: AsRef<Path>> DatabaseFactory for SqliteDatabaseFactory<P> {
989+
type Inner = SqliteDatabase;
990+
991+
fn build_with_change(
992+
&self,
993+
descriptor: crate::descriptor::ExtendedDescriptor,
994+
change_descriptor: Option<crate::descriptor::ExtendedDescriptor>,
995+
network: bitcoin::Network,
996+
secp: &crate::wallet::utils::SecpCtx,
997+
) -> Result<Self::Inner, Error> {
998+
// ensure dir exists
999+
std::fs::create_dir_all(&self.dir).map_err(|e| Error::Generic(e.to_string()))?;
1000+
1001+
let name = wallet_name_from_descriptor(descriptor, change_descriptor, network, secp)?;
1002+
let ext = self.ext.trim_start_matches('.');
1003+
1004+
let mut path = std::path::PathBuf::new();
1005+
path.push(&self.dir);
1006+
path.push(name);
1007+
path.set_extension(ext);
1008+
1009+
// TODO: This is stupid, fix this
1010+
Ok(Self::Inner::new(path.to_str().unwrap().to_string()))
1011+
}
1012+
}
1013+
9731014
#[cfg(test)]
9741015
pub mod test {
9751016
use crate::database::SqliteDatabase;
9761017
use std::time::{SystemTime, UNIX_EPOCH};
9771018

1019+
use super::SqliteDatabaseFactory;
1020+
9781021
fn get_database() -> SqliteDatabase {
9791022
let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
9801023
let mut dir = std::env::temp_dir();
@@ -1031,4 +1074,20 @@ pub mod test {
10311074
fn test_txs() {
10321075
crate::database::test::test_list_transaction(get_database());
10331076
}
1077+
1078+
#[test]
1079+
fn test_factory() {
1080+
let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
1081+
let mut dir = std::env::temp_dir();
1082+
dir.push(format!("bdk_{}", time.as_nanos()));
1083+
1084+
let fac = SqliteDatabaseFactory {
1085+
dir: dir.clone(),
1086+
ext: "db".to_string(),
1087+
};
1088+
1089+
crate::database::test::test_factory(&fac);
1090+
1091+
std::fs::remove_dir_all(&dir).unwrap();
1092+
}
10341093
}

0 commit comments

Comments
 (0)