Skip to content

Commit 1a6249c

Browse files
authored
Merge branch 'main' into ap/save_sharded_array
2 parents 12691ff + 2d46987 commit 1a6249c

File tree

7 files changed

+41
-13
lines changed

7 files changed

+41
-13
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
4-
version = "0.2.64"
4+
version = "0.2.66"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -87,7 +87,7 @@ PythonCall = "0.9"
8787
Random = "1.10"
8888
Random123 = "1.7"
8989
ReactantCore = "0.1.9"
90-
Reactant_jll = "0.0.123"
90+
Reactant_jll = "0.0.128"
9191
ScopedValues = "1.3.0"
9292
Scratch = "1.2"
9393
Sockets = "1.10"

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "fbb729606494eaa4c8f8f140817d26375e3c79f6"
12+
ENZYMEXLA_COMMIT = "3f1054ecec69f60575370f873fe88bb1296a18ba"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(

src/Compiler.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@ end
459459

460460
const WHILE_CONCAT = Ref(false)
461461
const DUS_TO_CONCAT = Ref(false)
462+
const SUM_TO_CONV = Ref(false)
462463

463464
# Optimization passes via transform dialect
464465
function optimization_passes(;
@@ -654,8 +655,19 @@ function optimization_passes(;
654655
"elementwise_licm(0)",
655656
"concatenate_licm(0)",
656657
"slice_broadcast",
658+
"while_pad_induction_reduction",
659+
"while_licm<1>(1)",
660+
"associative_common_mul_op_reordering",
661+
"slice_select_to_select_slice",
662+
"pad_concat_to_concat_pad",
663+
"slice_if",
664+
"dus_to_i32",
657665
]
658666

667+
if SUM_TO_CONV[]
668+
push!(transform_passes_list, "sum_to_conv")
669+
end
670+
659671
if WHILE_CONCAT[]
660672
push!(transform_passes_list, "while_concat")
661673
end
@@ -701,7 +713,7 @@ function optimization_passes(;
701713
[
702714
"transpose_while",
703715
"transpose_slice",
704-
"transpose_elementwise",
716+
"transpose_elementwise(0)",
705717
"transpose_concat",
706718
"transpose_iota",
707719
"transpose_reduce",

src/ConcreteRArray.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ end
3333
function Base.deepcopy_internal(
3434
x::Union{AbstractConcreteArray,AbstractConcreteNumber}, stackdict::IdDict
3535
)
36-
if haskey(stackdict, x)
37-
return stackdict[x]::typeof(x)
38-
end
36+
haskey(stackdict, x) && return stackdict[x]
3937
return deepcopy(x)
4038
end
4139

src/Ops.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,16 +1572,16 @@ julia> Reactant.@jit(
15721572
ftype_attr = MLIR.IR.attr(fn, "function_type")
15731573
ftype = MLIR.IR.Type(ftype_attr)
15741574

1575-
@assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "hlo_call: all inputs to hlo_call should be reactant arrays"
1575+
@assert all(Base.Fix2(isa, Union{TracedRArray,TracedRNumber}), args) "hlo_call: all inputs to hlo_call should be reactant arrays or numbers"
15761576
@assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name"
15771577

15781578
for (i, arg) in enumerate(args)
15791579
expected_type = MLIR.IR.input(ftype, i)
1580-
arg_type = MLIR.IR.type(arg.mlir_data)
1580+
arg_type = MLIR.IR.type(Reactant.TracedUtils.get_mlir_data(arg))
15811581
@assert expected_type == arg_type "hlo_call: argument #$i has the wrong type (expected $expected_type, got $arg_type)"
15821582
end
15831583

1584-
operands = [a.mlir_data for a in args]
1584+
operands = MLIR.IR.Value[Reactant.TracedUtils.get_mlir_data(a) for a in args]
15851585
call = MLIR.Dialects.func.call(
15861586
operands;
15871587
result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],

src/mlir/IR/Pass.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,8 @@ end
145145
Run the provided `passManager` on the given `module`.
146146
"""
147147
function run!(pm::PassManager, mod::Module, key::String="")
148-
# Dump MLIR before running the pass manager. We set `pm` to nothing because
149-
# the pass manager isn't called yet here.
150-
DUMP_MLIR_ALWAYS[] && dump_mlir(mod, nothing, isempty(key) ? "pre_pm" : "pre_$(key)_pm")
148+
# Dump MLIR before running the pass manager, but also print the list of passes that will be called later.
149+
DUMP_MLIR_ALWAYS[] && dump_mlir(mod, pm, isempty(key) ? "pre_pm" : "pre_$(key)_pm")
151150
status = LogicalResult(@static if isdefined(API, :mlirPassManagerRunOnOp)
152151
API.mlirPassManagerRunOnOp(pm, Operation(mod))
153152
else

test/ops.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,25 @@ end
981981
y_reactant,
982982
)
983983
)[1] x .+ y
984+
985+
@test Float32(
986+
only(
987+
Reactant.@jit(
988+
Ops.hlo_call(
989+
"""
990+
module {
991+
func.func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
992+
%0 = stablehlo.add %arg0, %arg1 : tensor<f32>
993+
return %0 : tensor<f32>
994+
}
995+
}
996+
""",
997+
Reactant.ConcreteRNumber(2.0f0),
998+
Reactant.ConcreteRNumber(2.0f0),
999+
)
1000+
)
1001+
),
1002+
) == 4.0f0
9841003
end
9851004

9861005
function f_repeat(x, y)

0 commit comments

Comments
 (0)