Skip to content

Commit

Permalink
[move][move-2024] Match typing fix (MystenLabs#20584)
Browse files Browse the repository at this point in the history
## Description 

This fixes a bug where some match arms' final type was computed based on
the inner arm, instead of the expected output type.

## Test plan 

New tests past

---

## Release notes

Check each box that your changes affect. If none of the boxes relate to
your changes, release notes aren't required.

For each box you select, include information after the relevant heading
that describes the impact of your changes that a user might notice and
any actions they must take to implement updates.

- [ ] Protocol: 
- [ ] Nodes (Validators and Full nodes): 
- [ ] Indexer: 
- [ ] JSON-RPC: 
- [ ] GraphQL: 
- [ ] CLI: 
- [ ] Rust SDK:
- [ ] REST API:
  • Loading branch information
cgswords authored Dec 10, 2024
1 parent 6fc12e7 commit b48e421
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 38 deletions.
33 changes: 31 additions & 2 deletions external-crates/move/crates/move-compiler/src/cfgir/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
hlir::ast::{self as H, BlockLabel, Label, Value, Value_, Var},
ice_assert,
parser::ast::{ConstantName, FunctionName},
shared::{program_info::TypingProgramInfo, unique_map::UniqueMap, CompilationEnv},
shared::{program_info::TypingProgramInfo, unique_map::UniqueMap, AstDebug, CompilationEnv},
FullyCompiledProgram,
};
use cfgir::ast::LoopInfo;
Expand Down Expand Up @@ -43,6 +43,13 @@ enum NamedBlockType {
Named,
}

pub(super) struct CFGIRDebugFlags {
#[allow(dead_code)]
pub(super) print_blocks: bool,
#[allow(dead_code)]
pub(super) print_optimized_blocks: bool,
}

struct Context<'env> {
env: &'env CompilationEnv,
info: &'env TypingProgramInfo,
Expand All @@ -52,6 +59,7 @@ struct Context<'env> {
named_blocks: UniqueMap<BlockLabel, (Label, Label)>,
// Used for populating block_info
loop_bounds: BTreeMap<Label, G::LoopInfo>,
debug: CFGIRDebugFlags,
}

impl<'env> Context<'env> {
Expand All @@ -65,6 +73,10 @@ impl<'env> Context<'env> {
label_count: 0,
named_blocks: UniqueMap::new(),
loop_bounds: BTreeMap::new(),
debug: CFGIRDebugFlags {
print_blocks: false,
print_optimized_blocks: false,
},
}
}

Expand Down Expand Up @@ -654,7 +666,15 @@ fn function_body(
let (start, mut blocks, block_info) = finalize_blocks(context, blocks);
context.clear_block_state();
let binfo = block_info.iter().map(destructure_tuple);

if context.debug.print_blocks {
for (lbl, block) in &blocks {
println!("{lbl}:");
for cmd in block {
print!(" ");
cmd.print_verbose();
}
}
}
let (mut cfg, infinite_loop_starts, diags) =
MutForwardCFG::new(start, &mut blocks, binfo);
context.add_diags(diags);
Expand Down Expand Up @@ -685,6 +705,15 @@ fn function_body(
&UniqueMap::new(),
&mut cfg,
);
if context.debug.print_optimized_blocks {
for (lbl, block) in &blocks {
println!("{lbl}:");
for cmd in block {
print!(" ");
cmd.print_verbose();
}
}
}
}
let block_info = block_info
.into_iter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ pub(super) fn compile_match(
let match_tree = build_match_tree(context, VecDeque::from([match_subject]), pattern_matrix);
debug_print!(
context.debug.match_translation,
("match tree" => match_tree; sdbg)
("match tree" => match_tree; sdbg),
("result type" => result_type)
);
let mut resolution_context = ResolutionContext {
hlir_context: context,
Expand Down Expand Up @@ -419,6 +420,7 @@ fn match_tree_to_exp(
if let Some((unpack_fields, next)) = arms.remove(&v) {
let rest_result = match_tree_to_exp(context, init_subject, *next);
let unpack_block = make_match_variant_unpack(
context,
m,
e,
v,
Expand All @@ -436,7 +438,7 @@ fn match_tree_to_exp(
}
let out_exp = T::UnannotatedExp_::VariantMatch(make_var_ref(subject), (m, e), blocks);
let body_exp = T::exp(context.output_type(), sp(context.arms_loc(), out_exp));
make_copy_bindings(bindings, body_exp)
make_copy_bindings(context, bindings, body_exp)
}
MatchTree::StructUnpack {
subject,
Expand All @@ -460,6 +462,7 @@ fn match_tree_to_exp(
StructUnpack::Unpack(unpack_fields, next) => {
let rest_result = match_tree_to_exp(context, init_subject, *next);
make_match_struct_unpack(
context,
m,
s,
tyargs.clone(),
Expand All @@ -469,7 +472,7 @@ fn match_tree_to_exp(
)
}
};
make_copy_bindings(bindings, unpack_exp)
make_copy_bindings(context, bindings, unpack_exp)
}
MatchTree::LiteralSwitch {
subject,
Expand Down Expand Up @@ -497,11 +500,11 @@ fn match_tree_to_exp(

let true_arm = match_tree_to_exp(context, init_subject, *true_arm);
let false_arm = match_tree_to_exp(context, init_subject, *false_arm);
let result_ty = context.output_type().clone();

make_copy_bindings(
context,
bindings,
make_if_else(lit_subject, true_arm, false_arm, result_ty),
make_if_else_arm(context, lit_subject, true_arm, false_arm),
)
}
MatchTree::LiteralSwitch {
Expand All @@ -526,10 +529,9 @@ fn match_tree_to_exp(
for (key, next_tree) in entries.into_iter().rev() {
let match_arm = match_tree_to_exp(context, init_subject, *next_tree);
let test_exp = make_lit_test(lit_subject.clone(), key);
let result_ty = context.output_type().clone();
out_exp = make_if_else(test_exp, match_arm, out_exp, result_ty);
out_exp = make_if_else_arm(context, test_exp, match_arm, out_exp);
}
make_copy_bindings(bindings, out_exp)
make_copy_bindings(context, bindings, out_exp)
}
}
}
Expand All @@ -549,7 +551,8 @@ fn make_leaf(
last.guard.unwrap().exp.loc,
"Must have a non-guarded leaf"
);
return make_copy_bindings(last.bindings, make_arm(context, subject.clone(), last.arm));
let arm = make_arm(context, subject.clone(), last.arm);
return make_copy_bindings(context, last.bindings, arm);
}

let last = leaf.pop().unwrap();
Expand All @@ -559,17 +562,16 @@ fn make_leaf(
last.guard.unwrap().exp.loc,
"Must have a non-guarded leaf"
);
let mut out_exp =
make_copy_bindings(last.bindings, make_arm(context, subject.clone(), last.arm));
let out_ty = out_exp.ty.clone();
let arm = make_arm(context, subject.clone(), last.arm);
let mut out_exp = make_copy_bindings(context, last.bindings, arm);
while let Some(arm) = leaf.pop() {
ice_assert!(
context.hlir_context.reporter,
arm.guard.is_some(),
arm.loc,
"Expected a guard"
);
out_exp = make_guard_exp(context, subject, arm, out_exp, out_ty.clone());
out_exp = make_guard_exp(context, subject, arm, out_exp);
}
out_exp
}
Expand All @@ -579,7 +581,6 @@ fn make_guard_exp(
subject: &FringeEntry,
arm: ArmResult,
cur_exp: T::Exp,
result_ty: Type,
) -> T::Exp {
let ArmResult {
loc: _,
Expand All @@ -593,8 +594,8 @@ fn make_guard_exp(
.map(|(x, (_mut, entry))| (x, (Mutability::Imm, entry)))
.collect();
let guard_arm = make_arm(context, subject.clone(), arm);
let body = make_if_else(*guard.unwrap(), guard_arm, cur_exp, result_ty);
make_copy_bindings(bindings, body)
let body = make_if_else_arm(context, *guard.unwrap(), guard_arm, cur_exp);
make_copy_bindings(context, bindings, body)
}

fn make_arm(context: &mut ResolutionContext, subject: FringeEntry, arm: Arm) -> T::Exp {
Expand Down Expand Up @@ -715,11 +716,10 @@ fn make_arm_unpack(
}

let nloc = next.exp.loc;
let out_type = next.ty.clone();
seq.push_back(sp(nloc, T::SequenceItem_::Seq(Box::new(next))));

let body = T::UnannotatedExp_::Block((UseFuns::new(0), seq));
T::exp(out_type, sp(ploc, body))
T::exp(context.output_type(), sp(ploc, body))
}

fn match_pattern_has_binders(pat: &T::MatchPattern, rhs_binders: &BTreeSet<Var>) -> bool {
Expand Down Expand Up @@ -937,6 +937,7 @@ fn make_var_ref(subject: FringeEntry) -> Box<T::Exp> {

// Performs an unpack for the purpose of matching, where we are matching against an imm. ref.
fn make_match_variant_unpack(
context: &ResolutionContext,
mident: ModuleIdent,
enum_: DatatypeName,
variant: VariantName,
Expand Down Expand Up @@ -970,16 +971,16 @@ fn make_match_variant_unpack(
let binder = T::SequenceItem_::Bind(sp(rhs_loc, vec![unpack_lvalue]), vec![Some(ty)], rhs);
seq.push_back(sp(rhs_loc, binder));

let result_type = next.ty.clone();
let eloc = next.exp.loc;
seq.push_back(sp(eloc, T::SequenceItem_::Seq(Box::new(next))));

let exp_value = sp(eloc, T::UnannotatedExp_::Block((UseFuns::new(0), seq)));
T::exp(result_type, exp_value)
T::exp(context.output_type(), exp_value)
}

// Performs a struct unpack for the purpose of matching, where we are matching against an imm. ref.
fn make_match_struct_unpack(
context: &ResolutionContext,
mident: ModuleIdent,
struct_: DatatypeName,
tyargs: Vec<Type>,
Expand Down Expand Up @@ -1012,12 +1013,11 @@ fn make_match_struct_unpack(
let binder = T::SequenceItem_::Bind(sp(rhs_loc, vec![unpack_lvalue]), vec![Some(ty)], rhs);
seq.push_back(sp(rhs_loc, binder));

let result_type = next.ty.clone();
let eloc = next.exp.loc;
seq.push_back(sp(eloc, T::SequenceItem_::Seq(Box::new(next))));

let exp_value = sp(eloc, T::UnannotatedExp_::Block((UseFuns::new(0), seq)));
T::exp(result_type, exp_value)
T::exp(context.output_type(), exp_value)
}

fn make_arm_variant_unpack_stmt(
Expand Down Expand Up @@ -1111,11 +1111,16 @@ fn make_match_lit(subject: FringeEntry) -> T::Exp {
}
}

fn make_copy_bindings(bindings: PatBindings, next: T::Exp) -> T::Exp {
make_bindings(bindings, next, true)
fn make_copy_bindings(context: &ResolutionContext, bindings: PatBindings, next: T::Exp) -> T::Exp {
make_bindings(context, bindings, next, true)
}

fn make_bindings(bindings: PatBindings, next: T::Exp, as_copy: bool) -> T::Exp {
fn make_bindings(
context: &ResolutionContext,
bindings: PatBindings,
next: T::Exp,
as_copy: bool,
) -> T::Exp {
let eloc = next.exp.loc;
let mut seq = VecDeque::new();
for (lhs, (mut_, rhs)) in bindings {
Expand All @@ -1126,10 +1131,9 @@ fn make_bindings(bindings: PatBindings, next: T::Exp, as_copy: bool) -> T::Exp {
};
seq.push_back(binding);
}
let result_type = next.ty.clone();
seq.push_back(sp(eloc, T::SequenceItem_::Seq(Box::new(next))));
let exp_value = sp(eloc, T::UnannotatedExp_::Block((UseFuns::new(0), seq)));
T::exp(result_type, exp_value)
T::exp(context.output_type(), exp_value)
}

fn make_lvalue(lhs: Var, mut_: Mutability, ty: Type) -> T::LValue {
Expand Down Expand Up @@ -1174,11 +1178,16 @@ fn make_lit_test(lit_exp: T::Exp, value: Value) -> T::Exp {
make_eq_test(loc, lit_exp, value_exp)
}

fn make_if_else(test: T::Exp, conseq: T::Exp, alt: T::Exp, result_ty: Type) -> T::Exp {
fn make_if_else_arm(
context: &ResolutionContext,
test: T::Exp,
conseq: T::Exp,
alt: T::Exp,
) -> T::Exp {
// FIXME: this span is woefully wrong
let loc = test.exp.loc;
T::exp(
result_ty,
context.output_type(),
sp(
loc,
T::UnannotatedExp_::IfElse(Box::new(test), Box::new(conseq), Some(Box::new(alt))),
Expand Down
13 changes: 6 additions & 7 deletions external-crates/move/crates/move-compiler/src/hlir/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ pub(super) struct HLIRDebugFlags {
#[allow(dead_code)]
pub(super) match_specialization: bool,
#[allow(dead_code)]
pub(super) match_work_queue: bool,
#[allow(dead_code)]
pub(super) function_translation: bool,
#[allow(dead_code)]
pub(super) eval_order: bool,
Expand Down Expand Up @@ -161,7 +159,6 @@ impl<'env> Context<'env> {
eval_order: false,
match_translation: false,
match_specialization: false,
match_work_queue: false,
};
let reporter = env.diagnostic_reporter_at_top_level();
Context {
Expand Down Expand Up @@ -479,7 +476,7 @@ fn function_body(
let (locals, body) = function_body_defined(context, sig, loc, seq);
debug_print!(context.debug.function_translation,
(msg "--------"),
(lines "body" => &body));
(lines "body" => &body; verbose));
HB::Defined { locals, body }
}
TB::Macro => unreachable!("ICE macros filtered above"),
Expand Down Expand Up @@ -871,10 +868,11 @@ fn tail(
E::Match(subject, arms) => {
debug_print!(context.debug.match_translation,
("subject" => subject),
("type" => in_type),
(lines "arms" => &arms.value)
);
let compiled = match_compilation::compile_match(context, in_type, *subject, arms);
debug_print!(context.debug.match_translation, ("compiled" => compiled));
debug_print!(context.debug.match_translation, ("compiled" => compiled; verbose));
let result = tail(context, block, expected_type, compiled);
debug_print!(context.debug.match_variant_translation,
(lines "block" => block; verbose),
Expand Down Expand Up @@ -1203,10 +1201,11 @@ fn value(
E::Match(subject, arms) => {
debug_print!(context.debug.match_translation,
("subject" => subject),
("type" => in_type),
(lines "arms" => &arms.value)
);
let compiled = match_compilation::compile_match(context, in_type, *subject, arms);
debug_print!(context.debug.match_translation, ("compiled" => compiled));
debug_print!(context.debug.match_translation, ("compiled" => compiled; verbose));
let result = value(context, block, None, compiled);
debug_print!(context.debug.match_variant_translation, ("result" => &result));
result
Expand Down Expand Up @@ -1848,7 +1847,7 @@ fn statement(context: &mut Context, block: &mut Block, e: T::Exp) {
);
let subject_type = subject.ty.clone();
let compiled = match_compilation::compile_match(context, &subject_type, *subject, arms);
debug_print!(context.debug.match_translation, ("compiled" => compiled));
debug_print!(context.debug.match_translation, ("compiled" => compiled; verbose));
statement(context, block, compiled);
debug_print!(context.debug.match_variant_translation, (lines "block" => block));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module 0x0::variable_undefined;

const V_1: u8 = 0;
const V_2: u8 = 1;

public fun get_v1(a: u8): u8 {
match (a) {
V_1 => 10,
V_2 => 20,
_ => abort,
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module 0x0::variable_undefined;

const V_1: bool = true;
const V_2: bool = false;

public fun get_v1(a: bool): bool {
match (a) {
V_1 => true,
V_2 => false,
_ => abort,
}
}

0 comments on commit b48e421

Please sign in to comment.