Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
343 changes: 214 additions & 129 deletions payjoin-cli/src/app/v2/mod.rs

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions payjoin-cli/src/db/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub(crate) enum Error {
Deserialize(serde_json::Error),
#[cfg(feature = "v2")]
NotFound(String),
#[cfg(feature = "v2")]
TryFromSlice(std::array::TryFromSliceError),
}

impl fmt::Display for Error {
Expand All @@ -27,6 +29,8 @@ impl fmt::Display for Error {
Error::Deserialize(e) => write!(f, "Deserialization failed: {e}"),
#[cfg(feature = "v2")]
Error::NotFound(key) => write!(f, "Key not found: {key}"),
#[cfg(feature = "v2")]
Error::TryFromSlice(e) => write!(f, "TryFromSlice failed: {e}"),
}
}
}
Expand Down
126 changes: 92 additions & 34 deletions payjoin-cli/src/db/v2.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,34 @@
use std::sync::Arc;
use std::time::SystemTime;

use bitcoincore_rpc::jsonrpc::serde_json;
use payjoin::persist::{Persister, Value};
use payjoin::receive::v2::{Receiver, ReceiverToken, WithContext};
use payjoin::bitcoin::hex::DisplayHex;
use payjoin::persist::{Persister, SessionPersister, Value};
use payjoin::receive::v2::SessionEvent;
use payjoin::send::v2::{Sender, SenderToken, WithReplyKey};
use serde::{Deserialize, Serialize};
use sled::Tree;
use url::Url;

use super::*;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct SessionWrapper<V> {
pub(crate) completed_at: Option<SystemTime>,
pub(crate) events: Vec<V>,
}

#[derive(Debug, Clone)]
pub struct SessionId([u8; 8]);

impl SessionId {
pub fn new(id: u64) -> Self { Self(id.to_be_bytes()) }
}

impl AsRef<[u8]> for SessionId {
fn as_ref(&self) -> &[u8] { self.0.as_ref() }
}

pub(crate) struct SenderPersister(Arc<Database>);
impl SenderPersister {
pub fn new(db: Arc<Database>) -> Self { Self(db) }
Expand Down Expand Up @@ -36,50 +56,88 @@ impl Persister<Sender<WithReplyKey>> for SenderPersister {
}
}

pub(crate) struct ReceiverPersister(Arc<Database>);
#[derive(Clone)]
pub(crate) struct ReceiverPersister {
db: Arc<Database>,
session_id: SessionId,
}
impl ReceiverPersister {
pub fn new(db: Arc<Database>) -> Self { Self(db) }
pub fn new(db: Arc<Database>) -> crate::db::Result<Self> {
let id = SessionId::new(db.0.generate_id()?);
let recv_tree = db.0.open_tree("recv_sessions")?;
let empty_session: SessionWrapper<SessionEvent> =
SessionWrapper { completed_at: None, events: vec![] };
let value = serde_json::to_vec(&empty_session).map_err(Error::Serialize)?;
recv_tree.insert(id.as_ref(), value.as_slice())?;
recv_tree.flush()?;

Ok(Self { db: db.clone(), session_id: id })
}

pub fn from_id(db: Arc<Database>, id: SessionId) -> crate::db::Result<Self> {
Ok(Self { db: db.clone(), session_id: id })
}
}

impl Persister<Receiver<WithContext>> for ReceiverPersister {
type Token = ReceiverToken;
type Error = crate::db::error::Error;
fn save(
&mut self,
value: Receiver<WithContext>,
) -> std::result::Result<ReceiverToken, Self::Error> {
let recv_tree = self.0 .0.open_tree("recv_sessions")?;
let key = value.key();
let value = serde_json::to_vec(&value).map_err(Error::Serialize)?;
recv_tree.insert(key.clone(), value.as_slice())?;
impl SessionPersister for ReceiverPersister {
type SessionEvent = SessionEvent;
type InternalStorageError = crate::db::error::Error;

fn save_event(
&self,
event: &SessionEvent,
) -> std::result::Result<(), Self::InternalStorageError> {
let recv_tree = self.db.0.open_tree("recv_sessions")?;
let key = self.session_id.as_ref();
let session =
recv_tree.get(key)?.ok_or(Error::NotFound(key.to_vec().to_lower_hex_string()))?;
let mut session_wrapper: SessionWrapper<SessionEvent> =
serde_json::from_slice(&session).map_err(Error::Deserialize)?;
session_wrapper.events.push(event.clone());
let value = serde_json::to_vec(&session_wrapper).map_err(Error::Serialize)?;
recv_tree.insert(key, value.as_slice())?;
recv_tree.flush()?;
Ok(key)
Ok(())
}
fn load(&self, key: ReceiverToken) -> std::result::Result<Receiver<WithContext>, Self::Error> {
let recv_tree = self.0 .0.open_tree("recv_sessions")?;
let value = recv_tree.get(key.as_ref())?.ok_or(Error::NotFound(key.to_string()))?;
serde_json::from_slice(&value).map_err(Error::Deserialize)

fn load(
&self,
) -> std::result::Result<Box<dyn Iterator<Item = SessionEvent>>, Self::InternalStorageError>
{
let recv_tree = self.db.0.open_tree("recv_sessions")?;
let session_wrapper = recv_tree.get(self.session_id.as_ref())?;
let value = session_wrapper.expect("key should exist");
let wrapper: SessionWrapper<SessionEvent> =
serde_json::from_slice(&value).map_err(Error::Deserialize)?;
Ok(Box::new(wrapper.events.into_iter()))
}

fn close(&self) -> std::result::Result<(), Self::InternalStorageError> {
let recv_tree = self.db.0.open_tree("recv_sessions")?;
let key = self.session_id.as_ref();
if let Some(existing) = recv_tree.get(key)? {
let mut wrapper: SessionWrapper<SessionEvent> =
serde_json::from_slice(&existing).map_err(Error::Deserialize)?;
wrapper.completed_at = Some(SystemTime::now());
let value = serde_json::to_vec(&wrapper).map_err(Error::Serialize)?;
recv_tree.insert(key, value.as_slice())?;
}
recv_tree.flush()?;
Ok(())
}
}

impl Database {
pub(crate) fn get_recv_sessions(&self) -> Result<Vec<Receiver<WithContext>>> {
pub(crate) fn get_recv_session_ids(&self) -> Result<Vec<SessionId>> {
let recv_tree = self.0.open_tree("recv_sessions")?;
let mut sessions = Vec::new();
let mut session_ids = Vec::new();
for item in recv_tree.iter() {
let (_, value) = item?;
let session: Receiver<WithContext> =
serde_json::from_slice(&value).map_err(Error::Deserialize)?;
sessions.push(session);
let (key, _) = item?;
session_ids.push(SessionId::new(u64::from_be_bytes(
key.as_ref().try_into().map_err(Error::TryFromSlice)?,
)));
}
Ok(sessions)
}

pub(crate) fn clear_recv_session(&self) -> Result<()> {
let recv_tree: Tree = self.0.open_tree("recv_sessions")?;
recv_tree.clear()?;
recv_tree.flush()?;
Ok(())
Ok(session_ids)
}

pub(crate) fn get_send_sessions(&self) -> Result<Vec<Sender<WithReplyKey>>> {
Expand Down
141 changes: 95 additions & 46 deletions payjoin-ffi/python/test/test_payjoin_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,20 @@

SelectParams("regtest")

class InMemoryReceiverPersister(ReceiverPersister):
def __init__(self):
super().__init__()
self.receivers = {}
class InMemoryReceiverSessionEventLog(JsonReceiverSessionPersister):
def __init__(self, id):
self.id = id
self.events = []
self.closed = False

def save(self, receiver: WithContext) -> ReceiverToken:
self.receivers[str(receiver.key())] = receiver.to_json()
def save(self, event: str):
self.events.append(event)

return receiver.key()
def load(self):
return self.events

def load(self, token: ReceiverToken) -> WithContext:
token = str(token)
if token not in self.receivers.keys():
raise ValueError(f"Token not found: {token}")
return WithContext.from_json(self.receivers[token])
def close(self):
self.closed = True

class InMemorySenderPersister(SenderPersister):
def __init__(self):
Expand All @@ -59,6 +58,79 @@ def setUpClass(cls):
cls.bitcoind = cls.env.get_bitcoind()
cls.receiver = cls.env.get_receiver()
cls.sender = cls.env.get_sender()

async def process_receiver_proposal(self, receiver: ReceiverTypeState, recv_persister: InMemoryReceiverSessionEventLog, ohttp_relay: Url) -> Optional[ReceiverTypeState]:
if receiver.is_WITH_CONTEXT():
res = await self.retrieve_receiver_proposal(receiver.inner, recv_persister, ohttp_relay)
if res is None:
return None
return res

if receiver.is_UNCHECKED_PROPOSAL():
return await self.process_unchecked_proposal(receiver.inner, recv_persister)
if receiver.is_MAYBE_INPUTS_OWNED():
return await self.process_maybe_inputs_owned(receiver.inner, recv_persister)
if receiver.is_MAYBE_INPUTS_SEEN():
return await self.process_maybe_inputs_seen(receiver.inner, recv_persister)
if receiver.is_OUTPUTS_UNKNOWN():
return await self.process_outputs_unknown(receiver.inner, recv_persister)
if receiver.is_WANTS_OUTPUTS():
return await self.process_wants_outputs(receiver.inner, recv_persister)
if receiver.is_WANTS_INPUTS():
return await self.process_wants_inputs(receiver.inner, recv_persister)
if receiver.is_PROVISIONAL_PROPOSAL():
return await self.process_provisional_proposal(receiver.inner, recv_persister)
if receiver.is_PAYJOIN_PROPOSAL():
return receiver

raise Exception(f"Unknown receiver state: {receiver}")


def create_receiver_context(self, receiver_address: bitcoinffi.Address, directory: Url, ohttp_keys: OhttpKeys, recv_persister: InMemoryReceiverSessionEventLog) -> WithContext:
receiver = UninitializedReceiver().create_session(address=receiver_address, directory=directory.as_string(), ohttp_keys=ohttp_keys, expire_after=None).save(recv_persister)
return receiver

async def retrieve_receiver_proposal(self, receiver: WithContext, recv_persister: InMemoryReceiverSessionEventLog, ohttp_relay: Url):
agent = httpx.AsyncClient()
request: RequestResponse = receiver.extract_req(ohttp_relay.as_string())
response = await agent.post(
url=request.request.url.as_string(),
headers={"Content-Type": request.request.content_type},
content=request.request.body
)
res = receiver.process_res(response.content, request.client_response).save(recv_persister)
if res.is_none():
return None
proposal = res.success()
return await self.process_unchecked_proposal(proposal, recv_persister)

async def process_unchecked_proposal(self, proposal: UncheckedProposal, recv_persister: InMemoryReceiverSessionEventLog) :
receiver = proposal.check_broadcast_suitability(None, MempoolAcceptanceCallback(self.receiver)).save(recv_persister)
return await self.process_maybe_inputs_owned(receiver, recv_persister)

async def process_maybe_inputs_owned(self, proposal: MaybeInputsOwned, recv_persister: InMemoryReceiverSessionEventLog):
maybe_inputs_owned = proposal.check_inputs_not_owned(IsScriptOwnedCallback(self.receiver)).save(recv_persister)
return await self.process_maybe_inputs_seen(maybe_inputs_owned, recv_persister)

async def process_maybe_inputs_seen(self, proposal: MaybeInputsSeen, recv_persister: InMemoryReceiverSessionEventLog):
outputs_unknown = proposal.check_no_inputs_seen_before(CheckInputsNotSeenCallback(self.receiver)).save(recv_persister)
return await self.process_outputs_unknown(outputs_unknown, recv_persister)

async def process_outputs_unknown(self, proposal: OutputsUnknown, recv_persister: InMemoryReceiverSessionEventLog):
wants_outputs = proposal.identify_receiver_outputs(IsScriptOwnedCallback(self.receiver)).save(recv_persister)
return await self.process_wants_outputs(wants_outputs, recv_persister)

async def process_wants_outputs(self, proposal: WantsOutputs, recv_persister: InMemoryReceiverSessionEventLog):
wants_inputs = proposal.commit_outputs().save(recv_persister)
return await self.process_wants_inputs(wants_inputs, recv_persister)

async def process_wants_inputs(self, proposal: WantsInputs, recv_persister: InMemoryReceiverSessionEventLog):
provisional_proposal = proposal.contribute_inputs(get_inputs(self.receiver)).commit_inputs().save(recv_persister)
return await self.process_provisional_proposal(provisional_proposal, recv_persister)

async def process_provisional_proposal(self, proposal: ProvisionalProposal, recv_persister: InMemoryReceiverSessionEventLog):
payjoin_proposal = proposal.finalize_proposal(ProcessPsbtCallback(self.receiver), 1, 10).save(recv_persister)
return ReceiverTypeState.PAYJOIN_PROPOSAL(payjoin_proposal)

async def test_integration_v2_to_v2(self):
try:
Expand All @@ -69,26 +141,16 @@ async def test_integration_v2_to_v2(self):
services.wait_for_services_ready()
directory = services.directory_url()
ohttp_keys = services.fetch_ohttp_keys()
ohttp_relay = services.ohttp_relay_url()
agent = httpx.AsyncClient()

# **********************
# Inside the Receiver:
new_receiver = NewReceiver(receiver_address, directory.as_string(), ohttp_keys, None)
persister = InMemoryReceiverPersister()
token = new_receiver.persist(persister)
session: WithContext = WithContext.load(token, persister)
recv_persister = InMemoryReceiverSessionEventLog(1)
session = self.create_receiver_context(receiver_address, directory, ohttp_keys, recv_persister)
process_response = await self.process_receiver_proposal(ReceiverTypeState.WITH_CONTEXT(session), recv_persister, ohttp_relay)
print(f"session: {session.to_json()}")
# Poll receive request
ohttp_relay = services.ohttp_relay_url()
request: RequestResponse = session.extract_req(ohttp_relay.as_string())
agent = httpx.AsyncClient()
response = await agent.post(
url=request.request.url.as_string(),
headers={"Content-Type": request.request.content_type},
content=request.request.body
)
response_body = session.process_res(response.content, request.client_response)
# No proposal yet since sender has not responded
self.assertIsNone(response_body)
self.assertIsNone(process_response)

# **********************
# Inside the Sender:
Expand All @@ -112,15 +174,11 @@ async def test_integration_v2_to_v2(self):
# Inside the Receiver:

# GET fallback psbt
request: RequestResponse = session.extract_req(ohttp_relay.as_string())
response = await agent.post(
url=request.request.url.as_string(),
headers={"Content-Type": request.request.content_type},
content=request.request.body
)
# POST payjoin
proposal = session.process_res(response.content, request.client_response)
payjoin_proposal = handle_directory_payjoin_proposal(self.receiver, proposal)
payjoin_proposal = await self.process_receiver_proposal(ReceiverTypeState.WITH_CONTEXT(session), recv_persister, ohttp_relay)
self.assertIsNotNone(payjoin_proposal)
self.assertEqual(payjoin_proposal.is_PAYJOIN_PROPOSAL(), True)

payjoin_proposal = payjoin_proposal.inner
request: RequestResponse = payjoin_proposal.extract_req(ohttp_relay.as_string())
response = await agent.post(
url=request.request.url.as_string(),
Expand Down Expand Up @@ -158,15 +216,6 @@ async def test_integration_v2_to_v2(self):
print("Caught:", e)
raise

def handle_directory_payjoin_proposal(receiver: Proxy, proposal: UncheckedProposal) -> PayjoinProposal:
maybe_inputs_owned = proposal.check_broadcast_suitability(None, MempoolAcceptanceCallback(receiver))
maybe_inputs_seen = maybe_inputs_owned.check_inputs_not_owned(IsScriptOwnedCallback(receiver))
outputs_unknown = maybe_inputs_seen.check_no_inputs_seen_before(CheckInputsNotSeenCallback(receiver))
wants_outputs = outputs_unknown.identify_receiver_outputs(IsScriptOwnedCallback(receiver))
wants_inputs = wants_outputs.commit_outputs()
provisional_proposal = wants_inputs.contribute_inputs(get_inputs(receiver)).commit_inputs()
return provisional_proposal.finalize_proposal(ProcessPsbtCallback(receiver), 1, 10)

def build_sweep_psbt(sender: Proxy, pj_uri: PjUri) -> bitcoinffi.Psbt:
outputs = {}
outputs[pj_uri.address()] = 50
Expand Down
Loading
Loading