Skip to content

Commit

Permalink
Try fix reentrancy-2 detector
Browse files Browse the repository at this point in the history
  • Loading branch information
jgcrosta committed Nov 7, 2024
1 parent 9fe8b66 commit 011eb42
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 148 deletions.
281 changes: 136 additions & 145 deletions detectors/reentrancy-2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,190 +103,181 @@ const INSERT: &str = "insert";
const MAPPING: &str = "Mapping";
const ACCOUNT_ID: &str = "AccountId";
const U128: &str = "u128";
const CALL_FLAGS: &str = "call_flags";
const ALLOW_REENTRY: &str = "ALLOW_REENTRY";

impl<'tcx> LateLintPass<'tcx> for Reentrancy2 {
fn check_fn(
&mut self,
cx: &LateContext<'tcx>,
_: FnKind<'tcx>,
_: &'tcx FnDecl<'_>,
body: &'tcx Body<'_>,
_: Span,
_: LocalDefId,
) {
struct ReentrancyVisitor<'a, 'tcx: 'a> {
cx: &'a LateContext<'tcx>,
contracts_tainted_for_reentrancy: HashSet<Symbol>,
current_method_call: Option<Symbol>,
bool_var_values: HashMap<HirId, bool>,
reentrancy_spans: Vec<Span>,
should_look_for_insert: bool,
has_insert_operation: bool,
}
struct ReentrancyVisitor<'a, 'tcx> {
cx: &'a LateContext<'tcx>,
tainted_contracts: HashSet<Symbol>,
current_method: Option<Symbol>,
bool_values: HashMap<HirId, bool>,
reentrancy_spans: Vec<Span>,
looking_for_insert: bool,
found_insert: bool,
}

// This function is called whenever a contract is identified as potentially susceptible to reentrancy.
fn set_tainted_contract(visitor: &mut ReentrancyVisitor) {
if let Some(method_calls) = &visitor.current_method_call {
visitor
.contracts_tainted_for_reentrancy
.insert(*method_calls);
visitor.current_method_call = None;
}
impl<'a, 'tcx> ReentrancyVisitor<'a, 'tcx> {
fn mark_current_as_tainted(&mut self) {
if let Some(method) = self.current_method.take() {
self.tainted_contracts.insert(method);
}
}

fn handle_set_allow_reentry(visitor: &mut ReentrancyVisitor, args: &&[Expr<'_>]) {
match &args[0].kind {
ExprKind::Lit(lit) => {
// If the argument is a boolean literal and it's true, call set_tainted_contract
if let LitKind::Bool(value) = lit.node {
if value {
set_tainted_contract(visitor);
}
}
}
ExprKind::Path(qpath) => {
// If the argument is a local variable, check if it's a boolean and if it's true
if_chain! {
if let res = visitor.cx.qpath_res(qpath, args[0].hir_id);
if let Res::Local(_) = res;
if let QPath::Resolved(_, path) = qpath;
then {
for path_segment in path.segments {
// If the argument is a known boolean variable, check if it's true
if let Res::Local(hir_id) = path_segment.res {
if visitor.bool_var_values.get(&hir_id).map_or(true, |v| *v) {
set_tainted_contract(visitor);
}
}
fn handle_set_allow_reentry(&mut self, args: &[Expr<'_>]) {
let is_reentry_enabled = match &args[0].kind {
ExprKind::Lit(lit) => matches!(lit.node, LitKind::Bool(true)),
ExprKind::Path(qpath) => {
if_chain! {
if let res = self.cx.qpath_res(qpath, args[0].hir_id);
if let Res::Local(_) = res;
if let QPath::Resolved(_, path) = qpath;
then {
path.segments.iter().any(|segment| {
if let Res::Local(hir_id) = segment.res {
self.bool_values.get(&hir_id).copied().unwrap_or(true)
} else {
false
}
}
})
} else {
false
}
}
_ => (),
}
_ => false,
};

if is_reentry_enabled {
self.mark_current_as_tainted();
}
}

fn handle_invoke_contract(
visitor: &mut ReentrancyVisitor,
args: &&[Expr<'_>],
expr: &Expr<'_>,
) {
if_chain! {
if let ExprKind::AddrOf(_, _, invoke_expr) = &args[0].kind;
if let ExprKind::Path(qpath) = &invoke_expr.kind;
if let QPath::Resolved(_, path) = qpath;
then{
for path_segment in path.segments {
// If the argument is a tainted contract, add the span of this expression to the span vector
if visitor.contracts_tainted_for_reentrancy.contains(&path_segment.ident.name) {
visitor.should_look_for_insert = true;
visitor.reentrancy_spans.push(expr.span);
}
fn handle_invoke_contract(&mut self, args: &[Expr<'_>], expr: &Expr<'_>) {
if_chain! {
if let ExprKind::AddrOf(_, _, invoke_expr) = &args[0].kind;
if let ExprKind::Path(qpath) = &invoke_expr.kind;
if let QPath::Resolved(_, path) = qpath;
then {
for segment in path.segments.iter() {
if self.tainted_contracts.contains(&segment.ident.name) {
self.looking_for_insert = true;
self.reentrancy_spans.push(expr.span);
}
}
}
}
}

fn handle_insert(visitor: &mut ReentrancyVisitor, expr: &Expr<'_>) {
if_chain! {
if let ExprKind::MethodCall(_, expr1, _, _) = &expr.kind;
if let object_type = visitor.cx.typeck_results().expr_ty(expr1);
if let TyKind::Adt(adt_def, substs) = object_type.kind();
if let Some(variant) = adt_def.variants().get(VariantIdx::from_u32(0));
if variant.name.as_str() == MAPPING;
if let mut has_account_id = false;
if let mut has_u128 = false;
then{
substs.types().for_each(|inner_type| {
let str_inner_type = inner_type.to_string();
if str_inner_type.contains(ACCOUNT_ID) {
has_account_id = true;
} else if str_inner_type.contains(U128) {
has_u128 = true;
}
});
visitor.has_insert_operation = has_account_id && has_u128;
}
fn handle_call_flags(&mut self, args: &[Expr<'_>]) {
if_chain! {
if let ExprKind::Path(qpath) = &args[0].kind;
if let QPath::TypeRelative(_, segment) = qpath;
if segment.ident.name.as_str() == ALLOW_REENTRY;
then {
self.mark_current_as_tainted();
}
}
}

impl<'a, 'tcx> Visitor<'tcx> for ReentrancyVisitor<'a, 'tcx> {
fn visit_local(&mut self, local: &'tcx Local<'tcx>) {
if let Some(init) = &local.init {
if let PatKind::Binding(_, _, ident, _) = &local.pat.kind {
match &init.kind {
// Check if the variable being declared is a boolean, if so, add it to the bool_declarations hashmap
ExprKind::Lit(lit) => {
if let LitKind::Bool(value) = lit.node {
self.bool_var_values.insert(local.pat.hir_id, value);
}
}
ExprKind::MethodCall(_, _, _, _) => {
self.current_method_call = Some(ident.name);
}
// Check if the variable being declared is a boolean, if so, add it to the bool_declarations hashmap
ExprKind::Path(QPath::Resolved(_, path)) => {
if let Some(segment) = path.segments.last() {
if let Res::Local(hir_id) = segment.res {
if let Some(value) = self.bool_var_values.get(&hir_id) {
self.bool_var_values.insert(local.pat.hir_id, *value);
}
}
fn handle_insert(&mut self, expr: &Expr<'_>) {
if_chain! {
if let ExprKind::MethodCall(_, receiver, _, _) = &expr.kind;
if let object_type = self.cx.typeck_results().expr_ty(receiver);
if let TyKind::Adt(adt_def, substs) = object_type.kind();
if let Some(variant) = adt_def.variants().get(VariantIdx::from_u32(0));
if variant.name.as_str() == MAPPING;
then {
let mut has_account_id = false;
let mut has_u128 = false;

substs.types().for_each(|ty| {
let type_str = ty.to_string();
has_account_id |= type_str.contains(ACCOUNT_ID);
has_u128 |= type_str.contains(U128);
});

self.found_insert = has_account_id && has_u128;
}
}
}
}

impl<'a, 'tcx> Visitor<'tcx> for ReentrancyVisitor<'a, 'tcx> {
fn visit_local(&mut self, local: &'tcx Local<'tcx>) {
if let Some(init) = &local.init {
if let PatKind::Binding(_, _, ident, _) = &local.pat.kind {
match &init.kind {
ExprKind::Lit(lit) => {
if let LitKind::Bool(value) = lit.node {
self.bool_values.insert(local.pat.hir_id, value);
}
}
ExprKind::MethodCall(_, _, _, _) => {
self.current_method = Some(ident.name);
}
ExprKind::Path(QPath::Resolved(_, path)) => {
if let Some(segment) = path.segments.last() {
if let Res::Local(hir_id) = segment.res {
if let Some(&value) = self.bool_values.get(&hir_id) {
self.bool_values.insert(local.pat.hir_id, value);
}
}
_ => (),
}
}
walk_local(self, local);
_ => (),
}
}
}
walk_local(self, local);
}

// This method is called for every expression.
fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
if let ExprKind::MethodCall(func, _, args, _) = &expr.kind {
let function_name = func.ident.name.as_str();
match function_name {
// The function "set_allow_reentry" is being called
SET_ALLOW_REENTRY => handle_set_allow_reentry(self, args),
// The function "invoke_contract" is being called
INVOKE_CONTRACT => handle_invoke_contract(self, args, expr),
// The function "insert" is being called
INSERT => {
if self.should_look_for_insert {
handle_insert(self, expr)
}
}
_ => (),
}
}
walk_expr(self, expr)
fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
if let ExprKind::MethodCall(func, _, args, _) = &expr.kind {
match func.ident.name.as_str() {
SET_ALLOW_REENTRY => self.handle_set_allow_reentry(args),
CALL_FLAGS => self.handle_call_flags(args),
INVOKE_CONTRACT => self.handle_invoke_contract(args, expr),
INSERT if self.looking_for_insert => self.handle_insert(expr),
_ => (),
}
}
walk_expr(self, expr);
}
}

// The main function where we start the visitor to traverse the AST.
let mut reentrancy_visitor = ReentrancyVisitor {
impl<'tcx> LateLintPass<'tcx> for Reentrancy2 {
fn check_fn(
&mut self,
cx: &LateContext<'tcx>,
_: FnKind<'tcx>,
_: &'tcx FnDecl<'_>,
body: &'tcx Body<'_>,
_: Span,
_: LocalDefId,
) {
let mut visitor = ReentrancyVisitor {
cx,
contracts_tainted_for_reentrancy: HashSet::new(),
current_method_call: None,
bool_var_values: HashMap::new(),
tainted_contracts: HashSet::new(),
current_method: None,
bool_values: HashMap::new(),
reentrancy_spans: Vec::new(),
has_insert_operation: false,
should_look_for_insert: false,
looking_for_insert: false,
found_insert: false,
};
walk_expr(&mut reentrancy_visitor, body.value);
walk_expr(&mut visitor, body.value);

// Iterate over all potential reentrancy spans and emit a warning for each.
if reentrancy_visitor.has_insert_operation {
reentrancy_visitor.reentrancy_spans.into_iter().for_each(|span| {
if visitor.found_insert {
for span in visitor.reentrancy_spans {
span_lint_and_help(
cx,
REENTRANCY_2,
span,
LINT_MESSAGE,
None,
"This statement seems to call another contract after the flag set_allow_reentry was enabled [todo: check state changes after this statement]"
"This statement seems to call another contract after the flag \
set_allow_reentry was enabled [todo: check state changes after this statement]",
);
})
}
}
}
}
8 changes: 5 additions & 3 deletions test-cases/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ members = ["*/*/*-example"]
resolver = "2"

[workspace.dependencies]
getrandom = { version = "0.2" }
ink = { version = "5.0.0", default-features = false }
scale = { package = "parity-scale-codec", version = "3", default-features = false, features = ["derive"] }
scale-info = { version = "2.6", default-features = false, features = ["derive"] }
ink_e2e = { version = "=5.0.0" }
getrandom = { version = "0.2" }
scale = { package = "parity-scale-codec", version = "3", default-features = false, features = [
"derive",
] }
scale-info = { version = "2.6", default-features = false, features = ["derive"] }

[profile.release]
codegen-units = 1
Expand Down

0 comments on commit 011eb42

Please sign in to comment.