@@ -130,18 +130,27 @@ impl<'a, S: Deref> ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: Scor
130
130
131
131
impl < ' a , S : Deref > ScoreLookUp for ScorerAccountingForInFlightHtlcs < ' a , S > where S :: Target : ScoreLookUp {
132
132
type ScoreParams = <S :: Target as ScoreLookUp >:: ScoreParams ;
133
- fn channel_penalty_msat ( & self , short_channel_id : u64 , source : & NodeId , target : & NodeId , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
133
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , usage : ChannelUsage , score_params : & Self :: ScoreParams ) -> u64 {
134
+ let target = match candidate. target ( ) {
135
+ Some ( target) => target,
136
+ None => return self . scorer . channel_penalty_msat ( candidate, usage, score_params) ,
137
+ } ;
138
+ let short_channel_id = match candidate. short_channel_id ( ) {
139
+ Some ( short_channel_id) => short_channel_id,
140
+ None => return self . scorer . channel_penalty_msat ( candidate, usage, score_params) ,
141
+ } ;
142
+ let source = candidate. source ( ) ;
134
143
if let Some ( used_liquidity) = self . inflight_htlcs . used_liquidity_msat (
135
- source, target, short_channel_id
144
+ & source, & target, short_channel_id
136
145
) {
137
146
let usage = ChannelUsage {
138
147
inflight_htlc_msat : usage. inflight_htlc_msat . saturating_add ( used_liquidity) ,
139
148
..usage
140
149
} ;
141
150
142
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
151
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
143
152
} else {
144
- self . scorer . channel_penalty_msat ( short_channel_id , source , target , usage, score_params)
153
+ self . scorer . channel_penalty_msat ( candidate , usage, score_params)
145
154
}
146
155
}
147
156
}
@@ -1062,7 +1071,7 @@ impl<'a> CandidateRouteHop<'a> {
1062
1071
/// For `Blinded` and `OneHopBlinded` we return `None` because next hop is not known.
1063
1072
pub fn short_channel_id ( & self ) -> Option < u64 > {
1064
1073
match self {
1065
- CandidateRouteHop :: FirstHop { details, .. } => Some ( details. get_outbound_payment_scid ( ) . unwrap ( ) ) ,
1074
+ CandidateRouteHop :: FirstHop { details, .. } => details. get_outbound_payment_scid ( ) ,
1066
1075
CandidateRouteHop :: PublicHop { short_channel_id, .. } => Some ( * short_channel_id) ,
1067
1076
CandidateRouteHop :: PrivateHop { hint, .. } => Some ( hint. short_channel_id ) ,
1068
1077
CandidateRouteHop :: Blinded { .. } => None ,
@@ -1168,7 +1177,7 @@ impl<'a> CandidateRouteHop<'a> {
1168
1177
CandidateRouteHop :: PublicHop { source_node_id, .. } => * source_node_id,
1169
1178
CandidateRouteHop :: PrivateHop { hint, .. } => hint. src_node_id . into ( ) ,
1170
1179
CandidateRouteHop :: Blinded { hint, .. } => hint. 1 . introduction_node_id . into ( ) ,
1171
- CandidateRouteHop :: OneHopBlinded { hint, .. } => hint. 1 . introduction_node_id . into ( )
1180
+ CandidateRouteHop :: OneHopBlinded { hint, .. } => hint. 1 . introduction_node_id . into ( ) ,
1172
1181
}
1173
1182
}
1174
1183
/// Returns the target node id of this hop, if known.
@@ -1795,7 +1804,7 @@ where L::Target: Logger {
1795
1804
let mut num_ignored_htlc_minimum_msat_limit: u32 = 0 ;
1796
1805
1797
1806
macro_rules! add_entry {
1798
- // Adds entry which goes from $src_node_id to $dest_node_id over the $candidate hop.
1807
+ // Adds entry which goes from candidate.source() to candiadte.target() over the $candidate hop.
1799
1808
// $next_hops_fee_msat represents the fees paid for using all the channels *after* this one,
1800
1809
// since that value has to be transferred over this channel.
1801
1810
// Returns the contribution amount of $candidate if the channel caused an update to `targets`.
@@ -1811,7 +1820,7 @@ where L::Target: Logger {
1811
1820
// - for first and last hops early in get_route
1812
1821
let src_node_id = $candidate. source( ) ;
1813
1822
let dest_node_id = $candidate. target( ) . unwrap_or( maybe_dummy_payee_node_id) ;
1814
- if src_node_id != dest_node_id {
1823
+ if Some ( $candidate . source ( ) ) != $candidate . target ( ) {
1815
1824
let scid_opt = $candidate. short_channel_id( ) ;
1816
1825
let effective_capacity = $candidate. effective_capacity( ) ;
1817
1826
let htlc_maximum_msat = max_htlc_from_capacity( effective_capacity, channel_saturation_pow_half) ;
@@ -1976,9 +1985,10 @@ where L::Target: Logger {
1976
1985
inflight_htlc_msat: used_liquidity_msat,
1977
1986
effective_capacity,
1978
1987
} ;
1979
- let channel_penalty_msat = scid_opt. map_or( 0 ,
1980
- |scid| scorer. channel_penalty_msat( scid, & src_node_id, & dest_node_id,
1981
- channel_usage, score_params) ) ;
1988
+ let channel_penalty_msat =
1989
+ scorer. channel_penalty_msat( $candidate,
1990
+ channel_usage,
1991
+ score_params) ;
1982
1992
let path_penalty_msat = $next_hops_path_penalty_msat
1983
1993
. saturating_add( channel_penalty_msat) ;
1984
1994
let new_graph_node = RouteGraphNode {
@@ -1991,7 +2001,7 @@ where L::Target: Logger {
1991
2001
path_length_to_node,
1992
2002
} ;
1993
2003
1994
- // Update the way of reaching $src_node_id with the given short_channel_id (from $dest_node_id ),
2004
+ // Update the way of reaching $candidate.source() with the given short_channel_id (from $candidate.target() ),
1995
2005
// if this way is cheaper than the already known
1996
2006
// (considering the cost to "reach" this channel from the route destination,
1997
2007
// the cost of using this channel,
@@ -2285,7 +2295,7 @@ where L::Target: Logger {
2285
2295
effective_capacity : candidate. effective_capacity ( ) ,
2286
2296
} ;
2287
2297
let channel_penalty_msat = scorer. channel_penalty_msat (
2288
- hop . short_channel_id , & source , & target , channel_usage, score_params
2298
+ & candidate , channel_usage, score_params
2289
2299
) ;
2290
2300
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
2291
2301
. saturating_add ( channel_penalty_msat) ;
@@ -2649,7 +2659,6 @@ where L::Target: Logger {
2649
2659
let mut paths = Vec :: new ( ) ;
2650
2660
for payment_path in selected_route {
2651
2661
let mut hops = Vec :: with_capacity ( payment_path. hops . len ( ) ) ;
2652
- let mut prev_hop_node_id = our_node_id;
2653
2662
for ( hop, node_features) in payment_path. hops . iter ( )
2654
2663
. filter ( |( h, _) | h. candidate . short_channel_id ( ) . is_some ( ) )
2655
2664
{
@@ -2666,7 +2675,7 @@ where L::Target: Logger {
2666
2675
// an alias, in which case we don't take any chances here.
2667
2676
network_graph. node ( & hop. node_id ) . map_or ( false , |hop_node|
2668
2677
hop_node. channels . iter ( ) . any ( |scid| network_graph. channel ( * scid)
2669
- . map_or ( false , |c| c. as_directed_from ( & prev_hop_node_id ) . is_some ( ) ) )
2678
+ . map_or ( false , |c| c. as_directed_from ( & hop . candidate . source ( ) ) . is_some ( ) ) )
2670
2679
)
2671
2680
} ;
2672
2681
@@ -2679,8 +2688,6 @@ where L::Target: Logger {
2679
2688
cltv_expiry_delta : hop. candidate . cltv_expiry_delta ( ) ,
2680
2689
maybe_announced_channel,
2681
2690
} ) ;
2682
-
2683
- prev_hop_node_id = hop. node_id ;
2684
2691
}
2685
2692
let mut final_cltv_delta = final_cltv_expiry_delta;
2686
2693
let blinded_tail = payment_path. hops . last ( ) . and_then ( |( h, _) | {
@@ -2843,13 +2850,13 @@ fn build_route_from_hops_internal<L: Deref>(
2843
2850
2844
2851
impl ScoreLookUp for HopScorer {
2845
2852
type ScoreParams = ( ) ;
2846
- fn channel_penalty_msat ( & self , _short_channel_id : u64 , source : & NodeId , target : & NodeId ,
2853
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop ,
2847
2854
_usage : ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64
2848
2855
{
2849
2856
let mut cur_id = self . our_node_id ;
2850
2857
for i in 0 ..self . hop_ids . len ( ) {
2851
2858
if let Some ( next_id) = self . hop_ids [ i] {
2852
- if cur_id == * source && next_id == * target {
2859
+ if cur_id == candidate . source ( ) && Some ( next_id) == candidate . target ( ) {
2853
2860
return 0 ;
2854
2861
}
2855
2862
cur_id = next_id;
@@ -2925,6 +2932,8 @@ mod tests {
2925
2932
2926
2933
use core:: convert:: TryInto ;
2927
2934
2935
+ use super :: CandidateRouteHop ;
2936
+
2928
2937
fn get_channel_details ( short_channel_id : Option < u64 > , node_id : PublicKey ,
2929
2938
features : InitFeatures , outbound_capacity_msat : u64 ) -> channelmanager:: ChannelDetails {
2930
2939
channelmanager:: ChannelDetails {
@@ -6197,8 +6206,8 @@ mod tests {
6197
6206
}
6198
6207
impl ScoreLookUp for BadChannelScorer {
6199
6208
type ScoreParams = ( ) ;
6200
- fn channel_penalty_msat ( & self , short_channel_id : u64 , _ : & NodeId , _ : & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6201
- if short_channel_id == self . short_channel_id { u64:: max_value ( ) } else { 0 }
6209
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6210
+ if candidate . short_channel_id ( ) == Some ( self . short_channel_id ) { u64:: max_value ( ) } else { 0 }
6202
6211
}
6203
6212
}
6204
6213
@@ -6213,8 +6222,8 @@ mod tests {
6213
6222
6214
6223
impl ScoreLookUp for BadNodeScorer {
6215
6224
type ScoreParams = ( ) ;
6216
- fn channel_penalty_msat ( & self , _ : u64 , _ : & NodeId , target : & NodeId , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6217
- if * target == self . node_id { u64:: max_value ( ) } else { 0 }
6225
+ fn channel_penalty_msat ( & self , candidate : & CandidateRouteHop , _: ChannelUsage , _score_params : & Self :: ScoreParams ) -> u64 {
6226
+ if candidate . target ( ) == Some ( self . node_id ) { u64:: max_value ( ) } else { 0 }
6218
6227
}
6219
6228
}
6220
6229
@@ -6702,26 +6711,34 @@ mod tests {
6702
6711
} ;
6703
6712
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) , 123 ) ;
6704
6713
scorer_params. set_manual_penalty ( & NodeId :: from_pubkey ( & nodes[ 4 ] ) , 456 ) ;
6705
- assert_eq ! ( scorer. channel_penalty_msat( 42 , & NodeId :: from_pubkey( & nodes[ 3 ] ) , & NodeId :: from_pubkey( & nodes[ 4 ] ) , usage, & scorer_params) , 456 ) ;
6714
+ let network_graph = network_graph. read_only ( ) ;
6715
+ let channels = network_graph. channels ( ) ;
6716
+ let channel = channels. get ( & 5 ) . unwrap ( ) ;
6717
+ let info = channel. as_directed_from ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) . unwrap ( ) ;
6718
+ let candidate: CandidateRouteHop = CandidateRouteHop :: PublicHop {
6719
+ info : info. 0 ,
6720
+ short_channel_id : 5 ,
6721
+ source_node_id : NodeId :: from_pubkey ( & nodes[ 3 ] ) ,
6722
+ target_node_id : NodeId :: from_pubkey ( & nodes[ 4 ] ) ,
6723
+ } ;
6724
+ assert_eq ! ( scorer. channel_penalty_msat( & candidate, usage, & scorer_params) , 456 ) ;
6706
6725
6707
6726
// Then check we can get a normal route
6708
6727
let payment_params = PaymentParameters :: from_node_id ( nodes[ 10 ] , 42 ) ;
6709
6728
let route_params = RouteParameters :: from_payment_params_and_value (
6710
6729
payment_params, 100 ) ;
6711
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6730
+ let route = get_route ( & our_id, & route_params, & network_graph, None ,
6712
6731
Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6713
6732
assert ! ( route. is_ok( ) ) ;
6714
6733
6715
6734
// Then check that we can't get a route if we ban an intermediate node.
6716
6735
scorer_params. add_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6717
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6718
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6736
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6719
6737
assert ! ( route. is_err( ) ) ;
6720
6738
6721
6739
// Finally make sure we can route again, when we remove the ban.
6722
6740
scorer_params. remove_banned ( & NodeId :: from_pubkey ( & nodes[ 3 ] ) ) ;
6723
- let route = get_route ( & our_id, & route_params, & network_graph. read_only ( ) , None ,
6724
- Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6741
+ let route = get_route ( & our_id, & route_params, & network_graph, None , Arc :: clone ( & logger) , & scorer, & scorer_params, & random_seed_bytes) ;
6725
6742
assert ! ( route. is_ok( ) ) ;
6726
6743
}
6727
6744
0 commit comments