Skip to content

Commit

Permalink
[mlir][tosa] Added more shape inference for tosa ops
Browse files Browse the repository at this point in the history
Added shape inference for:
- scatter
- gather
- transpose
- slice
- pad
- concat
- reduction operations

Also updated reshape for more aggressive shape inference.

Differential Revision: https://reviews.llvm.org/D105383
  • Loading branch information
rsuderman committed Jul 12, 2021
1 parent a95f56f commit 5a4e776
Show file tree
Hide file tree
Showing 4 changed files with 1,144 additions and 297 deletions.
89 changes: 70 additions & 19 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
//===----------------------------------------------------------------------===//
// Operator: argmax
//===----------------------------------------------------------------------===//
def Tosa_ArgMaxOp : Tosa_Op<"argmax", [NoSideEffect]> {
def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Perform argmax on the input.";

let description = [{
Expand Down Expand Up @@ -173,7 +176,10 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: fully_connected
//===----------------------------------------------------------------------===//
def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [NoSideEffect]> {
def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Fully Connected operator";

let description = [{
Expand All @@ -199,7 +205,10 @@ def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: matmul
//===----------------------------------------------------------------------===//
def Tosa_MatMulOp : Tosa_Op<"matmul", [NoSideEffect]> {
def Tosa_MatMulOp : Tosa_Op<"matmul", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Matrix multiplication with bias";

let description = [{
Expand Down Expand Up @@ -589,8 +598,9 @@ def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [
// Operator: logical_right_shift
//===----------------------------------------------------------------------===//
def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, ["inferReturnTypeComponents"]>, ResultsBroadcastableShape,
NoSideEffect]> {
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
ResultsBroadcastableShape, NoSideEffect]> {
let summary = "Elementwise Logical Right Shift";

let description = [{
Expand Down Expand Up @@ -783,7 +793,10 @@ def Tosa_SubOp : Tosa_Op<"sub", [
//===----------------------------------------------------------------------===//
// Operator: table
//===----------------------------------------------------------------------===//
def Tosa_TableOp : Tosa_Op<"table", [NoSideEffect]> {
def Tosa_TableOp : Tosa_Op<"table", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Table lookup op";

let description = [{
Expand Down Expand Up @@ -1178,7 +1191,10 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
//===----------------------------------------------------------------------===//
// Operator: reduce_all
//===----------------------------------------------------------------------===//
def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [NoSideEffect]> {
def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Reduce All operator";

let description = [{
Expand All @@ -1198,7 +1214,10 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: reduce_any
//===----------------------------------------------------------------------===//
def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [NoSideEffect]> {
def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Reduce Any operator";

let description = [{
Expand All @@ -1218,7 +1237,10 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: reduce_max
//===----------------------------------------------------------------------===//
def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [NoSideEffect]> {
def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Reduce Max operator";

let description = [{
Expand All @@ -1238,7 +1260,10 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: reduce_min
//===----------------------------------------------------------------------===//
def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [NoSideEffect]> {
def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Reduce Min operator";

let description = [{
Expand All @@ -1258,7 +1283,10 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: reduce_prod
//===----------------------------------------------------------------------===//
def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [NoSideEffect]> {
def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Reduce Prod operator";

let description = [{
Expand All @@ -1278,7 +1306,10 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: reduce_sum
//===----------------------------------------------------------------------===//
def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [NoSideEffect]> {
def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Reduce Sum operator";

let description = [{
Expand All @@ -1303,7 +1334,10 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: concat
//===----------------------------------------------------------------------===//
def Tosa_ConcatOp : Tosa_Op<"concat", [NoSideEffect]> {
def Tosa_ConcatOp : Tosa_Op<"concat", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Concatenates tensors along one dimension.";

let description = [{
Expand All @@ -1324,7 +1358,10 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: pad
//===----------------------------------------------------------------------===//
def Tosa_PadOp : Tosa_Op<"pad", [NoSideEffect]> {
def Tosa_PadOp : Tosa_Op<"pad", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Pads a tensor with zeros.";

let description = [{
Expand Down Expand Up @@ -1396,7 +1433,9 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
//===----------------------------------------------------------------------===//
// Operator: slice
//===----------------------------------------------------------------------===//
def Tosa_SliceOp: Tosa_Op<"slice", [NoSideEffect]> {
def Tosa_SliceOp: Tosa_Op<"slice", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>, NoSideEffect]> {
let summary = "Slice operator";

let description = [{
Expand All @@ -1419,7 +1458,10 @@ def Tosa_SliceOp: Tosa_Op<"slice", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: tile
//===----------------------------------------------------------------------===//
def Tosa_TileOp: Tosa_Op<"tile", [NoSideEffect]> {
def Tosa_TileOp: Tosa_Op<"tile", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Tile operator";

let description = [{
Expand All @@ -1438,7 +1480,10 @@ def Tosa_TileOp: Tosa_Op<"tile", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: transpose
//===----------------------------------------------------------------------===//
def Tosa_TransposeOp : Tosa_Op<"transpose", [NoSideEffect]> {
def Tosa_TransposeOp : Tosa_Op<"transpose", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Transpose operator";

let description = [{
Expand All @@ -1463,7 +1508,10 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: gather
//===----------------------------------------------------------------------===//
def Tosa_GatherOp : Tosa_Op<"gather", [NoSideEffect]> {
def Tosa_GatherOp : Tosa_Op<"gather", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Gather operation,";

let description = [{
Expand All @@ -1484,7 +1532,10 @@ def Tosa_GatherOp : Tosa_Op<"gather", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// Operator: scatter
//===----------------------------------------------------------------------===//
def Tosa_ScatterOp : Tosa_Op<"scatter", [NoSideEffect]> {
def Tosa_ScatterOp : Tosa_Op<"scatter", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NoSideEffect]> {
let summary = "Scatter operation,";

let description = [{
Expand Down
Loading

0 comments on commit 5a4e776

Please sign in to comment.