Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion payjoin-cli/src/app/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ pub struct Config {
pub bitcoind: BitcoindConfig,
#[serde(skip)]
pub version: Option<VersionConfig>,
#[cfg(feature = "_danger-local-https")]
pub root_certificate: Option<PathBuf>,
#[cfg(feature = "_danger-local-https")]
pub certificate_key: Option<PathBuf>,
}

impl Config {
Expand Down Expand Up @@ -134,14 +138,26 @@ impl Config {
max_fee_rate: built_config.get("max_fee_rate").ok(),
bitcoind: built_config.get("bitcoind")?,
version: None,
#[cfg(feature = "_danger-local-https")]
root_certificate: built_config.get("root_certificate").ok(),
#[cfg(feature = "_danger-local-https")]
certificate_key: built_config.get("certificate_key").ok(),
};

match version {
Version::One => {
#[cfg(feature = "v1")]
{
match built_config.get::<V1Config>("v1") {
Ok(v1) => config.version = Some(VersionConfig::V1(v1)),
Ok(v1) => {
if v1.pj_endpoint.port().is_none() != (v1.port == 0) {
return Err(ConfigError::Message(
"If --port is 0, --pj-endpoint may not have a port".to_owned(),
));
}

config.version = Some(VersionConfig::V1(v1))
}
Err(e) =>
return Err(ConfigError::Message(format!(
"Valid V1 configuration is required for BIP78 mode: {e}"
Expand Down Expand Up @@ -266,6 +282,18 @@ fn add_v2_defaults(config: Builder, cli: &Cli) -> Result<Builder, ConfigError> {

/// Handles configuration overrides based on CLI subcommands
fn handle_subcommands(config: Builder, cli: &Cli) -> Result<Builder, ConfigError> {
#[cfg(feature = "_danger-local-https")]
let config = {
config
.set_override_option(
"root_certificate",
Some(cli.root_certificate.as_ref().map(|s| s.to_string_lossy().into_owned())),
)?
.set_override_option(
"certificate_key",
Some(cli.certificate_key.as_ref().map(|s| s.to_string_lossy().into_owned())),
)?
};
match &cli.command {
Commands::Send { .. } => Ok(config),
Commands::Receive {
Expand Down
37 changes: 15 additions & 22 deletions payjoin-cli/src/app/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ pub(crate) mod v1;
#[cfg(feature = "v2")]
pub(crate) mod v2;

#[cfg(feature = "_danger-local-https")]
pub const LOCAL_CERT_FILE: &str = "localhost.der";

#[async_trait::async_trait]
pub trait App: Send + Sync {
fn new(config: Config) -> Result<Self>
Expand Down Expand Up @@ -56,29 +53,25 @@ pub trait App: Send + Sync {
}

#[cfg(feature = "_danger-local-https")]
fn http_agent() -> Result<reqwest::Client> { Ok(http_agent_builder()?.build()?) }
fn http_agent(config: &Config) -> Result<reqwest::Client> {
Ok(http_agent_builder(config.root_certificate.as_ref())?.build()?)
}

#[cfg(not(feature = "_danger-local-https"))]
fn http_agent() -> Result<reqwest::Client> { Ok(reqwest::Client::new()) }
fn http_agent(_config: &Config) -> Result<reqwest::Client> { Ok(reqwest::Client::new()) }

#[cfg(feature = "_danger-local-https")]
fn http_agent_builder() -> Result<reqwest::ClientBuilder> {
use rustls::pki_types::CertificateDer;
use rustls::RootCertStore;

let cert_der = read_local_cert()?;
let mut root_cert_store = RootCertStore::empty();
root_cert_store.add(CertificateDer::from(cert_der.as_slice()))?;
Ok(reqwest::ClientBuilder::new()
.use_rustls_tls()
.add_root_certificate(reqwest::tls::Certificate::from_der(cert_der.as_slice())?))
}

#[cfg(feature = "_danger-local-https")]
fn read_local_cert() -> Result<Vec<u8>> {
let mut local_cert_path = std::env::temp_dir();
local_cert_path.push(LOCAL_CERT_FILE);
Ok(std::fs::read(local_cert_path)?)
fn http_agent_builder(
root_cert_path: Option<&std::path::PathBuf>,
) -> Result<reqwest::ClientBuilder> {
let mut builder = reqwest::ClientBuilder::new().use_rustls_tls();

if let Some(root_cert_path) = root_cert_path {
let cert_der = std::fs::read(root_cert_path)?;
builder =
builder.add_root_certificate(reqwest::tls::Certificate::from_der(cert_der.as_slice())?)
}
Ok(builder)
}

async fn handle_interrupt(tx: watch::Sender<()>) {
Expand Down
82 changes: 45 additions & 37 deletions payjoin-cli/src/app/v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use payjoin::bitcoin::FeeRate;
use payjoin::receive::v1::{PayjoinProposal, UncheckedProposal};
use payjoin::receive::ReplyableError::{self, Implementation, V1};
use payjoin::send::v1::SenderBuilder;
use payjoin::{ImplementationError, Uri, UriExt};
use payjoin::{ImplementationError, IntoUrl, Uri, UriExt};
use tokio::net::TcpListener;
use tokio::sync::watch;

Expand All @@ -26,8 +26,6 @@ use super::wallet::BitcoindWallet;
use super::App as AppTrait;
use crate::app::{handle_interrupt, http_agent};
use crate::db::Database;
#[cfg(feature = "_danger-local-https")]
pub const LOCAL_CERT_FILE: &str = "localhost.der";

struct Headers<'a>(&'a hyper::HeaderMap);
impl payjoin::receive::v1::Headers for Headers<'_> {
Expand Down Expand Up @@ -70,7 +68,7 @@ impl AppTrait for App {
.build_recommended(fee_rate)
.with_context(|| "Failed to build payjoin request")?
.create_v1_post_request();
let http = http_agent()?;
let http = http_agent(&self.config)?;
let body = String::from_utf8(req.body.clone()).unwrap();
println!("Sending fallback request to {}", &req.url);
let response = http
Expand Down Expand Up @@ -99,16 +97,9 @@ impl AppTrait for App {

#[allow(clippy::incompatible_msrv)]
async fn receive_payjoin(&self, amount: Amount) -> Result<()> {
let pj_uri_string = self.construct_payjoin_uri(amount, None)?;
println!(
"Listening at {}. Configured to accept payjoin at BIP 21 Payjoin Uri:",
self.config.v1()?.port
);
println!("{}", pj_uri_string);

let mut interrupt = self.interrupt.clone();
tokio::select! {
res = self.start_http_server() => { res?; }
res = self.start_http_server(amount) => { res?; }
_ = interrupt.changed() => {
println!("Interrupted.");
}
Expand All @@ -123,36 +114,45 @@ impl AppTrait for App {
}

impl App {
fn construct_payjoin_uri(
&self,
amount: Amount,
fallback_target: Option<&str>,
) -> Result<String> {
fn construct_payjoin_uri(&self, amount: Amount, endpoint: impl IntoUrl) -> Result<String> {
let pj_receiver_address = self.wallet.get_new_address()?;
let pj_part = match fallback_target {
Some(target) => target,
None => self.config.v1()?.pj_endpoint.as_str(),
};
let pj_part = payjoin::Url::parse(pj_part)
.map_err(|e| anyhow!("Failed to parse pj_endpoint: {}", e))?;

let mut pj_uri = payjoin::receive::v1::build_v1_pj_uri(
&pj_receiver_address,
&pj_part,
endpoint,
payjoin::OutputSubstitution::Enabled,
)?;
pj_uri.amount = Some(amount);

Ok(pj_uri.to_string())
}

async fn start_http_server(&self) -> Result<()> {
let addr = SocketAddr::from(([0, 0, 0, 0], self.config.v1()?.port));
async fn start_http_server(&self, amount: Amount) -> Result<()> {
let port = self.config.v1()?.port;
let addr = SocketAddr::from(([0, 0, 0, 0], port));
let listener = TcpListener::bind(addr).await?;

let mut endpoint = self.config.v1()?.pj_endpoint.clone();

// If --port 0 is specified, a free port is chosen, so we need to set it
// on the endpoint which must not have a port.
if port == 0 {
endpoint
.set_port(Some(listener.local_addr()?.port()))
.expect("setting port must succeed");
}

let pj_uri_string = self.construct_payjoin_uri(amount, endpoint)?;
println!(
"Listening at {}. Configured to accept payjoin at BIP 21 Payjoin Uri:",
listener.local_addr()?
);
println!("{}", pj_uri_string);

let app = self.clone();

#[cfg(feature = "_danger-local-https")]
let tls_acceptor = Self::init_tls_acceptor()?;
let tls_acceptor = self.init_tls_acceptor()?;
while let Ok((stream, _)) = listener.accept().await {
let app = app.clone();
#[cfg(feature = "_danger-local-https")]
Expand All @@ -179,28 +179,36 @@ impl App {
}

#[cfg(feature = "_danger-local-https")]
fn init_tls_acceptor() -> Result<tokio_rustls::TlsAcceptor> {
use std::io::Write;

fn init_tls_acceptor(&self) -> Result<tokio_rustls::TlsAcceptor> {
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::ServerConfig;
use tokio_rustls::TlsAcceptor;

let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])?;
let cert_der = cert.serialize_der()?;
let mut local_cert_path = std::env::temp_dir();
local_cert_path.push(LOCAL_CERT_FILE);
let mut file = std::fs::File::create(local_cert_path)?;
file.write_all(&cert_der)?;
let key = PrivateKeyDer::try_from(cert.serialize_private_key_der())
let key_der = std::fs::read(
self.config
.certificate_key
.as_ref()
.expect("certificate key is required if listening with tls"),
)?;
let key = PrivateKeyDer::try_from(key_der.clone())
.map_err(|e| anyhow::anyhow!("Could not parse key: {}", e))?;

let cert_der = std::fs::read(
self.config
.root_certificate
.as_ref()
.expect("certificate key is required if listening with tls"),
)?;
let certs = vec![CertificateDer::from(cert_der)];

let mut server_config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| anyhow::anyhow!("TLS error: {}", e))?;

server_config.alpn_protocols =
vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()];

Ok(TlsAcceptor::from(Arc::new(server_config)))
}

Expand Down
Loading
Loading