Skip to content

Commit 1de0959

Browse files
committed
add test
1 parent de4fcdb commit 1de0959

File tree

1 file changed

+78
-5
lines changed

1 file changed

+78
-5
lines changed

mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
10+
#include "mlir/IR/BuiltinDialect.h"
911
#include "mlir/IR/BuiltinOps.h"
1012
#include "mlir/IR/MLIRContext.h"
1113
#include "mlir/Parser/Parser.h"
@@ -30,24 +32,46 @@ using namespace mlir;
3032
#define SKIP_WITHOUT_NATIVE(x) x
3133
#endif
3234

35+
namespace {
36+
// Dummy interface for testing.
37+
class TargetAttrImpl
38+
: public gpu::TargetAttrInterface::FallbackModel<TargetAttrImpl> {
39+
public:
40+
std::optional<SmallVector<char, 0>>
41+
serializeToObject(Attribute attribute, Operation *module,
42+
const gpu::TargetOptions &options) const;
43+
44+
Attribute createObject(Attribute attribute, Operation *module,
45+
const SmallVector<char, 0> &object,
46+
const gpu::TargetOptions &options) const;
47+
};
48+
} // namespace
49+
3350
class MLIRTargetLLVM : public ::testing::Test {
3451
protected:
3552
void SetUp() override {
3653
llvm::InitializeNativeTarget();
3754
llvm::InitializeNativeTargetAsmPrinter();
55+
registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
56+
IntegerAttr::attachInterface<TargetAttrImpl>(*ctx);
57+
});
58+
registerBuiltinDialectTranslation(registry);
59+
registerLLVMDialectTranslation(registry);
60+
registry.insert<gpu::GPUDialect>();
3861
}
39-
};
4062

41-
TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) {
63+
// Dialect registry.
64+
DialectRegistry registry;
65+
66+
// MLIR module used for the tests.
4267
std::string moduleStr = R"mlir(
4368
llvm.func @foo(%arg0 : i32) {
4469
llvm.return
4570
}
4671
)mlir";
72+
};
4773

48-
DialectRegistry registry;
49-
registerBuiltinDialectTranslation(registry);
50-
registerLLVMDialectTranslation(registry);
74+
TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) {
5175
MLIRContext context(registry);
5276

5377
OwningOpRef<ModuleOp> module =
@@ -74,3 +98,52 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) {
7498
// Check that it has a function named `foo`.
7599
ASSERT_TRUE((*llvmModule)->getFunction("foo") != nullptr);
76100
}
101+
102+
std::optional<SmallVector<char, 0>>
103+
TargetAttrImpl::serializeToObject(Attribute attribute, Operation *module,
104+
const gpu::TargetOptions &options) const {
105+
module->setAttr("serialize_attr", UnitAttr::get(module->getContext()));
106+
std::string targetTriple = llvm::sys::getProcessTriple();
107+
LLVM::ModuleToObject serializer(*module, targetTriple, "", "");
108+
return serializer.run();
109+
}
110+
111+
Attribute
112+
TargetAttrImpl::createObject(Attribute attribute, Operation *module,
113+
const SmallVector<char, 0> &object,
114+
const gpu::TargetOptions &options) const {
115+
return gpu::ObjectAttr::get(
116+
module->getContext(), attribute, gpu::CompilationTarget::Offload,
117+
StringAttr::get(module->getContext(),
118+
StringRef(object.data(), object.size())),
119+
module->getAttrDictionary());
120+
}
121+
122+
TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(TargetAttrAPI)) {
123+
MLIRContext context(registry);
124+
context.loadAllAvailableDialects();
125+
126+
OwningOpRef<ModuleOp> module =
127+
parseSourceString<ModuleOp>(moduleStr, &context);
128+
ASSERT_TRUE(!!module);
129+
Builder builder(&context);
130+
IntegerAttr target = builder.getI32IntegerAttr(0);
131+
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
132+
// Check the attribute holds the interface.
133+
ASSERT_TRUE(!!targetAttr);
134+
gpu::TargetOptions opts;
135+
std::optional<SmallVector<char, 0>> serializedBinary =
136+
targetAttr.serializeToObject(*module, opts);
137+
// Check the serialized string.
138+
ASSERT_TRUE(!!serializedBinary);
139+
ASSERT_TRUE(!serializedBinary->empty());
140+
// Create the object attribute.
141+
auto object = cast<gpu::ObjectAttr>(
142+
targetAttr.createObject(*module, *serializedBinary, opts));
143+
// Check the object has properties.
144+
DictionaryAttr properties = object.getProperties();
145+
ASSERT_TRUE(!!properties);
146+
// Check that it contains the attribute added to the module in
147+
// `serializeToObject`.
148+
ASSERT_TRUE(properties.contains("serialize_attr"));
149+
}

0 commit comments

Comments
 (0)