Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 109 additions & 115 deletions payjoin-cli/src/app/config.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::path::PathBuf;

use anyhow::Result;
use clap::ArgMatches;
use config::builder::DefaultState;
use config::{ConfigError, File, FileFormat};
use payjoin::bitcoin::FeeRate;
use payjoin::core::version::Version;
use serde::Deserialize;
use url::Url;

use crate::cli::{Cli, Commands};
use crate::db;

type Builder = config::builder::ConfigBuilder<DefaultState>;
Expand Down Expand Up @@ -58,78 +59,75 @@ pub struct Config {
}

impl Config {
/// Version flags in order of precedence (newest to oldest)
const VERSION_FLAGS: &'static [(&'static str, u8)] = &[("bip77", 2), ("bip78", 1)];

/// Check for multiple version flags and return the highest precedence version
fn determine_version(matches: &ArgMatches) -> Result<u8, ConfigError> {
fn determine_version(cli: &Cli) -> Result<Version, ConfigError> {
let mut selected_version = None;
for &(flag, version) in Self::VERSION_FLAGS.iter() {
if matches.get_flag(flag) {
if selected_version.is_some() {
return Err(ConfigError::Message(format!(
"Multiple version flags specified. Please use only one of: {}",
Self::VERSION_FLAGS
.iter()
.map(|(flag, _)| format!("--{flag}"))
.collect::<Vec<_>>()
.join(", ")
)));
}
selected_version = Some(version);

// Check for BIP77 (v2)
if cli.flags.bip77.unwrap_or(false) {
selected_version = Some(Version::Two);
}

// Check for BIP78 (v1)
if cli.flags.bip78.unwrap_or(false) {
if selected_version.is_some() {
return Err(ConfigError::Message(
"Multiple version flags specified. Please use only one of: --bip77, --bip78"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I aimed to not change any language or behavior here, but the difference between the v1/v2 features and bip77/bip78 would be helpful. Is bip77 a "version" flag? If not, I can update the language in another PR.

.to_string(),
));
}
selected_version = Some(Version::One);
}

if let Some(version) = selected_version {
return Ok(version);
}
};

// If no version explicitly selected, use default based on available features
#[cfg(feature = "v2")]
return Ok(2);
return Ok(Version::Two);
#[cfg(all(feature = "v1", not(feature = "v2")))]
return Ok(1);

return Ok(Version::One);
#[cfg(not(any(feature = "v1", feature = "v2")))]
return Err(ConfigError::Message(
"No valid version available - must compile with v1 or v2 feature".to_string(),
));
}

pub(crate) fn new(matches: &ArgMatches) -> Result<Self, ConfigError> {
let mut builder = config::Config::builder();
builder = add_bitcoind_defaults(builder, matches)?;
builder = add_common_defaults(builder, matches)?;
pub(crate) fn new(cli: &Cli) -> Result<Self, ConfigError> {
let mut config = config::Config::builder();
config = add_bitcoind_defaults(config, cli)?;
config = add_common_defaults(config, cli)?;

let version = Self::determine_version(matches)?;
let version = Self::determine_version(cli)?;

match version {
1 => {
Version::One => {
#[cfg(feature = "v1")]
{
builder = add_v1_defaults(builder)?;
config = add_v1_defaults(config, cli)?;
}
#[cfg(not(feature = "v1"))]
return Err(ConfigError::Message(
"BIP78 (v1) selected but v1 feature not enabled".to_string(),
));
}
2 => {
Version::Two => {
#[cfg(feature = "v2")]
{
builder = add_v2_defaults(builder, matches)?;
config = add_v2_defaults(config, cli)?;
}
#[cfg(not(feature = "v2"))]
return Err(ConfigError::Message(
"BIP77 (v2) selected but v2 feature not enabled".to_string(),
));
}
_ => unreachable!("determine_version() should only return 1 or 2"),
}

builder = handle_subcommands(builder, matches)?;
builder = builder.add_source(File::new("config.toml", FileFormat::Toml).required(false));
config = handle_subcommands(config, cli)?;
config = config.add_source(File::new("config.toml", FileFormat::Toml).required(false));

let built_config = builder.build()?;
let built_config = config.build()?;

let mut config = Config {
db_path: built_config.get("db_path")?,
Expand All @@ -139,7 +137,7 @@ impl Config {
};

match version {
1 => {
Version::One => {
#[cfg(feature = "v1")]
{
match built_config.get::<V1Config>("v1") {
Expand All @@ -155,7 +153,7 @@ impl Config {
"BIP78 (v1) selected but v1 feature not enabled".to_string(),
));
}
2 => {
Version::Two => {
#[cfg(feature = "v2")]
{
match built_config.get::<V2Config>("v2") {
Expand All @@ -171,7 +169,6 @@ impl Config {
"BIP77 (v2) selected but v2 feature not enabled".to_string(),
));
}
_ => unreachable!("determine_version() should only return 1 or 2"),
}

if config.version.is_none() {
Expand Down Expand Up @@ -204,105 +201,102 @@ impl Config {
}

/// Set up default values and CLI overrides for Bitcoin RPC connection settings
fn add_bitcoind_defaults(builder: Builder, matches: &ArgMatches) -> Result<Builder, ConfigError> {
builder
fn add_bitcoind_defaults(config: Builder, cli: &Cli) -> Result<Builder, ConfigError> {
// Set default values
let config = config
.set_default("bitcoind.rpchost", "http://localhost:18443")?
.set_override_option(
"bitcoind.rpchost",
matches.get_one::<Url>("rpchost").map(|s| s.as_str()),
)?
.set_default("bitcoind.cookie", None::<String>)?
.set_override_option(
"bitcoind.cookie",
matches.get_one::<String>("cookie_file").map(|s| s.as_str()),
)?
.set_default("bitcoind.rpcuser", "bitcoin")?
.set_override_option(
"bitcoind.rpcuser",
matches.get_one::<String>("rpcuser").map(|s| s.as_str()),
)?
.set_default("bitcoind.rpcpassword", "")?
.set_override_option(
"bitcoind.rpcpassword",
matches.get_one::<String>("rpcpassword").map(|s| s.as_str()),
)
.set_default("bitcoind.rpcpassword", "")?;
Comment on lines +205 to +210
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The missing defaults were what was causing the test failures before, so the tests caught a behavioral change I made. It didn't really make sense to me to provide defaults for these since rpcuser and rpcpassword don't have sensible defaults the way rpchost does, so it felt weird to give one here. I can take a look at refactoring these out of the tests later if we want.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good issue to follow up with. The way you solved it for now works since it doesn't regress.


// Override config values with command line arguments if applicable
let rpchost = cli.rpchost.as_ref().map(|s| s.as_str());
let cookie_file = cli.cookie_file.as_ref().map(|p| p.to_string_lossy().into_owned());
let rpcuser = cli.rpcuser.as_deref();
let rpcpassword = cli.rpcpassword.as_deref();

config
.set_override_option("bitcoind.rpchost", rpchost)?
.set_override_option("bitcoind.cookie", cookie_file)?
.set_override_option("bitcoind.rpcuser", rpcuser)?
.set_override_option("bitcoind.rpcpassword", rpcpassword)
}

/// Set up default values and CLI overrides for common settings shared between v1 and v2
fn add_common_defaults(builder: Builder, matches: &ArgMatches) -> Result<Builder, ConfigError> {
builder
.set_default("db_path", db::DB_PATH)?
.set_override_option("db_path", matches.get_one::<String>("db_path").map(|s| s.as_str()))
fn add_common_defaults(config: Builder, cli: &Cli) -> Result<Builder, ConfigError> {
let db_path = cli.db_path.as_ref().map(|p| p.to_string_lossy().into_owned());
config.set_default("db_path", db::DB_PATH)?.set_override_option("db_path", db_path)
}

/// Set up default values for v1-specific settings when v2 is not enabled
#[cfg(feature = "v1")]
fn add_v1_defaults(builder: Builder) -> Result<Builder, ConfigError> {
builder
fn add_v1_defaults(config: Builder, cli: &Cli) -> Result<Builder, ConfigError> {
// Set default values
let config = config
.set_default("v1.port", 3000_u16)?
.set_default("v1.pj_endpoint", "https://localhost:3000")
.set_default("v1.pj_endpoint", "https://localhost:3000")?;

// Override config values with command line arguments if applicable
let pj_endpoint = cli.pj_endpoint.as_ref().map(|s| s.as_str());

config
.set_override_option("v1.port", cli.port)?
.set_override_option("v1.pj_endpoint", pj_endpoint)
}

/// Set up default values and CLI overrides for v2-specific settings
#[cfg(feature = "v2")]
fn add_v2_defaults(builder: Builder, matches: &ArgMatches) -> Result<Builder, ConfigError> {
builder
.set_override_option(
"v2.ohttp_relays",
matches
.get_many::<Url>("ohttp_relays")
.map(|val| val.map(|s| s.as_str()).collect::<Vec<_>>()),
)?
fn add_v2_defaults(config: Builder, cli: &Cli) -> Result<Builder, ConfigError> {
// Set default values
let config = config
.set_default("v2.pj_directory", "https://payjo.in")?
.set_default("v2.ohttp_keys", None::<String>)
.set_default("v2.ohttp_keys", None::<String>)?;

// Override config values with command line arguments if applicable
let pj_directory = cli.pj_directory.as_ref().map(|s| s.as_str());
let ohttp_keys = cli.ohttp_keys.as_ref().map(|p| p.to_string_lossy().into_owned());
let ohttp_relays = cli
.ohttp_relays
.as_ref()
.map(|urls| urls.iter().map(|url| url.as_str()).collect::<Vec<_>>());

config
.set_override_option("v2.pj_directory", pj_directory)?
.set_override_option("v2.ohttp_keys", ohttp_keys)?
.set_override_option("v2.ohttp_relays", ohttp_relays)
}

/// Handles configuration overrides based on CLI subcommands
fn handle_subcommands(builder: Builder, matches: &ArgMatches) -> Result<Builder, ConfigError> {
match matches.subcommand() {
Some(("send", _)) => Ok(builder),
Some(("receive", matches)) => {
let builder = handle_receive_command(builder, matches)?;
let max_fee_rate = matches.get_one::<FeeRate>("max_fee_rate");
builder.set_override_option("max_fee_rate", max_fee_rate.map(|f| f.to_string()))
fn handle_subcommands(config: Builder, cli: &Cli) -> Result<Builder, ConfigError> {
match &cli.command {
Commands::Send { .. } => Ok(config),
Commands::Receive {
#[cfg(feature = "v1")]
port,
#[cfg(feature = "v1")]
pj_endpoint,
#[cfg(feature = "v2")]
pj_directory,
#[cfg(feature = "v2")]
ohttp_keys,
..
} => {
#[cfg(feature = "v1")]
let config = config
.set_override_option("v1.port", port.map(|p| p.to_string()))?
.set_override_option("v1.pj_endpoint", pj_endpoint.as_ref().map(|s| s.as_str()))?;
#[cfg(feature = "v2")]
let config = config
.set_override_option("v2.pj_directory", pj_directory.as_ref().map(|s| s.as_str()))?
.set_override_option(
"v2.ohttp_keys",
ohttp_keys.as_ref().map(|s| s.to_string_lossy().into_owned()),
)?;
Ok(config)
}
#[cfg(feature = "v2")]
Some(("resume", _)) => Ok(builder),
_ => unreachable!(), // If all subcommands are defined above, anything else is unreachabe!()
Commands::Resume => Ok(config),
}
}

/// Handle configuration overrides specific to the receive command
fn handle_receive_command(builder: Builder, matches: &ArgMatches) -> Result<Builder, ConfigError> {
#[cfg(feature = "v1")]
let builder = {
let port = matches
.get_one::<String>("port")
.map(|port| port.parse::<u16>())
.transpose()
.map_err(|_| ConfigError::Message("\"port\" must be a valid number".to_string()))?;
builder.set_override_option("v1.port", port)?.set_override_option(
"v1.pj_endpoint",
matches.get_one::<Url>("pj_endpoint").map(|s| s.as_str()),
)?
};

#[cfg(feature = "v2")]
let builder = {
builder
.set_override_option(
"v2.pj_directory",
matches.get_one::<Url>("pj_directory").map(|s| s.as_str()),
)?
.set_override_option(
"v2.ohttp_keys",
matches.get_one::<String>("ohttp_keys").map(|s| s.as_str()),
)?
};

Ok(builder)
}

#[cfg(feature = "v2")]
fn deserialize_ohttp_keys_from_path<'de, D>(
deserializer: D,
Expand Down
Loading