[Torch-MLIR] support for asr (alibaba#929)
* support for asr
* erase traced shape information when shape analysis failed
zzpmiracle authored Jan 4, 2023
1 parent cfed001 commit 3d791ca
Showing 9 changed files with 215 additions and 34 deletions.
64 changes: 37 additions & 27 deletions pytorch_blade/pytorch_blade/compiler/jit/torch/shape_analysis.cpp
@@ -113,8 +113,6 @@ void PropertyPropBase::propagateBlock(Block* block, bool insert_expands) {
for (Node* node : block->nodes()) {
try {
propagateNode(node, insert_expands);
} catch (propagation_error& e) {
setUnshapedType(node);
} catch (std::exception& e) {
ErrorReport errMsg(node->sourceRange());
errMsg << ExceptionMessage(e)
@@ -158,7 +156,14 @@ void PropertyPropBase::processLoop(Node* node) {
}

void PropertyPropBase::setUnshapedType(Value* o) {
o->setType(unshapedType(o->type()));
auto type = o->type();
TypePtr withDimOrUnshapedType;
if (TensorTypePtr tt = type->cast<TensorType>()) {
withDimOrUnshapedType = tt->withDim(tt->sizes().size());
} else {
withDimOrUnshapedType = unshapedType(type);
}
o->setType(withDimOrUnshapedType);
}

void PropertyPropBase::setUnshapedType(Node* node) {
@@ -171,10 +176,6 @@ namespace prim {
using namespace ::c10::prim;
}

#define SHAPE_ASSERT(cond) \
if (!(cond)) \
throw propagation_error()

namespace {

bool isValidArgumentForRunning(Value* v) {
@@ -603,7 +604,8 @@ class ShapePropagator : public PropertyPropBase {
case aten::FloatImplicit:
case aten::IntImplicit:
case aten::size:
return; // correct num type is already set
case prim::device:
return; // correct type is already set
case aten::item:
case aten::ScalarImplicit: {
if (auto dtype = getDType(*node->input()->type())) {
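
For context on the new prim::device case (not part of the commit): in TorchScript, reading `x.device` compiles to a prim::device node whose output is already Device-typed, so shape propagation has nothing to refine. A minimal sketch, assuming a stock PyTorch install:

import torch

@torch.jit.script
def get_device(x: torch.Tensor):
    # `x.device` compiles to a prim::device node with a Device-typed output
    return x.device

print(get_device.graph)  # the printed graph should contain a prim::device node
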
@@ -768,14 +770,7 @@ class ShapePropagator : public PropertyPropBase {
<< node->schema();
}

if (DoesntRefineOutputs(node)) {
return;
}

if (PropagateShapeOnNodeByRunningIt(node)) {
return;
}

// shape analysis failed, erase traced shape only
return setUnshapedType(node);
}

@@ -915,7 +910,6 @@ class ShapePropagator : public PropertyPropBase {
"aten::feature_dropout(Tensor input, float p, bool train) -> Tensor",
"aten::hardshrink(Tensor self, Scalar lambd) -> Tensor",
"aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
"aten::glu(Tensor self, int dim) -> Tensor",
"aten::inverse(Tensor self) -> Tensor",
"aten::group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, bool cudnn_enabled) -> Tensor",
"aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor",
@@ -1316,6 +1310,7 @@ class ShapePropagator : public PropertyPropBase {
// - First input should be the only tensor input
static const register_formula_for aten_to_dtype{
{"aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor",
"aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
#if PYTORCH_MAJOR_VERSION == 1 && PYTORCH_MINOR_VERSION >= 8
"aten::to.dtype_layout(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor"
#endif
@@ -1340,7 +1335,8 @@ class ShapePropagator : public PropertyPropBase {
// Additionally:
// - First input should be the only tensor input
static const register_formula_for aten_to_device{
{"aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor"},
{"aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor",
"aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"},
[](Node* node) -> type_vec_t {
at::optional<IValue> maybe_device_option = node->get(attr::device);
if (auto type = node->input(0)->type()->cast<TensorType>()) {
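
As an eager-mode illustration of what the to.prim_dtype / to.prim_Device formulas propagate (a sketch, not part of the commit; these overload names are, as far as I can tell, what TorchScript resolves `Tensor.to(dtype)` and `Tensor.to(device)` calls to): the output keeps the input's shape, and only the dtype or device changes.

import torch

x = torch.randn(2, 3, 4)                  # float32 on CPU
y = x.to(torch.float16)                   # dtype changes, shape and device are kept
assert y.shape == x.shape and y.dtype == torch.float16 and y.device == x.device

if torch.cuda.is_available():             # the device variant needs a GPU, so guard it
    z = x.to(torch.device("cuda:0"))
    assert z.shape == x.shape and z.dtype == x.dtype and z.device.type == "cuda"
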
@@ -1752,22 +1748,14 @@ class ShapePropagator : public PropertyPropBase {

static const auto factory_with_ndim = [](Node* node,
int dim) -> type_vec_t {
at::optional<IValue> maybe_layout_option = node->get(attr::layout);
if (!maybe_layout_option)
return {};

at::optional<IValue> maybe_device_option = node->get(attr::device);
if (!maybe_device_option)
return {};
auto device =
(maybe_device_option->isNone() ? at::kCPU
: maybe_device_option->toDevice());

at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
if (!maybe_dtype_option)
return {};
auto dtype =
(maybe_dtype_option->isNone() ? at::kDouble
(maybe_dtype_option->isNone() ? at::kFloat
: maybe_dtype_option->toScalarType());

return {TensorType::create(
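
The fallback above changes the inferred dtype for factory ops with dtype=None from at::kDouble to at::kFloat, which matches eager PyTorch's default. A quick check (a sketch, assuming the global default dtype has not been changed with torch.set_default_dtype):

import torch

assert torch.ones(3).dtype == torch.float32      # at::kFloat, not at::kDouble
assert torch.zeros(2, 2).dtype == torch.float32
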
@@ -2319,6 +2307,28 @@ class ShapePropagator : public PropertyPropBase {
node->output()->setType(type->withDim(0));
return true;
}
} else if (node->matches("aten::glu(Tensor self, int dim) -> Tensor")) {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
auto sizesOptional = type->symbolic_sizes().sizes();
auto dimOptional = node->get<int64_t>(attr::dim);
if (!(sizesOptional && dimOptional))
return false;

std::vector<c10::ShapeSymbol> new_sizes = sizesOptional.value();
int64_t input_rank = new_sizes.size();
int64_t dim =
at::maybe_wrap_dim(dimOptional.value(), input_rank, false);

if (new_sizes[dim].is_static()) {
new_sizes[dim] =
ShapeSymbol::fromStaticSize(new_sizes[dim].static_size() / 2);
} else {
// set default to dynamic
new_sizes[dim] = ShapeSymbol::newSymbol();
}
node->outputs()[0]->setType(type->withSymbolicShapes(new_sizes));
}
return true;
} else if (
node->matches(
#if PYTORCH_VERSION_GE(1, 8)
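
A quick eager-mode check of the aten::glu shape rule added above (a sketch, not part of the commit): the output halves the input along `dim` (with negative dims wrapped) and leaves every other dimension unchanged, matching the static case in the new basics.graph test.

import torch

x = torch.randn(2, 4, 16, 16)
assert torch.nn.functional.glu(x, dim=1).shape == (2, 2, 16, 16)
assert torch.nn.functional.glu(x, dim=-1).shape == (2, 4, 16, 8)   # same case as the graph test
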
@@ -102,6 +102,8 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
"aten::neg",
// TODO(disc): need to lower mhlo::RngOp to mhlo_disc::UniformOp
//"aten::native_dropout",
"aten::ones",
"aten::ones_like",
"aten::permute",
"aten::pow",
"aten::relu",
@@ -136,6 +138,7 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
"aten::unsqueeze",
"aten::view",
"aten::view_as",
"aten::zeros"
"aten::zeros_like",
"prim::Constant",
"prim::ListConstruct",
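
The whitelist additions cover the factory ops that the new conv1d/conv2d tests rely on (their scripted functions build weights and bias with torch.ones on a fixed device). A small eager-mode reminder of what these ops produce (a sketch, not part of the commit):

import torch

x = torch.randn(2, 3)
y = torch.ones_like(x)                       # same shape, dtype, and device as x
assert y.shape == x.shape and y.dtype == x.dtype and y.device == x.device
assert torch.zeros(2, 3, device="cpu").shape == (2, 3)   # fixed device, as in the tests
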
@@ -287,8 +287,83 @@ LogicalResult ConvertAtenOp<OperatorOp>::matchAndRewrite(

rewriter.replaceOpWithNewOp<AtenMulScalarOp>(op, outTy, addVal, floatScale);
return success();
} else if ("aten.conv1d" == name) {
// "aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride,
// int[] padding, int[] dilation, int groups) -> Tensor"
Value emptyList = rewriter.create<PrimListConstructOp>(
op.getLoc(),
Torch::ListType::get(Torch::IntType::get(op.getContext())),
SmallVector<Value>());
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op,
op->getResultTypes(),
op.getOperand(0),
op.getOperand(1),
op.getOperand(2),
op.getOperand(3),
op.getOperand(4),
op.getOperand(5),
cstFalse,
emptyList,
op.getOperand(6));
return success();
} else if ("aten.view_as" == name) {
auto sizeListType =
Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value sizeList = rewriter.create<AtenSizeOp>(
op.getLoc(), sizeListType, op.getOperand(1));
rewriter.replaceOpWithNewOp<AtenViewOp>(
op, op->getResultTypes(), op.getOperand(0), sizeList);
return success();
} else if ("aten.glu" == name) {
auto inputTy = op.getOperand(0).getType().cast<BaseTensorType>();
int64_t inputRank = inputTy.getSizes().size();
Value dim = op.getOperand(1);
int64_t dimInt;
if (!matchPattern(dim, m_TorchConstantInt(&dimInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dim must be a constant");
dimInt = toPositiveDim(dimInt, inputRank);
Value size =
rewriter.create<AtenSizeIntOp>(op.getLoc(), op.getOperand(0), dim);
Value constTwo = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getI64IntegerAttr(2));
Value constNone = rewriter.create<ConstantNoneOp>(loc);
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
ArrayRef<int64_t> inputShape = inputTy.getSizes();
SmallVector<int64_t> sliceShape{inputShape.begin(), inputShape.end()};
sliceShape[dimInt] = ShapedType::kDynamicSize;

Type sliceTy = inputTy.getWithSizesAndDtype(
llvm::makeArrayRef(sliceShape), inputTy.getDtype());
SmallVector<int64_t> empty;
Value halfSize =
rewriter.create<AtenFloordivIntOp>(op.getLoc(), size, constTwo);
Value a = rewriter.create<AtenSliceTensorOp>(
op.getLoc(),
sliceTy,
op.getOperand(0),
dim,
constZero,
halfSize,
constOne);
Value b = rewriter.create<AtenSliceTensorOp>(
op.getLoc(),
sliceTy,
op.getOperand(0),
dim,
halfSize,
constNone,
constOne);
Value sigmoidB = rewriter.create<AtenSigmoidOp>(op.getLoc(), sliceTy, b);
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(
op, op->getResultTypes(), a, sigmoidB);
return success();
}

return failure();
}
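
The three new lowerings above rewrite aten.conv1d into the general aten.convolution, aten.view_as into aten.view on the target's size, and aten.glu into slice / sigmoid / multiply along `dim`. An eager-mode equivalence check of those rewrites (a sketch, not part of the commit; torch.ops.aten.convolution stands in here for the generated AtenConvolutionOp):

import torch

# conv1d as the general convolution op (default stride/padding/dilation, not transposed)
x = torch.randn(20, 16, 100)
w, b = torch.ones(16, 16, 1), torch.ones(16)
ref = torch.nn.functional.conv1d(x, w, b)
dec = torch.ops.aten.convolution(x, w, b, [1], [0], [1], False, [0], 1)
assert torch.allclose(ref, dec)

# view_as as view on the other tensor's size
a, t = torch.randn(2, 8), torch.randn(4, 4)
assert torch.equal(a.view_as(t), a.view(t.size()))

# glu as first_half * sigmoid(second_half) along dim
g = torch.randn(2, 4, 16, 16)
first, second = g.chunk(2, dim=1)
assert torch.allclose(torch.nn.functional.glu(g, dim=1), first * torch.sigmoid(second))
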

@@ -519,7 +594,9 @@ class DiscDecomposeComplexOpsPass
"aten.narrow.Tensor",
"aten.selu",
"aten.selu_",
};
"aten.conv1d",
"aten.view_as",
"aten.glu"};

if (illegalSet.find(op.name().str()) != illegalSet.end()) {
return false;
1 change: 0 additions & 1 deletion pytorch_blade/tests/disc/ops/test_activation.py
@@ -62,7 +62,6 @@ def test_hardtanh_static_shape(self):
def test_hardtanh_dynamic_shape(self):
self._test_activation(torch.nn.Hardtanh(), torch.nn.functional.hardtanh, [([-1, -1, -1, -1], torch.float)])

@skipIfEnableTorchMlir()
def test_glu(self):
self._test_activation(torch.nn.GLU(), torch.nn.functional.glu, [([2, 4, 16, 16], torch.float)])

52 changes: 50 additions & 2 deletions pytorch_blade/tests/disc/ops/test_conv_ops.py
@@ -26,14 +26,62 @@ def _test_conv(self, conv_func, inp_test_data=None):
test_data = (test_data.to(self.device),)
self._test_cvt_to_disc(conv_func, test_data)

def test_conv1d(self):
def test_conv1d_aten_convolution(self):
# traced to aten::_convolution
conv = torch.nn.Conv1d(16, 33, 3, stride=2, padding=2)
self._test_conv(conv, torch.randn([20, 16, 60], device=self.device))

def test_conv2d(self):

def test_conv2d_aten_convolution(self):
# traced to aten::_convolution
conv = torch.nn.Conv2d(16, 33, (3, 4), stride=2, padding=[2, 1], dilation=2)
self._test_conv(conv)

def test_conv1d(self):

# traced to aten::conv1d
@torch.jit.script
def cuda_conv_func(x):
weights = torch.ones([16, 16, 1], device="cuda:0")
bias = torch.ones(16, device="cuda:0")
out_y = torch.nn.functional.conv1d(x, weights, bias)
return out_y

@torch.jit.script
def cpu_conv_func(x):
weights = torch.ones([16, 16, 1], device="cpu")
bias = torch.ones(16, device="cpu")
out_y = torch.nn.functional.conv1d(x, weights, bias)
return out_y
# aten::ones errors out when its device is not fixed, so pick the device-specific function
inputs = torch.randn([20, 16, 100], device=self.device)
if self.device == torch.device('cuda'):
self._test_conv(cuda_conv_func, inputs)
else:
self._test_conv(cpu_conv_func, inputs)

def test_conv2d(self):

# traced to aten::conv2d
@torch.jit.script
def cuda_conv_func(x):
weights = torch.ones([33, 16, 1, 1], device="cuda:0")
bias = torch.ones(33, device="cuda:0")
out_y = torch.nn.functional.conv2d(x, weights, bias)
return out_y

@torch.jit.script
def cpu_conv_func(x):
weights = torch.ones([33, 16, 1, 1], device="cpu")
bias = torch.ones(33, device="cpu")
out_y = torch.nn.functional.conv2d(x, weights, bias)
return out_y

if self.device == torch.device('cuda'):
self._test_conv(cuda_conv_func)
else:
self._test_conv(cpu_conv_func)

@unittest.skipIf(torch_blade.version.cuda_available, "disc-gpu not support 3d conv yet.")
def test_conv3d(self):
conv = torch.nn.Conv3d(16, 33, (3, 4, 5), stride=[2, 1, 3], padding=2)
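
The split between test_conv1d_aten_convolution and test_conv1d above mirrors how the graphs are produced: tracing an nn.Conv1d module typically records aten::_convolution, while scripting a call to torch.nn.functional.conv1d keeps an aten::conv1d node, which is what the new decomposition handles. A sketch of how to see the difference (not part of the commit):

import torch

def conv_fn(x):
    w = torch.ones(16, 16, 1)
    b = torch.ones(16)
    return torch.nn.functional.conv1d(x, w, b)

scripted = torch.jit.script(conv_fn)
print(scripted.graph)    # should contain an aten::conv1d node

traced = torch.jit.trace(torch.nn.Conv1d(16, 33, 3), torch.randn(20, 16, 60))
print(traced.graph)      # typically contains aten::_convolution instead
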
1 change: 0 additions & 1 deletion pytorch_blade/tests/disc/ops/test_shapes.py
@@ -36,7 +36,6 @@ def reshape_as(x, y):
annotations = [([-1,-1,-1,-1], dtype), ([-1,-1,-1], dtype)]
self._test_disc(reshape_as, annotations, test_data)

@skipIfEnableTorchMlir()
def test_view_as(self):
@torch.jit.script
def view_as(x, y):
26 changes: 26 additions & 0 deletions pytorch_blade/tests/torchscript/basics.graph
@@ -101,3 +101,29 @@ graph(%p1 : Float(*, *, *, device=cuda:0)):
// CHECK: Float(*, *, *, device=cuda:0), %{{[a-z.0-9]+}} : Long(*, *, *, device=cuda:0) = aten::topk
%r2: Tensor, %idx2 : Tensor = aten::topk(%p1, %11, %1, %true, %false)
return (%r1, %idx1, %r2, %idx2)

// CHECK-LABEL: graph
graph(%p1 : Float(2, 4, 16, 16, device=cuda:0)):
%1 : int = prim::Constant[value=-1]()
// CHECK: Float(2, 4, 16, 8, device=cuda:0) = aten::glu(%p1, %1)
%2 : Tensor = aten::glu(%p1, %1)
return (%2)

// CHECK-LABEL: graph
graph(%p1 : Float(*, *, *, device=cuda:0)):
%1 : int = prim::Constant[value=-1]()
// CHECK: Float(*, *, *, device=cuda:0) = aten::glu(%p1, %1)
%2 : Tensor = aten::glu(%p1, %1)
return (%2)

// analysis fail, erase shape information
// CHECK-LABEL: graph
graph(%p1 : Float(*, *, *, device=cuda:0),
%p2 : Float(*, *, *, device=cuda:0),
%p3 : Float(*, *, *, device=cuda:0),
%p4 : Float(*, *, *, device=cuda:0),
%p5 : Float(*, *, *, device=cuda:0),
%p6 : Float(*, *, *, device=cuda:0)):
// CHECK: Float(*, *, *, device=cuda:0) = aten::gru_cell(%p1, %p2, %p3, %p4, %p5, %p6)
%1 : Float(32, 32, 10, device=cuda:0) = aten::gru_cell(%p1, %p2, %p3, %p4, %p5, %p6)
return (%1)
19 changes: 19 additions & 0 deletions pytorch_blade/tests/torchscript/since_1_10.graph
@@ -41,4 +41,23 @@ graph(%p1 : Float(20, 30, 40, device=cpu)):
%4 : NoneType = prim::Constant()
// CHECK: Float(20, 30, 40, device=cuda:0) = aten::to(%p1, %1, %2, %3, %3, %4)
%3 : Tensor = aten::to(%p1, %1, %2, %3, %3, %4)
return (%3)

// aten::to.prim_Device
// CHECK-LABEL: graph
graph(%p1 : Float(*, *, *, device=cuda:0)):
%1 : Device = prim::Constant[value="cuda:1"]()
%2 : NoneType = prim::Constant()
%3 : bool = prim::Constant[value=0]()
// CHECK: Float(*, *, *, device=cuda:1) = aten::to(%p1, %1, %2, %3, %3)
%5 : Tensor = aten::to(%p1, %1, %2, %3, %3)
return (%5)

// aten::to.prim_dtype
// CHECK-LABEL: graph
graph(%p1 : Float(*, *, *, device=cuda:0)):
%1 : int = prim::Constant[value=5]()
%2 : bool = prim::Constant[value=0]()
// CHECK: Half(*, *, *, device=cuda:0) = aten::to(%p1, %1, %2, %2)
%3 : Tensor = aten::to(%p1, %1, %2, %2)
return (%3)
2 changes: 1 addition & 1 deletion pytorch_blade/torch_blade/config.py
@@ -585,7 +585,7 @@ def annotate_args(self):
def annotate_args(self, val):
assert isinstance(val, list), "annotate_args should be list, got {}".format(type(val))
for i, v in enumerate(val):
assert isinstance(v, tuple), "annotate_args[{}] should be list, got{}".format(i, type(v))
assert isinstance(v, tuple), "annotate_args[{}] should be tuple, got{}".format(i, type(v))
self._annotate_args = val

@property
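
The corrected assertion documents the expected format: annotate_args is a list of (shape, dtype) tuples, with -1 marking a dynamic dimension, as used throughout the disc tests. A hedged sketch of setting it (the exact way the Config object is obtained is an assumption and may differ):

import torch
import torch_blade

cfg = torch_blade.config.Config()                        # assumption: constructed directly
cfg.annotate_args = [([-1, -1, -1, -1], torch.float),    # each entry is a (shape, dtype) tuple
                     ([20, 16, 60], torch.float)]
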
