Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

features/bladedisc rebase 20220830 #20

Closed
wants to merge 200 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
200 commits
Select commit Hold shift + click to select a range
247dd64
Change to notifyMatchFailure (#1073)
jpienaar Jul 18, 2022
4c25878
[MLIR][TORCH] Add canonicalization pattern for prim.ListUnpack op
vivekkhandelwal1 Jul 15, 2022
df0b1e7
[MLIR][TORCH] Add negative dim support for aten.cat and aten.slice op
vivekkhandelwal1 Jul 15, 2022
7f08169
bump llvm tag to 3580daa (#1078)
ashay Jul 18, 2022
c73a39e
Add support for index.Tensor on dimensions other than the first
qedawkins Jul 15, 2022
21f905a
Emit underscore version of aten.sqrt (#1072)
kkiningh Jul 19, 2022
e06ee08
torch: [nfc] use `WalkResult::isInterrupted()` instead of booleans (#…
ashay Jul 19, 2022
647e75e
Allow expanding and collapsing in aten::view (#1082)
qedawkins Jul 20, 2022
c61c99e
[MHLO] Init MHLO integration. (#1083)
ZihengJiang Jul 20, 2022
ad283c1
python: trim registration and loading of dialects and passes (#1084)
ashay Jul 21, 2022
72dd04c
Revert "python: trim registration and loading of dialects and passes"…
ashay Jul 21, 2022
c0ef192
Improve error message
silvasean Jul 21, 2022
31fd812
Add linux and macOS source builds in CI (#1070)
powderluv Jul 21, 2022
f271e6a
Add verifiers for ToBuiltinTensorOp and FromBuiltinTensorOp (#1089)
ramiro050 Jul 21, 2022
a02dbb2
[MHLO] Init MHLO slice like op patterns (#1091)
Jul 22, 2022
b80ce79
[MHLO] Init MHLO view like op patterns (#1090)
Jul 22, 2022
f424930
Add option to expose custom PyTorch repo/branch (#1103)
powderluv Jul 25, 2022
f50d701
[MHLO] Add [un]squeeze op patterns (#1099)
Jul 25, 2022
44ead68
[MHLO] Init MHLO gather op patterns (#1104)
Jul 25, 2022
e8f327c
Add lowering to linalg for softplus and log1p
kkiningh Jul 17, 2022
e23fbc8
Update development.md with source builds (#1105)
powderluv Jul 25, 2022
052d2f8
[MHLO] Init MHLO basic op conversion (#1092)
Vremold Jul 27, 2022
3c9addf
Add e2e support for aten.expm1
qedawkins Jul 27, 2022
b36a17c
README: Add op office hours
silvasean Jul 28, 2022
11a8901
[MLIR][TORCH] Add support for multiple indexing tensors for aten.inde…
qedawkins Jul 28, 2022
7247c6a
[MLIR][TORCH] Add E2E support for aten.ge.int op
vivekkhandelwal1 Jul 22, 2022
d386b8f
[MLIR][TORCH] Add decomposition for aten.var.correction op
vivekkhandelwal1 Jul 22, 2022
c681c34
[MLIR][TORCH} Fix empty dim cases for the .dim ops
vivekkhandelwal1 Jul 28, 2022
b389053
Add mhlo to bazel build (#1120)
asaadaldien Jul 29, 2022
9a1203c
Fix CI failure due to upstream PyTorch change in aten.mean.dim op
vivekkhandelwal1 Jul 29, 2022
db4a699
buildAndTest.yml for matrix builds (#1098)
powderluv Jul 29, 2022
8b5631d
[MLIR][TORCH] Add decomposition for aten.std.dim Op
PhaneeshB Jul 26, 2022
2f22e2e
Add initial LTC backend (#610)
antoniojkim Feb 18, 2022
58338f7
Torch-MLIR LTC Backend Lowering Codegen (#621)
antoniojkim Feb 26, 2022
c3b20e4
Got LTC working until compile (#689)
antoniojkim Mar 24, 2022
65cf146
Fix Torch-MLIR LTC Backend based off latest PyTorch master (#723)
antoniojkim Apr 13, 2022
3e9b1cb
Added JIT to MLIR lowering (#724)
henrytwo Apr 14, 2022
a605fe2
Add example Torch MLIR LTC Backend (#725)
henrytwo Apr 14, 2022
615ff1d
Generate MLIR with shape information via LTC frontend (#742)
antoniojkim May 26, 2022
cca9fe1
Enable support for LTC Input/Output Mapping (#764)
henrytwo Apr 27, 2022
1bde00c
Fix LTC Decoupling (#815)
antoniojkim May 3, 2022
406d1e7
Use JIT GraphExecutor for execution in example backend (#830)
henrytwo May 5, 2022
de5b380
Bert example and relevant shape inference functions (#831)
henrytwo May 10, 2022
0c35e60
Add static shape for scalar tensors (#833)
henrytwo May 12, 2022
d9aee0d
E2E HuggingFace Bert using LTC Backend (#912)
antoniojkim Jun 7, 2022
8312fa5
Refactor Node Lowering (#914)
antoniojkim Jun 9, 2022
dfcc265
Added e2e LTC tests (#916)
henrytwo Jun 9, 2022
a62d608
Refactor autogen (#925)
antoniojkim Jun 10, 2022
0cee0dc
Only import the LTC backend that's used (#939)
henrytwo Jun 14, 2022
1510eae
Upstream native_batch_norm and native_batch_norm_backward shape infer…
henrytwo Jun 24, 2022
fb21c9e
Integrate Functionalization Pass (#998)
antoniojkim Jun 30, 2022
a3244b6
Update README.md to reflect LTC landing in main branch (#1000)
henrytwo Jun 30, 2022
61db88c
LTC Documentation (#1021)
henrytwo Jul 7, 2022
9de06f3
Update Torch MLIR readme
henrytwo Jul 7, 2022
f5acad8
Prune xfail e2e LTC tests & fix bugs from functionalization pass (#1044)
henrytwo Jul 12, 2022
47bb38d
Reference Lazy Backend (#1045)
henrytwo Jul 12, 2022
cec74b8
Blacklist _convolution op (#1048)
henrytwo Jul 13, 2022
de6c135
Fix LTC autogen for CI with nightly PyTorch
antoniojkim Jul 19, 2022
e37891b
Default Device Ordinal API (#1079)
antoniojkim Jul 19, 2022
0d16a91
Add support for lift_fresh op (#1101)
antoniojkim Jul 25, 2022
70395de
Resolve CI testing failure for Lazy Tensor Core (#1088)
henrytwo Jul 25, 2022
3689632
Export LTC Headers (#1108)
antoniojkim Jul 27, 2022
4253622
Clean up Autogen (#1112)
antoniojkim Jul 27, 2022
2c3b360
Resolve remaining LTC CI failures (#1110)
henrytwo Jul 30, 2022
55bbbec
llvm: update tag to 02b3a35 (#1124)
ashay Jul 31, 2022
30017af
Disable LTC on Release builds (#1125)
powderluv Aug 1, 2022
554570f
Implemented a decomposition of aten::narrow
JakopinA Aug 1, 2022
ed13ebf
E2E support for AtenEmbeddingBagPaddingIdxOp SUM Mode (#1066)
vidsinghal Aug 1, 2022
fe3c9f5
Add embedding Bag e2e case in xfail set (#1130)
vidsinghal Aug 1, 2022
3772e0b
[NFC][MHLO] move util funcs to MhloLegalizeUtils.h/cpp (#1128)
Aug 2, 2022
76c9766
[MHLO] Support for dynamic shape in basic op conversion by introducin…
Vremold Aug 2, 2022
704efdc
[MHLO] add aten::gelu op pattern (#1127)
Yancey1989 Aug 2, 2022
38d8498
add e2e support for aten.atan2 (#1117)
qedawkins Aug 2, 2022
a7af1fd
Add support for `dim=None` to `AtenMeanDimOp` (#1129)
ramiro050 Aug 2, 2022
82af44d
Fix mhlo bazel rule name (#1136)
asaadaldien Aug 2, 2022
0b23af2
[MHLO] support non-constant torch scalar in BasicOps (#1134)
Aug 3, 2022
636f5ac
[MHLO] Init MHLO reduce-like op conversion (#1133)
Vremold Aug 3, 2022
f2a0e32
[MLIR][TORCH] Fix CI failure
vivekkhandelwal1 Aug 3, 2022
0d25b6f
Fix cache-suffix name bug (#1138)
powderluv Aug 3, 2022
37a229c
Update buildAndTest.yml (#1145)
powderluv Aug 3, 2022
48ec300
[Fix bazel] Add recently added torch->mhlo conversion pass to bazel (…
asaadaldien Aug 3, 2022
f0a24f5
[MHLO] Init MHLO linear op patterns (#1132)
Aug 4, 2022
d030591
[MHLO] Init MHLO pooling-like op conversion (#1141)
Vremold Aug 4, 2022
08fc2d8
Add non-unit groups support to aten.convolution (#858)
gpetters94 Aug 4, 2022
c94431f
[MHLO] Add convolution op pattern (#1152)
Vremold Aug 4, 2022
6484776
Make numerical stability test more perverse
silvasean Aug 3, 2022
31727f8
torch_mlir.compile: Allow ignoring traced shapes
silvasean Aug 3, 2022
c129a6d
[MLIR][TORCH] Add support for dim=None to Aten[Var|Std]DimOp
vivekkhandelwal1 Aug 5, 2022
8ce5d3f
E2E framework: Report tensor dtype in summary
silvasean Aug 3, 2022
e322f6a
Update LTC CMake hack documentation (#1155)
henrytwo Aug 5, 2022
1fdaf2f
development.md: How to enable ASan
silvasean Aug 5, 2022
5618890
development.md: Avoid name collisions with PYTORCH_ variables
silvasean Aug 5, 2022
1ee8659
[MHLO] fix tensor mode aten.div op pattern (#1160)
Aug 6, 2022
290d775
importer: add initial support for loading Float16 tensors (#1169)
Aug 8, 2022
f85ae9c
Reenable LTC in out-of-tree build (#1177)
henrytwo Aug 8, 2022
b70548e
Add decomposition and E2E support for Aten_EmbeddingBag (#1137)
vidsinghal Aug 8, 2022
34e207e
E2E support for AtenRemainderScalarOp (#1119)
vidsinghal Aug 9, 2022
504de5e
Rework how global slot initializers work.
silvasean Jul 13, 2022
351f154
[MHLO] Add transposed convolution conversion pattern (#1171)
Vremold Aug 9, 2022
3e97a33
Revert "Reenable LTC in out-of-tree build (#1177)" (#1183)
henrytwo Aug 9, 2022
bb47c16
llvm: update tag to 061e0189 (#1180)
ashay Aug 9, 2022
e55fc4d
Revert "E2E support for AtenRemainderScalarOp (#1119)" (#1190)
powderluv Aug 9, 2022
f83a905
[MHLO]fix lowering failed on reduction op with i32 shape (#1185)
Yancey1989 Aug 9, 2022
202076c
Add CMake dep to Func dialect (#1196)
marbre Aug 9, 2022
7473561
Don't set MLIR_TABLEGEN_EXE (#1197)
marbre Aug 9, 2022
e75c7c5
Flip to C++17 (#1198)
jpienaar Aug 9, 2022
b696362
Enable OOT builds in CI (#1188)
sjain-stanford Aug 9, 2022
d41c7be
[Bazel] Allow workflow_dispatch manual trigger on bazel workflow (#1203)
sjain-stanford Aug 9, 2022
072c2b5
[Bazel] Add EraseModuleInitializer to TorchMLIRTorchPasses library (#…
sjain-stanford Aug 9, 2022
9cf0b6e
Disable out-of-tree and PyTorch binary (#1206)
powderluv Aug 10, 2022
8756277
[MHLO] Add AtenCatOp conversion pattern to MHLO (#1208)
Vremold Aug 10, 2022
2342456
mac m1 cross compile (#1204)
powderluv Aug 10, 2022
738f4fe
Rename TorchToStd pass as TorchToArith (#1163)
Aug 10, 2022
71c240a
Add note about MLIR compiled outputs in dev docs (#1195)
rengolin Aug 10, 2022
b8d51a7
Update TorchToStd to TorchtoArith in bazel files too. (#1210)
Aug 10, 2022
79b9cf9
Add lowering for aten.to.device (#1107)
gpetters94 Aug 10, 2022
dd2da5a
E2E support for AtenRemainderScalarOp (#1200)
vidsinghal Aug 11, 2022
d96ec64
remove torch dialect from legal list (#1192)
Yancey1989 Aug 11, 2022
b1a5066
Add decomposition of `aten.masked.tensor` op.
Aug 8, 2022
51bfe25
Add PyYaml to requirements.txt (#1174)
rengolin Aug 11, 2022
f00ca91
Simplify matrix configuration for CI workflows (#1213)
sjain-stanford Aug 11, 2022
b8bd0a4
use pytorch binary for macos-arm64 builds (#1215)
sjain-stanford Aug 12, 2022
aed0ec3
Merge matrix runs to fail fast globally (#1216)
sjain-stanford Aug 12, 2022
1581d6a
build: fix typo in path (#1218)
ashay Aug 12, 2022
34478ab
[Build] Add concurrency groups to address long queue times (#1219)
sjain-stanford Aug 13, 2022
606f4d2
build: streamline options for enabling LTC and MHLO (#1221)
ashay Aug 13, 2022
41aa562
s/external/externals/g (#1222)
sjain-stanford Aug 13, 2022
c935795
add native_dropout and related ops pattern (#1211)
Yancey1989 Aug 15, 2022
9d6ee48
Fix unused-variables warnings about EmbeddingBag ops (#1220)
ramiro050 Aug 15, 2022
3b3cb99
Generalize canonicalization pattern for more aten.sub/div/mul/add op …
Vremold Aug 16, 2022
84d345c
build: update llvm tag to 2dde4ba6 (#1229)
ashay Aug 16, 2022
0af5578
Propagate device data names (#1157)
antoniojkim Aug 16, 2022
fde390c
Re-enable custom op support
nithinsubbiah Aug 2, 2022
11a5b5a
[MHLO] Add AtenRSubScalarOp conversion pattern to MHLO (#1233)
Vremold Aug 17, 2022
85f383c
Bump the shape lib to match the upstream functions currently in PyTor…
qedawkins Aug 17, 2022
9be8997
Revert "add native_dropout and related ops pattern (#1211)" (#1230)
Yancey1989 Aug 17, 2022
9c8b962
Dockerize and Cache Bazel {Local, CI} Builds (#1240)
sjain-stanford Aug 17, 2022
57681f7
Iteratively run the main simplification pipeline.
silvasean Aug 4, 2022
114f48e
[Bazel] Check cache directory exists before changing owners (#1241)
sjain-stanford Aug 18, 2022
7d4a0d0
[Bazel] Add LowerToBackendContract.cpp to TorchMLIRTorchPasses bazel …
sjain-stanford Aug 18, 2022
f07f7d2
Clean up shape functions that use `sum_mean_dim` (#1217)
ramiro050 Aug 18, 2022
283e0f1
Add a concept of "backend legal ops".
silvasean Aug 17, 2022
1a7fc39
[docs] Add architecture doc.
silvasean Aug 5, 2022
9bc606c
Add support for returning more than one copy of the same tensor (#1228)
ramiro050 Aug 18, 2022
f601435
Add white background to diagram.
silvasean Aug 18, 2022
0d1aa43
Drop Python 3.7x from the nightly binary builds (#1246)
powderluv Aug 18, 2022
7bd173a
[MHLO] Eliminate explicit dynamic output shape generating in converti…
Vremold Aug 19, 2022
1e1759c
[Bazel] Run buildifier (#1250)
sjain-stanford Aug 19, 2022
65d811e
[MLIR][TORCH] Fix dynamic cases for aten.index.Tensor
vivekkhandelwal1 Aug 16, 2022
ba17a4d
Reenable LTC in out-of-tree build (for real this time) (#1205)
henrytwo Aug 19, 2022
99fb4c8
Add folder for ToF64Op and FromF64Op (#1257)
Vremold Aug 22, 2022
c38308f
Add lowering for _convolution.deprecated (#1259)
alextsao1999 Aug 22, 2022
3815cfa
[MLIR][TORCH] Fix CI failure due to failing tests
vivekkhandelwal1 Aug 22, 2022
1f1abda
Don't explicitly set MLIR_PDLL_TABLEGEN_EXE (#1262)
marbre Aug 22, 2022
ef89dad
Update Torch-MLIR Architecture Diagram (#1254)
powderluv Aug 22, 2022
01290d1
Add a way for backends to control which ops are legal for them.
silvasean Aug 19, 2022
9176b5e
Add decomposition for aten.flatten.using_ints (#1161)
Aug 23, 2022
8cad02f
[MLIR][TORCH] Add torch.Device type to backend contract scalar types
vivekkhandelwal1 Aug 19, 2022
2374098
[MHLO] Init end to end unit tests (#1223)
Aug 23, 2022
1106b9a
[MHLO] bert-tiny and resnet18 example from torchscript to mhlo (#1266)
Vremold Aug 23, 2022
3d0e18b
Add decomposition for aten.roll (#1170)
Aug 24, 2022
d7d6797
[cleanup] Change OutputType enum values to strings
silvasean Aug 19, 2022
f012279
Add transposed case for at::convolution (#917)
gpetters94 Aug 24, 2022
e2f862c
Fix LTC build warnings (#1272)
henrytwo Aug 24, 2022
1d9d925
mlir: fix replacement of `OpaqueElementsAttr` (#1274)
ashay Aug 24, 2022
233fd12
doc: fix instructions for LLVM and MHLO updates (#1273)
ashay Aug 24, 2022
a1ace06
Revert updating mlir_native_functions.cpp signature (#1281)
henrytwo Aug 25, 2022
e869e68
Fix LTC lib_torch_mlir_ltc.so import error (#1283)
henrytwo Aug 25, 2022
e153694
Add TestUtils.randint + replace torch.randint with tu.randint (#1276)
ramiro050 Aug 26, 2022
b1fa7a2
Fix a few build warnings
silvasean Aug 25, 2022
0e3ddba
Remove VerifyInvariantsBeforeBackendLowering
silvasean Aug 25, 2022
8e880a2
Fix symint related functionalization ops (#1289)
antoniojkim Aug 26, 2022
883c6b4
Add LTC architecture diagram (#1291)
henrytwo Aug 26, 2022
493f45f
Add documentation for adding E2E tests (#1269)
ramiro050 Aug 26, 2022
a507ae4
Avoid cascading failures when compiler crashes
silvasean Aug 25, 2022
f245613
Add reference to slides and presentation.
Aug 25, 2022
c0630da
Disable LTC by default until upstream revert relands (#1303)
powderluv Aug 29, 2022
2623185
Rename an outdated class name
silvasean Aug 27, 2022
bcccf41
Add CI for generated files.
silvasean Aug 25, 2022
15fca6e
Update MHLO xfails.
silvasean Aug 26, 2022
f402eb2
Move development.md to docs/ for consistency
silvasean Aug 27, 2022
e16b43e
Remove "torchscript" association from the e2e framework.
silvasean Aug 29, 2022
079bff3
Sort tests before anything else.
silvasean Aug 29, 2022
0f40d98
Ensure that tests have unique names
silvasean Aug 29, 2022
9f061ea
Dockerize CI + Release builds (#1234)
powderluv Aug 30, 2022
51ef1b1
Add some missing dependencies.
silvasean Aug 27, 2022
e52e886
build: update llvm tag to 00d648bd (#1307)
ashay Aug 30, 2022
928c815
Add shapelib and Torch ODS gen tests (#1318)
powderluv Aug 31, 2022
3704363
Use pre-compiled headers for PyTorch Source builds (#1327)
powderluv Aug 31, 2022
7769eb8
Set ccache logging to verbose temporarily (#1326)
powderluv Aug 31, 2022
29cafdb
[MHLO] refactor pass configurations (#1315)
Sep 1, 2022
a924de3
Slightly tweak generated file checks
silvasean Aug 31, 2022
f5a8a93
aicompiler rebase 20220830
Aug 12, 2022
1465814
To comply with the old pytorch versions
Jun 11, 2022
5d5504d
add relu6 op
Yancey1989 Jun 21, 2022
5d6b854
Add decomposition for aten::native_layer_norm (#13)
Jul 20, 2022
3bc47c3
Add native_dropout_backward & native_layer_norm_backward decompositio…
Aug 8, 2022
c1b556a
add native_dropout and related ops pattern (#1211)
Yancey1989 Aug 15, 2022
207c7ad
Fix dot product result type setting (#19)
Sep 20, 2022
8f29de9
fix BatchNormInference
Sep 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[MHLO] Init MHLO view like op patterns (llvm#1090)
* [MHLO] Init MHLO view like op patterns

See RFC: llvm#999

Co-authored-by: Bairen Yi yibairen.byron@bytedance.com
Co-authored-by: Jiawei Wu xremold@gmail.com
Co-authored-by: Tianyou Guo tianyou.gty@alibaba-inc.com
Co-authored-by: Xu Yan yancey.yx@alibaba-inc.com
Co-authored-by: Ziheng Jiang ziheng.jiang@bytedance.com

* update filecheck test cases

* rebase, remove chlo and clang-format
  • Loading branch information
Tanyo Kwok authored Jul 22, 2022
commit b80ce79b9fd26247e935899ce65f5cc11ec3e346
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ endmacro()
option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
if(TORCH_MLIR_ENABLE_MHLO)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
# i64 arithmetic is much slower than i32 on some devices, such as NVIDIA GPUs.
# Dimension sizes can be truncated from i64 to i32 because they are unlikely
# to exceed the range of i32 (4GiB).
option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
"Enable truncate dimension size from i64 to i32(unsafely)" OFF)
if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
endif()
endif()

torch_mlir_add_llvm_external_project(
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToMhlo/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
add_mlir_conversion_library(TorchMLIRTorchToMhlo
TorchToMhlo.cpp
BasicOp.cpp
SliceLikeOps.cpp
ViewLikeOps.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
Expand Down
7 changes: 3 additions & 4 deletions lib/Conversion/TorchToMhlo/PopulatePatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ namespace torch_to_mhlo {
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
void populateSliceLikeOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);

void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);

} // namespace torch_to_mhlo
} // namespace torch
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {

torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
target);
torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(typeConverter, patterns,
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, patterns,
target);

if (failed(applyPartialConversion(getOperation(), target,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
static constexpr size_t kMhloDimSizeBits = 32;
Expand All @@ -31,10 +36,8 @@ static constexpr size_t kMhloDimSizeBits = 64;

namespace {

SmallVector<Value, 4> getDimSizesOfTensor(
PatternRewriter& rewriter,
Operation* op,
Value value) {
SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
Operation *op, Value value) {
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
if (!valueTy) {
op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
Expand All @@ -51,20 +54,16 @@ SmallVector<Value, 4> getDimSizesOfTensor(
auto loc = op->getLoc();
for (auto d = 0; d < rank; ++d) {
dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
loc,
rewriter.getIntegerType(kMhloDimSizeBits),
loc, rewriter.getIntegerType(kMhloDimSizeBits),
rewriter.create<tensor::DimOp>(loc, value, d)));
}
return dimSizes;
}

// A dimension index from the torch dialect might be outside the range
// [0, dimSize]. This function normalizes the input index into that range.
Value getNormalizedDimSizeInternal(
PatternRewriter& rewriter,
Operation* op,
Value index,
Value dimSize) {
Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
Value index, Value dimSize) {
auto loc = op->getLoc();
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
Expand All @@ -79,19 +78,14 @@ Value getNormalizedDimSizeInternal(
auto indexPositive = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, index, zero);
// get positive index: (index >=0) ? index: index + dimSize
return rewriter.create<arith::SelectOp>(
loc, indexPositive, index, dimSizePlusIndex);
return rewriter.create<arith::SelectOp>(loc, indexPositive, index,
dimSizePlusIndex);
}

Value getDynamicSliceInternal(
PatternRewriter& rewriter,
Operation* op,
Value input,
Value startIndex,
Value endIndex,
Value step,
size_t dimIndex,
ArrayRef<Value> dimSizes) {
Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
Value input, Value startIndex, Value endIndex,
Value step, size_t dimIndex,
ArrayRef<Value> dimSizes) {
auto loc = op->getLoc();
// startIndex & endIndex have been normalized into the range [0, dSize]
Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
Expand All @@ -112,8 +106,8 @@ Value getDynamicSliceInternal(

auto endIndexIsZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, endIndex, zero);
endIndex = rewriter.create<arith::SelectOp>(
loc, endIndexIsZero, dimSizes[dimIndex], endIndex);
endIndex = rewriter.create<arith::SelectOp>(loc, endIndexIsZero,
dimSizes[dimIndex], endIndex);

for (size_t r = 0; r < rank; ++r) {
if (r == dimIndex) {
Expand Down Expand Up @@ -143,51 +137,47 @@ Value getDynamicSliceInternal(
loc, sliceoutputTy, input, startTensor, endTensor, stridesTensor);
}

// Get a dynamic slice of the tensor from startIndex to endIndex with stride step
// on the specified dimension. The input startIndex(default to 0),
// Get a dynamic slice of the tensor from startIndex to endIndex with stride
// step on the specified dimension. The input startIndex(default to 0),
// endIndex(default to dimSize), and step(default to 1) can be optional.
Value getDynamicSlice(
PatternRewriter& rewriter,
Operation* op,
Value input,
llvm::Optional<Value> startIndexOpt,
llvm::Optional<Value> endIndexOpt,
llvm::Optional<Value> stepOpt,
int64_t dim) {
Value getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input,
llvm::Optional<Value> startIndexOpt,
llvm::Optional<Value> endIndexOpt,
llvm::Optional<Value> stepOpt, int64_t dim) {
auto loc = op->getLoc();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
auto rank = inputTy.getRank();

dim = (dim + rank) % rank;
Value dimSize = rewriter.create<arith::IndexCastOp>(
loc,
rewriter.getI64Type(),
loc, rewriter.getI64Type(),
rewriter.create<tensor::DimOp>(loc, input, dim));

Value normStartIndex = startIndexOpt
? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
: rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
Value normEndIndex = endIndexOpt
? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
: dimSize;
Value step = stepOpt
? *stepOpt
: rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
Value normStartIndex =
startIndexOpt
? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
: rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
Value normEndIndex =
endIndexOpt
? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
: dimSize;
Value step =
stepOpt ? *stepOpt
: rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
auto i32Type = rewriter.getIntegerType(kMhloDimSizeBits);
normStartIndex =
rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex);
normEndIndex =
rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
normEndIndex = rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
#endif
auto dimSizes = getDimSizesOfTensor(rewriter, op, input);

return getDynamicSliceInternal(
rewriter, op, input, normStartIndex, normEndIndex, step, dim, dimSizes);
return getDynamicSliceInternal(rewriter, op, input, normStartIndex,
normEndIndex, step, dim, dimSizes);
}

template <typename AtenOpT>
Expand All @@ -202,9 +192,8 @@ class ConvertAtenOp : public OpConversionPattern<AtenOpT> {

template <>
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
AtenSliceTensorOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
AtenSliceTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.self();
auto selfTy = self.getType().template cast<RankedTensorType>();
if (!selfTy)
Expand All @@ -226,16 +215,110 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
llvm::Optional<Value> end = getOptionalVal(adaptor.end());
llvm::Optional<Value> step = getOptionalVal(adaptor.step());

Value sliced =
getDynamicSlice(rewriter, op, self, start, end, step, dim);
Value sliced = getDynamicSlice(rewriter, op, self, start, end, step, dim);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), sliced);

return success();
}

// Conversion pattern shared by the view-like ops. The rewrite itself is the
// same for every op; only the way the requested sizes are obtained differs,
// and that part is supplied by a per-op specialization of
// getAtenViewOpSizes().
template <typename AtenOpT>
class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;

  // Replaces `op` with a reshape of adaptor.self().
  //
  // - Fails (emitting an error on `op`) when the converted operand is not a
  //   RankedTensorType or when getAtenViewOpSizes() returns false.
  // - When the requested size list is empty or the operand has rank 0, a
  //   static mhlo::ReshapeOp to the converted result type is emitted.
  // - Otherwise the requested sizes are materialized as integers, packed into
  //   a tensor, passed through mhlo::ComputeReshapeShapeOp and used as the
  //   shape operand of an mhlo::DynamicReshapeOp.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.self();
    auto selfTy = self.getType().dyn_cast<RankedTensorType>();
    if (!selfTy)
      return op.emitError("Only ranked tensor types are currently supported");

    SmallVector<Value, 4> shapeElems;
    if (!getAtenViewOpSizes(op, adaptor, rewriter, shapeElems))
      return op.emitError("Dims size must be a list of Scalar");

    // Rank-0 source or rank-0 target: no dynamic shape operand is built.
    if (shapeElems.empty() || selfTy.getRank() == 0) {
      Type resultTy = this->getTypeConverter()->convertType(op.getType());
      rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, resultTy, self);
      return success();
    }

    Location loc = op.getLoc();

    // Turn every requested size into a builtin i64 value.
    for (Value &elem : shapeElems)
      elem = rewriter.create<ToI64Op>(loc, elem).getResult();

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
    // i64 arithmetic is much slower than i32 on some devices, such as NVIDIA
    // GPUs, so the sizes are narrowed from i64 to i32 when this option is
    // enabled (dimension sizes are unlikely to exceed the i32 range).
    for (Value &elem : shapeElems)
      elem =
          rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), elem);
#endif

    // Element count = product of all requested sizes, seeded with the
    // constant 1 of the configured dimension-size integer type, then cast to
    // index.
    // NOTE(review): the count is derived from the requested sizes rather than
    // from the operand; confirm this is what mhlo.compute_reshape_shape
    // expects when one of the sizes is a -1 placeholder.
    Type dimTy = rewriter.getIntegerType(kMhloDimSizeBits);
    Value elemCount = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(dimTy, 1));
    for (Value elem : shapeElems)
      elemCount = rewriter.create<arith::MulIOp>(loc, elemCount, elem);
    elemCount = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getIndexType(), elemCount);

    Value shapeTensor =
        rewriter.create<tensor::FromElementsOp>(loc, shapeElems);
    Value resolvedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
        loc, shapeTensor.getType(), elemCount, shapeTensor);

    Type resultTy = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, resultTy, self,
                                                        resolvedShape);
    return success();
  }

  // Per-op hook: fills `dimSizes` with the values describing the requested
  // shape and returns false when they cannot be obtained. Only declared here;
  // each supported op provides an explicit specialization.
  bool getAtenViewOpSizes(AtenOpT op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &dimSizes) const;
};

// Size provider for AtenViewOp: forwards the op's `size` operand to
// getListConstructElements(), which fills `dimSizes`, and returns that
// helper's result unchanged. `op` and `rewriter` are unused here.
template <>
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
    AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
    SmallVector<Value, 4> &dimSizes) const {
  Value sizeList = adaptor.size();
  return getListConstructElements(sizeList, dimSizes);
}

// Size provider for AtenReshapeOp: forwards the op's `shape` operand to
// getListConstructElements(), which fills `dimSizes`, and returns that
// helper's result unchanged. `op` and `rewriter` are unused here.
template <>
bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
    AtenReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
    SmallVector<Value, 4> &dimSizes) const {
  Value shapeList = adaptor.shape();
  return getListConstructElements(shapeList, dimSizes);
}

} // namespace

void mlir::torch::torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(
void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();
Expand All @@ -246,4 +329,10 @@ void mlir::torch::torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_VIEW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
INSERT_VIEW_OP_PATTERN(AtenViewOp);
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
#undef INSERT_VIEW_OP_PATTERN
}
Loading