6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
9
+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
10
+ #include " mlir/IR/BuiltinDialect.h"
9
11
#include " mlir/IR/BuiltinOps.h"
10
12
#include " mlir/IR/MLIRContext.h"
11
13
#include " mlir/Parser/Parser.h"
@@ -30,24 +32,46 @@ using namespace mlir;
30
32
#define SKIP_WITHOUT_NATIVE (x ) x
31
33
#endif
32
34
35
+ namespace {
36
+ // Dummy interface for testing.
37
+ class TargetAttrImpl
38
+ : public gpu::TargetAttrInterface::FallbackModel<TargetAttrImpl> {
39
+ public:
40
+ std::optional<SmallVector<char , 0 >>
41
+ serializeToObject (Attribute attribute, Operation *module ,
42
+ const gpu::TargetOptions &options) const ;
43
+
44
+ Attribute createObject (Attribute attribute, Operation *module ,
45
+ const SmallVector<char , 0 > &object,
46
+ const gpu::TargetOptions &options) const ;
47
+ };
48
+ } // namespace
49
+
33
50
class MLIRTargetLLVM : public ::testing::Test {
34
51
protected:
35
52
void SetUp () override {
36
53
llvm::InitializeNativeTarget ();
37
54
llvm::InitializeNativeTargetAsmPrinter ();
55
+ registry.addExtension (+[](MLIRContext *ctx, BuiltinDialect *dialect) {
56
+ IntegerAttr::attachInterface<TargetAttrImpl>(*ctx);
57
+ });
58
+ registerBuiltinDialectTranslation (registry);
59
+ registerLLVMDialectTranslation (registry);
60
+ registry.insert <gpu::GPUDialect>();
38
61
}
39
- };
40
62
41
- TEST_F (MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) {
63
+ // Dialect registry.
64
+ DialectRegistry registry;
65
+
66
+ // MLIR module used for the tests.
42
67
std::string moduleStr = R"mlir(
43
68
llvm.func @foo(%arg0 : i32) {
44
69
llvm.return
45
70
}
46
71
)mlir" ;
72
+ };
47
73
48
- DialectRegistry registry;
49
- registerBuiltinDialectTranslation (registry);
50
- registerLLVMDialectTranslation (registry);
74
+ TEST_F (MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) {
51
75
MLIRContext context (registry);
52
76
53
77
OwningOpRef<ModuleOp> module =
@@ -74,3 +98,52 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) {
74
98
// Check that it has a function named `foo`.
75
99
ASSERT_TRUE ((*llvmModule)->getFunction (" foo" ) != nullptr );
76
100
}
101
+
102
+ std::optional<SmallVector<char , 0 >>
103
+ TargetAttrImpl::serializeToObject (Attribute attribute, Operation *module ,
104
+ const gpu::TargetOptions &options) const {
105
+ module ->setAttr (" serialize_attr" , UnitAttr::get (module ->getContext ()));
106
+ std::string targetTriple = llvm::sys::getProcessTriple ();
107
+ LLVM::ModuleToObject serializer (*module , targetTriple, " " , " " );
108
+ return serializer.run ();
109
+ }
110
+
111
+ Attribute
112
+ TargetAttrImpl::createObject (Attribute attribute, Operation *module ,
113
+ const SmallVector<char , 0 > &object,
114
+ const gpu::TargetOptions &options) const {
115
+ return gpu::ObjectAttr::get (
116
+ module ->getContext (), attribute, gpu::CompilationTarget::Offload,
117
+ StringAttr::get (module ->getContext (),
118
+ StringRef (object.data (), object.size ())),
119
+ module ->getAttrDictionary ());
120
+ }
121
+
122
+ TEST_F (MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(TargetAttrAPI)) {
123
+ MLIRContext context (registry);
124
+ context.loadAllAvailableDialects ();
125
+
126
+ OwningOpRef<ModuleOp> module =
127
+ parseSourceString<ModuleOp>(moduleStr, &context);
128
+ ASSERT_TRUE (!!module );
129
+ Builder builder (&context);
130
+ IntegerAttr target = builder.getI32IntegerAttr (0 );
131
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
132
+ // Check the attribute holds the interface.
133
+ ASSERT_TRUE (!!targetAttr);
134
+ gpu::TargetOptions opts;
135
+ std::optional<SmallVector<char , 0 >> serializedBinary =
136
+ targetAttr.serializeToObject (*module , opts);
137
+ // Check the serialized string.
138
+ ASSERT_TRUE (!!serializedBinary);
139
+ ASSERT_TRUE (!serializedBinary->empty ());
140
+ // Create the object attribute.
141
+ auto object = cast<gpu::ObjectAttr>(
142
+ targetAttr.createObject (*module , *serializedBinary, opts));
143
+ // Check the object has properties.
144
+ DictionaryAttr properties = object.getProperties ();
145
+ ASSERT_TRUE (!!properties);
146
+ // Check that it contains the attribute added to the module in
147
+ // `serializeToObject`.
148
+ ASSERT_TRUE (properties.contains (" serialize_attr" ));
149
+ }
0 commit comments