Skip to content

Commit 04b63ac

Browse files
authored
[tosa] Change VariableOp to align with spec (#142240)
This fixes Tosa VariableOp to align with spec 1.0 - add var_shape attribute to store shape of variable type - change type attribute to store element type of variable type - add a builder so previous construction calls still work - fix up level check of rank to be on variable type instead of initial value which is optional - add level check of size for variable type - add lit tests for variable op's without initial values - add lit test for variable op with fixed rank but unknown dimension - add invalid lit test for variable op with unranked type Signed-off-by: Tai Ly <tai.ly@arm.com>
1 parent 4d42c8e commit 04b63ac

File tree

11 files changed

+266
-69
lines changed

11 files changed

+266
-69
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,16 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder<
197197
input, paddings);
198198
}]>;
199199

200+
// This builder is called on the TOSA variable operator with a variable type
201+
// and optional initial value. The builder will extract var_shape and element type
202+
// attributes from variable type.
203+
def Tosa_VariableOpBuilder : OpBuilder<
204+
(ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value),
205+
[{
206+
buildVariableOp($_builder, $_state, name, variable_type, initial_value);
207+
}]>;
208+
209+
200210
// Wrapper over base I32EnumAttr to set common fields.
201211
class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
202212
: I32EnumAttr<name, description, cases> {

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,14 @@ class PatternRewriter;
4444

4545
namespace tosa {
4646

47-
ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
48-
Attribute &attr);
49-
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
50-
Attribute attr);
47+
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser,
48+
DenseElementsAttr &varShapeAttr,
49+
TypeAttr &typeAttr,
50+
Attribute &initialValueAttr);
51+
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op,
52+
DenseElementsAttr varShapeAttr,
53+
TypeAttr typeAttr,
54+
Attribute initialValueAttr);
5155

5256
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
5357

@@ -172,6 +176,9 @@ std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
172176
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
173177
int32_t val = 0);
174178

179+
// returns type of variable op
180+
RankedTensorType getVariableType(VariableOp variableOp);
181+
175182
} // namespace tosa
176183
} // namespace mlir
177184

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
9292

9393
let arguments = (ins
9494
SymbolNameAttr:$name,
95+
IndexElementsAttr:$var_shape,
9596
TypeAttr:$type,
9697
OptionalAttr<AnyAttr>:$initial_value
9798
);
@@ -101,12 +102,16 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> {
101102
Extension<[Tosa_EXT_VARIABLE]>,
102103
];
103104

105+
let hasCustomAssemblyFormat = 1;
106+
104107
let assemblyFormat = [{
105108
$name
106109
attr-dict
107-
custom<TypeOrAttr>($type, $initial_value)
110+
custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
108111
}];
109112

113+
let builders = [Tosa_VariableOpBuilder];
114+
110115
let hasVerifier = 1;
111116
}
112117

mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
2626

2727
LogicalResult matchAndRewrite(tosa::VariableOp op,
2828
PatternRewriter &rewriter) const final {
29+
auto variableType = tosa::getVariableType(op);
2930
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
30-
op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
31+
op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
3132
op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
3233
newVariable.setPrivate();
3334
rewriter.replaceOp(op, newVariable);

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 105 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,24 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131131
return {&getBodyGraph()};
132132
}
133133

134+
//===----------------------------------------------------------------------===//
135+
// TOSA variable operator support.
136+
//===----------------------------------------------------------------------===//
137+
138+
static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
139+
return to_vector(llvm::map_range(shape, [](int64_t dim) {
140+
return dim == -1 ? ShapedType::kDynamic : dim;
141+
}));
142+
}
143+
144+
// returns type of variable op
145+
RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
146+
Type elementType = variableOp.getType();
147+
DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
148+
auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
149+
return RankedTensorType::get(shape, elementType);
150+
}
151+
134152
//===----------------------------------------------------------------------===//
135153
// Tosa dialect initialization.
136154
//===----------------------------------------------------------------------===//
@@ -177,42 +195,80 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
177195
// Parsers and printers
178196
//===----------------------------------------------------------------------===//
179197

180-
ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
181-
Attribute &attr) {
198+
namespace {
199+
200+
ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
201+
DenseElementsAttr &varShapeAttr,
202+
TypeAttr &typeAttr) {
203+
if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
204+
if (!shapedType.hasRank())
205+
return parser.emitError(parser.getCurrentLocation())
206+
<< "expected ranked type";
207+
208+
auto elementType = shapedType.getElementType();
209+
typeAttr = TypeAttr::get(elementType);
210+
ArrayRef<int64_t> shape = shapedType.getShape();
211+
Builder builder(parser.getContext());
212+
varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
213+
return success();
214+
}
215+
return parser.emitError(parser.getCurrentLocation())
216+
<< "expected shaped type";
217+
}
218+
219+
} // namespace
220+
221+
// parses the optional initial value or type for a tosa variable
222+
// with initial value:
223+
// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
224+
//
225+
// without initial value:
226+
// tosa.variable @name : tensor<1x8xf32>
227+
ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
228+
OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
229+
Attribute &initialValueAttr) {
182230
if (succeeded(parser.parseOptionalEqual())) {
183-
if (failed(parser.parseAttribute(attr))) {
231+
if (failed(parser.parseAttribute(initialValueAttr))) {
184232
return parser.emitError(parser.getCurrentLocation())
185233
<< "expected attribute";
186234
}
187-
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
188-
typeAttr = TypeAttr::get(typedAttr.getType());
235+
if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
236+
return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
237+
typeAttr);
189238
}
190-
return success();
239+
return parser.emitError(parser.getCurrentLocation())
240+
<< "expected Typed attr";
191241
}
192242

193-
Type type;
194-
if (failed(parser.parseColonType(type))) {
195-
return parser.emitError(parser.getCurrentLocation()) << "expected type";
243+
initialValueAttr = nullptr;
244+
Type parsedType;
245+
if (failed(parser.parseColonType(parsedType))) {
246+
return parser.emitError(parser.getCurrentLocation())
247+
<< "expected type after colon";
196248
}
197-
typeAttr = TypeAttr::get(type);
198-
199-
return success();
249+
return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
200250
}
201251

202-
void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
203-
Attribute attr) {
252+
void mlir::tosa::printVariableOpTypeOrInitialValue(
253+
OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
254+
TypeAttr typeAttr, Attribute initialValueAttr) {
204255
bool needsSpace = false;
205-
auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
206-
if (!typedAttr || typedAttr.getType() != type.getValue()) {
256+
if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
257+
auto shape =
258+
convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
259+
Type elementType = typeAttr.getValue();
260+
RankedTensorType tensorType =
261+
RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
262+
auto tensorTypeAttr = TypeAttr::get(tensorType);
207263
p << ": ";
208-
p.printAttribute(type);
264+
p.printAttribute(tensorTypeAttr);
209265
needsSpace = true; // subsequent attr value needs a space separator
210266
}
211-
if (attr) {
267+
if (initialValueAttr) {
212268
if (needsSpace)
213269
p << ' ';
214270
p << "= ";
215-
p.printAttribute(attr);
271+
p.printAttribute(initialValueAttr);
216272
}
217273
}
218274

@@ -657,8 +713,9 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
657713
<< symName << "' has not been declared by 'tosa.variable'";
658714

659715
// Verify type and shape
660-
Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
661-
if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
716+
auto variableType = getVariableType(varOp.value());
717+
if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
718+
"the input tensor")
662719
.failed())
663720
return failure();
664721

@@ -1103,6 +1160,33 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
11031160
result.types.push_back(outputType);
11041161
}
11051162

1163+
static void buildVariableOp(OpBuilder &builder, OperationState &result,
1164+
StringRef name, Type variableType,
1165+
Attribute initialValue) {
1166+
const Location loc{result.location};
1167+
auto nameAttr = builder.getStringAttr(name);
1168+
1169+
auto shapedType = dyn_cast<ShapedType>(variableType);
1170+
if (!shapedType) {
1171+
(void)emitError(loc, "variable type must be a shaped type");
1172+
return;
1173+
}
1174+
if (!shapedType.hasRank()) {
1175+
(void)emitError(loc, "variable type must be a ranked type");
1176+
return;
1177+
}
1178+
1179+
auto elementType = shapedType.getElementType();
1180+
auto elementTypeAttr = TypeAttr::get(elementType);
1181+
ArrayRef<int64_t> shape = shapedType.getShape();
1182+
auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1183+
1184+
result.addAttribute("name", nameAttr);
1185+
result.addAttribute("var_shape", varShapeAttr);
1186+
result.addAttribute("type", elementTypeAttr);
1187+
result.addAttribute("initial_value", initialValue);
1188+
}
1189+
11061190
//===----------------------------------------------------------------------===//
11071191
// TOSA Operator Return Type Inference.
11081192
//===----------------------------------------------------------------------===//
@@ -1676,12 +1760,6 @@ LogicalResult tosa::PadOp::verify() {
16761760
return success();
16771761
}
16781762

1679-
static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
1680-
return to_vector(llvm::map_range(shape, [](int64_t dim) {
1681-
return dim == -1 ? ShapedType::kDynamic : dim;
1682-
}));
1683-
}
1684-
16851763
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
16861764
MLIRContext *context, ::std::optional<Location> location,
16871765
SliceOp::Adaptor adaptor,

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,15 +215,8 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
215215

216216
template <>
217217
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
218-
::mlir::Attribute attr = op.getInitialValueAttr();
219-
if (attr == nullptr)
220-
return failure();
221-
222-
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
223-
addType(getElementTypeOrSelf(typedAttr));
224-
return success();
225-
}
226-
return failure();
218+
addType(op.getType());
219+
return success();
227220
}
228221

229222
template <>

0 commit comments

Comments
 (0)