Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .cargo/mutants.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ exclude_re = [
"Iterator",
".*Error",

# ---------------------Crate-specific exculsions---------------------
# ---------------------Crate-specific exclusions ---------------------
# Timeout loops
# src/receive/v1/mod.rs
"interleave_shuffle", # Replacing index += 1 with index *= 1 in a loop causes a timeout due to an infinite loop

# Trivial mutations
# These exlusions are allowing code blocks to run with artithmetic involving zero and as a result are no-ops
# These exclusions are allowing code blocks to run with arithmetic involving zero and as a result are no-ops
# payjoin/src/core/send/mod.rs
"replace < with <= in PsbtContext::check_outputs",
"replace > with >= in PsbtContext::check_fees",
Expand All @@ -26,8 +26,9 @@ exclude_re = [

# Async SystemTime comparison
# checking if the system time is equal to the expiry is difficult to reasonably test
# payjoin/src/core/receive/v2/session.rs and payjoin/src/core/send/v2/session.rs
"replace > with >= in replay_event_log",
# payjoin/src/core/receive/v2/mod.rs
"replace < with <= in Receiver<Initialized>::apply_unchecked_from_payload",
"replace > with >= in Receiver<Initialized>::create_poll_request",
"replace > with >= in extract_err_req",
# payjoin/src/core/send/v2/mod.rs
Expand Down
2 changes: 1 addition & 1 deletion payjoin/src/core/receive/v2/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::receive::ProtocolError;
/// This is currently opaque type because we aren't sure which variants will stay.
/// You can only display it.
#[derive(Debug)]
pub struct SessionError(InternalSessionError);
pub struct SessionError(pub(super) InternalSessionError);

impl From<InternalSessionError> for SessionError {
fn from(value: InternalSessionError) -> Self { SessionError(value) }
Expand Down
5 changes: 0 additions & 5 deletions payjoin/src/core/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,11 +472,6 @@ impl Receiver<Initialized> {
event: OriginalPayload,
reply_key: Option<HpkePublicKey>,
) -> Result<ReceiveSession, InternalReplayError> {
if self.state.context.expiry < SystemTime::now() {
// Session is expired, close the session
return Err(InternalReplayError::SessionExpired(self.state.context.expiry));
}

let new_state = Receiver {
state: UncheckedOriginalPayload {
original: event,
Expand Down
61 changes: 30 additions & 31 deletions payjoin/src/core/receive/v2/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use super::{ReceiveSession, SessionContext};
use crate::output_substitution::OutputSubstitution;
use crate::persist::SessionPersister;
use crate::receive::v2::{extract_err_req, SessionError};
use crate::receive::v2::{extract_err_req, InternalSessionError, SessionError};
use crate::receive::{common, JsonReply, OriginalPayload, PsbtContext};
use crate::{ImplementationError, IntoUrl, PjUri, Request};

Expand All @@ -17,7 +17,6 @@ impl std::fmt::Display for ReplayError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use InternalReplayError::*;
match &self.0 {
SessionExpired(expiry) => write!(f, "Session expired at {expiry:?}"),
InvalidStateAndEvent(state, event) => write!(
f,
"Invalid combination of state ({state:?}) and event ({event:?}) during replay",
Expand All @@ -34,8 +33,6 @@ impl From<InternalReplayError> for ReplayError {

#[derive(Debug)]
pub(crate) enum InternalReplayError {
/// Session expired
SessionExpired(SystemTime),
/// Invalid combination of state and event
InvalidStateAndEvent(Box<ReceiveSession>, Box<SessionEvent>),
/// Application storage error
Expand All @@ -48,6 +45,7 @@ pub fn replay_event_log<P>(persister: &P) -> Result<(ReceiveSession, SessionHist
where
P: SessionPersister,
P::SessionEvent: Into<SessionEvent> + Clone,
P::SessionEvent: From<SessionEvent>,
{
let logs = persister
.load()
Expand All @@ -68,6 +66,21 @@ where
})?;
}

let ctx =
history.session_context().expect("Session context should be present after the first event");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this expect makes more sense after we remove uninitlized as a session state #1014

if SystemTime::now() > ctx.expiry {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expect this needs a mutants exclusion too, see #1036 (review)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy that. I excluded this pattern: ""replace > with >= in replay_event_log","

We'll see if that works. Thanks

// Session has expired: close the session and persist a fatal error
let err = SessionError(InternalSessionError::Expired(ctx.expiry));
persister
.save_event(SessionEvent::SessionInvalid(err.to_string(), None).into())
.map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?;
persister
.close()
.map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?;

return Ok((ReceiveSession::TerminalFailure, history));
}

Ok((receiver, history))
}

Expand Down Expand Up @@ -193,6 +206,8 @@ pub enum SessionEvent {

#[cfg(test)]
mod tests {
use std::time::Duration;

use payjoin_test_utils::{BoxError, EXAMPLE_URL};

use super::*;
Expand Down Expand Up @@ -324,34 +339,26 @@ mod tests {
}

#[test]
fn test_replaying_unchecked_proposal() -> Result<(), BoxError> {
let session_context = SHARED_CONTEXT.clone();
let original = original_from_test_vector();
let reply_key = Some(crate::HpkeKeyPair::gen_keypair().1);

fn test_replaying_session_creation_with_expired_session() -> Result<(), BoxError> {
let session_context = SessionContext {
expiry: SystemTime::now() - Duration::from_secs(1),
..SHARED_CONTEXT.clone()
};
let test = SessionHistoryTest {
events: vec![
SessionEvent::Created(session_context.clone()),
SessionEvent::UncheckedOriginalPayload((original.clone(), reply_key.clone())),
],
events: vec![SessionEvent::Created(session_context.clone())],
expected_session_history: SessionHistoryExpectedOutcome {
psbt_with_fee_contributions: None,
fallback_tx: None,
},
expected_receiver_state: ReceiveSession::UncheckedOriginalPayload(Receiver {
state: UncheckedOriginalPayload {
original,
session_context: SessionContext { reply_key, ..session_context },
},
}),
expected_receiver_state: ReceiveSession::TerminalFailure,
};
// TODO: should check for the expired error message off the session history
run_session_history_test(test)
}

#[test]
fn test_replaying_unchecked_proposal_expiry() {
let now = SystemTime::now();
let session_context = SessionContext { expiry: now, ..SHARED_CONTEXT.clone() };
fn test_replaying_unchecked_proposal() -> Result<(), BoxError> {
let session_context = SHARED_CONTEXT.clone();
let original = original_from_test_vector();
let reply_key = Some(crate::HpkeKeyPair::gen_keypair().1);

Expand All @@ -371,15 +378,7 @@ mod tests {
},
}),
};
let session_history = run_session_history_test(test);

match session_history {
Err(error) => assert_eq!(
error.to_string(),
ReplayError::from(InternalReplayError::SessionExpired(now)).to_string()
),
Ok(_) => panic!("Expected session expiry error, got success"),
}
run_session_history_test(test)
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions payjoin/src/core/send/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ mod test {
}

#[test]
fn test_extract_v2_success() -> Result<(), BoxError> {
fn test_create_v2_post_request_success() -> Result<(), BoxError> {
let sender = create_sender_context(SystemTime::now() + Duration::from_secs(60))?;
let ohttp_relay = EXAMPLE_URL.clone();
let result = sender.create_v2_post_request(ohttp_relay);
Expand All @@ -580,7 +580,7 @@ mod test {
}

#[test]
fn test_extract_v2_fails_when_expired() -> Result<(), BoxError> {
fn test_create_v2_post_request_fails_when_expired() -> Result<(), BoxError> {
let expected_error = "session expired at SystemTime";
let sender = create_sender_context(SystemTime::now() - Duration::from_secs(60))?;
let ohttp_relay = EXAMPLE_URL.clone();
Expand Down
56 changes: 56 additions & 0 deletions payjoin/src/core/send/v2/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub fn replay_event_log<P>(persister: &P) -> Result<(SendSession, SessionHistory
where
P: SessionPersister + Clone,
P::SessionEvent: Into<SessionEvent> + Clone,
P::SessionEvent: From<SessionEvent>,
{
let logs = persister
.load()
Expand All @@ -59,6 +60,18 @@ where
}
}

let pj_param = history.pj_param().expect("pj_param should be present");
if std::time::SystemTime::now() > pj_param.expiration() {
// Session has expired: close the session and persist a fatal error
persister
.save_event(SessionEvent::SessionInvalid("Session expired".to_string()).into())
.map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?;
persister
.close()
.map_err(|e| InternalReplayError::PersistenceFailure(ImplementationError::new(e)))?;

return Ok((SendSession::TerminalFailure, history));
}
Ok((sender, history))
}

Expand Down Expand Up @@ -190,6 +203,49 @@ mod tests {
assert_eq!(session_history.pj_param().cloned(), test.expected_session_history.pj_param);
}

#[test]
fn test_sender_session_history_with_expired_session() {
// TODO(armins): how can we reduce the boilerplate for these tests?
let psbt = PARSED_ORIGINAL_PSBT.clone();
let sender = SenderBuilder::new(
psbt.clone(),
Uri::try_from(PJ_URI)
.expect("Valid uri")
.assume_checked()
.check_pj_supported()
.expect("Payjoin to be supported"),
)
.build_recommended(FeeRate::BROADCAST_MIN)
.unwrap();
let reply_key = HpkeKeyPair::gen_keypair();
let endpoint = sender.endpoint().clone();
let fallback_tx = sender.psbt_ctx.original_psbt.clone().extract_tx_unchecked_fee_rate();
let id = crate::uri::ShortId::try_from(&b"12345670"[..]).expect("valid short id");
let pj_param = crate::uri::v2::PjParam::new(
endpoint,
id,
std::time::SystemTime::now() - std::time::Duration::from_secs(1),
crate::OhttpKeys(
ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).expect("valid key config"),
),
reply_key.1,
);
let with_reply_key = WithReplyKey {
pj_param: pj_param.clone(),
psbt_ctx: sender.psbt_ctx.clone(),
reply_key: reply_key.0,
};
let test = SessionHistoryTest {
events: vec![SessionEvent::CreatedReplyKey(with_reply_key)],
expected_session_history: SessionHistoryExpectedOutcome {
fallback_tx: Some(fallback_tx),
pj_param: Some(pj_param),
},
expected_sender_state: SendSession::TerminalFailure,
};
run_session_history_test(test);
}

#[test]
fn test_sender_session_history_with_reply_key_event() {
let psbt = PARSED_ORIGINAL_PSBT.clone();
Expand Down
Loading