Skip to content

[mlir][gpu] Pass GPU module to TargetAttrInterface::createObject. #94910

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def GPUTargetAttrInterface : AttrInterface<"TargetAttrInterface"> {
meant to be used for passing additional options that are not in the
attribute.
}], "::mlir::Attribute", "createObject",
(ins "const ::llvm::SmallVector<char, 0>&":$object,
"const ::mlir::gpu::TargetOptions&":$options)>
(ins "::mlir::Operation *":$module,
Copy link
Collaborator

@joker-eph joker-eph Aug 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're missing documentation here!

(also should it be GPUModuleOp instead of Operation*)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GPUModuleOp vs Operation* goes back to the original introduction of these attributes. If you see, all other methods also use Operation*. Back then, there was a build dependency that required the interfaces to be generated before the dialect, hence the types were not available (I'm not sure this is needed anymore, I'll check). Then a second reason, was to make the header independent of GPUDialect.h

"const ::llvm::SmallVector<char, 0> &":$object,
"const ::mlir::gpu::TargetOptions &":$options)>
];
}

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ LogicalResult moduleSerializer(GPUModuleOp op,
return failure();
}

Attribute object = target.createObject(*serializedModule, targetOptions);
Attribute object =
target.createObject(op, *serializedModule, targetOptions);
if (!object) {
op.emitError("An error happened while creating the object.");
return failure();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Target/LLVM/NVVM/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class NVVMTargetAttrImpl
serializeToObject(Attribute attribute, Operation *module,
const gpu::TargetOptions &options) const;

Attribute createObject(Attribute attribute,
Attribute createObject(Attribute attribute, Operation *module,
const SmallVector<char, 0> &object,
const gpu::TargetOptions &options) const;
};
Expand Down Expand Up @@ -591,7 +591,7 @@ NVVMTargetAttrImpl::serializeToObject(Attribute attribute, Operation *module,
}

Attribute
NVVMTargetAttrImpl::createObject(Attribute attribute,
NVVMTargetAttrImpl::createObject(Attribute attribute, Operation *module,
const SmallVector<char, 0> &object,
const gpu::TargetOptions &options) const {
auto target = cast<NVVMTargetAttr>(attribute);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Target/LLVM/ROCDL/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class ROCDLTargetAttrImpl
serializeToObject(Attribute attribute, Operation *module,
const gpu::TargetOptions &options) const;

Attribute createObject(Attribute attribute,
Attribute createObject(Attribute attribute, Operation *module,
const SmallVector<char, 0> &object,
const gpu::TargetOptions &options) const;
};
Expand Down Expand Up @@ -500,7 +500,7 @@ std::optional<SmallVector<char, 0>> ROCDLTargetAttrImpl::serializeToObject(
}

Attribute
ROCDLTargetAttrImpl::createObject(Attribute attribute,
ROCDLTargetAttrImpl::createObject(Attribute attribute, Operation *module,
const SmallVector<char, 0> &object,
const gpu::TargetOptions &options) const {
gpu::CompilationTarget format = options.getCompilationTarget();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Target/SPIRV/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class SPIRVTargetAttrImpl
serializeToObject(Attribute attribute, Operation *module,
const gpu::TargetOptions &options) const;

Attribute createObject(Attribute attribute,
Attribute createObject(Attribute attribute, Operation *module,
const SmallVector<char, 0> &object,
const gpu::TargetOptions &options) const;
};
Expand Down Expand Up @@ -89,7 +89,7 @@ std::optional<SmallVector<char, 0>> SPIRVTargetAttrImpl::serializeToObject(

// Prepare Attribute for gpu.binary with serialized kernel object
Attribute
SPIRVTargetAttrImpl::createObject(Attribute attribute,
SPIRVTargetAttrImpl::createObject(Attribute attribute, Operation *module,
const SmallVector<char, 0> &object,
const gpu::TargetOptions &options) const {
gpu::CompilationTarget format = options.getCompilationTarget();
Expand Down
83 changes: 78 additions & 5 deletions mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
Expand All @@ -30,24 +32,46 @@ using namespace mlir;
#define SKIP_WITHOUT_NATIVE(x) x
#endif

namespace {
// Dummy interface for testing.
class TargetAttrImpl
: public gpu::TargetAttrInterface::FallbackModel<TargetAttrImpl> {
public:
std::optional<SmallVector<char, 0>>
serializeToObject(Attribute attribute, Operation *module,
const gpu::TargetOptions &options) const;

Attribute createObject(Attribute attribute, Operation *module,
const SmallVector<char, 0> &object,
const gpu::TargetOptions &options) const;
};
} // namespace

class MLIRTargetLLVM : public ::testing::Test {
protected:
void SetUp() override {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
IntegerAttr::attachInterface<TargetAttrImpl>(*ctx);
});
registerBuiltinDialectTranslation(registry);
registerLLVMDialectTranslation(registry);
registry.insert<gpu::GPUDialect>();
}
};

TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) {
// Dialect registry.
DialectRegistry registry;

// MLIR module used for the tests.
std::string moduleStr = R"mlir(
llvm.func @foo(%arg0 : i32) {
llvm.return
}
)mlir";
};

DialectRegistry registry;
registerBuiltinDialectTranslation(registry);
registerLLVMDialectTranslation(registry);
TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) {
MLIRContext context(registry);

OwningOpRef<ModuleOp> module =
Expand All @@ -74,3 +98,52 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) {
// Check that it has a function named `foo`.
ASSERT_TRUE((*llvmModule)->getFunction("foo") != nullptr);
}

std::optional<SmallVector<char, 0>>
TargetAttrImpl::serializeToObject(Attribute attribute, Operation *module,
const gpu::TargetOptions &options) const {
module->setAttr("serialize_attr", UnitAttr::get(module->getContext()));
std::string targetTriple = llvm::sys::getProcessTriple();
LLVM::ModuleToObject serializer(*module, targetTriple, "", "");
return serializer.run();
}

Attribute
TargetAttrImpl::createObject(Attribute attribute, Operation *module,
const SmallVector<char, 0> &object,
const gpu::TargetOptions &options) const {
return gpu::ObjectAttr::get(
module->getContext(), attribute, gpu::CompilationTarget::Offload,
StringAttr::get(module->getContext(),
StringRef(object.data(), object.size())),
module->getAttrDictionary());
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add more documentation to the tests, out of the context of this PR it wouldn't be straightforward to understand the intent of all this I believe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, wrt docs everywhere, would you prefer a fresh patch, or can I bundle the changes into this patch #95292 ? (they are related)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td isn't touched at all in #95292 ? I wouldn't add it to the changeset there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, I'm touching CompilationAttrs.td. I'll merge that other one, create a patch with the docs and then revisit if we can change Operation* to GPUModuleOp.


TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(TargetAttrAPI)) {
MLIRContext context(registry);
context.loadAllAvailableDialects();

OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);
Builder builder(&context);
IntegerAttr target = builder.getI32IntegerAttr(0);
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
// Check the attribute holds the interface.
ASSERT_TRUE(!!targetAttr);
gpu::TargetOptions opts;
std::optional<SmallVector<char, 0>> serializedBinary =
targetAttr.serializeToObject(*module, opts);
// Check the serialized string.
ASSERT_TRUE(!!serializedBinary);
ASSERT_TRUE(!serializedBinary->empty());
// Create the object attribute.
auto object = cast<gpu::ObjectAttr>(
targetAttr.createObject(*module, *serializedBinary, opts));
// Check the object has properties.
DictionaryAttr properties = object.getProperties();
ASSERT_TRUE(!!properties);
// Check that it contains the attribute added to the module in
// `serializeToObject`.
ASSERT_TRUE(properties.contains("serialize_attr"));
}
Loading