Skip to content

Commit

Permalink
# This is a combination of 22 commits.
Browse files Browse the repository at this point in the history
# This is the 1st commit message:

[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (llvm#2340)

[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.
# This is the commit message llvm#2:

update PyTorch version to 2.1.0.dev20230729 (llvm#2354)

- torch version: 2.1.0.dev20230729
 - torch commit hash: b638df0afb83572724032c824c64e481bb4499a0
 - torchvision version: 0.16.0.dev20230729

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
# This is the commit message llvm#3:

update PyTorch version to 2.1.0.dev20230730 (llvm#2356)

- torch version: 2.1.0.dev20230730
 - torch commit hash: 0ff243ff350268cc98fe03fa6364375ee2824742
 - torchvision version: 0.16.0.dev20230730

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
# This is the commit message llvm#4:

update PyTorch version to 2.1.0.dev20230731 (llvm#2359)

- torch version: 2.1.0.dev20230731
 - torch commit hash: 6298ac688f8caafe30d71ff2ea2e20fbb32065c7
 - torchvision version: 0.16.0.dev20230731

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
# This is the commit message llvm#5:

LTC->MLIR Debug Info support (llvm#1922)

* LTC->MLIR Debug Info support

* SW-95317 Propagate Lazy->Jit->MLIR scope name.

* Enhance location information based on op names

Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.

* Update locations logic; updated debug-info.py test

* Use {scope}/{op_name} format to track names by default

---------

Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>
# This is the commit message llvm#6:

build: update llvm tag to 4189584

Summary of changes:
- Update tags
  llvm: 4189584
  mhlo: 4726d31f7025da66de0dea709bd56c462edb83c2

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

# This is the commit message llvm#7:

update PyTorch version to 2.1.0.dev20230802 (llvm#2366)

- torch version: 2.1.0.dev20230802
 - torch commit hash: c89b16917755c2abbef7b6420e340baf9ae8089e
 - torchvision version: 0.16.0.dev20230802

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
# This is the commit message llvm#8:

Change Python version from 3.10 to 3.11 in installation instructions (llvm#2370)


# This is the commit message llvm#9:

Add CITATION file (llvm#2371)


# This is the commit message llvm#10:

Add packaging as an install dependency (llvm#2369)

Needed by `torch_mlir._version`. Resolves llvm#2368.
# This is the commit message llvm#11:

[Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (llvm#2358)

* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op
# This is the commit message llvm#12:

update PyTorch version to 2.1.0.dev20230803 (llvm#2372)

- torch version: 2.1.0.dev20230803
 - torch commit hash: f89c73be3a3e8274d025ac46a33a780853841c9e
 - torchvision version: 0.16.0.dev20230803

Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
# This is the commit message llvm#13:

Prevent failed stable CI job from cancelling nightly jobs (llvm#2373)

The CI jobs that use stable PyTorch are currently not required to pass
in order for a patch to get merged in `main`. This commit makes sure
that if a CI job for stable PyTorch fails, it does not cancel the
other required jobs.
# This is the commit message llvm#14:

[Torch Dialect] emit aten.tile op and decompose it into aten.repeat (llvm#2355)


# This is the commit message llvm#15:

update

# This is the commit message llvm#16:

update xfail sets

# This is the commit message llvm#17:

update xfail_sets

# This is the commit message llvm#18:

update

# This is the commit message llvm#19:

fix xfail_sets

# This is the commit message llvm#20:

update:

# This is the commit message llvm#21:

update

# This is the commit message llvm#22:

update:
  • Loading branch information
Vremold authored and JianzheXiao committed Aug 4, 2023
1 parent 0109bf7 commit f85ea8c
Show file tree
Hide file tree
Showing 28 changed files with 764 additions and 35 deletions.
1 change: 1 addition & 0 deletions .github/workflows/buildAndTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ concurrency:
# macos - arm64 - llvm in-tree - pytorch binary - build only # cross compile, can't test arm64
jobs:
build-test:
continue-on-error: ${{ matrix.torch-version == 'stable' }}
strategy:
fail-fast: true
matrix:
Expand Down
19 changes: 19 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
cff-version: 1.2.0
title: Torch-MLIR
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- name: LLVM
repository-code: 'https://github.com/llvm/torch-mlir'
abstract: >-
The Torch-MLIR project aims to provide first class support
from the PyTorch ecosystem to the MLIR ecosystem.
keywords:
- Compiler
- PyTorch
- MLIR
license:
- Apache-2.0 with LLVM Exceptions
- BSD
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,17 @@ We have few paths to lower down to the Torch MLIR Dialect.

## Install torch-mlir snapshot

At the time of writing, we release pre-built snapshot of torch-mlir for Python 3.10 on Linux and macOS.
At the time of writing, we release pre-built snapshot of torch-mlir for Python 3.11 on Linux and macOS.

If you have Python 3.10, the following commands initialize a virtual environment.
If you have Python 3.11, the following commands initialize a virtual environment.
```shell
python3.10 -m venv mlir_venv
python3.11 -m venv mlir_venv
source mlir_venv/bin/activate
```

Or, if you want to switch over multiple versions of Python using conda, you can create a conda environment with Python 3.10.
Or, if you want to switch over multiple versions of Python using conda, you can create a conda environment with Python 3.11.
```shell
conda create -n torch-mlir python=3.10
conda create -n torch-mlir python=3.11
conda activate torch-mlir
python -m pip install --upgrade pip
```
Expand Down
3 changes: 2 additions & 1 deletion e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
LINALG_XFAIL_SET,
MAKE_FX_TOSA_PASS_SET,
STABLEHLO_PASS_SET,
STABLEHLO_CRASHING_SET,
TOSA_PASS_SET,
LTC_XFAIL_SET,
TORCHDYNAMO_XFAIL_SET,
Expand Down Expand Up @@ -101,7 +102,7 @@ def main():
elif args.config == "stablehlo":
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
crashing_set = set()
crashing_set = STABLEHLO_CRASHING_SET
elif args.config == "native_torch":
config = NativeTorchTestConfig()
xfail_set = set()
Expand Down
55 changes: 55 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,34 @@
}

STABLEHLO_PASS_SET = {
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"AddIntModule_basic",
"AtenIntBoolOpModule_basic",
"AtenIntTensorByteDtypeModule_basic",
"AtenIntTensorCharDtypeModule_basic",
"BoolFloatFalseModule_basic",
"BoolFloatTrueModule_basic",
"BoolIntFalseModule_basic",
"BoolIntTrueModule_basic",
"CeilFloatModule_basic",
"DivFloatModule_basic",
"DivIntModule_basic",
"EqIntModule_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
"GtFloatIntModule_basic",
"GtIntModule_basic",
"MulIntModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"SqrtIntModule_basic",
"SubFloatModule_basic",
"SubIntModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToIntZeroRank_basic",
"TensorToFloatZeroRank_basic",
"AliasModule_basic",
"TensorIntModule_basic",
"AllBoolFalseModule_basic",
Expand Down Expand Up @@ -352,6 +380,7 @@
"MaskedFillScalarIntValueStaticModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AddSizeIntModule_basic",
"AddSizeIntNegDimModule_basic",
"ArangeDtypeFloatModule_basic",
Expand Down Expand Up @@ -753,6 +782,7 @@
"ReshapeExpandModule_basic",
"RollModule_basic",
"TestMultipleTensorReturn_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"BaddbmmStaticModule_basic",
"BaddbmmBroadcast1DInputModule_basic",
Expand Down Expand Up @@ -826,9 +856,23 @@
"TupleModule_basic",
}

STABLEHLO_CRASHING_SET = {
# These e2e tests crash because currently mlir-hlo's shape-component-analysis
# only support exact one index in tensor::ExtractOp when it's related with
# some tensors' shape. REF:
# https://github.com/tensorflow/mlir-hlo/blob/master/mhlo/analysis/shape_component_analysis.cc#L586
# FIXME if upstream mlir-hlo fix this.
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",

"Aten_EmbeddingBagExample_basic"
}

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"AliasModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
Expand Down Expand Up @@ -1142,6 +1186,11 @@
"TupleModule_basic",
"NumpyTRank0Module_basic",
"Permute0RankModule_basic",
"Add_Module_basic",
"SoftmaxIntModule_basic",
"SoftmaxIntNegDimModule_basic",
"_LogSoftmaxModule_basic",
"_SoftmaxModule_basic",
}

MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
Expand All @@ -1150,6 +1199,8 @@
"SliceWholeTensorModule_basic",
"TensorFloatModule_basic",
"TensorIntModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
}) - {
### Test failing in make_fx_tosa but not in tosa

Expand All @@ -1172,6 +1223,8 @@
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
# RuntimeError: The size of tensor a (7) must match the size of tensor b (3) at non-singleton dimension 1
"Add_Module_basic",
}

if torch_version_for_comparison() < version.parse("2.1.0.dev"):
Expand All @@ -1190,6 +1243,8 @@
"_ConvolutionDeprecated2DBenchmarkModule_basic",
"_ConvolutionDeprecated2DCudnnModule_basic",
"_ConvolutionDeprecated2DDeterministicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddIntModule_basic",
Expand Down
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 7726 files
2 changes: 1 addition & 1 deletion externals/mlir-hlo
Submodule mlir-hlo updated 145 files
100 changes: 100 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5250,6 +5250,55 @@ def Torch_AtenScatter_ValueOp : Torch_Op<"aten.scatter_.value", [
}];
}

def Torch_AtenMaskedScatterOp : Torch_Op<"aten.masked_scatter", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$mask,
AnyTorchTensorType:$source
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedScatterOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedScatterOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::masked_scatter_ : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$mask,
AnyTorchTensorType:$source
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedScatter_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedScatter_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand All @@ -5274,6 +5323,30 @@ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
}];
}

def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAdaptiveAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenAdaptiveAvgPool1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenTopkOp : Torch_Op<"aten.topk", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -7697,6 +7770,33 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
}];
}

<<<<<<< HEAD
def Torch_AtenTileOp : Torch_Op<"aten.tile", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::tile : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dims
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenTileOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenTileOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

=======
>>>>>>> ad24482e (update)
def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
AllowsTypeRefinement,
ReadOnly
Expand Down
55 changes: 55 additions & 0 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,51 @@ class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {

} // namespace

namespace {
// Casts a tensor of exactly one element to an elemental type.
// Many codes borrowed from
// `lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp`
template <typename AtenOpT>
class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto inputType =
adaptor.getA().getType().template dyn_cast<RankedTensorType>();
if (!inputType)

op.emitError("only Tensor types supported in StableHLO");
auto outType =
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType());
Location loc = op.getLoc();
Value input = adaptor.getA();
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
int64_t inputRank = inputSizes.size();
Type inputDtype =
op.getA().getType().template cast<BaseTensorType>().getDtype();

Value constantOne =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
for (int64_t i = 0; i < inputRank; i++)
checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);

Value constantZero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
SmallVector<Value> indices(inputRank, constantZero);
Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
resultType, inputDtype));
return success();
}
};
} // namespace

// The binary broadcast patterns
namespace {
template <typename AtenOpT, typename ChloOpT>
Expand Down Expand Up @@ -1662,6 +1707,16 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN

#define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenOp>>(typeConverter, \
context)

INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp);
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp);
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenBoolTensorOp);
#undef INSERT_TENSOR_TO_SCALAR_PATTERN

#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context)
Expand Down
Loading

0 comments on commit f85ea8c

Please sign in to comment.