@@ -131,6 +131,24 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131
131
return {&getBodyGraph ()};
132
132
}
133
133
134
+ // ===----------------------------------------------------------------------===//
135
+ // TOSA variable operator support.
136
+ // ===----------------------------------------------------------------------===//
137
+
138
+ static SmallVector<int64_t > convertToMlirShape (ArrayRef<int64_t > shape) {
139
+ return to_vector (llvm::map_range (shape, [](int64_t dim) {
140
+ return dim == -1 ? ShapedType::kDynamic : dim;
141
+ }));
142
+ }
143
+
144
+ // returns type of variable op
145
+ RankedTensorType mlir::tosa::getVariableType (tosa::VariableOp variableOp) {
146
+ Type elementType = variableOp.getType ();
147
+ DenseIntElementsAttr varShapeAttr = variableOp.getVarShape ();
148
+ auto shape = convertToMlirShape (to_vector (varShapeAttr.getValues <int64_t >()));
149
+ return RankedTensorType::get (shape, elementType);
150
+ }
151
+
134
152
// ===----------------------------------------------------------------------===//
135
153
// Tosa dialect initialization.
136
154
// ===----------------------------------------------------------------------===//
@@ -177,42 +195,80 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
177
195
// Parsers and printers
178
196
// ===----------------------------------------------------------------------===//
179
197
180
- ParseResult mlir::tosa::parseTypeOrAttr (OpAsmParser &parser, TypeAttr &typeAttr,
181
- Attribute &attr) {
198
+ namespace {
199
+
200
+ ParseResult getShapeAndElementType (OpAsmParser &parser, Type parsedType,
201
+ DenseElementsAttr &varShapeAttr,
202
+ TypeAttr &typeAttr) {
203
+ if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
204
+ if (!shapedType.hasRank ())
205
+ return parser.emitError (parser.getCurrentLocation ())
206
+ << " expected ranked type" ;
207
+
208
+ auto elementType = shapedType.getElementType ();
209
+ typeAttr = TypeAttr::get (elementType);
210
+ ArrayRef<int64_t > shape = shapedType.getShape ();
211
+ Builder builder (parser.getContext ());
212
+ varShapeAttr = builder.getIndexTensorAttr (convertFromMlirShape (shape));
213
+ return success ();
214
+ }
215
+ return parser.emitError (parser.getCurrentLocation ())
216
+ << " expected shaped type" ;
217
+ }
218
+
219
+ } // namespace
220
+
221
+ // parses the optional initial value or type for a tosa variable
222
+ // with initial value:
223
+ // tosa.variable @name = dense<0.0> : tensor<1x8xf32>
224
+ //
225
+ // without initial value:
226
+ // tosa.variable @name : tensor<1x8xf32>
227
+ ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue (
228
+ OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
229
+ Attribute &initialValueAttr) {
182
230
if (succeeded (parser.parseOptionalEqual ())) {
183
- if (failed (parser.parseAttribute (attr ))) {
231
+ if (failed (parser.parseAttribute (initialValueAttr ))) {
184
232
return parser.emitError (parser.getCurrentLocation ())
185
233
<< " expected attribute" ;
186
234
}
187
- if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
188
- typeAttr = TypeAttr::get (typedAttr.getType ());
235
+ if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
236
+ return getShapeAndElementType (parser, typedAttr.getType (), varShapeAttr,
237
+ typeAttr);
189
238
}
190
- return success ();
239
+ return parser.emitError (parser.getCurrentLocation ())
240
+ << " expected Typed attr" ;
191
241
}
192
242
193
- Type type;
194
- if (failed (parser.parseColonType (type))) {
195
- return parser.emitError (parser.getCurrentLocation ()) << " expected type" ;
243
+ initialValueAttr = nullptr ;
244
+ Type parsedType;
245
+ if (failed (parser.parseColonType (parsedType))) {
246
+ return parser.emitError (parser.getCurrentLocation ())
247
+ << " expected type after colon" ;
196
248
}
197
- typeAttr = TypeAttr::get (type);
198
-
199
- return success ();
249
+ return getShapeAndElementType (parser, parsedType, varShapeAttr, typeAttr);
200
250
}
201
251
202
- void mlir::tosa::printTypeOrAttr (OpAsmPrinter &p, Operation *op, TypeAttr type,
203
- Attribute attr) {
252
+ void mlir::tosa::printVariableOpTypeOrInitialValue (
253
+ OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
254
+ TypeAttr typeAttr, Attribute initialValueAttr) {
204
255
bool needsSpace = false ;
205
- auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
206
- if (!typedAttr || typedAttr.getType () != type.getValue ()) {
256
+ if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
257
+ auto shape =
258
+ convertToMlirShape (to_vector (varShapeAttr.getValues <int64_t >()));
259
+ Type elementType = typeAttr.getValue ();
260
+ RankedTensorType tensorType =
261
+ RankedTensorType::get (ArrayRef<int64_t >(shape), elementType);
262
+ auto tensorTypeAttr = TypeAttr::get (tensorType);
207
263
p << " : " ;
208
- p.printAttribute (type );
264
+ p.printAttribute (tensorTypeAttr );
209
265
needsSpace = true ; // subsequent attr value needs a space separator
210
266
}
211
- if (attr ) {
267
+ if (initialValueAttr ) {
212
268
if (needsSpace)
213
269
p << ' ' ;
214
270
p << " = " ;
215
- p.printAttribute (attr );
271
+ p.printAttribute (initialValueAttr );
216
272
}
217
273
}
218
274
@@ -657,8 +713,9 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
657
713
<< symName << " ' has not been declared by 'tosa.variable'" ;
658
714
659
715
// Verify type and shape
660
- Type varType = cast<tosa::VariableOp>(varOp.value ()).getType ();
661
- if (errorIfTypeOrShapeMismatch (op, type, name, varType, " the input tensor" )
716
+ auto variableType = getVariableType (varOp.value ());
717
+ if (errorIfTypeOrShapeMismatch (op, type, name, variableType,
718
+ " the input tensor" )
662
719
.failed ())
663
720
return failure ();
664
721
@@ -1103,6 +1160,33 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1103
1160
result.types .push_back (outputType);
1104
1161
}
1105
1162
1163
+ static void buildVariableOp (OpBuilder &builder, OperationState &result,
1164
+ StringRef name, Type variableType,
1165
+ Attribute initialValue) {
1166
+ const Location loc{result.location };
1167
+ auto nameAttr = builder.getStringAttr (name);
1168
+
1169
+ auto shapedType = dyn_cast<ShapedType>(variableType);
1170
+ if (!shapedType) {
1171
+ (void )emitError (loc, " variable type must be a shaped type" );
1172
+ return ;
1173
+ }
1174
+ if (!shapedType.hasRank ()) {
1175
+ (void )emitError (loc, " variable type must be a ranked type" );
1176
+ return ;
1177
+ }
1178
+
1179
+ auto elementType = shapedType.getElementType ();
1180
+ auto elementTypeAttr = TypeAttr::get (elementType);
1181
+ ArrayRef<int64_t > shape = shapedType.getShape ();
1182
+ auto varShapeAttr = builder.getIndexTensorAttr (convertFromMlirShape (shape));
1183
+
1184
+ result.addAttribute (" name" , nameAttr);
1185
+ result.addAttribute (" var_shape" , varShapeAttr);
1186
+ result.addAttribute (" type" , elementTypeAttr);
1187
+ result.addAttribute (" initial_value" , initialValue);
1188
+ }
1189
+
1106
1190
// ===----------------------------------------------------------------------===//
1107
1191
// TOSA Operator Return Type Inference.
1108
1192
// ===----------------------------------------------------------------------===//
@@ -1676,12 +1760,6 @@ LogicalResult tosa::PadOp::verify() {
1676
1760
return success ();
1677
1761
}
1678
1762
1679
- static SmallVector<int64_t > convertToMlirShape (ArrayRef<int64_t > shape) {
1680
- return to_vector (llvm::map_range (shape, [](int64_t dim) {
1681
- return dim == -1 ? ShapedType::kDynamic : dim;
1682
- }));
1683
- }
1684
-
1685
1763
LogicalResult tosa::SliceOp::inferReturnTypeComponents (
1686
1764
MLIRContext *context, ::std::optional<Location> location,
1687
1765
SliceOp::Adaptor adaptor,
0 commit comments