
Commit

Merge branch 'main' of https://github.com/llvm/torch-mlir into zhekunz/hlo_index
zhekunz2 committed May 22, 2023
2 parents b5de7bd + 588bdc1 commit 2db5b79
Showing 29 changed files with 1,092 additions and 102 deletions.
1 change: 1 addition & 0 deletions .github/workflows/RollPyTorch.yml
@@ -132,4 +132,5 @@ jobs:
- torchvision version: ${{ env.PTVISION_RELEASE }}
committer: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
title: update PyTorch version to ${{ env.PT_RELEASE }}
token: ${{ secrets.ROLLPYTORCH_TOKEN0 }}
reviewers: ashay, powderluv, vivekkhandelwal1
3 changes: 3 additions & 0 deletions build_tools/autogen_ltc_backend.yaml
@@ -7,6 +7,9 @@ blacklist:
- index_put # Error: TODO not sure if there are other valid types to handle here
- index_put_ # Error: TODO not sure if there are other valid types to handle here

# Ops with list of tensors output
- unbind.int

# Additional ops for which autogen is supported but which don't compile yet
- _convolution
- detach
35 changes: 34 additions & 1 deletion e2e_testing/xfail_sets.py
@@ -213,6 +213,7 @@
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt64Module_basic",
"ElementwiseAddScalarIntModule_basic",
@@ -258,6 +259,10 @@
"AtenComplexImagModule_basic",
"AtenComplexRealModule_basic",
# END tests failing due to: complex floating point ops

# ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
}
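
For context on these xfails: aten.unbind.int slices a tensor along a
dimension and returns one tensor per slice, i.e. a List[Tensor] in the op's
schema, which is exactly what the error message quoted above rejects. Below is
a minimal PyTorch sketch of the two failing patterns, list unpacking and
getitem; the tensor shapes are illustrative, not taken from the test suite:

    import torch

    x = torch.arange(6).reshape(3, 2)
    pieces = torch.unbind(x, dim=0)    # tuple of 3 tensors, each of shape (2,)
    a, b, c = pieces                   # list unpacking, as in UnbindIntListUnpack
    first = torch.unbind(x, dim=0)[0]  # indexing, as in UnbindIntGetItem
    print(len(pieces), first.shape)    # 3 torch.Size([2])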

TORCHDYNAMO_CRASHING_SET = {
@@ -317,6 +322,12 @@
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"BatchMlpLayerModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",
"ResNet18StaticModule_basic",
"BmmModule_basic",
"BroadcastToModule_basic",
"BroadcastToSameRankStaticModule_basic",
@@ -580,6 +591,21 @@
"MmModule_basic",
"MmModule_chained",
"MaxPool2dStaticModule_basic",
"EmptyModule_contiguous",
"EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory",
"EmptyModule_int",
"EmptyModule_float",
"NewEmptyModuleDefaultDtype_basic",
"NewEmptyModuleFalsePinMemory_basic",
"NewEmptyModuleFloat2D_basic",
"NewEmptyModuleFloat3D_basic",
"NewEmptyModuleInt2D_basic",
"NewEmptyModuleInt3D_basic",
"NewEmptyModuleLayoutIntDtype_basic",
"NewEmptyModuleNonDefaultFloatDtype_basic",
"NewEmptyModuleNonDefaultIntDtype_basic",
"NewEmptyStridedModuleDefaultDtype_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
@@ -702,6 +728,8 @@
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"AtenComplex64Module_basic",
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
@@ -807,6 +835,7 @@
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",
"FlattenStaticModule_basic",
"FlattenRank0Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
@@ -980,6 +1009,8 @@
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"DetachModule_basic",
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
"TensorsConcatStaticModule_basic",
"TensorsConcatNegativeDimStaticModule_basic",
"AtenComplex64Module_basic",
@@ -1161,5 +1192,7 @@
"VarMeanDimBiasedModule_basic",
"AtenComplexImagModule_basic",
"AtenComplexRealModule_basic",
"AtenComplexViewModule_basic"
"AtenComplexViewModule_basic",
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
}
@@ -244,6 +244,90 @@ def TMTensor_SortOp : TMTensor_Op<"sort",
}];
}

def TMTensor_AttentionOp : TMTensor_Op<"attention",
[DeclareOpInterfaceMethods<TMTensorInterface,
["payloadUsesValueFromOperand"]>,
DeclareOpInterfaceMethods<ScalarLoopOpInterface,
["generateScalarImplementation"]>]> {
let summary = "Attention operator";
let description = [{
This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
the attention. Each of the inputs has shape BxNxd, where B is the
batch dimension, N is the sequence length, and d is the head dimension.
Typically N >>> d. Mathematically, the attention is defined as
matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
this operator also performs scaling, masking and dropout, but we leave
that out of the current implementation.
}];

let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs
);

let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs)>
];

let results = (outs Variadic<AnyRankedTensor>:$result);
let assemblyFormat = [{
attr-dict
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
}];

let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{
Value getQuery() {
return getInputOperand(0)->get();
}
Value getKey() {
return getInputOperand(1)->get();
}
Value getValue() {
return getInputOperand(2)->get();
}
Value getOutput() {
return getOutputOperand(0)->get();
}
ShapedType getQueryType() {
return getQuery().getType().cast<ShapedType>();
}
ShapedType getKeyType() {
return getKey().getType().cast<ShapedType>();
}
ShapedType getValueType() {
return getValue().getType().cast<ShapedType>();
}
ShapedType getOutputType() {
return getOutput().getType().cast<ShapedType>();
}
int64_t getQueryRank() {
return getQueryType().getRank();
}
int64_t getKeyRank() {
return getKeyType().getRank();
}
int64_t getValueRank() {
return getValueType().getRank();
}
int64_t getOutputRank() {
return getOutputType().getRank();
}
int64_t getIterationDomainRank() {
return 2;
}
// Method to implement for specifying output range for
// DestinationStyleOpInterface
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
std::pair<unsigned, unsigned> outputsIndexAndLength =
getODSOperandIndexAndLength(1);
return std::make_pair<int64_t, int64_t>(
outputsIndexAndLength.first,
outputsIndexAndLength.first + outputsIndexAndLength.second);
}
}];
}
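
The description above pins down the math exactly. As a concrete reference,
here is a minimal NumPy sketch of what the op computes: no scaling, masking,
or dropout, matching the description; the shapes and names are illustrative,
not taken from this commit:

    import numpy as np

    def attention(query, key, value):
        # query, key, value: (B, N, d): batch, sequence length, head dim.
        scores = np.einsum("bnd,bmd->bnm", query, key)    # Q @ K^T, (B, N, N)
        scores -= scores.max(axis=-1, keepdims=True)      # stabilize softmax
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)    # row-wise softmax
        return np.einsum("bnm,bmd->bnd", weights, value)  # @ V, back to (B, N, d)

    B, N, d = 2, 8, 4
    q, k, v = (np.random.rand(B, N, d) for _ in range(3))
    out = attention(q, k, v)
    assert out.shape == (B, N, d)

Going by the declared assemblyFormat, the op's textual form would presumably
look something like
tm_tensor.attention ins(%q, %k, %v : ...) outs(%out : ...) -> tensor<...>,
though no IR example appears in this commit.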

//===----------------------------------------------------------------------===//
// Pure ops
//===----------------------------------------------------------------------===//
