diff --git a/src/overdub.jl b/src/overdub.jl index 51635cd..f615810 100644 --- a/src/overdub.jl +++ b/src/overdub.jl @@ -2,7 +2,7 @@ using Cassette Cassette.@context CounterCtx; -const ops = ( +const binops = ( (:add32, Core.Intrinsics.add_float, Float32), (:sub32, Core.Intrinsics.sub_float, Float32), (:mul32, Core.Intrinsics.mul_float, Float32), @@ -13,6 +13,13 @@ const ops = ( (:div64, Core.Intrinsics.div_float, Float64), ) +const unops = ( + (:sqrt32, Core.Intrinsics.sqrt_llvm, Float32), + (:sqrt64, Core.Intrinsics.sqrt_llvm, Float64), +) + +const ops = Iterators.flatten((binops, unops)) |> collect + @eval mutable struct Counter $((:($(op[1]) ::Int) for op in ops)...) Counter() = new($((0 for _ in 1:length(ops))...)) @@ -24,7 +31,22 @@ for typ1 in (Float32, Float64) ::$typ1, ::$typ1) $(Expr(:block, - (map(ops) do (name, op, typ2) + (map(binops) do (name, op, typ2) + typ1 == typ2 || return :nothing + quote + if op == $op + ctx.metadata.$name += 1 + return + end + end + end)...)) + end + + @eval function Cassette.prehook(ctx::CounterCtx, + op::Core.IntrinsicFunction, + ::$typ1) + $(Expr(:block, + (map(unops) do (name, op, typ2) typ1 == typ2 || return :nothing quote if op == $op diff --git a/test/runtests.jl b/test/runtests.jl index ad087f2..cb8e4b7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,8 @@ Flop Counter: sub64: 0 mul64: 0 div64: 0 + sqrt32: 0 + sqrt64: 0 """ end end @@ -79,6 +81,16 @@ Flop Counter: @test cnt.mul64 == N*N @test GFlops.flop(cnt) == 2*N*N end + + let cnt = @count_ops sqrt(4.2) + @test cnt.sqrt64 == 1 + @test GFlops.flop(cnt) == 1 + end + + let cnt = @count_ops sqrt(4.2f0) + @test cnt.sqrt32 == 1 + @test GFlops.flop(cnt) == 1 + end end @testset "@gflops" begin