Skip to content

Commit

Permalink
[Arith] MLIR PresburgerSet compile fix mlir >= 160 (apache#15638)
Browse files Browse the repository at this point in the history
Hi folks,

Some fixes for MLIR based analyzer module introduced by apache#14690 .

---

* Make CMake at par with LLVM info:
```
{...}
-- Use llvm-config=llvm-config-64
-- LLVM libdir: /usr/lib64
-- Found MLIR
-- Build with MLIR
-- Set TVM_MLIR_VERSION=160
-- Found LLVM_INCLUDE_DIRS=/usr/include
{...}
--    USE_MKL                            : OFF
--    USE_MLIR                           : ON
--    USE_MSVC_MT                        : OFF
{...}
```

* Fix several compilation errors:
```
error: cannot convert 'llvm::SmallVector<long int>' to 'llvm::ArrayRef<mlir::presburger::MPInt>'
error: no matching function for call to 'tvm::IntImm::IntImm(tvm::runtime::DataType, mlir::presburger::MPInt&)'
note:   no known conversion for argument 2 from 'mlir::presburger::MPInt' to 'int64_t' {aka 'long int'}
```

Tested using: ```llvm/mlir 16.0.6```, ```llvm/mlir 15.0.7```,  ```llvm/mlir 17.0.0rc3```
  • Loading branch information
cbalint13 authored Aug 31, 2023
1 parent 79f9e57 commit 022299b
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 9 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ tvm_option(USE_HEXAGON_EXTERNAL_LIBS "Path to git repo containing external Hexag
tvm_option(USE_RPC "Build with RPC" ON)
tvm_option(USE_THREADS "Build with thread support" ON)
tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" OFF)
tvm_option(USE_MLIR "Build with MLIR support" OFF)
tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF)
tvm_option(USE_GRAPH_EXECUTOR "Build with tiny graph executor" ON)
tvm_option(USE_GRAPH_EXECUTOR_CUDA_GRAPH "Build with tiny graph executor with CUDA Graph for GPUs" OFF)
Expand Down
7 changes: 7 additions & 0 deletions cmake/modules/LibInfo.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ function(add_lib_info src_file)
else()
string(STRIP ${TVM_INFO_LLVM_VERSION} TVM_INFO_LLVM_VERSION)
endif()
if (NOT DEFINED TVM_INFO_MLIR_VERSION)
set(TVM_INFO_MLIR_VERSION "NOT-FOUND")
else()
string(STRIP ${TVM_INFO_MLIR_VERSION} TVM_INFO_MLIR_VERSION)
endif()
if (NOT DEFINED CUDA_VERSION)
set(TVM_INFO_CUDA_VERSION "NOT-FOUND")
else()
Expand All @@ -47,6 +52,7 @@ function(add_lib_info src_file)
TVM_INFO_INDEX_DEFAULT_I64="${INDEX_DEFAULT_I64}"
TVM_INFO_INSTALL_DEV="${INSTALL_DEV}"
TVM_INFO_LLVM_VERSION="${TVM_INFO_LLVM_VERSION}"
TVM_INFO_MLIR_VERSION="${TVM_INFO_MLIR_VERSION}"
TVM_INFO_PICOJSON_PATH="${PICOJSON_PATH}"
TVM_INFO_RANG_PATH="${RANG_PATH}"
TVM_INFO_ROCM_PATH="${ROCM_PATH}"
Expand Down Expand Up @@ -86,6 +92,7 @@ function(add_lib_info src_file)
TVM_INFO_USE_LIBBACKTRACE="${USE_LIBBACKTRACE}"
TVM_INFO_USE_LIBTORCH="${USE_LIBTORCH}"
TVM_INFO_USE_LLVM="${USE_LLVM}"
TVM_INFO_USE_MLIR="${USE_MLIR}"
TVM_INFO_USE_METAL="${USE_METAL}"
TVM_INFO_USE_MICRO_STANDALONE_RUNTIME="${USE_MICRO_STANDALONE_RUNTIME}"
TVM_INFO_USE_MICRO="${USE_MICRO}"
Expand Down
2 changes: 2 additions & 0 deletions cmake/utils/FindLLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ macro(find_llvm use_llvm)
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRPresburger.a")
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRSupport.a")
set(TVM_MLIR_VERSION ${TVM_LLVM_VERSION})
message(STATUS "Build with MLIR")
message(STATUS "Set TVM_MLIR_VERSION=" ${TVM_MLIR_VERSION})
endif()
endif()
endif()
Expand Down
51 changes: 42 additions & 9 deletions src/arith/presburger_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,38 +126,54 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const {
for (const IntegerRelation& disjunct : disjuncts) {
PrimExpr union_entry = Bool(1);
for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) {
PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
PrimExpr linear_eq = IntImm(DataType::Int(64), 0);
if (disjunct.getNumCols() > 1) {
for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
#if TVM_MLIR_VERSION >= 160
auto coeff = int64_t(disjunct.atEq(i, j));
#else
auto coeff = disjunct.atEq(i, j);
#endif
if (coeff >= 0 || is_zero(linear_eq)) {
linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j];
} else {
linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j];
linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * vars[j];
}
}
}
#if TVM_MLIR_VERSION >= 160
auto c0 = int64_t(disjunct.atEq(i, disjunct.getNumCols() - 1));
#else
auto c0 = disjunct.atEq(i, disjunct.getNumCols() - 1);
linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
#endif
linear_eq = linear_eq + IntImm(DataType::Int(64), c0);
union_entry = (union_entry && (linear_eq == 0));
}
for (unsigned i = 0, e = disjunct.getNumInequalities(); i < e; ++i) {
PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
PrimExpr linear_eq = IntImm(DataType::Int(64), 0);
if (disjunct.getNumCols() > 1) {
for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) {
#if TVM_MLIR_VERSION >= 160
auto coeff = int64_t(disjunct.atIneq(i, j));
#else
auto coeff = disjunct.atIneq(i, j);
#endif
if (coeff >= 0 || is_zero(linear_eq)) {
linear_eq = linear_eq + IntImm(DataType::Int(32), coeff) * vars[j];
linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j];
} else {
linear_eq = linear_eq - IntImm(DataType::Int(32), -coeff) * vars[j];
linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * vars[j];
}
}
}
#if TVM_MLIR_VERSION >= 160
auto c0 = int64_t(disjunct.atIneq(i, disjunct.getNumCols() - 1));
#else
auto c0 = disjunct.atIneq(i, disjunct.getNumCols() - 1);
#endif
if (c0 >= 0) {
linear_eq = linear_eq + IntImm(DataType::Int(32), c0);
linear_eq = linear_eq + IntImm(DataType::Int(64), c0);
} else {
linear_eq = linear_eq - IntImm(DataType::Int(32), -c0);
linear_eq = linear_eq - IntImm(DataType::Int(64), -c0);
}
union_entry = (union_entry && (linear_eq >= 0));
}
Expand Down Expand Up @@ -199,10 +215,19 @@ PresburgerSet Intersect(const Array<PresburgerSet>& sets) {

IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
Array<PrimExpr> tvm_coeffs = DetectLinearEquation(e, set->GetVars());
#if TVM_MLIR_VERSION >= 160
SmallVector<mlir::presburger::MPInt> coeffs;
#else
SmallVector<int64_t> coeffs;
#endif

coeffs.reserve(tvm_coeffs.size());
for (const PrimExpr& it : tvm_coeffs) {
#if TVM_MLIR_VERSION >= 160
coeffs.push_back(mlir::presburger::MPInt(*as_const_int(it)));
#else
coeffs.push_back(*as_const_int(it));
#endif
}

IntSet result = IntSet().Nothing();
Expand All @@ -211,9 +236,17 @@ IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
auto range = simplex.computeIntegerBounds(coeffs);
auto maxRoundedDown(simplex.computeOptimum(Simplex::Direction::Up, coeffs));
auto opt = range.first.getOptimumIfBounded();
#if TVM_MLIR_VERSION >= 160
auto min = opt.has_value() ? IntImm(DataType::Int(64), int64_t(opt.value())) : neg_inf();
#else
auto min = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : neg_inf();
#endif
opt = range.second.getOptimumIfBounded();
#if TVM_MLIR_VERSION >= 160
auto max = opt.has_value() ? IntImm(DataType::Int(64), int64_t(opt.value())) : pos_inf();
#else
auto max = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : pos_inf();
#endif
auto interval = IntervalSet(min, max);
result = Union({result, interval});
}
Expand Down
6 changes: 6 additions & 0 deletions src/support/libinfo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
#define TVM_INFO_LLVM_VERSION "NOT-FOUND"
#endif

#ifndef TVM_INFO_MLIR_VERSION
#define TVM_INFO_MLIR_VERSION "NOT-FOUND"
#endif

#ifndef TVM_INFO_USE_CUDA
#define TVM_INFO_USE_CUDA "NOT-FOUND"
#endif
Expand Down Expand Up @@ -271,6 +275,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
{"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64},
{"INSTALL_DEV", TVM_INFO_INSTALL_DEV},
{"LLVM_VERSION", TVM_INFO_LLVM_VERSION},
{"MLIR_VERSION", TVM_INFO_MLIR_VERSION},
{"PICOJSON_PATH", TVM_INFO_PICOJSON_PATH},
{"RANG_PATH", TVM_INFO_RANG_PATH},
{"ROCM_PATH", TVM_INFO_ROCM_PATH},
Expand Down Expand Up @@ -311,6 +316,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
{"USE_LIBBACKTRACE", TVM_INFO_USE_LIBBACKTRACE},
{"USE_LIBTORCH", TVM_INFO_USE_LIBTORCH},
{"USE_LLVM", TVM_INFO_USE_LLVM},
{"USE_MLIR", TVM_INFO_USE_MLIR},
{"USE_METAL", TVM_INFO_USE_METAL},
{"USE_MICRO_STANDALONE_RUNTIME", TVM_INFO_USE_MICRO_STANDALONE_RUNTIME},
{"USE_MICRO", TVM_INFO_USE_MICRO},
Expand Down

0 comments on commit 022299b

Please sign in to comment.