
[mlir][gpu] Clean up prints in GPU dialect. NFC. #136250

Merged
merged 1 commit on Apr 18, 2025

18 changes: 10 additions & 8 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -35,6 +35,8 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>
@@ -479,10 +481,11 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
if (values.empty())
return;

-  p << ' ' << keyword << '(';
-  llvm::interleaveComma(
-      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
-  p << ')';
+  auto printBlockArg = [](BlockArgument v) {
+    return llvm::formatv("{} : {}", v, v.getType());

Review comment (Contributor): This code is technically more expensive as llvm::formatv allocates a temp string, but I don't think it matters for printers much.

Reply (Member Author): Ack

+  };
+  p << ' ' << keyword << '('
+    << llvm::interleaved(llvm::map_range(values, printBlockArg)) << ')';
}

/// Verifies a GPU function memory attribution.
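
The review comment above concerns the new printAttributions body: llvm::formatv renders each block argument into a temporary string before llvm::interleaved joins the results, whereas the old interleaveComma callback streamed straight into the printer. Below is a minimal, self-contained sketch of the two styles; it is not taken from the patch and uses a made-up Elem struct in place of mlir::BlockArgument, writing to llvm::outs() instead of an OpAsmPrinter.

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"
#include <vector>

// Hypothetical stand-in for mlir::BlockArgument: something printable plus a type.
struct Elem {
  int id;
  llvm::StringRef type;
};

static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Elem &e) {
  return os << "%arg" << e.id;
}

int main() {
  std::vector<Elem> values = {{0, "f32"}, {1, "index"}};

  // Old style: the per-element callback streams directly into the output,
  // so no intermediate string is built per element.
  llvm::interleaveComma(values, llvm::outs(), [](const Elem &v) {
    llvm::outs() << v << " : " << v.type;
  });
  llvm::outs() << '\n';

  // New style: each element is rendered into a temporary std::string via
  // llvm::formatv, and llvm::interleaved joins the mapped range with ", ".
  auto printElem = [](const Elem &v) {
    return llvm::formatv("{0} : {1}", v, v.type).str();
  };
  llvm::outs() << llvm::interleaved(llvm::map_range(values, printElem)) << '\n';
  return 0;
}
```

As the reviewer notes, the extra per-element allocation is unlikely to matter in an IR printer.
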
@@ -1311,11 +1314,10 @@ static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
if (operands.empty())
return;
printer << "args(";
-  llvm::interleaveComma(llvm::zip(operands, types), printer,
+  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
[&](const auto &pair) {
-                          printer.printOperand(std::get<0>(pair));
-                          printer << " : ";
-                          printer.printType(std::get<1>(pair));
+                          auto [operand, type] = pair;
+                          printer << operand << " : " << type;
});
printer << ")";
}
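
The printLaunchFuncOperands hunk above also switches llvm::zip to llvm::zip_equal and unpacks each element with a structured binding. A quick standalone sketch of the difference, using made-up operand and type vectors rather than MLIR values:

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::SmallVector<int> operands = {1, 2, 3};
  llvm::SmallVector<llvm::StringRef> types = {"i32", "f32", "index"};

  // llvm::zip stops at the shorter range, silently truncating a mismatch;
  // llvm::zip_equal asserts (in assert-enabled builds) that the ranges have
  // the same length, which is the invariant the printer relies on.
  for (auto [operand, type] : llvm::zip_equal(operands, types))
    llvm::outs() << operand << " : " << type << "\n";
  return 0;
}
```
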
46 changes: 16 additions & 30 deletions mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -40,6 +40,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include <type_traits>

using namespace mlir;
@@ -450,20 +451,14 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
// Otherwise, we have a new insertion without a size -> use size 1.
tmpMappingSizes.push_back(1);
}
-  LLVM_DEBUG(
-      llvm::interleaveComma(
-          tmpMappingSizes,
-          DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
-      llvm::dbgs() << "\n");
+  LDBG("----tmpMappingSizes extracted from scf.forall op: "
+       << llvm::interleaved(tmpMappingSizes));

// Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
-  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
-                                   DBGS() << "----forallMappingSizes: ");
-             llvm::dbgs() << "\n"; llvm::interleaveComma(
-                 forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
-             llvm::dbgs() << "\n");
+  LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes));
+  LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs));

// Step 3. Generate the mappingIdOps using the provided generator.
Location loc = forallOp.getLoc();
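
This hunk and the next two apply the same substitution: a multi-statement LLVM_DEBUG block built around llvm::interleaveComma plus an explicit trailing newline becomes a single LDBG(...) line that streams llvm::interleaved(range). A self-contained before/after sketch follows; the DEBUG_TYPE, DBGS(), and LDBG definitions are simplified assumptions for illustration, not copied from GPUTransformOps.cpp.

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "gpu-transforms"
// Simplified stand-ins for the debug helpers used in GPUTransformOps.cpp
// (assumed definitions, for illustration only).
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

static void dumpSizes(llvm::ArrayRef<int64_t> forallMappingSizes) {
  // Before: several statements inside one LLVM_DEBUG, with an explicit newline.
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----forallMappingSizes: ");
             llvm::dbgs() << "\n");

  // After: one streamed expression; llvm::interleaved() renders the range
  // with ", " separators and LDBG appends the newline.
  LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes));
}

int main() {
  llvm::DebugFlag = true; // enable LLVM_DEBUG output in assert-enabled builds
  llvm::SmallVector<int64_t> sizes = {4, 8, 1};
  dumpSizes(sizes);
  return 0;
}
```

In release builds LLVM_DEBUG (and therefore LDBG) compiles away entirely, so the change is purely about source readability.
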
@@ -501,17 +496,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
SmallVector<int64_t> availableMappingSizes =
builderResult.availableMappingSizes;
SmallVector<Value> activeIdOps = builderResult.activeIdOps;
-  // clang-format off
-  LLVM_DEBUG(
-      llvm::interleaveComma(
-          activeMappingSizes, DBGS() << "----activeMappingSizes: ");
-      llvm::dbgs() << "\n";
-      llvm::interleaveComma(
-          availableMappingSizes, DBGS() << "----availableMappingSizes: ");
-      llvm::dbgs() << "\n";
-      llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
-      llvm::dbgs() << "\n");
-  // clang-format on
+  LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
+  LDBG("----availableMappingSizes: "
+       << llvm::interleaved(availableMappingSizes));
+  LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
for (auto [activeId, activeMappingSize, availableMappingSize] :
llvm::zip_equal(activeIdOps, activeMappingSizes,
availableMappingSizes)) {
@@ -566,11 +554,9 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
// Step 8. Erase old op.
rewriter.eraseOp(forallOp);

-  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
-                                   DBGS() << "----result forallMappingSizes: ");
-             llvm::dbgs() << "\n"; llvm::interleaveComma(
-                 mappingIdOps, DBGS() << "----result mappingIdOps: ");
-             llvm::dbgs() << "\n");
+  LDBG("----result forallMappingSizes: "
+       << llvm::interleaved(forallMappingSizes));
+  LDBG("----result mappingIdOps: " << llvm::interleaved(mappingIdOps));

result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
return DiagnosedSilenceableFailure::success();
@@ -740,7 +726,7 @@ static DiagnosedSilenceableFailure checkMappingSpec(
auto diag = definiteFailureHelper(
transformOp, forallOp,
Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
-            std::to_string(factor));
+            Twine(factor));
return diag;
}
if (computeProduct(numParallelIterations) * factor >
@@ -749,9 +735,9 @@
transformOp, forallOp,
Twine("the number of required parallel resources (blocks or "
"threads) ") +
-            std::to_string(computeProduct(numParallelIterations) * factor) +
-            std::string(" overflows the number of available resources ") +
-            std::to_string(computeProduct(blockOrGridSizes)));
+            Twine(computeProduct(numParallelIterations) * factor) +
+            " overflows the number of available resources " +
+            Twine(computeProduct(blockOrGridSizes)));
return diag;
}
return DiagnosedSilenceableFailure::success();
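
The last two hunks replace std::to_string and std::string pieces with llvm::Twine in the diagnostic message concatenations. As a rough standalone illustration (assuming nothing beyond LLVM's ADT and Support headers), Twine has integer constructors, so the whole message stays a lazy concatenation until it is materialized, rather than eagerly building intermediate std::string temporaries:

```cpp
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <string>

int main() {
  int64_t factor = 32;

  // Before: numbers converted eagerly via std::to_string, producing
  // intermediate std::string temporaries that the Twine then references.
  std::string before =
      (llvm::Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
       std::to_string(factor))
          .str();

  // After: Twine(factor) prints the integer directly when the Twine is
  // materialized; the concatenation stays lazy until .str() is called.
  std::string after =
      (llvm::Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
       llvm::Twine(factor))
          .str();

  llvm::outs() << before << "\n" << after << "\n";
  return 0;
}
```

Since a Twine only references its operands, the concatenation must be consumed within the same full expression, which checkMappingSpec does by passing it directly to the diagnostic helper.
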