Skip to content

Commit a52cb73

Browse files
sstipanovchuravy
andauthored
Value traits support (#121)
Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
1 parent 726df05 commit a52cb73

File tree

11 files changed

+375
-71
lines changed

11 files changed

+375
-71
lines changed

example/ExampleDialect.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ def SizeOfOp : ExampleOp<"sizeof", [Memory<[]>, NoUnwind, WillReturn]> {
152152
let results = (outs I64:$result);
153153
let arguments = (ins type:$sizeof_type);
154154

155+
let value_traits = [
156+
(NoCapture $sizeof_type),
157+
];
158+
155159
let summary = "size of a given type";
156160
let description = [{
157161
Returns the store size of the given type in bytes.
@@ -235,6 +239,10 @@ def IExtOp : ExampleOp<"iext", [Memory<[]>, NoUnwind, WillReturn]> {
235239
def StreamReduceOp : OpClass<ExampleDialect> {
236240
let arguments = (ins Ptr:$ptr, I64:$count, value:$initial);
237241

242+
let value_traits = [
243+
(NoCapture $ptr)
244+
];
245+
238246
let summary = "family of operations that reduce some array in memory";
239247
let description = [{
240248
Illustrate the use of the OpClass feature.
@@ -336,3 +344,20 @@ def NoDescriptionOp : Op<ExampleDialect, "no.description.op", [WillReturn]> {
336344

337345
let summary = "Some summary";
338346
}
347+
348+
def BufferCompareOp : Op<ExampleDialect, "buffer.compare.op", [WillReturn, NoUnwind]> {
349+
let results = (outs I32:$ret);
350+
let arguments = (ins Ptr:$lhs, Ptr:$rhs);
351+
352+
let value_traits = [
353+
(NoCapture $lhs),
354+
(NoCapture $rhs),
355+
(NoUndef $lhs),
356+
(NoUndef $ret),
357+
];
358+
359+
let summary = "demonstrate how multiple parameter attributes are added";
360+
let description = [{
361+
Both arguments get a parameter attribute, as well as return value
362+
}];
363+
}

include/llvm-dialects/Dialect/Dialect.td

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -322,29 +322,42 @@ multiclass AttrEnum<string cppType_> {
322322
/// Traits generally map to llvm::Attributes.
323323
// ============================================================================
324324

325-
class Trait;
325+
class TraitProperty;
326+
def FnTrait : TraitProperty;
327+
def ParamTrait : TraitProperty;
328+
def RetTrait : TraitProperty;
326329

327-
class LlvmEnumAttributeTrait<string llvmEnum_> : Trait {
330+
class Trait<list<TraitProperty> P> {
331+
list<TraitProperty> Properties = P;
332+
}
333+
334+
class LlvmEnumAttributeTrait<string llvmEnum_, list<TraitProperty> P> : Trait<P> {
328335
string llvmEnum = llvmEnum_;
329336
}
330337

331-
def NoUnwind : LlvmEnumAttributeTrait<"NoUnwind">;
332-
def WillReturn : LlvmEnumAttributeTrait<"WillReturn">;
333-
def NoReturn : LlvmEnumAttributeTrait<"NoReturn">;
334-
def NoRecurse : LlvmEnumAttributeTrait<"NoRecurse">;
335-
def NoSync : LlvmEnumAttributeTrait<"NoSync">;
336-
def NoFree : LlvmEnumAttributeTrait<"NoFree">;
337-
def MustProgress : LlvmEnumAttributeTrait<"MustProgress">;
338-
def NoCallback : LlvmEnumAttributeTrait<"NoCallback">;
339-
def NoDuplicate : LlvmEnumAttributeTrait<"NoDuplicate">;
340-
def NoBuiltin : LlvmEnumAttributeTrait<"NoBuiltin">;
341-
def Builtin : LlvmEnumAttributeTrait<"Builtin">;
342-
def InlineHint : LlvmEnumAttributeTrait<"InlineHint">;
343-
def AlwaysInline : LlvmEnumAttributeTrait<"AlwaysInline">;
344-
def Cold : LlvmEnumAttributeTrait<"Cold">;
345-
def Hot : LlvmEnumAttributeTrait<"Hot">;
346-
def Convergent : LlvmEnumAttributeTrait<"Convergent">;
347-
def Speculatable : LlvmEnumAttributeTrait<"Speculatable">;
338+
def NoUnwind : LlvmEnumAttributeTrait<"NoUnwind", [FnTrait]>;
339+
def WillReturn : LlvmEnumAttributeTrait<"WillReturn", [FnTrait]>;
340+
def NoReturn : LlvmEnumAttributeTrait<"NoReturn", [FnTrait]>;
341+
def NoRecurse : LlvmEnumAttributeTrait<"NoRecurse", [FnTrait]>;
342+
def NoSync : LlvmEnumAttributeTrait<"NoSync", [FnTrait]>;
343+
def NoFree : LlvmEnumAttributeTrait<"NoFree", [FnTrait]>;
344+
def MustProgress : LlvmEnumAttributeTrait<"MustProgress", [FnTrait]>;
345+
def NoCallback : LlvmEnumAttributeTrait<"NoCallback", [FnTrait]>;
346+
def NoDuplicate : LlvmEnumAttributeTrait<"NoDuplicate", [FnTrait]>;
347+
def NoBuiltin : LlvmEnumAttributeTrait<"NoBuiltin", [FnTrait]>;
348+
def Builtin : LlvmEnumAttributeTrait<"Builtin", [FnTrait]>;
349+
def InlineHint : LlvmEnumAttributeTrait<"InlineHint", [FnTrait]>;
350+
def AlwaysInline : LlvmEnumAttributeTrait<"AlwaysInline", [FnTrait]>;
351+
def Cold : LlvmEnumAttributeTrait<"Cold", [FnTrait]>;
352+
def Hot : LlvmEnumAttributeTrait<"Hot", [FnTrait]>;
353+
def Convergent : LlvmEnumAttributeTrait<"Convergent", [FnTrait]>;
354+
def Speculatable : LlvmEnumAttributeTrait<"Speculatable", [FnTrait]>;
355+
356+
def NoCapture : LlvmEnumAttributeTrait<"NoCapture", [ParamTrait]>;
357+
def ReadOnly : LlvmEnumAttributeTrait<"ReadOnly", [ParamTrait]>;
358+
359+
def NoUndef : LlvmEnumAttributeTrait<"NoUndef", [ParamTrait, RetTrait]>;
360+
def NonNull : LlvmEnumAttributeTrait<"NonNull", [ParamTrait, RetTrait]>;
348361

349362
/// Represent the LLVM `memory(...)` attribute as the OR (or union) of memory
350363
/// effects. An empty effects list means the operation does not access memory
@@ -358,7 +371,7 @@ def Speculatable : LlvmEnumAttributeTrait<"Speculatable">;
358371
/// Example: `Memory<[(ref), (mod ArgMem, InaccessibleMem)]>` means the
359372
/// operation may read from any kind of memory and write to argument and
360373
/// inaccessible memory.
361-
class Memory<list<dag> effects_> : Trait {
374+
class Memory<list<dag> effects_> : Trait<[FnTrait]> {
362375
list<dag> effects = effects_;
363376
}
364377

@@ -394,6 +407,8 @@ class OpClass<Dialect dialect_> : OpClassBase {
394407

395408
dag arguments = ?;
396409

410+
list<dag> value_traits = [];
411+
397412
string summary = ?;
398413
string description = ?;
399414
}
@@ -412,6 +427,8 @@ class Op<Dialect dialect_, string mnemonic_, list<Trait> traits_> {
412427
dag arguments = ?;
413428
dag results = ?;
414429

430+
list<dag> value_traits = [];
431+
415432
list<dag> verifier = [];
416433

417434
string summary = ?;

include/llvm-dialects/TableGen/Dialects.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#pragma once
2020

2121
#include <memory>
22+
#include <unordered_map>
2223

2324
#include "llvm-dialects/TableGen/Common.h"
2425
#include "llvm/ADT/ArrayRef.h"
@@ -75,7 +76,7 @@ class GenDialectsContext {
7576
void init(RecordKeeperTy &records,
7677
const llvm::DenseSet<llvm::StringRef> &dialects);
7778

78-
Trait *getTrait(RecordTy *traitRec);
79+
Trait *getTrait(RecordTy *traitRec, int idx = -1);
7980
Predicate *getPredicate(const llvm::Init *init, llvm::raw_ostream &errs);
8081
Attr *getAttr(RecordTy *record, llvm::raw_ostream &errs);
8182
OpClass *getOpClass(RecordTy *opClassRec);
@@ -97,6 +98,7 @@ class GenDialectsContext {
9798
const llvm::Init *m_any = nullptr;
9899
bool m_attrsComplete = false;
99100
llvm::DenseMap<RecordTy *, std::unique_ptr<Trait>> m_traits;
101+
llvm::DenseMap<RecordTy *, std::unordered_map<int, std::unique_ptr<Trait>>> m_valueTraits;
100102
llvm::DenseMap<const llvm::Init *, std::unique_ptr<Predicate>> m_predicates;
101103
llvm::DenseMap<RecordTy *, std::unique_ptr<Attr>> m_attrs;
102104
llvm::DenseMap<RecordTy *, std::unique_ptr<OpClass>> m_opClasses;

include/llvm-dialects/TableGen/Operations.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ class OperationBase {
6767
void emitArgumentAccessorDefinitions(llvm::raw_ostream &out,
6868
FmtContext &fmt) const;
6969

70+
void parseValueTraits(llvm::raw_ostream &errs, RecordTy *record,
71+
GenDialectsContext &context);
72+
73+
std::vector<Trait *> traits;
74+
7075
protected:
7176
bool init(llvm::raw_ostream &errs, GenDialectsContext &context,
7277
RecordTy *record);
@@ -105,7 +110,6 @@ class Operation : public OperationBase {
105110
std::string mnemonic;
106111
std::string summary;
107112
std::string description;
108-
std::vector<Trait *> traits;
109113

110114
std::vector<NamedValue> results;
111115

include/llvm-dialects/TableGen/Traits.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,19 @@ class Trait {
4040
enum class Kind : uint8_t {
4141
LlvmAttributeTrait_First,
4242
LlvmEnumAttributeTrait = LlvmAttributeTrait_First,
43+
LlvmEnumFnAttributeTrait,
44+
LlvmEnumRetAttributeTrait,
45+
LlvmEnumParamAttributeTrait,
4346
LlvmMemoryAttributeTrait,
4447
LlvmAttributeTrait_Last = LlvmMemoryAttributeTrait,
4548
};
4649

4750
static std::unique_ptr<Trait> fromRecord(GenDialectsContext *context,
48-
RecordTy *record);
51+
RecordTy *record, int idx = 0);
4952

5053
virtual ~Trait() = default;
5154

52-
virtual void init(GenDialectsContext *context, RecordTy *record);
55+
virtual void init(GenDialectsContext *context, RecordTy *record, int idx);
5356

5457
Kind getKind() const { return m_kind; }
5558
RecordTy *getRecord() const { return m_record; }

lib/TableGen/Dialects.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,14 @@ void GenDialect::finalize(raw_ostream &errs) {
7979
GenDialectsContext::GenDialectsContext() = default;
8080
GenDialectsContext::~GenDialectsContext() = default;
8181

82-
Trait *GenDialectsContext::getTrait(RecordTy *traitRec) {
82+
Trait *GenDialectsContext::getTrait(RecordTy *traitRec, int idx) {
8383
if (!traitRec->isSubClassOf("Trait"))
8484
report_fatal_error(Twine("Trying to use '") + traitRec->getName() +
8585
"' as a trait, but it is not a subclass of 'Trait'");
8686

87-
auto &result = m_traits[traitRec];
87+
auto &result = idx < 0 ? m_traits[traitRec] : m_valueTraits[traitRec][idx];
8888
if (!result)
89-
result = Trait::fromRecord(this, traitRec);
89+
result = Trait::fromRecord(this, traitRec, idx);
9090
return result.get();
9191
}
9292

lib/TableGen/GenDialect.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,11 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
353353
if (!dialect->attribute_lists_empty()) {
354354
FmtContextScope scope{fmt};
355355
fmt.addSubst("attrBuilder", "attrBuilder");
356+
fmt.addSubst("argAttrList", "argAttrList");
356357

357358
for (const auto &enumeratedTraits : enumerate(dialect->attribute_lists())) {
358359
out << tgfmt("{\n ::llvm::AttrBuilder $attrBuilder{context};\n", &fmt);
360+
out << tgfmt(" ::llvm::AttributeList $argAttrList;\n", &fmt);
359361

360362
for (const Trait *trait : enumeratedTraits.value()) {
361363
if (auto *llvmAttribute = dyn_cast<LlvmAttributeTrait>(trait)) {
@@ -365,8 +367,8 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
365367
}
366368
}
367369

368-
out << tgfmt("m_attributeLists[$0] = ::llvm::AttributeList::get(context, "
369-
"::llvm::AttributeList::FunctionIndex, $attrBuilder);\n}\n",
370+
out << tgfmt("m_attributeLists[$0] = "
371+
"$argAttrList.addFnAttributes(context, $attrBuilder);\n}\n",
370372
&fmt, enumeratedTraits.index());
371373
}
372374
}

lib/TableGen/Operations.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm-dialects/TableGen/Dialects.h"
2323
#include "llvm-dialects/TableGen/Format.h"
2424

25+
#include "llvm/Support/ErrorHandling.h"
2526
#include "llvm/TableGen/Record.h"
2627

2728
using namespace llvm;
@@ -90,6 +91,11 @@ bool OperationBase::init(raw_ostream &errs, GenDialectsContext &context,
9091

9192
m_arguments = std::move(*arguments);
9293

94+
if (m_superclass && m_superclass->traits.size() > 0)
95+
traits = m_superclass->traits;
96+
97+
parseValueTraits(errs, record, context);
98+
9399
// Don't allow any other arguments if the superclass already uses
94100
// variadic arguments, as the arguments will be appended to the arguments of
95101
// the superclass.
@@ -297,6 +303,47 @@ void OperationBase::emitArgumentAccessorDefinitions(llvm::raw_ostream &out,
297303
}
298304
}
299305

306+
void OperationBase::parseValueTraits(raw_ostream &errs, RecordTy *record,
307+
GenDialectsContext &context) {
308+
const DagInit *insDag = record->getValueAsDag("arguments");
309+
std::unordered_map<std::string, unsigned> nameToIndexMap;
310+
for (unsigned i = 0; i < insDag->getNumArgs(); ++i) {
311+
StringRef name = insDag->getArgNameStr(i);
312+
nameToIndexMap[name.str()] = i + 1;
313+
}
314+
315+
const RecordVal *outsVal = record->getValue("results");
316+
if (outsVal) {
317+
const DagInit *DI = cast<DagInit>(outsVal->getValue());
318+
if (DI->getNumArgs() > 0) {
319+
StringRef name = DI->getArgNameStr(0);
320+
nameToIndexMap[name.str()] = 0;
321+
}
322+
}
323+
324+
const ListInit *List = record->getValueAsListInit("value_traits");
325+
for (const Init *I : List->getValues()) {
326+
if (const DagInit *DI = dyn_cast<DagInit>(I)) {
327+
if (DI->getNumArgs() != 1) {
328+
errs << "value_traits " << *DI << " is missing argument name";
329+
return;
330+
}
331+
332+
StringRef name = DI->getArgNameStr(0);
333+
334+
if (const DefInit *Op = dyn_cast<DefInit>(DI->getOperator())) {
335+
traits.push_back(
336+
context.getTrait(Op->getDef(), nameToIndexMap[name.str()]));
337+
} else {
338+
errs << "value_traits " << *DI << " is not of form (Trait $arg)";
339+
return;
340+
}
341+
} else {
342+
report_fatal_error("value_traits was not a list of DAG's");
343+
}
344+
}
345+
}
346+
300347
std::unique_ptr<OpClass> OpClass::parse(raw_ostream &errs,
301348
GenDialectsContext &context,
302349
RecordTy *record) {

0 commit comments

Comments
 (0)