Skip to content

Commit 2d46987

Browse files
authored
hlo_call: allow passing in numbers (#1152)
1 parent 08713d6 commit 2d46987

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

src/Ops.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,16 +1572,16 @@ julia> Reactant.@jit(
15721572
ftype_attr = MLIR.IR.attr(fn, "function_type")
15731573
ftype = MLIR.IR.Type(ftype_attr)
15741574

1575-
@assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "hlo_call: all inputs to hlo_call should be reactant arrays"
1575+
@assert all(Base.Fix2(isa, Union{TracedRArray,TracedRNumber}), args) "hlo_call: all inputs to hlo_call should be reactant arrays or numbers"
15761576
@assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name"
15771577

15781578
for (i, arg) in enumerate(args)
15791579
expected_type = MLIR.IR.input(ftype, i)
1580-
arg_type = MLIR.IR.type(arg.mlir_data)
1580+
arg_type = MLIR.IR.type(Reactant.TracedUtils.get_mlir_data(arg))
15811581
@assert expected_type == arg_type "hlo_call: argument #$i has the wrong type (expected $expected_type, got $arg_type)"
15821582
end
15831583

1584-
operands = [a.mlir_data for a in args]
1584+
operands = MLIR.IR.Value[Reactant.TracedUtils.get_mlir_data(a) for a in args]
15851585
call = MLIR.Dialects.func.call(
15861586
operands;
15871587
result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],

test/ops.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,25 @@ end
981981
y_reactant,
982982
)
983983
)[1] x .+ y
984+
985+
@test Float32(
986+
only(
987+
Reactant.@jit(
988+
Ops.hlo_call(
989+
"""
990+
module {
991+
func.func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
992+
%0 = stablehlo.add %arg0, %arg1 : tensor<f32>
993+
return %0 : tensor<f32>
994+
}
995+
}
996+
""",
997+
Reactant.ConcreteRNumber(2.0f0),
998+
Reactant.ConcreteRNumber(2.0f0),
999+
)
1000+
)
1001+
),
1002+
) == 4.0f0
9841003
end
9851004

9861005
function f_repeat(x, y)

0 commit comments

Comments
 (0)