Skip to content

Commit 4be84a1

Browse files
authored
[mlir][gpu] Clean up prints in GPU dialect. NFC. (#136250)
Clean up printing code by switching to `llvm::interleaved` from #135517. Also make some minor readability & performance fixes.
1 parent d0dd697 commit 4be84a1

File tree

2 files changed

+26
-38
lines changed

2 files changed

+26
-38
lines changed

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

+10-8
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include "llvm/ADT/TypeSwitch.h"
3636
#include "llvm/Support/CommandLine.h"
3737
#include "llvm/Support/ErrorHandling.h"
38+
#include "llvm/Support/FormatVariadic.h"
39+
#include "llvm/Support/InterleavedRange.h"
3840
#include "llvm/Support/StringSaver.h"
3941
#include <cassert>
4042
#include <numeric>
@@ -479,10 +481,11 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
479481
if (values.empty())
480482
return;
481483

482-
p << ' ' << keyword << '(';
483-
llvm::interleaveComma(
484-
values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
485-
p << ')';
484+
auto printBlockArg = [](BlockArgument v) {
485+
return llvm::formatv("{} : {}", v, v.getType());
486+
};
487+
p << ' ' << keyword << '('
488+
<< llvm::interleaved(llvm::map_range(values, printBlockArg)) << ')';
486489
}
487490

488491
/// Verifies a GPU function memory attribution.
@@ -1311,11 +1314,10 @@ static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
13111314
if (operands.empty())
13121315
return;
13131316
printer << "args(";
1314-
llvm::interleaveComma(llvm::zip(operands, types), printer,
1317+
llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
13151318
[&](const auto &pair) {
1316-
printer.printOperand(std::get<0>(pair));
1317-
printer << " : ";
1318-
printer.printType(std::get<1>(pair));
1319+
auto [operand, type] = pair;
1320+
printer << operand << " : " << type;
13191321
});
13201322
printer << ")";
13211323
}

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

+16-30
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "llvm/ADT/TypeSwitch.h"
4141
#include "llvm/Support/Debug.h"
4242
#include "llvm/Support/ErrorHandling.h"
43+
#include "llvm/Support/InterleavedRange.h"
4344
#include <type_traits>
4445

4546
using namespace mlir;
@@ -450,20 +451,14 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
450451
// Otherwise, we have a new insertion without a size -> use size 1.
451452
tmpMappingSizes.push_back(1);
452453
}
453-
LLVM_DEBUG(
454-
llvm::interleaveComma(
455-
tmpMappingSizes,
456-
DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
457-
llvm::dbgs() << "\n");
454+
LDBG("----tmpMappingSizes extracted from scf.forall op: "
455+
<< llvm::interleaved(tmpMappingSizes));
458456

459457
// Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
460458
SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
461459
forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
462-
LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
463-
DBGS() << "----forallMappingSizes: ");
464-
llvm::dbgs() << "\n"; llvm::interleaveComma(
465-
forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
466-
llvm::dbgs() << "\n");
460+
LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes));
461+
LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs));
467462

468463
// Step 3. Generate the mappingIdOps using the provided generator.
469464
Location loc = forallOp.getLoc();
@@ -501,17 +496,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
501496
SmallVector<int64_t> availableMappingSizes =
502497
builderResult.availableMappingSizes;
503498
SmallVector<Value> activeIdOps = builderResult.activeIdOps;
504-
// clang-format off
505-
LLVM_DEBUG(
506-
llvm::interleaveComma(
507-
activeMappingSizes, DBGS() << "----activeMappingSizes: ");
508-
llvm::dbgs() << "\n";
509-
llvm::interleaveComma(
510-
availableMappingSizes, DBGS() << "----availableMappingSizes: ");
511-
llvm::dbgs() << "\n";
512-
llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
513-
llvm::dbgs() << "\n");
514-
// clang-format on
499+
LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
500+
LDBG("----availableMappingSizes: "
501+
<< llvm::interleaved(availableMappingSizes));
502+
LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
515503
for (auto [activeId, activeMappingSize, availableMappingSize] :
516504
llvm::zip_equal(activeIdOps, activeMappingSizes,
517505
availableMappingSizes)) {
@@ -566,11 +554,9 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
566554
// Step 8. Erase old op.
567555
rewriter.eraseOp(forallOp);
568556

569-
LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
570-
DBGS() << "----result forallMappingSizes: ");
571-
llvm::dbgs() << "\n"; llvm::interleaveComma(
572-
mappingIdOps, DBGS() << "----result mappingIdOps: ");
573-
llvm::dbgs() << "\n");
557+
LDBG("----result forallMappingSizes: "
558+
<< llvm::interleaved(forallMappingSizes));
559+
LDBG("----result mappingIdOps: " << llvm::interleaved(mappingIdOps));
574560

575561
result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
576562
return DiagnosedSilenceableFailure::success();
@@ -740,7 +726,7 @@ static DiagnosedSilenceableFailure checkMappingSpec(
740726
auto diag = definiteFailureHelper(
741727
transformOp, forallOp,
742728
Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
743-
std::to_string(factor));
729+
Twine(factor));
744730
return diag;
745731
}
746732
if (computeProduct(numParallelIterations) * factor >
@@ -749,9 +735,9 @@ static DiagnosedSilenceableFailure checkMappingSpec(
749735
transformOp, forallOp,
750736
Twine("the number of required parallel resources (blocks or "
751737
"threads) ") +
752-
std::to_string(computeProduct(numParallelIterations) * factor) +
753-
std::string(" overflows the number of available resources ") +
754-
std::to_string(computeProduct(blockOrGridSizes)));
738+
Twine(computeProduct(numParallelIterations) * factor) +
739+
" overflows the number of available resources " +
740+
Twine(computeProduct(blockOrGridSizes)));
755741
return diag;
756742
}
757743
return DiagnosedSilenceableFailure::success();

0 commit comments

Comments
 (0)