-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir][gpu] Clean up prints in GPU dialect. NFC. #136250
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Clean up printing code by switching to `llvm::interleaved` from llvm#135517. Also make some minor readability & performance fixes.
@llvm/pr-subscribers-mlir Author: Jakub Kuderski (kuhar) Changes: Clean up printing code by switching to `llvm::interleaved` from llvm#135517. Also make some minor readability & performance fixes. Full diff: https://github.com/llvm/llvm-project/pull/136250.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 976432ea37120..9391d2c4ec840 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -35,6 +35,8 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>
@@ -479,10 +481,11 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
if (values.empty())
return;
- p << ' ' << keyword << '(';
- llvm::interleaveComma(
- values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
- p << ')';
+ auto printBlockArg = [](BlockArgument v) {
+ return llvm::formatv("{} : {}", v, v.getType());
+ };
+ p << ' ' << keyword << '('
+ << llvm::interleaved(llvm::map_range(values, printBlockArg)) << ')';
}
/// Verifies a GPU function memory attribution.
@@ -1311,11 +1314,10 @@ static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
if (operands.empty())
return;
printer << "args(";
- llvm::interleaveComma(llvm::zip(operands, types), printer,
+ llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
[&](const auto &pair) {
- printer.printOperand(std::get<0>(pair));
- printer << " : ";
- printer.printType(std::get<1>(pair));
+ auto [operand, type] = pair;
+ printer << operand << " : " << type;
});
printer << ")";
}
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index cce477370a539..3970539db6675 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -40,6 +40,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/InterleavedRange.h"
#include <type_traits>
using namespace mlir;
@@ -450,20 +451,14 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
// Otherwise, we have a new insertion without a size -> use size 1.
tmpMappingSizes.push_back(1);
}
- LLVM_DEBUG(
- llvm::interleaveComma(
- tmpMappingSizes,
- DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
- llvm::dbgs() << "\n");
+ LDBG("----tmpMappingSizes extracted from scf.forall op: "
+ << llvm::interleaved(tmpMappingSizes));
// Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
- LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
- DBGS() << "----forallMappingSizes: ");
- llvm::dbgs() << "\n"; llvm::interleaveComma(
- forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
- llvm::dbgs() << "\n");
+ LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes));
+ LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs));
// Step 3. Generate the mappingIdOps using the provided generator.
Location loc = forallOp.getLoc();
@@ -501,17 +496,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
SmallVector<int64_t> availableMappingSizes =
builderResult.availableMappingSizes;
SmallVector<Value> activeIdOps = builderResult.activeIdOps;
- // clang-format off
- LLVM_DEBUG(
- llvm::interleaveComma(
- activeMappingSizes, DBGS() << "----activeMappingSizes: ");
- llvm::dbgs() << "\n";
- llvm::interleaveComma(
- availableMappingSizes, DBGS() << "----availableMappingSizes: ");
- llvm::dbgs() << "\n";
- llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
- llvm::dbgs() << "\n");
- // clang-format on
+ LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
+ LDBG("----availableMappingSizes: "
+ << llvm::interleaved(availableMappingSizes));
+ LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
for (auto [activeId, activeMappingSize, availableMappingSize] :
llvm::zip_equal(activeIdOps, activeMappingSizes,
availableMappingSizes)) {
@@ -566,11 +554,9 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
// Step 8. Erase old op.
rewriter.eraseOp(forallOp);
- LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
- DBGS() << "----result forallMappingSizes: ");
- llvm::dbgs() << "\n"; llvm::interleaveComma(
- mappingIdOps, DBGS() << "----result mappingIdOps: ");
- llvm::dbgs() << "\n");
+ LDBG("----result forallMappingSizes: "
+ << llvm::interleaved(forallMappingSizes));
+ LDBG("----result mappingIdOps: " << llvm::interleaved(mappingIdOps));
result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
return DiagnosedSilenceableFailure::success();
@@ -740,7 +726,7 @@ static DiagnosedSilenceableFailure checkMappingSpec(
auto diag = definiteFailureHelper(
transformOp, forallOp,
Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
- std::to_string(factor));
+ Twine(factor));
return diag;
}
if (computeProduct(numParallelIterations) * factor >
@@ -749,9 +735,9 @@ static DiagnosedSilenceableFailure checkMappingSpec(
transformOp, forallOp,
Twine("the number of required parallel resources (blocks or "
"threads) ") +
- std::to_string(computeProduct(numParallelIterations) * factor) +
- std::string(" overflows the number of available resources ") +
- std::to_string(computeProduct(blockOrGridSizes)));
+ Twine(computeProduct(numParallelIterations) * factor) +
+ " overflows the number of available resources " +
+ Twine(computeProduct(blockOrGridSizes)));
return diag;
}
return DiagnosedSilenceableFailure::success();
|
@llvm/pr-subscribers-mlir-gpu Author: Jakub Kuderski (kuhar) Changes: Clean up printing code by switching to `llvm::interleaved` from llvm#135517. Also make some minor readability & performance fixes. Full diff: https://github.com/llvm/llvm-project/pull/136250.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 976432ea37120..9391d2c4ec840 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -35,6 +35,8 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>
@@ -479,10 +481,11 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
if (values.empty())
return;
- p << ' ' << keyword << '(';
- llvm::interleaveComma(
- values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
- p << ')';
+ auto printBlockArg = [](BlockArgument v) {
+ return llvm::formatv("{} : {}", v, v.getType());
+ };
+ p << ' ' << keyword << '('
+ << llvm::interleaved(llvm::map_range(values, printBlockArg)) << ')';
}
/// Verifies a GPU function memory attribution.
@@ -1311,11 +1314,10 @@ static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
if (operands.empty())
return;
printer << "args(";
- llvm::interleaveComma(llvm::zip(operands, types), printer,
+ llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
[&](const auto &pair) {
- printer.printOperand(std::get<0>(pair));
- printer << " : ";
- printer.printType(std::get<1>(pair));
+ auto [operand, type] = pair;
+ printer << operand << " : " << type;
});
printer << ")";
}
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index cce477370a539..3970539db6675 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -40,6 +40,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/InterleavedRange.h"
#include <type_traits>
using namespace mlir;
@@ -450,20 +451,14 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
// Otherwise, we have a new insertion without a size -> use size 1.
tmpMappingSizes.push_back(1);
}
- LLVM_DEBUG(
- llvm::interleaveComma(
- tmpMappingSizes,
- DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
- llvm::dbgs() << "\n");
+ LDBG("----tmpMappingSizes extracted from scf.forall op: "
+ << llvm::interleaved(tmpMappingSizes));
// Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
- LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
- DBGS() << "----forallMappingSizes: ");
- llvm::dbgs() << "\n"; llvm::interleaveComma(
- forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
- llvm::dbgs() << "\n");
+ LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes));
+ LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs));
// Step 3. Generate the mappingIdOps using the provided generator.
Location loc = forallOp.getLoc();
@@ -501,17 +496,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
SmallVector<int64_t> availableMappingSizes =
builderResult.availableMappingSizes;
SmallVector<Value> activeIdOps = builderResult.activeIdOps;
- // clang-format off
- LLVM_DEBUG(
- llvm::interleaveComma(
- activeMappingSizes, DBGS() << "----activeMappingSizes: ");
- llvm::dbgs() << "\n";
- llvm::interleaveComma(
- availableMappingSizes, DBGS() << "----availableMappingSizes: ");
- llvm::dbgs() << "\n";
- llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
- llvm::dbgs() << "\n");
- // clang-format on
+ LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
+ LDBG("----availableMappingSizes: "
+ << llvm::interleaved(availableMappingSizes));
+ LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
for (auto [activeId, activeMappingSize, availableMappingSize] :
llvm::zip_equal(activeIdOps, activeMappingSizes,
availableMappingSizes)) {
@@ -566,11 +554,9 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
// Step 8. Erase old op.
rewriter.eraseOp(forallOp);
- LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
- DBGS() << "----result forallMappingSizes: ");
- llvm::dbgs() << "\n"; llvm::interleaveComma(
- mappingIdOps, DBGS() << "----result mappingIdOps: ");
- llvm::dbgs() << "\n");
+ LDBG("----result forallMappingSizes: "
+ << llvm::interleaved(forallMappingSizes));
+ LDBG("----result mappingIdOps: " << llvm::interleaved(mappingIdOps));
result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
return DiagnosedSilenceableFailure::success();
@@ -740,7 +726,7 @@ static DiagnosedSilenceableFailure checkMappingSpec(
auto diag = definiteFailureHelper(
transformOp, forallOp,
Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
- std::to_string(factor));
+ Twine(factor));
return diag;
}
if (computeProduct(numParallelIterations) * factor >
@@ -749,9 +735,9 @@ static DiagnosedSilenceableFailure checkMappingSpec(
transformOp, forallOp,
Twine("the number of required parallel resources (blocks or "
"threads) ") +
- std::to_string(computeProduct(numParallelIterations) * factor) +
- std::string(" overflows the number of available resources ") +
- std::to_string(computeProduct(blockOrGridSizes)));
+ Twine(computeProduct(numParallelIterations) * factor) +
+ " overflows the number of available resources " +
+ Twine(computeProduct(blockOrGridSizes)));
return diag;
}
return DiagnosedSilenceableFailure::success();
|
values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); }); | ||
p << ')'; | ||
auto printBlockArg = [](BlockArgument v) { | ||
return llvm::formatv("{} : {}", v, v.getType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code is technically more expensive as llvm::formatv
allocates a temp string, but I don't think it matters for printers much.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ack
Clean up printing code by switching to `llvm::interleaved` from #135517. Also make some minor readability & performance fixes.