Skip to content

Commit

Permalink
Merge pull request #853 from yihuang/fix-overflow
Browse files Browse the repository at this point in the history
Add methods to UInt128 for explicit overflow control
  • Loading branch information
webmaster128 authored Apr 1, 2021
2 parents 21adf0d + b441203 commit 53a7833
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 104 deletions.
26 changes: 13 additions & 13 deletions contracts/staking/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ pub fn transfer(
let sender_raw = deps.api.canonical_address(&info.sender)?;

let mut accounts = balances(deps.storage);
accounts.update(&sender_raw, |balance: Option<Uint128>| {
balance.unwrap_or_default() - send
accounts.update(&sender_raw, |balance: Option<Uint128>| -> StdResult<_> {
Ok(balance.unwrap_or_default().checked_sub(send)?)
})?;
accounts.update(&rcpt_raw, |balance: Option<Uint128>| -> StdResult<_> {
Ok(balance.unwrap_or_default() + send)
Expand Down Expand Up @@ -209,7 +209,7 @@ pub fn unbond(deps: DepsMut, env: Env, info: MessageInfo, amount: Uint128) -> St
// deduct all from the account
let mut accounts = balances(deps.storage);
accounts.update(&sender_raw, |balance| -> StdResult<_> {
balance.unwrap_or_default() - amount
Ok(balance.unwrap_or_default().checked_sub(amount)?)
})?;
if tax > Uint128(0) {
// add tax to the owner
Expand All @@ -223,14 +223,14 @@ pub fn unbond(deps: DepsMut, env: Env, info: MessageInfo, amount: Uint128) -> St
let bonded = get_bonded(&deps.querier, &env.contract.address)?;

// calculate how many native tokens this is worth and update supply
let remainder = (amount - tax)?;
let remainder = amount.checked_sub(tax)?;
let mut totals = total_supply(deps.storage);
let mut supply = totals.load()?;
// TODO: this is just temporary check - we should use dynamic query or have a way to recover
assert_bonds(&supply, bonded)?;
let unbond = remainder.multiply_ratio(bonded, supply.issued);
supply.bonded = (bonded - unbond)?;
supply.issued = (supply.issued - remainder)?;
supply.bonded = bonded.checked_sub(unbond)?;
supply.issued = supply.issued.checked_sub(remainder)?;
supply.claims += unbond;
totals.save(&supply)?;

Expand Down Expand Up @@ -273,15 +273,15 @@ pub fn claim(deps: DepsMut, env: Env, info: MessageInfo) -> StdResult<Response>
// check how much to send - min(balance, claims[sender]), and reduce the claim
let sender_raw = deps.api.canonical_address(&info.sender)?;
let mut to_send = balance.amount;
claims(deps.storage).update(sender_raw.as_slice(), |claim| {
claims(deps.storage).update(sender_raw.as_slice(), |claim| -> StdResult<_> {
let claim = claim.ok_or_else(|| StdError::generic_err("no claim for this address"))?;
to_send = to_send.min(claim);
claim - to_send
Ok(claim.checked_sub(to_send)?)
})?;

// update total supply (lower claim)
total_supply(deps.storage).update(|mut supply| -> StdResult<_> {
supply.claims = (supply.claims - to_send)?;
supply.claims = supply.claims.checked_sub(to_send)?;
Ok(supply)
})?;

Expand Down Expand Up @@ -353,15 +353,15 @@ pub fn _bond_all_tokens(
// we deduct pending claims from our account balance before reinvesting.
// if there is not enough funds, we just return a no-op
match total_supply(deps.storage).update(|mut supply| {
balance.amount = (balance.amount - supply.claims)?;
balance.amount = balance.amount.checked_sub(supply.claims)?;
// this just triggers the "no op" case if we don't have min_withdrawal left to reinvest
(balance.amount - invest.min_withdrawal)?;
balance.amount.checked_sub(invest.min_withdrawal)?;
supply.bonded += balance.amount;
Ok(supply)
}) {
Ok(_) => {}
// if it is below the minimum, we do a no-op (do not revert other state from withdrawal)
Err(StdError::Underflow { .. }) => return Ok(Response::default()),
Err(StdError::Overflow(_)) => return Ok(Response::default()),
Err(e) => return Err(e.into()),
}

Expand Down Expand Up @@ -739,7 +739,7 @@ mod tests {
let res = execute(deps.as_mut(), mock_env(), info, unbond_msg);
match res.unwrap_err() {
StakingError::Std {
original: StdError::Underflow { .. },
original: StdError::Overflow(_),
} => {}
err => panic!("Unexpected error: {:?}", err),
}
Expand Down
2 changes: 1 addition & 1 deletion packages/std/src/errors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ mod system_error;
mod verification_error;

pub use recover_pubkey_error::RecoverPubkeyError;
pub use std_error::{StdError, StdResult};
pub use std_error::{DivideByZeroError, OverflowError, OverflowOperation, StdError, StdResult};
pub use system_error::SystemError;
pub use verification_error::VerificationError;
143 changes: 92 additions & 51 deletions packages/std/src/errors/std_error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#[cfg(feature = "backtraces")]
use std::backtrace::Backtrace;
use std::fmt;
use thiserror::Error;

use crate::errors::{RecoverPubkeyError, VerificationError};
Expand Down Expand Up @@ -82,13 +83,10 @@ pub enum StdError {
#[cfg(feature = "backtraces")]
backtrace: Backtrace,
},
#[error("Cannot subtract {subtrahend} from {minuend}")]
Underflow {
minuend: String,
subtrahend: String,
#[cfg(feature = "backtraces")]
backtrace: Backtrace,
},
#[error(transparent)]
Overflow(#[from] OverflowError),
#[error(transparent)]
DivideByZero(#[from] DivideByZeroError),
}

impl StdError {
Expand Down Expand Up @@ -167,15 +165,6 @@ impl StdError {
backtrace: Backtrace::capture(),
}
}

pub fn underflow<U: ToString>(minuend: U, subtrahend: U) -> Self {
StdError::Underflow {
minuend: minuend.to_string(),
subtrahend: subtrahend.to_string(),
#[cfg(feature = "backtraces")]
backtrace: Backtrace::capture(),
}
}
}

impl PartialEq<StdError> for StdError {
Expand Down Expand Up @@ -331,20 +320,16 @@ impl PartialEq<StdError> for StdError {
false
}
}
StdError::Underflow {
minuend,
subtrahend,
#[cfg(feature = "backtraces")]
backtrace: _,
} => {
if let StdError::Underflow {
minuend: rhs_minuend,
subtrahend: rhs_subtrahend,
#[cfg(feature = "backtraces")]
backtrace: _,
} = rhs
{
minuend == rhs_minuend && subtrahend == rhs_subtrahend
StdError::Overflow(err) => {
if let StdError::Overflow(rhs_err) = rhs {
err == rhs_err
} else {
false
}
}
StdError::DivideByZero(err) => {
if let StdError::DivideByZero(rhs_err) = rhs {
err == rhs_err
} else {
false
}
Expand Down Expand Up @@ -384,6 +369,60 @@ impl From<RecoverPubkeyError> for StdError {
/// result/error type in cosmwasm-std.
pub type StdResult<T> = core::result::Result<T, StdError>;

#[derive(Error, Debug, PartialEq, Eq)]
pub enum OverflowOperation {
Add,
Sub,
Mul,
Pow,
}

impl fmt::Display for OverflowOperation {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self)
}
}

#[derive(Error, Debug, PartialEq, Eq)]
#[error("Cannot {operation} with {operand1} and {operand2}")]
pub struct OverflowError {
pub operation: OverflowOperation,
pub operand1: String,
pub operand2: String,
#[cfg(feature = "backtraces")]
backtrace: Backtrace,
}

impl OverflowError {
pub fn new<U: ToString>(operation: OverflowOperation, operand1: U, operand2: U) -> Self {
Self {
operation,
operand1: operand1.to_string(),
operand2: operand2.to_string(),
#[cfg(feature = "backtraces")]
backtrace: Backtrace::capture(),
}
}
}

#[derive(Error, Debug, PartialEq, Eq)]
#[error("Cannot devide {operand} by zero")]
pub struct DivideByZeroError {
pub operand: String,
#[cfg(feature = "backtraces")]
backtrace: Backtrace,
}

impl DivideByZeroError {
pub fn new<U: ToString>(operand: U) -> Self {
Self {
operand: operand.to_string(),
#[cfg(feature = "backtraces")]
backtrace: Backtrace::capture(),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -513,58 +552,60 @@ mod tests {

#[test]
fn underflow_works_for_u128() {
let error = StdError::underflow(123u128, 456u128);
let error = StdError::from(OverflowError::new(OverflowOperation::Sub, 123u128, 456u128));
match error {
StdError::Underflow {
minuend,
subtrahend,
StdError::Overflow(OverflowError {
operation: OverflowOperation::Sub,
operand1,
operand2,
..
} => {
assert_eq!(minuend, "123");
assert_eq!(subtrahend, "456");
}) => {
assert_eq!(operand1, "123");
assert_eq!(operand2, "456");
}
_ => panic!("expect different error"),
}
}

#[test]
fn underflow_works_for_i64() {
let error = StdError::underflow(777i64, 1234i64);
let error = StdError::from(OverflowError::new(OverflowOperation::Sub, 777i64, 1234i64));
match error {
StdError::Underflow {
minuend,
subtrahend,
StdError::Overflow(OverflowError {
operation: OverflowOperation::Sub,
operand1,
operand2,
..
} => {
assert_eq!(minuend, "777");
assert_eq!(subtrahend, "1234");
}) => {
assert_eq!(operand1, "777");
assert_eq!(operand2, "1234");
}
_ => panic!("expect different error"),
}
}

#[test]
fn implements_debug() {
let error: StdError = StdError::underflow(3, 5);
let error: StdError = StdError::from(OverflowError::new(OverflowOperation::Sub, 3, 5));
let embedded = format!("Debug message: {:?}", error);
assert_eq!(
embedded,
r#"Debug message: Underflow { minuend: "3", subtrahend: "5" }"#
r#"Debug message: Overflow(OverflowError { operation: Sub, operand1: "3", operand2: "5" })"#
);
}

#[test]
fn implements_display() {
let error: StdError = StdError::underflow(3, 5);
let error: StdError = StdError::from(OverflowError::new(OverflowOperation::Sub, 3, 5));
let embedded = format!("Display message: {}", error);
assert_eq!(embedded, "Display message: Cannot subtract 5 from 3");
assert_eq!(embedded, "Display message: Cannot Sub with 3 and 5");
}

#[test]
fn implements_partial_eq() {
let u1 = StdError::underflow(3, 5);
let u2 = StdError::underflow(3, 5);
let u3 = StdError::underflow(3, 7);
let u1 = StdError::from(OverflowError::new(OverflowOperation::Sub, 3, 5));
let u2 = StdError::from(OverflowError::new(OverflowOperation::Sub, 3, 5));
let u3 = StdError::from(OverflowError::new(OverflowOperation::Sub, 3, 7));
let s1 = StdError::serialize_err("Book", "Content too long");
let s2 = StdError::serialize_err("Book", "Content too long");
let s3 = StdError::serialize_err("Book", "Title too long");
Expand Down
5 changes: 4 additions & 1 deletion packages/std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ pub use crate::addresses::{CanonicalAddr, HumanAddr};
pub use crate::binary::Binary;
pub use crate::coins::{coin, coins, has_coins, Coin};
pub use crate::deps::{Deps, DepsMut, OwnedDeps};
pub use crate::errors::{RecoverPubkeyError, StdError, StdResult, SystemError, VerificationError};
pub use crate::errors::{
OverflowError, OverflowOperation, RecoverPubkeyError, StdError, StdResult, SystemError,
VerificationError,
};
#[cfg(feature = "stargate")]
pub use crate::ibc::{
ChannelResponse, IbcAcknowledgement, IbcBasicResponse, IbcChannel, IbcEndpoint, IbcMsg,
Expand Down
Loading

0 comments on commit 53a7833

Please sign in to comment.