diff --git a/chain/chain/src/runtime/mod.rs b/chain/chain/src/runtime/mod.rs index 477d7e886ac..1f684dd1c11 100644 --- a/chain/chain/src/runtime/mod.rs +++ b/chain/chain/src/runtime/mod.rs @@ -643,9 +643,25 @@ impl RuntimeAdapter for NightshadeRuntime { verify_signature: bool, epoch_id: &EpochId, current_protocol_version: ProtocolVersion, + receiver_congestion_info: Option, ) -> Result, Error> { let runtime_config = self.runtime_config_store.get_config(current_protocol_version); + if let Some(congestion_info) = receiver_congestion_info { + let congestion_control = CongestionControl::new( + runtime_config.congestion_control_config, + congestion_info.congestion_info, + congestion_info.missed_chunks_count, + ); + if !congestion_control.shard_accepts_transactions() { + let receiver_shard = + self.account_id_to_shard_uid(transaction.transaction.receiver_id(), epoch_id)?; + return Ok(Some(InvalidTxError::ShardCongested { + shard_id: receiver_shard.shard_id, + })); + } + } + if let Some(state_root) = state_root { let shard_uid = self.account_id_to_shard_uid(transaction.transaction.signer_id(), epoch_id)?; diff --git a/chain/chain/src/test_utils/kv_runtime.rs b/chain/chain/src/test_utils/kv_runtime.rs index 6b267550e81..7458b669a00 100644 --- a/chain/chain/src/test_utils/kv_runtime.rs +++ b/chain/chain/src/test_utils/kv_runtime.rs @@ -18,7 +18,7 @@ use near_primitives::account::{AccessKey, Account}; use near_primitives::apply::ApplyChunkReason; use near_primitives::block::Tip; use near_primitives::block_header::{Approval, ApprovalInner}; -use near_primitives::congestion_info::CongestionInfo; +use near_primitives::congestion_info::{CongestionInfo, ExtendedCongestionInfo}; use near_primitives::epoch_manager::block_info::BlockInfo; use near_primitives::epoch_manager::epoch_info::EpochInfo; use near_primitives::epoch_manager::EpochConfig; @@ -1089,6 +1089,7 @@ impl RuntimeAdapter for KeyValueRuntime { _verify_signature: bool, _epoch_id: &EpochId, _current_protocol_version: ProtocolVersion, + _receiver_congestion_info: Option, ) -> Result, Error> { Ok(None) } diff --git a/chain/chain/src/types.rs b/chain/chain/src/types.rs index 736b3279e33..2b7a3f4a7ff 100644 --- a/chain/chain/src/types.rs +++ b/chain/chain/src/types.rs @@ -417,6 +417,7 @@ pub trait RuntimeAdapter: Send + Sync { verify_signature: bool, epoch_id: &EpochId, current_protocol_version: ProtocolVersion, + receiver_congestion_info: Option, ) -> Result, Error>; /// Returns an ordered list of valid transactions from the pool up the given limits. diff --git a/chain/client/src/client.rs b/chain/client/src/client.rs index 7bac83771c5..94dca0f82f4 100644 --- a/chain/client/src/client.rs +++ b/chain/client/src/client.rs @@ -2270,7 +2270,8 @@ impl Client { ) -> Result { let head = self.chain.head()?; let me = self.validator_signer.as_ref().map(|vs| vs.validator_id()); - let cur_block_header = self.chain.head_header()?; + let cur_block = self.chain.get_head_block()?; + let cur_block_header = cur_block.header(); let transaction_validity_period = self.chain.transaction_validity_period; // here it is fine to use `cur_block_header` as it is a best effort estimate. If the transaction // were to be included, the block that the chunk points to will have height >= height of @@ -2285,12 +2286,23 @@ impl Client { } let gas_price = cur_block_header.next_gas_price(); let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(&head.last_block_hash)?; - + let receiver_shard = + self.epoch_manager.account_id_to_shard_id(tx.transaction.receiver_id(), &epoch_id)?; + let receiver_congestion_info = + cur_block.shards_congestion_info().get(&receiver_shard).copied(); let protocol_version = self.epoch_manager.get_epoch_protocol_version(&epoch_id)?; if let Some(err) = self .runtime_adapter - .validate_tx(gas_price, None, tx, true, &epoch_id, protocol_version) + .validate_tx( + gas_price, + None, + tx, + true, + &epoch_id, + protocol_version, + receiver_congestion_info, + ) .expect("no storage errors") { debug!(target: "client", tx_hash = ?tx.get_hash(), ?err, "Invalid tx during basic validation"); @@ -2322,7 +2334,15 @@ impl Client { }; if let Some(err) = self .runtime_adapter - .validate_tx(gas_price, Some(state_root), tx, false, &epoch_id, protocol_version) + .validate_tx( + gas_price, + Some(state_root), + tx, + false, + &epoch_id, + protocol_version, + receiver_congestion_info, + ) .expect("no storage errors") { debug!(target: "client", ?err, "Invalid tx"); diff --git a/chain/jsonrpc/res/rpc_errors_schema.json b/chain/jsonrpc/res/rpc_errors_schema.json index 93706ee0ada..627c84c208a 100644 --- a/chain/jsonrpc/res/rpc_errors_schema.json +++ b/chain/jsonrpc/res/rpc_errors_schema.json @@ -570,7 +570,8 @@ "ActionsValidation", "TransactionSizeExceeded", "InvalidTransactionVersion", - "StorageError" + "StorageError", + "ShardCongested" ], "props": {} }, @@ -773,6 +774,13 @@ "subtypes": [], "props": {} }, + "ShardCongested": { + "name": "ShardCongested", + "subtypes": [], + "props": { + "shard_id": "" + } + }, "SignerDoesNotExist": { "name": "SignerDoesNotExist", "subtypes": [], diff --git a/core/primitives/src/errors.rs b/core/primitives/src/errors.rs index 5fcd5a1e92b..fef4d3ec97e 100644 --- a/core/primitives/src/errors.rs +++ b/core/primitives/src/errors.rs @@ -213,6 +213,11 @@ pub enum InvalidTxError { InvalidTransactionVersion, // Error occurred during storage access StorageError(StorageError), + /// The receiver shard of the transaction is too congested to accept new + /// transactions at the moment. + ShardCongested { + shard_id: u32, + }, } impl From for InvalidTxError { @@ -620,6 +625,9 @@ impl Display for InvalidTxError { InvalidTxError::StorageError(error) => { write!(f, "Storage error: {}", error) } + InvalidTxError::ShardCongested { shard_id } => { + write!(f, "Shard {shard_id} is currently congested and rejects new transactions.") + } } } } diff --git a/integration-tests/src/tests/client/features/congestion_control.rs b/integration-tests/src/tests/client/features/congestion_control.rs index 65a29fa3670..5afa18fed25 100644 --- a/integration-tests/src/tests/client/features/congestion_control.rs +++ b/integration-tests/src/tests/client/features/congestion_control.rs @@ -1,3 +1,4 @@ +use assert_matches::assert_matches; use near_chain_configs::Genesis; use near_client::test_utils::TestEnv; use near_client::ProcessTxResponse; @@ -6,7 +7,9 @@ use near_o11y::testonly::init_test_logger; use near_parameters::{RuntimeConfig, RuntimeConfigStore}; use near_primitives::account::id::AccountId; use near_primitives::congestion_info::{CongestionControl, CongestionInfo}; -use near_primitives::errors::{ActionErrorKind, FunctionCallError, TxExecutionError}; +use near_primitives::errors::{ + ActionErrorKind, FunctionCallError, InvalidTxError, TxExecutionError, +}; use near_primitives::hash::CryptoHash; use near_primitives::shard_layout::ShardLayout; use near_primitives::sharding::{ShardChunk, ShardChunkHeader}; @@ -292,7 +295,11 @@ fn submit_n_100tgas_fns(env: &mut TestEnv, n: u32, nonce: &mut u64, signer: &InM let fn_tx = new_fn_call_100tgas(nonce, signer, *block.hash()); // this only adds the tx to the pool, no chain progress is made let response = env.clients[0].process_tx(fn_tx, false, false); - assert_eq!(response, ProcessTxResponse::ValidTx); + match response { + ProcessTxResponse::ValidTx + | ProcessTxResponse::InvalidTx(InvalidTxError::ShardCongested { .. }) => (), + other => panic!("unexpected result from submitting tx: {other:?}"), + } } } @@ -565,3 +572,53 @@ fn measure_tx_limit( local_tx_included_with_congestion, ) } + +/// Test that RPC clients stop accepting transactions when the receiver is +/// congested. +#[test] +fn test_rpc_client_rejection() { + let sender_id: AccountId = "test0".parse().unwrap(); + let mut env = setup_runtime(sender_id.clone(), PROTOCOL_VERSION); + + // prepare a contract to call + setup_contract(&mut env); + + let signer = InMemorySigner::from_seed(sender_id.clone(), KeyType::ED25519, sender_id.as_str()); + let mut nonce = 10; + + // Check we can send transactions at the start. + let fn_tx = new_fn_call_100tgas( + &mut nonce, + &signer, + *env.clients[0].chain.head_header().unwrap().hash(), + ); + let response = env.clients[0].process_tx(fn_tx, false, false); + assert_eq!(response, ProcessTxResponse::ValidTx); + + // Congest the network with a burst of 100 PGas. + submit_n_100tgas_fns(&mut env, 1_000, &mut nonce, &signer); + + // Allow transactions to enter the chain and enough receipts to arrive at + // the receiver shard for it to become congested. + let tip = env.clients[0].chain.head().unwrap(); + for i in 1..10 { + env.produce_block(0, tip.height + i); + } + + // Check that congestion control rejects new transactions. + let fn_tx = new_fn_call_100tgas( + &mut nonce, + &signer, + *env.clients[0].chain.head_header().unwrap().hash(), + ); + let response = env.clients[0].process_tx(fn_tx, false, false); + + if ProtocolFeature::CongestionControl.enabled(PROTOCOL_VERSION) { + assert_matches!( + response, + ProcessTxResponse::InvalidTx(InvalidTxError::ShardCongested { .. }) + ); + } else { + assert_eq!(response, ProcessTxResponse::ValidTx); + } +} diff --git a/integration-tests/src/tests/client/resharding.rs b/integration-tests/src/tests/client/resharding.rs index 9bdf874c8ca..bffff1f1124 100644 --- a/integration-tests/src/tests/client/resharding.rs +++ b/integration-tests/src/tests/client/resharding.rs @@ -239,7 +239,7 @@ impl TestReshardingEnv { .validator_seats(num_validators) .real_stores() .epoch_managers_with_test_overrides(epoch_config_test_overrides) - .nightshade_runtimes(&genesis) + .nightshade_runtimes_congestion_control_disabled(&genesis) .track_all_shards() .build(); assert_eq!(env.validators.len(), num_validators); diff --git a/integration-tests/src/tests/client/sync_state_nodes.rs b/integration-tests/src/tests/client/sync_state_nodes.rs index 9b60c2098fb..d42b028d7ed 100644 --- a/integration-tests/src/tests/client/sync_state_nodes.rs +++ b/integration-tests/src/tests/client/sync_state_nodes.rs @@ -570,7 +570,7 @@ fn test_dump_epoch_missing_chunk_in_last_block() { .clients_count(2) .use_state_snapshots() .real_stores() - .nightshade_runtimes(&genesis) + .nightshade_runtimes_congestion_control_disabled(&genesis) .build(); let genesis_block = env.clients[0].chain.get_block_by_height(0).unwrap();