Skip to content

Commit bb26a7a

Browse files
committed
address comments and cleanup
1 parent 2c37270 commit bb26a7a

File tree

2 files changed

+51
-54
lines changed

2 files changed

+51
-54
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2380,7 +2380,9 @@ def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
23802380
containing the specified address is brought.
23812381

23822382
`uniform` can be specified after the `cacheLevel` to indicate that the
2383-
prefetch is performed to the specified uniform cache level. If `uniform` is specified, `addr` must be a generic address pointer and no operation is performed if `addr` maps to a `const`, `local`, or `shared` memory location.
2383+
prefetch is performed to the specified uniform cache level. If `uniform` is
2384+
specified, `addr` must be a generic address pointer and no operation is
2385+
performed if `addr` maps to a `const`, `local`, or `shared` memory location.
23842386

23852387
The `evictPriority` attribute is optional and specifies the cache eviction
23862388
priority when `cacheLevel` is L2.

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 48 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,35 +1206,37 @@ LogicalResult NVVM::VoteSyncOp::verify() {
12061206
}
12071207

12081208
LogicalResult NVVM::PrefetchOp::verify() {
1209-
unsigned addressSpace =
1209+
using MemSpace = NVVM::NVVMMemorySpace;
1210+
using CacheLevel = NVVM::PrefetchCacheLevel;
1211+
1212+
unsigned as =
12101213
llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1211-
auto evictPriority = getEvictPriority();
1214+
std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
12121215

12131216
if (getUniform()) {
1214-
if (!(getCacheLevel() == NVVM::PrefetchCacheLevel::L1)) {
1217+
if (getCacheLevel() != CacheLevel::L1)
12151218
return emitOpError("unsupported cache level, the only supported uniform "
12161219
"cache level is L1");
1217-
}
1218-
if (addressSpace != NVVM::NVVMMemorySpace::kGenericMemorySpace) {
1220+
1221+
if (as != MemSpace::kGenericMemorySpace)
12191222
return emitOpError(
12201223
"prefetch to uniform cache requires a generic pointer");
1221-
}
12221224
}
12231225

1224-
if (evictPriority && getCacheLevel() != NVVM::PrefetchCacheLevel::L2)
1225-
return emitOpError(
1226-
"cache eviction priority supported only for cache level L2");
1226+
if (evictPriority) {
1227+
if (getCacheLevel() != CacheLevel::L2)
1228+
return emitOpError(
1229+
"cache eviction priority supported only for cache level L2");
12271230

1228-
if (evictPriority &&
1229-
(addressSpace != NVVM::NVVMMemorySpace::kGlobalMemorySpace))
1230-
return emitOpError("cache eviction priority requires a global pointer");
1231+
if (as != MemSpace::kGlobalMemorySpace)
1232+
return emitOpError("cache eviction priority requires a global pointer");
12311233

1232-
if (evictPriority &&
1233-
*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1234-
*evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1235-
return emitOpError(
1236-
"unsupported cache eviction priority, only evict_last and "
1237-
"evict_normal are supported");
1234+
if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1235+
*evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1236+
return emitOpError(
1237+
"unsupported cache eviction priority, only evict_last and "
1238+
"evict_normal are supported");
1239+
}
12381240

12391241
return success();
12401242
}
@@ -1769,52 +1771,45 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
17691771
}
17701772

17711773
llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(Operation &op) {
1774+
using MemSpace = NVVM::NVVMMemorySpace;
1775+
using CacheLevel = NVVM::PrefetchCacheLevel;
1776+
17721777
auto curOp = llvm::cast<NVVM::PrefetchOp>(op);
1773-
NVVM::PrefetchCacheLevel cacheLevel = curOp.getCacheLevel();
1778+
NVVM::PrefetchCacheLevel cl = curOp.getCacheLevel();
17741779
std::optional<NVVM::CacheEvictionPriority> evictPriority =
17751780
curOp.getEvictPriority();
17761781
unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
17771782
.getAddressSpace();
17781783

1779-
if (curOp.getUniform()) {
1780-
if (cacheLevel == NVVM::PrefetchCacheLevel::L1)
1781-
return llvm::Intrinsic::nvvm_prefetchu_L1;
1782-
else
1783-
llvm_unreachable("Invalid uniform cache level");
1784-
}
1784+
if (curOp.getUniform() && cl == CacheLevel::L1)
1785+
return llvm::Intrinsic::nvvm_prefetchu_L1;
17851786

1786-
if (cacheLevel == NVVM::PrefetchCacheLevel::L1) {
1787-
switch (as) {
1788-
case NVVM::NVVMMemorySpace::kGenericMemorySpace:
1789-
return llvm::Intrinsic::nvvm_prefetch_L1;
1790-
case NVVM::NVVMMemorySpace::kGlobalMemorySpace:
1791-
return llvm::Intrinsic::nvvm_prefetch_global_L1;
1792-
case NVVM::NVVMMemorySpace::kLocalMemorySpace:
1793-
return llvm::Intrinsic::nvvm_prefetch_local_L1;
1794-
default:
1795-
llvm_unreachable("Invalid pointer address space");
1796-
}
1797-
} else if (cacheLevel == NVVM::PrefetchCacheLevel::L2) {
1798-
switch (as) {
1799-
case NVVM::NVVMMemorySpace::kGenericMemorySpace:
1800-
return llvm::Intrinsic::nvvm_prefetch_L2;
1801-
case NVVM::NVVMMemorySpace::kGlobalMemorySpace:
1802-
if (evictPriority) {
1803-
if (*evictPriority == NVVM::CacheEvictionPriority::EvictLast)
1804-
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
1805-
else if (*evictPriority == NVVM::CacheEvictionPriority::EvictNormal)
1806-
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
1807-
else
1808-
llvm_unreachable("Invalid cache eviction priority");
1809-
}
1810-
return llvm::Intrinsic::nvvm_prefetch_global_L2;
1811-
case NVVM::NVVMMemorySpace::kLocalMemorySpace:
1787+
if (evictPriority && cl == CacheLevel::L2) {
1788+
switch (*evictPriority) {
1789+
case NVVM::CacheEvictionPriority::EvictLast:
1790+
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
1791+
case NVVM::CacheEvictionPriority::EvictNormal:
18121792
return llvm::Intrinsic::nvvm_prefetch_local_L2;
18131793
default:
1814-
llvm_unreachable("Invalid pointer address space");
1794+
llvm_unreachable("Invalid cache eviction priority");
18151795
}
18161796
}
1817-
llvm_unreachable("Invalid cache level");
1797+
1798+
switch (as) {
1799+
case MemSpace::kGenericMemorySpace:
1800+
return cl == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
1801+
: llvm::Intrinsic::nvvm_prefetch_L2;
1802+
case MemSpace::kGlobalMemorySpace:
1803+
return cl == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_global_L1
1804+
: llvm::Intrinsic::nvvm_prefetch_global_L2;
1805+
case MemSpace::kLocalMemorySpace:
1806+
return cl == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_local_L1
1807+
: llvm::Intrinsic::nvvm_prefetch_local_L2;
1808+
default:
1809+
llvm_unreachable("Invalid pointer address space");
1810+
}
1811+
1812+
llvm_unreachable("Invalid parameters for prefetch");
18181813
}
18191814

18201815
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)