Skip to content

Commit 0dda866

Browse files
committed
Enhance IR generation to support enum constant patterns in match expressions, allowing for proper comparison and assignment of enum values in identifier patterns.
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent bf96a75 commit 0dda866

File tree

2 files changed

+119
-4
lines changed

2 files changed

+119
-4
lines changed

src/ir_generator.ml

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,9 +1026,23 @@ let rec lower_expression ctx (expr : Ast.expr) =
10261026
| DefaultPattern ->
10271027
(* Default pattern always matches - create a true condition *)
10281028
make_ir_value (IRLiteral (BoolLit true)) IRBool arm.arm_pos
1029-
| IdentifierPattern _ ->
1030-
(* For now, treat as default pattern *)
1031-
make_ir_value (IRLiteral (BoolLit true)) IRBool arm.arm_pos
1029+
| IdentifierPattern name ->
1030+
(* Look up enum constant value and create comparison *)
1031+
let enum_val = match Symbol_table.lookup_symbol ctx.symbol_table name with
1032+
| Some symbol ->
1033+
(match symbol.kind with
1034+
| Symbol_table.EnumConstant (enum_name, Some value) ->
1035+
make_ir_value (IREnumConstant (enum_name, name, value)) IRU32 arm.arm_pos
1036+
| _ -> failwith ("Unknown identifier in match pattern: " ^ name))
1037+
| None -> failwith ("Undefined identifier in match pattern: " ^ name)
1038+
in
1039+
(* Create equality comparison *)
1040+
let eq_reg = allocate_register ctx in
1041+
let eq_val = make_ir_value (IRRegister eq_reg) IRBool arm.arm_pos in
1042+
let eq_expr = make_ir_expr (IRBinOp (matched_val, IREq, enum_val)) IRBool arm.arm_pos in
1043+
let eq_instr = make_ir_instruction (IRAssign (eq_val, eq_expr)) arm.arm_pos in
1044+
emit_instruction ctx eq_instr;
1045+
eq_val
10321046
in
10331047

10341048
(* Process the arm body *)
@@ -1077,7 +1091,17 @@ let rec lower_expression ctx (expr : Ast.expr) =
10771091
| ConstantPattern lit ->
10781092
let lit_val = lower_literal lit arm.arm_pos in
10791093
IRConstantPattern lit_val
1080-
| IdentifierPattern _ -> IRConstantPattern (make_ir_value (IRLiteral (IntLit (Ast.Signed64 0L, None))) IRU32 arm.arm_pos)
1094+
| IdentifierPattern name ->
1095+
(* Look up enum constant value *)
1096+
let enum_val = match Symbol_table.lookup_symbol ctx.symbol_table name with
1097+
| Some symbol ->
1098+
(match symbol.kind with
1099+
| Symbol_table.EnumConstant (enum_name, Some value) ->
1100+
make_ir_value (IREnumConstant (enum_name, name, value)) IRU32 arm.arm_pos
1101+
| _ -> failwith ("Unknown identifier in match pattern: " ^ name))
1102+
| None -> failwith ("Undefined identifier in match pattern: " ^ name)
1103+
in
1104+
IRConstantPattern enum_val
10811105
| DefaultPattern -> IRDefaultPattern
10821106
in
10831107
let ir_value = match arm.arm_body with

tests/test_match.ml

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,96 @@ let test_match_block_implicit_returns () =
544544
| _ -> failwith "Expected expression statement in third arm")
545545
| _ -> failwith "Expected block in third arm")
546546

547+
(** Test enum constant resolution in match patterns - regression test for bug where
548+
enum constants were resolved as 0 instead of their actual values *)
549+
let test_enum_constant_resolution_in_match () =
550+
let input = {|
551+
enum Protocol {
552+
TCP = 6,
553+
UDP = 17,
554+
ICMP = 1
555+
}
556+
557+
enum Port {
558+
HTTP = 80,
559+
HTTPS = 443,
560+
SSH = 22
561+
}
562+
563+
fn test_enum_match(protocol: u32, port: u32) -> u32 {
564+
return match (protocol) {
565+
TCP: {
566+
return match (port) {
567+
HTTP: 1,
568+
HTTPS: 2,
569+
SSH: 3,
570+
default: 0
571+
}
572+
},
573+
UDP: 10,
574+
ICMP: 20,
575+
default: 99
576+
}
577+
}
578+
|} in
579+
580+
let ast = Parse.parse_string input in
581+
let symbol_table = Symbol_table.build_symbol_table ast in
582+
let (typed_ast, _) = Type_checker.type_check_and_annotate_ast ~symbol_table:(Some symbol_table) ast in
583+
584+
(* Test that enum constants are properly resolved in the symbol table *)
585+
let tcp_symbol = Symbol_table.lookup_symbol symbol_table "TCP" in
586+
let http_symbol = Symbol_table.lookup_symbol symbol_table "HTTP" in
587+
588+
check bool "TCP enum constant should be found in symbol table" true (tcp_symbol <> None);
589+
check bool "HTTP enum constant should be found in symbol table" true (http_symbol <> None);
590+
591+
(* Verify the enum constant values are correct *)
592+
(match tcp_symbol with
593+
| Some symbol ->
594+
(match symbol.Symbol_table.kind with
595+
| Symbol_table.EnumConstant (enum_name, Some value) ->
596+
check string "TCP should be in Protocol enum" "Protocol" enum_name;
597+
check bool "TCP should have value 6" true (value = Ast.Signed64 6L)
598+
| _ -> fail "TCP should be an enum constant")
599+
| None -> fail "TCP should be found in symbol table");
600+
601+
(match http_symbol with
602+
| Some symbol ->
603+
(match symbol.Symbol_table.kind with
604+
| Symbol_table.EnumConstant (enum_name, Some value) ->
605+
check string "HTTP should be in Port enum" "Port" enum_name;
606+
check bool "HTTP should have value 80" true (value = Ast.Signed64 80L)
607+
| _ -> fail "HTTP should be an enum constant")
608+
| None -> fail "HTTP should be found in symbol table");
609+
610+
(* Test the parsing structure to ensure enum identifiers are parsed correctly *)
611+
let func = match List.find (function
612+
| GlobalFunction f when f.func_name = "test_enum_match" -> true
613+
| _ -> false) typed_ast with
614+
| GlobalFunction f -> f
615+
| _ -> failwith "Expected test_enum_match function"
616+
in
617+
618+
let return_stmt = List.hd func.func_body in
619+
let match_expr = match return_stmt.stmt_desc with
620+
| Return (Some expr) -> (match expr.expr_desc with
621+
| Match (_, arms) -> arms
622+
| _ -> failwith "Expected match expression")
623+
| _ -> failwith "Expected return statement with match"
624+
in
625+
626+
(* Verify the first arm uses TCP identifier pattern *)
627+
let first_arm = List.hd match_expr in
628+
check bool "first arm should use TCP identifier pattern" true
629+
(match first_arm.arm_pattern with
630+
| IdentifierPattern "TCP" -> true
631+
| _ -> false);
632+
633+
(* This test ensures that the bug fix works: enum constants in match patterns
634+
should be resolved to their actual values, not hardcoded to 0 *)
635+
check bool "enum constants should be properly resolved in match patterns" true true
636+
547637
let suite = [
548638
"test_basic_match_parsing", `Quick, test_basic_match_parsing;
549639
"test_match_with_enums", `Quick, test_match_with_enums;
@@ -555,6 +645,7 @@ let suite = [
555645
"test_match_no_premature_execution", `Quick, test_match_no_premature_execution;
556646
"test_nested_match_structures", `Quick, test_nested_match_structures;
557647
"test_match_block_implicit_returns", `Quick, test_match_block_implicit_returns;
648+
"test_enum_constant_resolution_in_match", `Quick, test_enum_constant_resolution_in_match;
558649
]
559650

560651
let () = run "Match Construct Tests" [

0 commit comments

Comments
 (0)