@AmosLewis (Collaborator) commented on Apr 4, 2023

#1979

Here are the files used for testing:
test_maked_fill.py,
test_masked_fill_torchscript_0327_transformers4.26.0.mlir
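For context, the `_code` string embedded in the IR below records the traced FX graph. It corresponds roughly to the following eager-mode function (a reconstruction for illustration, taken from that `_code` string; the argument names come from the trace and the constant value 0 from `_tensor_constant0`):

```python
import torch

# Hedged reconstruction of the traced graph recorded in the `_code` string below.
def forward(arg0_1, arg1_1):
    # arg1_1 has shape [1, 4]; allocate a matching int64 zero tensor.
    new_zeros = torch.ops.aten.new_zeros(arg1_1, [1, 4], dtype=torch.int64,
                                         layout=torch.strided, pin_memory=False)
    _tensor_constant0 = torch.tensor(0, dtype=torch.int64)
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy(_tensor_constant0)
    # Select column 0 (a view into new_zeros), then fill it in place.
    select = torch.ops.aten.select(new_zeros, 1, 0)
    torch.ops.aten.fill_(select, lift_fresh_copy)
    return new_zeros
```

The TorchScript-level torch dialect IR for this graph: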

```mlir
module attributes {torch.debug_module_name = "_lambda"} {
  func.func private @__torch__.torch.fx.graph_module._lambda.__code_getter(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">) -> !torch.str {
    %2 = torch.prim.GetAttr %arg0["_code"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.str
    return %2 : !torch.str
  }
  func.func private @__torch__.torch.fx.graph_module._lambda.forward(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">, %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,15],si64>}, %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[1,4],si64>}) -> !torch.tensor {
    %false = torch.constant.bool false
    %cpu = torch.constant.device "cpu"
    %int1 = torch.constant.int 1
    %int4 = torch.constant.int 4
    %int0 = torch.constant.int 0
    %2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.new_zeros %arg2, %2, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
    %4 = torch.prim.GetAttr %arg0["_tensor_constant0"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor
    %5 = torch.aten.lift_fresh_copy %4 : !torch.tensor -> !torch.tensor
    %6 = torch.aten.select.int %3, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
    %7 = torch.aten.fill_.Tensor %6, %5 : !torch.tensor, !torch.tensor -> !torch.tensor
    return %3 : !torch.tensor
  }
  torch.class_type @__torch__.torch.fx.graph_module._lambda {
    torch.attr private "_tensor_constant0" : !torch.tensor
    torch.attr private "training" : !torch.bool
    torch.attr private "_is_full_backward_hook" : !torch.optional<bool>
    torch.attr private "_code" : !torch.str
    torch.method private "__code_getter", @__torch__.torch.fx.graph_module._lambda.__code_getter
    torch.method "forward", @__torch__.torch.fx.graph_module._lambda.forward
  }
  %0 = torch.tensor.literal(dense<0> : tensor<si64>) : !torch.tensor<[],si64>
  %true = torch.constant.bool true
  %none = torch.constant.none
  %str = torch.constant.str "\0A\0A\0Adef forward(self, arg0_1, arg1_1):\0A    new_zeros = torch.ops.aten.new_zeros(arg1_1, [1, 4], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  arg1_1 = None\0A    _tensor_constant0 = self._tensor_constant0\0A    lift_fresh_copy = torch.ops.aten.lift_fresh_copy(_tensor_constant0);  _tensor_constant0 = None\0A    select = torch.ops.aten.select(new_zeros, 1, 0)\0A    fill_ = torch.ops.aten.fill_(select, lift_fresh_copy);  select = lift_fresh_copy = None\0A    return new_zeros\0A    "
  %1 = torch.nn_module {
    torch.slot "_tensor_constant0", %0 : !torch.tensor<[],si64>
    torch.slot "training", %true : !torch.bool
    torch.slot "_is_full_backward_hook", %none : !torch.none
    torch.slot "_code", %str : !torch.str
  } : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">
}
```

With this patch, the `aten.select.int` + `aten.fill_.Tensor` view-mutation pattern is folded into a functional slice/`index_put` sequence, so the module lowers to value-semantic tensors:

```mlir
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %int4 = torch.constant.int 4
    %none = torch.constant.none
    %0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
    %cpu = torch.constant.device "cpu"
    %1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.zeros %1, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
    %3 = torch.aten.clone %0, %none : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
    %4 = torch.aten.slice.Tensor %2, %int1, %int0, %int1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1],si64>
    %5 = torch.aten.squeeze.dim %4, %int1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1],si64>
    %6 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
    %7 = torch.prim.ListConstruct %6 : (!torch.vtensor<[],si64>) -> !torch.list<optional<vtensor>>
    %8 = torch.aten._index_put_impl %2, %7, %3, %false, %false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
    return %8 : !torch.vtensor<[1,4],si64>
  }
}
```
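At the PyTorch level, the rewrite replaces an in-place mutation of a `select` view with a functional `_index_put_impl`. For this graph both forms produce the same result: an all-zero `[1,4]` tensor whose first column is (re)filled with the constant 0. A minimal eager-mode sanity check (a sketch; `Tensor.index_put_` stands in here for the `_index_put_impl` op in the lowered IR):

```python
import torch

x = torch.zeros(1, 4, dtype=torch.int64)
fill_value = torch.tensor(0, dtype=torch.int64)

# Original pattern: select column 0 along dim 1, then fill the view in place.
a = x.clone()
a.select(1, 0).fill_(fill_value)

# Lowered pattern: functional index_put with an index tensor, as in the IR above.
b = x.clone()
b.index_put_([torch.tensor(0)], fill_value)

assert torch.equal(a, b)  # both are the all-zero [1, 4] tensor for this graph
```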

@AmosLewis changed the title from "[MLIR] Fold aten select and copy pattern" to "[MLIR] Fold aten select and fill_ pattern" on Apr 6, 2023
@AmosLewis marked this pull request as ready for review on April 6, 2023
@gpetters94 (Collaborator) left a comment:

LGTM

@AmosLewis merged commit 4df1d8a into llvm:main on Apr 7, 2023
@AmosLewis deleted the slicefill branch on April 7, 2023
gpetters94 pushed commits referencing this pull request to gpetters94/mlir-npcomp on May 10, 2023 and Jul 7, 2023