
Conversation

@AmosLewis (Collaborator) commented Mar 24, 2023

#1953

SUCCESS: test_slicecopy.py

test_slicecopy_masked_fill.py

t5_small_torchscript_test2.mlir

➜  t5small git:(main) ✗ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints,torch.aten.native_layer_norm,torch.aten.linear})' ./t5_small_torchscript_test2.mlir
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %int4 = torch.constant.int 4
    %none = torch.constant.none
    %int-1 = torch.constant.int -1
    %int-100 = torch.constant.int -100
    %int9223372036854775807 = torch.constant.int 9223372036854775807
    %cpu = torch.constant.device "cpu"
    %0 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.zeros %0, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
    %2 = torch.aten.slice.Tensor %arg1, %int1, %int0, %int-1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
    %3 = torch.aten.clone %2, %none : !torch.vtensor<[1,3],si64>, !torch.none -> !torch.vtensor<[1,3],si64>
    %4 = torch.aten.slice.Tensor %1, %int1, %int1, %int9223372036854775807, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
    %5 = torch.aten.arange.start_step %int1, %int4, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
    %6 = torch.prim.ListConstruct %none, %5 : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
    %7 = torch.aten._index_put_impl %1, %6, %3, %false, %false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[1,3],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
    %8 = torch.aten.slice.Tensor %7, %int1, %int0, %int1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1],si64>
    %9 = torch.aten.squeeze.dim %8, %int1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1],si64>
    %10 = torch.aten.eq.Scalar %7, %int-100 : !torch.vtensor<[1,4],si64>, !torch.int -> !torch.vtensor<[1,4],i1>
    %11 = torch.prim.ListConstruct  : () -> !torch.list<int>
    %12 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
    %13 = torch.aten.broadcast_to %12, %11 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[],si64>
    %14 = torch.aten.where.self %10, %13, %7 : !torch.vtensor<[1,4],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[1,4],si64> -> !torch.vtensor<[1,4],si64>
    return %14 : !torch.vtensor<[1,4],si64>
  }
}
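For readability, here is a rough eager-mode reading of the lowered forward() above (an editor's sketch inferred from the ops in the IR; it appears to mirror T5's decoder-input shifting, and the names are illustrative):

import torch

def forward(input_ids, labels):
    # %arg0 is unused in the IR, so input_ids is unused here as well.
    # aten.zeros: fresh [1, 4] int64 buffer.
    shifted = torch.zeros(1, 4, dtype=torch.int64)
    # aten.slice + aten.clone + aten._index_put_impl:
    # copy labels[:, :-1] into shifted[:, 1:].
    shifted[:, 1:] = labels[:, :-1].clone()
    # (The IR also slices and squeezes shifted[:, 0:1], but never uses it.)
    # aten.eq.Scalar + aten.where.self: replace -100 sentinels with 0.
    shifted = torch.where(shifted == -100, torch.zeros_like(shifted), shifted)
    return shifted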

@AmosLewis requested a review from gpetters94 on March 24, 2023 20:18
@AmosLewis force-pushed the int64_max branch 2 times, most recently from a241ec4 to 160cdeb on March 25, 2023 01:58
@AmosLewis (Collaborator, Author) commented Mar 25, 2023

Taken from @gpetters94's patch https://github.com/gpetters94/mlir-npcomp/tree/intmax
It looks like it fixes the int64_max issue but brings back the original slice-and-copy issue:
https://gist.github.com/AmosLewis/1826326e9f85480da9f13191cb4b86f7
%2 = torch.tensor_static_info_cast %1 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
The shape info disappears.

@AmosLewis (Collaborator, Author) commented

This patch already fixes the int64_max issue. The new shape issues are from the masked_fill_ op:
%144 = torch.aten.masked_fill_.Scalar %134, %143, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor
which comes from Python assigning a value to a slice:
x_new[..., 0] = 0
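A minimal module that triggers this pattern might look like the following (an editor's sketch; the exact model code is not shown in this thread):

import torch

class SliceAssignModule(torch.nn.Module):
    def forward(self, x):
        x_new = x.new_zeros(1, 4)
        # In-place assignment to a slice; TorchScript lowers this to
        # aten::select + aten::fill_ (masked_fill_ in the full model).
        x_new[..., 0] = 0
        return x_new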

@AmosLewis (Collaborator, Author) commented Mar 27, 2023

I tried using only x_new[..., 0] = 0 as a model, but got:

ts_g.graph: 
graph(%self : __torch__.torch.fx.graph_module._lambda,
      %arg0_1 : Tensor,
      %arg1_1.1 : Tensor):
  %11 : bool = prim::Constant[value=0]() # <eval_with_key>.2:5:144
  %37 : Device = prim::Constant[value="cpu"]()
  %4 : int = prim::Constant[value=1]() # <eval_with_key>.2:5:50
  %5 : int = prim::Constant[value=4]() # <eval_with_key>.2:5:53
  %19 : int = prim::Constant[value=0]() # <eval_with_key>.2:8:49
  %6 : int[] = prim::ListConstruct(%4, %5)
  %new_zeros.1 : Tensor = aten::new_zeros(%arg1_1.1, %6, %5, %19, %37, %11) # <eval_with_key>.2:5:16
  %_tensor_constant0.1 : Tensor = prim::GetAttr[name="_tensor_constant0"](%self)
  %lift_fresh_copy.1 : Tensor = aten::lift_fresh_copy(%_tensor_constant0.1) # <eval_with_key>.2:7:22
  %select.1 : Tensor = aten::select(%new_zeros.1, %4, %19) # <eval_with_key>.2:8:13
  %fill_ : Tensor = aten::fill_(%select.1, %lift_fresh_copy.1) # <eval_with_key>.2:9:12
  return (%new_zeros.1)

Traceback (most recent call last):
  File "/home/chi/src/ubuntu20/shark/SHARK/tank/pytorch/t5small/test.py", line 101, in <module>
    module = torch_mlir.compile(
  File "/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 358, in compile
    raise Exception(f"""
Exception: 
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
### Importer C++ Exception:
required keyword attribute 'split' is undefined
### Importer Diagnostics:

FIXED by upgrading my torch and torchvision versions.

@AmosLewis (Collaborator, Author) commented Mar 27, 2023

Success: test_slicecopy.py

@ramiro050 (Collaborator) left a comment

There are multiple edge cases here. All of them should be e2e tested to avoid off-by-one errors.
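The edge cases presumably include sentinel, negative, and out-of-range start/end values, e.g. (a hedged sketch; the exact set is not enumerated in the thread):

import torch

t = torch.arange(4)
INT64_MAX = 9223372036854775807

torch.ops.aten.slice(t, 0, 0, INT64_MAX, 1)  # end sentinel: slice to the end
torch.ops.aten.slice(t, 0, -3, -1, 1)        # negative start/end indices
torch.ops.aten.slice(t, 0, 2, 100, 1)        # end past the dimension size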

@AmosLewis (Collaborator, Author) commented

#2005

@AmosLewis (Collaborator, Author) commented Apr 18, 2023

Added a TODO for the general clamping approach: #2005 (comment)

@AmosLewis AmosLewis requested a review from ramiro050 April 18, 2023 18:01
@ramiro050 (Collaborator) commented

> I suggest we just add the fragile INT64_MAX handling and add a TODO there to fix it later. Otherwise, I will just be stuck here.

Sure, it does require quite a few ops. Can we keep the same structure as the other patch? In particular, we should have a helper function clampDimToValidRange that checks whether the value equals INT64_MAX or INT64_MIN and converts it to the dimension size or 0, respectively. The helper function can then be used on both start and end, since both values require the same clamping.

This will make it easier to add full clamping support in the future, since all that will then be needed is to improve clampDimToValidRange.
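A Python model of the suggested helper might look like this (the real implementation would live in the C++ lowering; this only illustrates the proposed semantics):

INT64_MAX = 9223372036854775807
INT64_MIN = -INT64_MAX - 1

def clamp_dim_to_valid_range(value, dim_size):
    # Map the INT64_MAX/INT64_MIN sentinels to dim_size/0; full clamping
    # support would also handle arbitrary out-of-range values.
    if value == INT64_MAX:
        return dim_size
    if value == INT64_MIN:
        return 0
    return value

# Used uniformly on both start and end of a slice:
# start = clamp_dim_to_valid_range(start, dim_size)
# end = clamp_dim_to_valid_range(end, dim_size)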

@ramiro050 (Collaborator) commented

Also, this PR should have e2e tests covering the new functionality.

@AmosLewis (Collaborator, Author) commented Apr 19, 2023

@gpetters94 Here is the new slice-copy e2e test for my case:

# Imports as in the torch-mlir e2e test suite (module paths may vary by version).
import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

class SliceCopy2DStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 4], torch.int64, True),
        ([1, 3], torch.int64, True),
    ])
    def forward(self, x, y):
        xslice = torch.ops.aten.slice(x, 1, 1, 4, 1)
        xslice.copy_(y)
        return x


@register_test_case(module_factory=lambda: SliceCopy2DStaticModule())
def SliceCopy2DStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(1, 4, high=4), tu.randint(1, 3, high=1))

# ==============================================================================
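For intuition, the aliasing behavior the test exercises can be checked in eager mode (an editor's sanity check, not part of the e2e suite):

import torch

x = torch.randint(0, 4, (1, 4))
y = torch.randint(0, 1, (1, 3))  # all zeros, matching tu.randint(..., high=1)
xslice = torch.ops.aten.slice(x, 1, 1, 4, 1)  # a view of x[:, 1:4]
xslice.copy_(y)  # writes through the view
assert torch.equal(x[:, 1:], y)  # x itself was mutated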

@AmosLewis (Collaborator, Author) commented

Fixed by a46b5c6

@AmosLewis closed this on Jun 27, 2023
@AmosLewis deleted the int64_max branch on January 19, 2024 19:15