From 116e527a2ff41849dff4a4937d2a57546da4e263 Mon Sep 17 00:00:00 2001 From: William Smith Date: Wed, 12 Jun 2024 09:08:37 -0700 Subject: [PATCH] [Traffic Control] Make it play nicely with proxy infra (#17854) ## Description We received reports from a community validator of error logs from failure to parse `x-forwarded-for` header value to a `SocketAddr`. Investigation revealed that this is because the validator is running HAProxy, which attempts to insert the header itself (and in this case the value inserted was `-`). We are not using this value anyway, and even if we were, we would not want to be reading the value inserted by HAProxy as this would correspond to the fullnode rather than the client. Note that this also should obviate the fact that running TrafficController or related on a validator that is running HAProxy can have bad unintended side effects - namely that a spammy fullnode may cause a validator running HAProxy to block its HAProxy instance rather than the fullnode itself. To fix all of these issues, this PR introduces configuration of `client-id-source`, which by default will select the `socket-addr` (as today) to be treated as the "connection". Alternatively, the node operator can select `x-forwarded-for`, in which case we will search for this header key and use its contents to determine the client connection. Note that In the x-forwarded-for case, we have seen that this can be written as a domain name rather than an IP by load balancers such as HAProxy, however resolving this to IP address from our side would be very expensive, so any node that configures `x-forwarded-for` source type must configure their proxy to also fully resolve the domain name and write this to the header, otherwise we will skip traffic control. Note that in addition to infra proxies such as HAProxy, traffic controller also has the concept of proxy, which from the perspective of a validator would be a fullnode, as it is forwarding a client request to the entire committee. This is a different issue entirely, and as of yet is unsupported. Nevertheless, to reduce the chance for confusion, there is some liberal renaming in this PR. TODOS * Handle `client-id-source: x-forwarded-for` on json rpc side. Currently we say "unsupported" and skip traffic control. * Add tests for when `client-id-source: x-forwarded-for` is set (via both IP and domain name) * Unit tests for domain name parsing ## Test plan Existing tests. More unit tests to come for the x-forwarded-for header case. --- ## Release notes Check each box that your changes affect. If none of the boxes relate to your changes, release notes aren't required. For each box you select, include information after the relevant heading that describes the impact of your changes that a user might notice and any actions they must take to implement updates. - [ ] Protocol: - [ ] Nodes (Validators and Full nodes): - [ ] Indexer: - [ ] JSON-RPC: - [ ] GraphQL: - [ ] CLI: - [ ] Rust SDK: --- crates/sui-core/src/authority_server.rs | 135 +++++++++++------- crates/sui-core/src/traffic_controller/mod.rs | 95 ++++++------ .../src/traffic_controller/policies.rs | 125 ++++++++-------- .../tests/traffic_control_tests.rs | 34 ++--- crates/sui-json-rpc/src/axum_router.rs | 33 +++-- crates/sui-types/src/traffic_control.rs | 49 +++++-- 6 files changed, 265 insertions(+), 206 deletions(-) diff --git a/crates/sui-core/src/authority_server.rs b/crates/sui-core/src/authority_server.rs index f4a4c82c64428..387898b4bf272 100644 --- a/crates/sui-core/src/authority_server.rs +++ b/crates/sui-core/src/authority_server.rs @@ -11,7 +11,12 @@ use prometheus::{ register_int_counter_vec_with_registry, register_int_counter_with_registry, IntCounter, IntCounterVec, Registry, }; -use std::{io, net::SocketAddr, sync::Arc, time::SystemTime}; +use std::{ + io, + net::{IpAddr, SocketAddr}, + sync::Arc, + time::SystemTime, +}; use sui_network::{ api::{Validator, ValidatorServer}, tonic, @@ -25,7 +30,7 @@ use sui_types::messages_grpc::{ }; use sui_types::multiaddr::Multiaddr; use sui_types::sui_system_state::SuiSystemState; -use sui_types::traffic_control::{PolicyConfig, RemoteFirewallConfig, Weight}; +use sui_types::traffic_control::{ClientIdSource, PolicyConfig, RemoteFirewallConfig, Weight}; use sui_types::{error::*, transaction::*}; use sui_types::{ fp_ensure, @@ -135,6 +140,7 @@ impl AuthorityServer { consensus_adapter: self.consensus_adapter, metrics: self.metrics.clone(), traffic_controller: None, + client_id_source: None, })) .bind(&address) .await @@ -167,6 +173,7 @@ pub struct ValidatorServiceMetrics { connection_ip_not_found: IntCounter, forwarded_header_parse_error: IntCounter, forwarded_header_invalid: IntCounter, + forwarded_header_not_included: IntCounter, } impl ValidatorServiceMetrics { @@ -257,6 +264,12 @@ impl ValidatorServiceMetrics { registry, ) .unwrap(), + forwarded_header_not_included: register_int_counter_with_registry!( + "validator_service_forwarded_header_not_included", + "Number of times x-forwarded-for header was (unexpectedly) not included in request", + registry, + ) + .unwrap(), } } @@ -272,6 +285,7 @@ pub struct ValidatorService { consensus_adapter: Arc, metrics: Arc, traffic_controller: Option>, + client_id_source: Option, } impl ValidatorService { @@ -287,13 +301,14 @@ impl ValidatorService { state, consensus_adapter, metrics: validator_metrics, - traffic_controller: policy_config.map(|policy| { + traffic_controller: policy_config.clone().map(|policy| { Arc::new(TrafficController::spawn( policy, traffic_controller_metrics, firewall_config, )) }), + client_id_source: policy_config.map(|policy| policy.client_id_source), } } @@ -326,6 +341,7 @@ impl ValidatorService { consensus_adapter, metrics, traffic_controller: _, + client_id_source: _, } = self.clone(); let transaction = request.into_inner(); let epoch_store = state.load_epoch_store_one_call_per_task(); @@ -707,15 +723,9 @@ impl ValidatorService { Ok(tonic::Response::new(response)) } - async fn handle_traffic_req( - &self, - connection_ip: Option, - proxy_ip: Option, - ) -> Result<(), tonic::Status> { + async fn handle_traffic_req(&self, client: Option) -> Result<(), tonic::Status> { if let Some(traffic_controller) = &self.traffic_controller { - let connection = connection_ip.map(|ip| ip.ip()); - let proxy = proxy_ip.map(|ip| ip.ip()); - if !traffic_controller.check(connection, proxy).await { + if !traffic_controller.check(&client, &None).await { // Entity in blocklist Err(tonic::Status::from_error(SuiError::TooManyRequests.into())) } else { @@ -728,8 +738,7 @@ impl ValidatorService { fn handle_traffic_resp( &self, - connection_ip: Option, - proxy_ip: Option, + client: Option, response: &Result, tonic::Status>, ) { let error: Option = if let Err(status) = response { @@ -740,8 +749,8 @@ impl ValidatorService { if let Some(traffic_controller) = self.traffic_controller.clone() { traffic_controller.tally(TrafficTally { - connection_ip: connection_ip.map(|ip| ip.ip()), - proxy_ip: proxy_ip.map(|ip| ip.ip()), + direct: client, + through_fullnode: None, error_weight: error.map(normalize).unwrap_or(Weight::zero()), timestamp: SystemTime::now(), }) @@ -781,57 +790,75 @@ fn normalize(err: SuiError) -> Weight { #[macro_export] macro_rules! handle_with_decoration { ($self:ident, $func_name:ident, $request:ident) => {{ - // extract IP info. Note that in addition to extracting the client IP from - // the request header, we also get the remote address in case we need to - // throttle a fullnode, or an end user is running a local quorum driver. - let connection_ip: Option = $request.remote_addr(); - - // We will hit this case if the IO type used does not - // implement Connected or when using a unix domain socket. - // TODO: once we have confirmed that no legitimate traffic - // is hitting this case, we should reject such requests that - // hit this case. - if connection_ip.is_none() { - if cfg!(msim) { - // Ignore the error from simtests. - } else if cfg!(test) { - panic!("Failed to get remote address from request"); - } else { - $self.metrics.connection_ip_not_found.inc(); - error!("Failed to get remote address from request"); - } + if $self.client_id_source.is_none() { + return $self.$func_name($request).await; } - let proxy_ip: Option = - if let Some(op) = $request.metadata().get("x-forwarded-for") { - match op.to_str() { - Ok(ip) => match ip.parse() { - Ok(ret) => Some(ret), + let client = match $self.client_id_source.as_ref().unwrap() { + ClientIdSource::SocketAddr => { + let socket_addr: Option = $request.remote_addr(); + + // We will hit this case if the IO type used does not + // implement Connected or when using a unix domain socket. + // TODO: once we have confirmed that no legitimate traffic + // is hitting this case, we should reject such requests that + // hit this case. + if let Some(socket_addr) = socket_addr { + Some(socket_addr.ip()) + } else { + if cfg!(msim) { + // Ignore the error from simtests. + } else if cfg!(test) { + panic!("Failed to get remote address from request"); + } else { + $self.metrics.connection_ip_not_found.inc(); + error!("Failed to get remote address from request"); + } + None + } + } + ClientIdSource::XForwardedFor => { + if let Some(op) = $request.metadata().get("x-forwarded-for") { + match op.to_str() { + Ok(header_val) => { + match header_val.parse::() { + Ok(socket_addr) => Some(socket_addr.ip()), + Err(err) => { + $self.metrics.forwarded_header_parse_error.inc(); + error!( + "Failed to parse x-forwarded-for header value of {:?} to ip address: {:?}. \ + Please ensure that your proxy is configured to resolve client domains to an \ + IP address before writing header", + header_val, + err, + ); + None + } + } + } Err(e) => { - $self.metrics.forwarded_header_parse_error.inc(); - error!("Failed to parse x-forwarded-for header value to SocketAddr: {:?}", e); + // TODO: once we have confirmed that no legitimate traffic + // is hitting this case, we should reject such requests that + // hit this case. + $self.metrics.forwarded_header_invalid.inc(); + error!("Invalid UTF-8 in x-forwarded-for header: {:?}", e); None } - }, - Err(e) => { - // TODO: once we have confirmed that no legitimate traffic - // is hitting this case, we should reject such requests that - // hit this case. - $self.metrics.forwarded_header_invalid.inc(); - error!("Invalid UTF-8 in x-forwarded-for header: {:?}", e); - None } + } else { + $self.metrics.forwarded_header_not_included.inc(); + error!("x-forwarded-header not present for request despite node configuring XForwardedFor tracking type"); + None } - } else { - None - }; + } + }; // check if either IP is blocked, in which case return early - $self.handle_traffic_req(connection_ip, proxy_ip).await?; + $self.handle_traffic_req(client.clone()).await?; // handle request let response = $self.$func_name($request).await; // handle response tallying - $self.handle_traffic_resp(connection_ip, proxy_ip, &response); + $self.handle_traffic_resp(client, &response); response }}; } diff --git a/crates/sui-core/src/traffic_controller/mod.rs b/crates/sui-core/src/traffic_controller/mod.rs index 11c9d6c8d5bde..933bb1b68f6c8 100644 --- a/crates/sui-core/src/traffic_controller/mod.rs +++ b/crates/sui-core/src/traffic_controller/mod.rs @@ -28,12 +28,12 @@ use tokio::sync::mpsc; use tokio::sync::mpsc::error::TrySendError; use tracing::{debug, error, info, warn}; -type BlocklistT = Arc>; +type Blocklist = Arc>; #[derive(Clone)] struct Blocklists { - connection_ips: BlocklistT, - proxy_ips: BlocklistT, + clients: Blocklist, + proxied_clients: Blocklist, } #[derive(Clone)] @@ -84,8 +84,8 @@ impl TrafficController { let ret = Self { tally_channel: tx, blocklists: Blocklists { - connection_ips: Arc::new(DashMap::new()), - proxy_ips: Arc::new(DashMap::new()), + clients: Arc::new(DashMap::new()), + proxied_clients: Arc::new(DashMap::new()), }, metrics: metrics.clone(), dry_run_mode: policy_config.dry_run, @@ -132,9 +132,9 @@ impl TrafficController { } /// Handle check with dry-run mode considered - pub async fn check(&self, connection_ip: Option, proxy_ip: Option) -> bool { + pub async fn check(&self, client: &Option, proxied_client: &Option) -> bool { match ( - self.check_impl(connection_ip, proxy_ip).await, + self.check_impl(client, proxied_client).await, self.dry_run_mode(), ) { // check succeeded @@ -142,8 +142,8 @@ impl TrafficController { // check failed while in dry-run mode (false, true) => { debug!( - "Dry run mode: Blocked request from connection IP {:?}, proxy IP: {:?}", - connection_ip, proxy_ip + "Dry run mode: Blocked request from client {:?}, proxied client: {:?}", + client, proxied_client ); self.metrics.num_dry_run_blocked_requests.inc(); true @@ -156,21 +156,22 @@ impl TrafficController { /// Returns true if the connection is allowed, false if it is blocked pub async fn check_impl( &self, - connection_ip: Option, - proxy_ip: Option, + client: &Option, + proxied_client: &Option, ) -> bool { - let connection_check = self.check_and_clear_blocklist( - connection_ip, - self.blocklists.connection_ips.clone(), + let client_check = self.check_and_clear_blocklist( + client, + self.blocklists.clients.clone(), &self.metrics.connection_ip_blocklist_len, ); - let proxy_check = self.check_and_clear_blocklist( - proxy_ip, - self.blocklists.proxy_ips.clone(), + let proxied_client_check = self.check_and_clear_blocklist( + proxied_client, + self.blocklists.proxied_clients.clone(), &self.metrics.proxy_ip_blocklist_len, ); - let (conn_check, proxy_check) = futures::future::join(connection_check, proxy_check).await; - conn_check && proxy_check + let (client_check, proxied_client_check) = + futures::future::join(client_check, proxied_client_check).await; + client_check && proxied_client_check } pub fn dry_run_mode(&self) -> bool { @@ -179,19 +180,19 @@ impl TrafficController { async fn check_and_clear_blocklist( &self, - ip: Option, - blocklist: BlocklistT, + client: &Option, + blocklist: Blocklist, blocklist_len_gauge: &IntGauge, ) -> bool { - let ip = match ip { - Some(ip) => ip, + let client = match client { + Some(client) => client, None => return true, }; let now = SystemTime::now(); // the below two blocks cannot be nested, otherwise we will deadlock // due to aquiring the lock on get, then holding across the remove let (should_block, should_remove) = { - match blocklist.get(&ip) { + match blocklist.get(client) { Some(expiration) if now >= *expiration => (false, true), None => (false, false), _ => (true, false), @@ -199,7 +200,7 @@ impl TrafficController { }; if should_remove { blocklist_len_gauge.dec(); - blocklist.remove(&ip); + blocklist.remove(client); } !should_block } @@ -360,39 +361,39 @@ async fn handle_policy_response( metrics: Arc, ) { let PolicyResponse { - block_connection_ip, - block_proxy_ip, + block_client, + block_proxied_client, } = response; let PolicyConfig { connection_blocklist_ttl_sec, proxy_blocklist_ttl_sec, .. } = policy_config; - if let Some(ip) = block_connection_ip { + if let Some(client) = block_client { if blocklists - .connection_ips + .clients .insert( - ip, + client, SystemTime::now() + Duration::from_secs(*connection_blocklist_ttl_sec), ) .is_none() { - // Only increment the metric if the IP was not already blocked - debug!("Blocking connection IP"); + // Only increment the metric if the client was not already blocked + debug!("Blocking client: {:?}", client); metrics.connection_ip_blocklist_len.inc(); } } - if let Some(ip) = block_proxy_ip { + if let Some(client) = block_proxied_client { if blocklists - .proxy_ips + .proxied_clients .insert( - ip, + client, SystemTime::now() + Duration::from_secs(*proxy_blocklist_ttl_sec), ) .is_none() { - // Only increment the metric if the IP was not already blocked - debug!("Blocking proxy IP"); + // Only increment the metric if the client was not already blocked + debug!("Blocking proxied client: {:?}", client); metrics.proxy_ip_blocklist_len.inc(); } } @@ -406,8 +407,8 @@ async fn delegate_policy_response( metrics: Arc, ) -> Result<(), reqwest::Error> { let PolicyResponse { - block_connection_ip, - block_proxy_ip, + block_client, + block_proxied_client, } = response; let PolicyConfig { connection_blocklist_ttl_sec, @@ -415,16 +416,16 @@ async fn delegate_policy_response( .. } = policy_config; let mut addresses = vec![]; - if let Some(ip) = block_connection_ip { - debug!("Delegating connection IP blocking to firewall"); + if let Some(client_id) = block_client { + debug!("Delegating client blocking to firewall"); addresses.push(BlockAddress { - source_address: ip.to_string(), + source_address: client_id.to_string(), destination_port, ttl: *connection_blocklist_ttl_sec, }); } - if let Some(ip) = block_proxy_ip { - debug!("Delegating proxy IP blocking to firewall"); + if let Some(ip) = block_proxied_client { + debug!("Delegating proxied client blocking to firewall"); addresses.push(BlockAddress { source_address: ip.to_string(), destination_port, @@ -603,15 +604,15 @@ impl TrafficSim { let start = Instant::now(); while start.elapsed() < duration { - let connection_ip = Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, task_num))); - let allowed = controller.check(connection_ip, None).await; + let client = Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, task_num))); + let allowed = controller.check(&client, &None).await; if allowed { if currently_blocked { total_time_blocked += time_blocked_start.elapsed(); currently_blocked = false; } controller.tally(TrafficTally::new( - connection_ip, + client, // TODO add proxy IP for testing None, // TODO add weight adjustment diff --git a/crates/sui-core/src/traffic_controller/policies.rs b/crates/sui-core/src/traffic_controller/policies.rs index 03700fd1c8c46..64ebfe8cbb59b 100644 --- a/crates/sui-core/src/traffic_controller/policies.rs +++ b/crates/sui-core/src/traffic_controller/policies.rs @@ -15,14 +15,15 @@ use std::time::{Instant, SystemTime}; use sui_types::traffic_control::{FreqThresholdConfig, PolicyConfig, PolicyType, Weight}; use tracing::info; +/// The type of request client. #[derive(Hash, Eq, PartialEq, Debug)] -enum IpType { - Connection, - Proxy, +enum ClientType { + Direct, + ThroughFullnode, } #[derive(Hash, Eq, PartialEq, Debug)] -struct SketchKey(IpAddr, IpType); +struct SketchKey(IpAddr, ClientType); pub struct TrafficSketch { /// Circular buffer Count Min Sketches representing a sliding window @@ -133,21 +134,21 @@ impl TrafficSketch { #[derive(Clone, Debug)] pub struct TrafficTally { - pub connection_ip: Option, - pub proxy_ip: Option, + pub direct: Option, + pub through_fullnode: Option, pub error_weight: Weight, pub timestamp: SystemTime, } impl TrafficTally { pub fn new( - connection_ip: Option, - proxy_ip: Option, + direct: Option, + through_fullnode: Option, error_weight: Weight, ) -> Self { Self { - connection_ip, - proxy_ip, + direct, + through_fullnode, error_weight, timestamp: SystemTime::now(), } @@ -156,8 +157,8 @@ impl TrafficTally { #[derive(Clone, Debug, Default)] pub struct PolicyResponse { - pub block_connection_ip: Option, - pub block_proxy_ip: Option, + pub block_client: Option, + pub block_proxied_client: Option, } pub trait Policy { @@ -225,16 +226,16 @@ impl TrafficControlPolicy { pub struct FreqThresholdPolicy { config: PolicyConfig, sketch: TrafficSketch, - connection_threshold: u64, - proxy_threshold: u64, + client_threshold: u64, + proxied_client_threshold: u64, } impl FreqThresholdPolicy { pub fn new( config: PolicyConfig, FreqThresholdConfig { - connection_threshold, - proxy_threshold, + client_threshold, + proxied_client_threshold, window_size_secs, update_interval_secs, sketch_capacity, @@ -252,28 +253,28 @@ impl FreqThresholdPolicy { Self { config, sketch, - connection_threshold, - proxy_threshold, + client_threshold, + proxied_client_threshold, } } fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse { - let block_connection_ip = if let Some(ip) = tally.connection_ip { - let key = SketchKey(ip, IpType::Connection); + let block_client = if let Some(source) = tally.direct { + let key = SketchKey(source, ClientType::Direct); self.sketch.increment_count(&key); - if self.sketch.get_request_rate(&key) >= self.connection_threshold as f64 { - Some(ip) + if self.sketch.get_request_rate(&key) >= self.client_threshold as f64 { + Some(source) } else { None } } else { None }; - let block_proxy_ip = if let Some(ip) = tally.proxy_ip { - let key = SketchKey(ip, IpType::Proxy); + let block_proxied_client = if let Some(source) = tally.through_fullnode { + let key = SketchKey(source, ClientType::ThroughFullnode); self.sketch.increment_count(&key); - if self.sketch.get_request_rate(&key) >= self.proxy_threshold as f64 { - Some(ip) + if self.sketch.get_request_rate(&key) >= self.proxied_client_threshold as f64 { + Some(source) } else { None } @@ -281,8 +282,8 @@ impl FreqThresholdPolicy { None }; PolicyResponse { - block_connection_ip, - block_proxy_ip, + block_client, + block_proxied_client, } } @@ -335,23 +336,23 @@ impl TestNConnIPPolicy { } fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse { - let ip = if let Some(ip) = tally.connection_ip { - ip + let client = if let Some(client) = tally.direct { + client } else { return PolicyResponse::default(); }; // increment the count for the IP let mut frequencies = self.frequencies.write(); - let count = frequencies.entry(tally.connection_ip.unwrap()).or_insert(0); + let count = frequencies.entry(client).or_insert(0); *count += 1; PolicyResponse { - block_connection_ip: if *count >= self.threshold { - Some(ip) + block_client: if *count >= self.threshold { + Some(client) } else { None }, - block_proxy_ip: None, + block_proxied_client: None, } } @@ -403,8 +404,8 @@ mod tests { let mut policy = TrafficControlPolicy::FreqThreshold(FreqThresholdPolicy::new( PolicyConfig::default(), FreqThresholdConfig { - connection_threshold: 5, - proxy_threshold: 2, + client_threshold: 5, + proxied_client_threshold: 2, window_size_secs: 5, update_interval_secs: 1, ..Default::default() @@ -414,20 +415,20 @@ mod tests { // same fullnode, thus have the same connection IP on // validator, but different proxy IPs let alice = TrafficTally { - connection_ip: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))), - proxy_ip: Some(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))), + direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))), + through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))), error_weight: Weight::zero(), timestamp: SystemTime::now(), }; let bob = TrafficTally { - connection_ip: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))), - proxy_ip: Some(IpAddr::V4(Ipv4Addr::new(4, 3, 2, 1))), + direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))), + through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(4, 3, 2, 1))), error_weight: Weight::zero(), timestamp: SystemTime::now(), }; let charlie = TrafficTally { - connection_ip: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))), - proxy_ip: Some(IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))), + direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))), + through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))), error_weight: Weight::zero(), timestamp: SystemTime::now(), }; @@ -435,19 +436,19 @@ mod tests { // initial 2 tallies for alice, should not block for _ in 0..2 { let response = policy.handle_tally(alice.clone()); - assert_eq!(response.block_proxy_ip, None); - assert_eq!(response.block_connection_ip, None); + assert_eq!(response.block_proxied_client, None); + assert_eq!(response.block_client, None); } // meanwhile bob spams 10 requests at once and is blocked for _ in 0..9 { let response = policy.handle_tally(bob.clone()); - assert_eq!(response.block_connection_ip, None); - assert_eq!(response.block_proxy_ip, None); + assert_eq!(response.block_client, None); + assert_eq!(response.block_proxied_client, None); } let response = policy.handle_tally(bob.clone()); - assert_eq!(response.block_connection_ip, None); - assert_eq!(response.block_proxy_ip, bob.proxy_ip); + assert_eq!(response.block_client, None); + assert_eq!(response.block_proxied_client, bob.through_fullnode); // 2 more tallies, so far we are above 2 tallies // per second, but over the average window of 5 seconds @@ -455,40 +456,40 @@ mod tests { tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; for _ in 0..2 { let response = policy.handle_tally(alice.clone()); - assert_eq!(response.block_connection_ip, None); - assert_eq!(response.block_proxy_ip, None); + assert_eq!(response.block_client, None); + assert_eq!(response.block_proxied_client, None); } // bob is no longer blocked, as we moved past the bursty traffic // in the sliding window let _ = policy.handle_tally(bob.clone()); let response = policy.handle_tally(bob.clone()); - assert_eq!(response.block_connection_ip, None); - assert_eq!(response.block_proxy_ip, bob.proxy_ip); + assert_eq!(response.block_client, None); + assert_eq!(response.block_proxied_client, bob.through_fullnode); // close to threshold for alice, but still below tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; for i in 0..5 { let response = policy.handle_tally(alice.clone()); - assert_eq!(response.block_connection_ip, None, "Blocked at i = {}", i); - assert_eq!(response.block_proxy_ip, None); + assert_eq!(response.block_client, None, "Blocked at i = {}", i); + assert_eq!(response.block_proxied_client, None); } // should block alice now tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; let response = policy.handle_tally(alice.clone()); - assert_eq!(response.block_connection_ip, None); - assert_eq!(response.block_proxy_ip, alice.proxy_ip); + assert_eq!(response.block_client, None); + assert_eq!(response.block_proxied_client, alice.through_fullnode); // spam through charlie to block connection for i in 0..2 { let response = policy.handle_tally(charlie.clone()); - assert_eq!(response.block_connection_ip, None, "Blocked at i = {}", i); - assert_eq!(response.block_proxy_ip, None); + assert_eq!(response.block_client, None, "Blocked at i = {}", i); + assert_eq!(response.block_proxied_client, None); } // Now we block connection IP let response = policy.handle_tally(charlie.clone()); - assert_eq!(response.block_proxy_ip, None); - assert_eq!(response.block_connection_ip, charlie.connection_ip); + assert_eq!(response.block_proxied_client, None); + assert_eq!(response.block_client, charlie.direct); // Ensure that if we wait another second, we are no longer blocked // as the bursty first second has finally rotated out of the sliding @@ -496,8 +497,8 @@ mod tests { tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; for i in 0..3 { let response = policy.handle_tally(charlie.clone()); - assert_eq!(response.block_connection_ip, None, "Blocked at i = {}", i); - assert_eq!(response.block_proxy_ip, None); + assert_eq!(response.block_client, None, "Blocked at i = {}", i); + assert_eq!(response.block_proxied_client, None); } } diff --git a/crates/sui-e2e-tests/tests/traffic_control_tests.rs b/crates/sui-e2e-tests/tests/traffic_control_tests.rs index e54489fe4ccb4..df12c95bee009 100644 --- a/crates/sui-e2e-tests/tests/traffic_control_tests.rs +++ b/crates/sui-e2e-tests/tests/traffic_control_tests.rs @@ -32,13 +32,12 @@ async fn test_validator_traffic_control_noop() -> Result<(), anyhow::Error> { let policy_config = PolicyConfig { connection_blocklist_ttl_sec: 1, proxy_blocklist_ttl_sec: 5, - spam_policy_type: PolicyType::NoOp, // This should never be invoked when set as an error policy // as we are not sending requests that error error_policy_type: PolicyType::TestPanicOnInvocation, - channel_capacity: 100, dry_run: false, spam_sample_rate: Weight::one(), + ..Default::default() }; let network_config = ConfigBuilder::new_with_temp_dir() .with_policy_config(Some(policy_config)) @@ -56,13 +55,12 @@ async fn test_fullnode_traffic_control_noop() -> Result<(), anyhow::Error> { let policy_config = PolicyConfig { connection_blocklist_ttl_sec: 1, proxy_blocklist_ttl_sec: 5, - spam_policy_type: PolicyType::NoOp, // This should never be invoked when set as an error policy // as we are not sending requests that error error_policy_type: PolicyType::TestPanicOnInvocation, - channel_capacity: 100, spam_sample_rate: Weight::one(), dry_run: false, + ..Default::default() }; let test_cluster = TestClusterBuilder::new() .with_fullnode_policy_config(Some(policy_config)) @@ -80,9 +78,9 @@ async fn test_validator_traffic_control_ok() -> Result<(), anyhow::Error> { // This should never be invoked when set as an error policy // as we are not sending requests that error error_policy_type: PolicyType::TestPanicOnInvocation, - channel_capacity: 100, dry_run: false, spam_sample_rate: Weight::one(), + ..Default::default() }; let network_config = ConfigBuilder::new_with_temp_dir() .with_policy_config(Some(policy_config)) @@ -104,9 +102,9 @@ async fn test_fullnode_traffic_control_ok() -> Result<(), anyhow::Error> { // This should never be invoked when set as an error policy // as we are not sending requests that error error_policy_type: PolicyType::TestPanicOnInvocation, - channel_capacity: 100, spam_sample_rate: Weight::one(), dry_run: false, + ..Default::default() }; let test_cluster = TestClusterBuilder::new() .with_fullnode_policy_config(Some(policy_config)) @@ -126,8 +124,8 @@ async fn test_validator_traffic_control_dry_run() -> Result<(), anyhow::Error> { // This should never be invoked when set as an error policy // as we are not sending requests that error error_policy_type: PolicyType::TestPanicOnInvocation, - channel_capacity: 100, dry_run: true, + ..Default::default() }; let network_config = ConfigBuilder::new_with_temp_dir() .with_policy_config(Some(policy_config)) @@ -151,8 +149,8 @@ async fn test_fullnode_traffic_control_dry_run() -> Result<(), anyhow::Error> { // This should never be invoked when set as an error policy // as we are not sending requests that error error_policy_type: PolicyType::TestPanicOnInvocation, - channel_capacity: 100, dry_run: true, + ..Default::default() }; let test_cluster = TestClusterBuilder::new() .with_fullnode_policy_config(Some(policy_config)) @@ -169,7 +167,6 @@ async fn test_validator_traffic_control_spam_blocked() -> Result<(), anyhow::Err // Test that any N requests will cause an IP to be added to the blocklist. spam_policy_type: PolicyType::TestNConnIP(n - 1), spam_sample_rate: Weight::one(), - channel_capacity: 100, dry_run: false, ..Default::default() }; @@ -191,7 +188,6 @@ async fn test_fullnode_traffic_control_spam_blocked() -> Result<(), anyhow::Erro // Test that any N requests will cause an IP to be added to the blocklist. spam_policy_type: PolicyType::TestNConnIP(n - 1), spam_sample_rate: Weight::one(), - channel_capacity: 100, dry_run: false, ..Default::default() }; @@ -212,7 +208,6 @@ async fn test_validator_traffic_control_spam_delegated() -> Result<(), anyhow::E // Test that any N - 1 requests will cause an IP to be added to the blocklist. spam_policy_type: PolicyType::TestNConnIP(n - 1), spam_sample_rate: Weight::one(), - channel_capacity: 100, dry_run: false, ..Default::default() }; @@ -246,7 +241,6 @@ async fn test_fullnode_traffic_control_spam_delegated() -> Result<(), anyhow::Er // Test that any N - 1 requests will cause an IP to be added to the blocklist. spam_policy_type: PolicyType::TestNConnIP(n - 1), spam_sample_rate: Weight::one(), - channel_capacity: 100, dry_run: false, ..Default::default() }; @@ -273,7 +267,6 @@ async fn test_traffic_control_dead_mans_switch() -> Result<(), anyhow::Error> { connection_blocklist_ttl_sec: 3, spam_policy_type: PolicyType::TestNConnIP(10), spam_sample_rate: Weight::one(), - channel_capacity: 100, dry_run: false, ..Default::default() }; @@ -338,8 +331,8 @@ async fn test_traffic_control_manual_set_dead_mans_switch() -> Result<(), anyhow #[sim_test] async fn test_traffic_sketch_no_blocks() { let sketch_config = FreqThresholdConfig { - connection_threshold: 10_100, - proxy_threshold: 10_100, + client_threshold: 10_100, + proxied_client_threshold: 10_100, window_size_secs: 4, update_interval_secs: 1, ..Default::default() @@ -377,8 +370,8 @@ async fn test_traffic_sketch_no_blocks() { #[sim_test] async fn test_traffic_sketch_with_slow_blocks() { let sketch_config = FreqThresholdConfig { - connection_threshold: 9_900, - proxy_threshold: 9_900, + client_threshold: 9_900, + proxied_client_threshold: 9_900, window_size_secs: 4, update_interval_secs: 1, ..Default::default() @@ -416,8 +409,8 @@ async fn test_traffic_sketch_with_slow_blocks() { #[sim_test] async fn test_traffic_sketch_with_sampled_spam() { let sketch_config = FreqThresholdConfig { - connection_threshold: 4_500, - proxy_threshold: 4_500, + client_threshold: 4_500, + proxied_client_threshold: 4_500, window_size_secs: 4, update_interval_secs: 1, ..Default::default() @@ -426,10 +419,9 @@ async fn test_traffic_sketch_with_sampled_spam() { connection_blocklist_ttl_sec: 1, proxy_blocklist_ttl_sec: 1, spam_policy_type: PolicyType::FreqThreshold(sketch_config), - error_policy_type: PolicyType::NoOp, spam_sample_rate: Weight::new(0.5).unwrap(), - channel_capacity: 100, dry_run: false, + ..Default::default() }; let metrics = TrafficSim::run( policy, diff --git a/crates/sui-json-rpc/src/axum_router.rs b/crates/sui-json-rpc/src/axum_router.rs index a4bac3bc23c58..5700f708f37e9 100644 --- a/crates/sui-json-rpc/src/axum_router.rs +++ b/crates/sui-json-rpc/src/axum_router.rs @@ -1,6 +1,7 @@ // Copyright (c) Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 +use std::net::IpAddr; use std::time::SystemTime; use std::{net::SocketAddr, sync::Arc}; use sui_types::traffic_control::RemoteFirewallConfig; @@ -21,7 +22,9 @@ use serde_json::value::RawValue; use sui_core::traffic_controller::{ metrics::TrafficControllerMetrics, policies::TrafficTally, TrafficController, }; +use sui_types::traffic_control::ClientIdSource; use sui_types::traffic_control::{PolicyConfig, Weight}; +use tracing::error; use crate::routing_layer::RpcRouter; use sui_json_rpc_api::CLIENT_TARGET_API_VERSION_HEADER; @@ -39,6 +42,7 @@ pub struct JsonRpcService { methods: Methods, rpc_router: RpcRouter, traffic_controller: Option>, + client_id_source: Option, } impl JsonRpcService { @@ -55,13 +59,14 @@ impl JsonRpcService { rpc_router, logger, id_provider: Arc::new(RandomIntegerIdProvider), - traffic_controller: policy_config.map(|policy| { + traffic_controller: policy_config.clone().map(|policy| { Arc::new(TrafficController::spawn( policy, traffic_controller_metrics, remote_fw_config, )) }), + client_id_source: policy_config.map(|policy| policy.client_id_source), } } } @@ -139,11 +144,23 @@ async fn process_raw_request( raw_request: &str, client_addr: SocketAddr, ) -> MethodResponse { + let client = match service.client_id_source { + Some(ClientIdSource::SocketAddr) => Some(client_addr.ip()), + Some(ClientIdSource::XForwardedFor) => { + // TODO - implement this later. Will need to read header at axum layer. + error!( + "X-Forwarded-For client ID source not yet supported on json \ + rpc servers. Skipping traffic controller request handling.", + ); + None + } + None => None, + }; if let Ok(request) = serde_json::from_str::(raw_request) { // check if either IP is blocked, in which case return early if let Some(traffic_controller) = &service.traffic_controller { if let Err(blocked_response) = - handle_traffic_req(traffic_controller.clone(), client_addr).await + handle_traffic_req(traffic_controller.clone(), &client).await { return blocked_response; } @@ -152,7 +169,7 @@ async fn process_raw_request( // handle response tallying if let Some(traffic_controller) = &service.traffic_controller { - handle_traffic_resp(traffic_controller.clone(), client_addr, &response); + handle_traffic_resp(traffic_controller.clone(), client, &response); } response } else if let Ok(_batch) = serde_json::from_str::>(raw_request) { @@ -168,9 +185,9 @@ async fn process_raw_request( async fn handle_traffic_req( traffic_controller: Arc, - client_ip: SocketAddr, + client: &Option, ) -> Result<(), MethodResponse> { - if !traffic_controller.check(Some(client_ip.ip()), None).await { + if !traffic_controller.check(client, &None).await { // Entity in blocklist let err_obj = ErrorObject::borrowed(ErrorCode::ServerIsBusy.code(), &TOO_MANY_REQUESTS_MSG, None); @@ -182,13 +199,13 @@ async fn handle_traffic_req( fn handle_traffic_resp( traffic_controller: Arc, - client_ip: SocketAddr, + client: Option, response: &MethodResponse, ) { let error = response.error_code.map(ErrorCode::from); traffic_controller.tally(TrafficTally { - connection_ip: Some(client_ip.ip()), - proxy_ip: None, + direct: client, + through_fullnode: None, error_weight: error.map(normalize).unwrap_or(Weight::zero()), timestamp: SystemTime::now(), }); diff --git a/crates/sui-types/src/traffic_control.rs b/crates/sui-types/src/traffic_control.rs index 4014096aebdb7..c5d5d41116f0a 100644 --- a/crates/sui-types/src/traffic_control.rs +++ b/crates/sui-types/src/traffic_control.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use serde_with::serde_as; -use std::{fmt::Debug, path::PathBuf}; +use std::path::PathBuf; // These values set to loosely attempt to limit // memory usage for a single sketch to ~20MB @@ -16,6 +16,19 @@ use rand::distributions::Distribution; const TRAFFIC_SINK_TIMEOUT_SEC: u64 = 300; +/// The source that should be used to identify the client's +/// IP address. To be used to configure cases where a node has +/// infra running in front of the node that is separate from the +/// protocol, such as a load balancer. Note that this is not the +/// same as the client type (e.g a direct client vs a proxy client, +/// as in the case of a fullnode driving requests from many clients). +#[derive(Clone, Debug, Deserialize, Serialize, Default)] +pub enum ClientIdSource { + #[default] + SocketAddr, + XForwardedFor, +} + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct Weight(f32); @@ -83,10 +96,10 @@ fn default_drain_timeout() -> u64 { #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "kebab-case")] pub struct FreqThresholdConfig { - #[serde(default = "default_connection_threshold")] - pub connection_threshold: u64, - #[serde(default = "default_proxy_threshold")] - pub proxy_threshold: u64, + #[serde(default = "default_client_threshold")] + pub client_threshold: u64, + #[serde(default = "default_proxied_client_threshold")] + pub proxied_client_threshold: u64, #[serde(default = "default_window_size_secs")] pub window_size_secs: u64, #[serde(default = "default_update_interval_secs")] @@ -102,8 +115,8 @@ pub struct FreqThresholdConfig { impl Default for FreqThresholdConfig { fn default() -> Self { Self { - connection_threshold: default_connection_threshold(), - proxy_threshold: default_proxy_threshold(), + client_threshold: default_client_threshold(), + proxied_client_threshold: default_proxied_client_threshold(), window_size_secs: default_window_size_secs(), update_interval_secs: default_update_interval_secs(), sketch_capacity: default_sketch_capacity(), @@ -113,16 +126,17 @@ impl Default for FreqThresholdConfig { } } -fn default_connection_threshold() -> u64 { - // by default only block connection IPs with unreasonably - // high qps, as a single fullnode could be routing the vast - // majority of all traffic in normal operations. If used as a - // spam policy, all requests would count against this threshold - // within the window time. In practice this should always be set +fn default_client_threshold() -> u64 { + // by default only block client with unreasonably + // high qps, as a client could be a single fullnode proxying + // the majority of traffic from many behaving clients in normal + // operations. If used as a spam policy, all requests would + // count against this threshold within the window time. In + // practice this should always be set 1_000_000 } -fn default_proxy_threshold() -> u64 { +fn default_proxied_client_threshold() -> u64 { 10 } @@ -174,6 +188,8 @@ pub enum PolicyType { #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "kebab-case")] pub struct PolicyConfig { + #[serde(default = "default_client_id_source")] + pub client_id_source: ClientIdSource, #[serde(default = "default_connection_blocklist_ttl_sec")] pub connection_blocklist_ttl_sec: u64, #[serde(default)] @@ -193,6 +209,7 @@ pub struct PolicyConfig { impl Default for PolicyConfig { fn default() -> Self { Self { + client_id_source: default_client_id_source(), connection_blocklist_ttl_sec: 0, proxy_blocklist_ttl_sec: 0, spam_policy_type: PolicyType::NoOp, @@ -204,6 +221,10 @@ impl Default for PolicyConfig { } } +pub fn default_client_id_source() -> ClientIdSource { + ClientIdSource::SocketAddr +} + pub fn default_connection_blocklist_ttl_sec() -> u64 { 60 }