From a7de02e05216db9a04e438703ddf1b6b12f3fbef Mon Sep 17 00:00:00 2001 From: David Rubin <87927264+Rexicon226@users.noreply.github.com> Date: Wed, 22 May 2024 08:51:16 -0700 Subject: [PATCH] implement `@expect` builtin (#19658) * implement `@expect` * add docs * add a second arg for expected bool * fix typo * move `expect` to use BinOp * update to newer langref format --- doc/langref.html.in | 8 ++++ doc/langref/expect_if.zig | 15 ++++++++ lib/std/zig/AstGen.zig | 10 ++++- lib/std/zig/AstRlAnnotate.zig | 5 +++ lib/std/zig/BuiltinFn.zig | 8 ++++ lib/std/zig/Zir.zig | 3 ++ lib/zig.h | 6 +++ src/Air.zig | 7 ++++ src/Liveness.zig | 2 + src/Liveness/Verify.zig | 1 + src/Module.zig | 1 + src/Sema.zig | 29 +++++++++++++++ src/arch/aarch64/CodeGen.zig | 2 + src/arch/arm/CodeGen.zig | 2 + src/arch/riscv64/CodeGen.zig | 2 + src/arch/sparc64/CodeGen.zig | 2 + src/arch/wasm/CodeGen.zig | 2 + src/arch/x86_64/CodeGen.zig | 2 + src/codegen/c.zig | 23 ++++++++++++ src/codegen/llvm.zig | 22 +++++++++++ src/print_air.zig | 1 + src/print_zir.zig | 1 + src/target.zig | 1 + test/behavior/expect.zig | 37 +++++++++++++++++++ .../cases/compile_errors/@expect_non_bool.zig | 11 ++++++ 25 files changed, 202 insertions(+), 1 deletion(-) create mode 100644 doc/langref/expect_if.zig create mode 100644 test/behavior/expect.zig create mode 100644 test/cases/compile_errors/@expect_non_bool.zig diff --git a/doc/langref.html.in b/doc/langref.html.in index 28577560e285..d8b70a10e1ab 100644 --- a/doc/langref.html.in +++ b/doc/langref.html.in @@ -4799,6 +4799,14 @@ fn cmpxchgWeakButNotAtomic(comptime T: type, ptr: *T, expected_value: T, new_val {#see_also|@export#} {#header_close#} + {#header_open|@expect#} +
{#syntax#}@expect(operand: bool, expected: bool) bool{#endsyntax#}+
+ Informs the optimizer that {#syntax#}operand{#endsyntax#} will likely be {#syntax#}expected{#endsyntax#}, which influences branch compilation to prefer generating the true branch first. +
+ {#code|expect_if.zig#} + {#header_close#} + {#header_open|@fence#}{#syntax#}@fence(order: AtomicOrder) void{#endsyntax#}
diff --git a/doc/langref/expect_if.zig b/doc/langref/expect_if.zig new file mode 100644 index 000000000000..e971db4e5a56 --- /dev/null +++ b/doc/langref/expect_if.zig @@ -0,0 +1,15 @@ +pub fn a(x: u32) void { + if (@expect(x == 0, false)) { + // condition check falls through at code generation + return; + } else { + // condition is branched to at code generation + return; + } +} + +test "expect" { + a(10); +} + +// test diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig index ee14d7dee477..9843423677a2 100644 --- a/lib/std/zig/AstGen.zig +++ b/lib/std/zig/AstGen.zig @@ -2823,6 +2823,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .set_float_mode, .set_align_stack, .set_cold, + .expect, => break :b true, else => break :b false, }, @@ -9292,7 +9293,14 @@ fn builtinCall( }); return rvalue(gz, ri, .void_value, node); }, - + .expect => { + const val = try gz.addExtendedPayload(.expect, Zir.Inst.BinNode{ + .node = gz.nodeIndexToRelative(node), + .lhs = try expr(gz, scope, .{ .rl = .{ .ty = .bool_type } }, params[0]), + .rhs = try expr(gz, scope, .{ .rl = .{ .ty = .bool_type } }, params[1]), + }); + return rvalue(gz, ri, val, node); + }, .src => { const token_starts = tree.tokens.items(.start); const node_start = token_starts[tree.firstToken(node)]; diff --git a/lib/std/zig/AstRlAnnotate.zig b/lib/std/zig/AstRlAnnotate.zig index 4a1203ca09fc..0e80d0687249 100644 --- a/lib/std/zig/AstRlAnnotate.zig +++ b/lib/std/zig/AstRlAnnotate.zig @@ -1100,5 +1100,10 @@ fn builtinCall(astrl: *AstRlAnnotate, block: ?*Block, ri: ResultInfo, node: Ast. _ = try astrl.expr(args[4], block, ResultInfo.type_only); return false; }, + .expect => { + _ = try astrl.expr(args[0], block, ResultInfo.none); + _ = try astrl.expr(args[1], block, ResultInfo.none); + return false; + }, } } diff --git a/lib/std/zig/BuiltinFn.zig b/lib/std/zig/BuiltinFn.zig index 4bea0278fa5d..b6aea0e84eb8 100644 --- a/lib/std/zig/BuiltinFn.zig +++ b/lib/std/zig/BuiltinFn.zig @@ -82,6 +82,7 @@ pub const Tag = enum { select, set_align_stack, set_cold, + expect, set_eval_branch_quota, set_float_mode, set_runtime_safety, @@ -743,6 +744,13 @@ pub const list = list: { .illegal_outside_function = true, }, }, + .{ + "@expect", + .{ + .tag = .expect, + .param_count = 2, + }, + }, .{ "@setEvalBranchQuota", .{ diff --git a/lib/std/zig/Zir.zig b/lib/std/zig/Zir.zig index 9453790fcf86..29d6634cfa6d 100644 --- a/lib/std/zig/Zir.zig +++ b/lib/std/zig/Zir.zig @@ -2060,6 +2060,9 @@ pub const Inst = struct { /// Guaranteed to not have the `ptr_cast` flag. /// Uses the `pl_node` union field with payload `FieldParentPtr`. field_parent_ptr, + /// Implements the `@expect` builtin. + /// `operand` is BinOp + expect, pub const InstData = struct { opcode: Extended, diff --git a/lib/zig.h b/lib/zig.h index 1e9da15a3fac..13240b45ddd7 100644 --- a/lib/zig.h +++ b/lib/zig.h @@ -318,6 +318,12 @@ typedef char bool; #define zig_noreturn #endif +#if defined(__GNUC__) || defined(__clang__) +#define zig_expect(op, exp) __builtin_expect(op, exp) +#else +#define zig_expect(op, exp) (op) +#endif + #define zig_bitSizeOf(T) (CHAR_BIT * sizeof(T)) #define zig_compiler_rt_abbrev_uint32_t si diff --git a/src/Air.zig b/src/Air.zig index 9d7016155e89..82a1715666ed 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -848,6 +848,10 @@ pub const Inst = struct { /// Operand is unused and set to Ref.none work_group_id, + /// Implements @expect builtin. + /// Uses the `bin_op` field. + expect, + pub fn fromCmpOp(op: std.math.CompareOperator, optimized: bool) Tag { switch (op) { .lt => return if (optimized) .cmp_lt_optimized else .cmp_lt, @@ -1517,6 +1521,8 @@ pub fn typeOfIndex(air: *const Air, inst: Air.Inst.Index, ip: *const InternPool) .work_group_id, => return Type.u32, + .expect => return Type.bool, + .inferred_alloc => unreachable, .inferred_alloc_comptime => unreachable, } @@ -1634,6 +1640,7 @@ pub fn mustLower(air: Air, inst: Air.Inst.Index, ip: *const InternPool) bool { .add_safe, .sub_safe, .mul_safe, + .expect, => true, .add, diff --git a/src/Liveness.zig b/src/Liveness.zig index 4ca28758e222..e66859ff4f5a 100644 --- a/src/Liveness.zig +++ b/src/Liveness.zig @@ -286,6 +286,7 @@ pub fn categorizeOperand( .cmp_gte_optimized, .cmp_gt_optimized, .cmp_neq_optimized, + .expect, => { const o = air_datas[@intFromEnum(inst)].bin_op; if (o.lhs == operand_ref) return matchOperandSmallIndex(l, inst, 0, .none); @@ -955,6 +956,7 @@ fn analyzeInst( .memset, .memset_safe, .memcpy, + .expect, => { const o = inst_datas[@intFromEnum(inst)].bin_op; return analyzeOperands(a, pass, data, inst, .{ o.lhs, o.rhs, .none }); diff --git a/src/Liveness/Verify.zig b/src/Liveness/Verify.zig index 4392f25e101d..31442f25b232 100644 --- a/src/Liveness/Verify.zig +++ b/src/Liveness/Verify.zig @@ -257,6 +257,7 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void { .memset, .memset_safe, .memcpy, + .expect, => { const bin_op = data[@intFromEnum(inst)].bin_op; try self.verifyInstOperands(inst, .{ bin_op.lhs, bin_op.rhs, .none }); diff --git a/src/Module.zig b/src/Module.zig index c571c851fe57..9fc038ecd917 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -5546,6 +5546,7 @@ pub const Feature = enum { /// to generate better machine code in the backends. All backends should migrate to /// enabling this feature. safety_checked_instructions, + can_expect, }; pub fn backendSupportsFeature(zcu: Module, feature: Feature) bool { diff --git a/src/Sema.zig b/src/Sema.zig index ca6d562eef51..fcf82b1153a5 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -1258,6 +1258,7 @@ fn analyzeBodyInner( .work_group_size => try sema.zirWorkItem( block, extended, extended.opcode), .work_group_id => try sema.zirWorkItem( block, extended, extended.opcode), .in_comptime => try sema.zirInComptime( block), + .expect => try sema.zirExpect( block, extended), .closure_get => try sema.zirClosureGet( block, extended), // zig fmt: on @@ -17553,6 +17554,34 @@ fn zirThis( return sema.analyzeDeclVal(block, src, this_decl_index); } +fn zirExpect(sema: *Sema, block: *Block, inst: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref { + const bin_op = sema.code.extraData(Zir.Inst.BinNode, inst.operand).data; + const operand = try sema.resolveInst(bin_op.lhs); + const expected = try sema.resolveInst(bin_op.rhs); + + const expected_src = LazySrcLoc{ .node_offset_builtin_call_arg1 = bin_op.node }; + + if (!try sema.isComptimeKnown(expected)) { + return sema.fail(block, expected_src, "@expect 'expected' must be comptime-known", .{}); + } + + if (try sema.resolveValue(operand)) |op| { + return Air.internedToRef(op.toIntern()); + } + + if (sema.mod.backendSupportsFeature(.can_expect) and sema.mod.optimizeMode() != .Debug) { + return try block.addInst(.{ + .tag = .expect, + .data = .{ .bin_op = .{ + .lhs = operand, + .rhs = expected, + } }, + }); + } else { + return operand; + } +} + fn zirClosureGet(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref { const mod = sema.mod; const ip = &mod.intern_pool; diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig index ddde72345efe..dc7e8c4dd580 100644 --- a/src/arch/aarch64/CodeGen.zig +++ b/src/arch/aarch64/CodeGen.zig @@ -803,6 +803,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .@"try" => try self.airTry(inst), .try_ptr => try self.airTryPtr(inst), + .expect => unreachable, + .dbg_stmt => try self.airDbgStmt(inst), .dbg_inline_block => try self.airDbgInlineBlock(inst), .dbg_var_ptr, diff --git a/src/arch/arm/CodeGen.zig b/src/arch/arm/CodeGen.zig index 86d4e8f7fdd6..7f001e5fa9c7 100644 --- a/src/arch/arm/CodeGen.zig +++ b/src/arch/arm/CodeGen.zig @@ -844,6 +844,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .wrap_errunion_payload => try self.airWrapErrUnionPayload(inst), .wrap_errunion_err => try self.airWrapErrUnionErr(inst), + .expect => unreachable, + .add_optimized, .sub_optimized, .mul_optimized, diff --git a/src/arch/riscv64/CodeGen.zig b/src/arch/riscv64/CodeGen.zig index 762251bc44b0..c8c5041a7d55 100644 --- a/src/arch/riscv64/CodeGen.zig +++ b/src/arch/riscv64/CodeGen.zig @@ -1200,6 +1200,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .@"try" => try self.airTry(inst), .try_ptr => return self.fail("TODO: try_ptr", .{}), + .expect => unreachable, + .dbg_var_ptr, .dbg_var_val, => try self.airDbgVar(inst), diff --git a/src/arch/sparc64/CodeGen.zig b/src/arch/sparc64/CodeGen.zig index 19c18ec4a6b0..b78935d5b205 100644 --- a/src/arch/sparc64/CodeGen.zig +++ b/src/arch/sparc64/CodeGen.zig @@ -636,6 +636,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .@"try" => try self.airTry(inst), .try_ptr => @panic("TODO try self.airTryPtr(inst)"), + .expect => unreachable, + .dbg_stmt => try self.airDbgStmt(inst), .dbg_inline_block => try self.airDbgInlineBlock(inst), .dbg_var_ptr, diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index fe94c061365f..a2d97326fc86 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -2016,6 +2016,8 @@ fn genInst(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { .c_va_start, => |tag| return func.fail("TODO: Implement wasm inst: {s}", .{@tagName(tag)}), + .expect => unreachable, + .atomic_load => func.airAtomicLoad(inst), .atomic_store_unordered, .atomic_store_monotonic, diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index cc6e01408072..4ecb39611f5b 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -2014,6 +2014,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .abs => try self.airAbs(inst), + .expect => unreachable, + .add_with_overflow => try self.airAddSubWithOverflow(inst), .sub_with_overflow => try self.airAddSubWithOverflow(inst), .mul_with_overflow => try self.airMulWithOverflow(inst), diff --git a/src/codegen/c.zig b/src/codegen/c.zig index 9514b826eaa8..f3ff66680228 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -3343,6 +3343,8 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, .@"try" => try airTry(f, inst), .try_ptr => try airTryPtr(f, inst), + + .expect => try airExpect(f, inst), .dbg_stmt => try airDbgStmt(f, inst), .dbg_inline_block => try airDbgInlineBlock(f, inst), @@ -4704,6 +4706,27 @@ fn airTryPtr(f: *Function, inst: Air.Inst.Index) !CValue { return lowerTry(f, inst, extra.data.ptr, body, err_union_ty, true); } +fn airExpect(f: *Function, inst: Air.Inst.Index) !CValue { + const bin_op = f.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; + const operand = try f.resolveInst(bin_op.lhs); + const expected = try f.resolveInst(bin_op.rhs); + + const writer = f.object.writer(); + const local = try f.allocLocal(inst, Type.bool); + const a = try Assignment.start(f, writer, CType.bool); + try f.writeCValue(writer, local, .Other); + try a.assign(f, writer); + + try writer.writeAll("zig_expect("); + try f.writeCValue(writer, operand, .FunctionArgument); + try writer.writeAll(", "); + try f.writeCValue(writer, expected, .FunctionArgument); + try writer.writeAll(")"); + + try a.end(f, writer); + return local; +} + fn lowerTry( f: *Function, inst: Air.Inst.Index, diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 9e51417ab623..c9f9384e91f5 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -5038,6 +5038,8 @@ pub const FuncGen = struct { .slice_ptr => try self.airSliceField(inst, 0), .slice_len => try self.airSliceField(inst, 1), + .expect => try self.airExpect(inst), + .call => try self.airCall(inst, .auto), .call_always_tail => try self.airCall(inst, .always_tail), .call_never_tail => try self.airCall(inst, .never_tail), @@ -6365,6 +6367,26 @@ pub const FuncGen = struct { return result; } + // Note that the LowerExpectPass only runs in Release modes + fn airExpect(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { + const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op; + + const operand = try self.resolveInst(bin_op.lhs); + const expected = try self.resolveInst(bin_op.rhs); + + return try self.wip.callIntrinsic( + .normal, + .none, + .expect, + &.{operand.typeOfWip(&self.wip)}, + &.{ + operand, + expected, + }, + "", + ); + } + fn sliceOrArrayPtr(fg: *FuncGen, ptr: Builder.Value, ty: Type) Allocator.Error!Builder.Value { const o = fg.dg.object; const mod = o.module; diff --git a/src/print_air.zig b/src/print_air.zig index e61ae9fff004..d0607fe1035f 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -162,6 +162,7 @@ const Writer = struct { .memcpy, .memset, .memset_safe, + .expect, => try w.writeBinOp(s, inst), .is_null, diff --git a/src/print_zir.zig b/src/print_zir.zig index dfe94d397097..bc95d41b28df 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -591,6 +591,7 @@ const Writer = struct { .wasm_memory_grow, .prefetch, .c_va_arg, + .expect, => { const inst_data = self.code.extraData(Zir.Inst.BinNode, extended.operand).data; const src = LazySrcLoc.nodeOffset(inst_data.node); diff --git a/src/target.zig b/src/target.zig index 6af301e00101..0c1cbe3f9227 100644 --- a/src/target.zig +++ b/src/target.zig @@ -535,5 +535,6 @@ pub fn backendSupportsFeature( .error_set_has_value => use_llvm or cpu_arch.isWasm(), .field_reordering => ofmt == .c or use_llvm, .safety_checked_instructions => use_llvm, + .can_expect => use_llvm or ofmt == .c, }; } diff --git a/test/behavior/expect.zig b/test/behavior/expect.zig new file mode 100644 index 000000000000..7d759a4d0447 --- /dev/null +++ b/test/behavior/expect.zig @@ -0,0 +1,37 @@ +const std = @import("std"); +const expect = std.testing.expect; + +test "@expect if-statement" { + const x: u32 = 10; + _ = &x; + if (@expect(x == 20, true)) {} +} + +test "@expect runtime if-statement" { + var x: u32 = 10; + var y: u32 = 20; + _ = &x; + _ = &y; + if (@expect(x != y, false)) {} +} + +test "@expect bool input/output" { + const b: bool = true; + try expect(@TypeOf(@expect(b, false)) == bool); +} + +test "@expect bool is transitive" { + const a: bool = true; + const b = @expect(a, false); + + const c = @intFromBool(!b); + std.mem.doNotOptimizeAway(c); + + try expect(c == 0); + try expect(@expect(c != 0, false) == false); +} + +test "@expect at comptime" { + const a: bool = true; + comptime try expect(@expect(a, true) == true); +} diff --git a/test/cases/compile_errors/@expect_non_bool.zig b/test/cases/compile_errors/@expect_non_bool.zig new file mode 100644 index 000000000000..da49b30d126a --- /dev/null +++ b/test/cases/compile_errors/@expect_non_bool.zig @@ -0,0 +1,11 @@ +export fn a() void { + var x: u32 = 10; + _ = &x; + _ = @expect(x, true); +} + +// error +// backend=stage2 +// target=native +// +// :4:17: error: expected type 'bool', found 'u32'