Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move handling of integer signedness to the backend conversions #2597

Merged
merged 1 commit into from
Nov 29, 2023

Conversation

ramiro050
Copy link
Collaborator

The function getTypeForScalarType currently takes an argument to specify the signedness of integer types. This is leakage of backend specific requirements into the torch dialect world. Because getTypeForScalarType is a utility function for the torch dialect, it should only produce types that match the sign conventions used by PyTorch (regular integers are signed and unsigned integers are unsigned).

This commit removes the signedness argument from
getTypeForScalarType, and moves the backend specific handling of integer types to the backend code.

The function `getTypeForScalarType` currently takes an argument to
specify the signedness of integer types. This is leakage of backend
specific requirements into the torch dialect world. Because
`getTypeForScalarType` is a utility function for the torch dialect, it
should only produce types that match the sign conventions used by
PyTorch (regular integers are signed and unsigned integers are
unsigned).

This commit removes the signedness argument from
`getTypeForScalarType`, and moves the backend specific handling of
integer types to the backend code.
@ramiro050 ramiro050 requested a review from AmosLewis November 29, 2023 03:00
@ramiro050
Copy link
Collaborator Author

With this patch, we get the following transformation:

func.func private @forward(%arg0: !torch.vtensor<[20,100,35,45],f32>) -> !torch.vtensor<[20,100,35,45],f32> {
  %int0 = torch.constant.int 0
  %0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
  %int0_0 = torch.constant.int 0
  %int0_1 = torch.constant.int 0
  %cpu = torch.constant.device "cpu"
  %none = torch.constant.none
  %none_2 = torch.constant.none
  %1 = torch.aten.empty.memory_format %0, %int0_0, %int0_1, %cpu, %none, %none_2 : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.none, !torch.none -> !torch.vtensor<[0],ui8>
  return %arg0 : !torch.vtensor<[20,100,35,45],f32>
}

// -----// IR Dump After ConvertTorchToLinalg (convert-torch-to-linalg) //----- //
func.func private @forward(%arg0: !torch.vtensor<[20,100,35,45],f32>) -> !torch.vtensor<[20,100,35,45],f32> {
  %int0 = torch.constant.int 0
  %0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
  %int0_0 = torch.constant.int 0
  %int0_1 = torch.constant.int 0
  %cpu = torch.constant.device "cpu"
  %none = torch.constant.none
  %none_2 = torch.constant.none
  %1 = torch_c.to_i64 %int0
  %2 = arith.index_cast %1 : i64 to index
  %3 = tensor.empty(%2) : tensor<?xi8>
  %cast = tensor.cast %3 : tensor<?xi8> to tensor<0xi8>
  return %arg0 : !torch.vtensor<[20,100,35,45],f32>
}

Copy link
Collaborator

@AmosLewis AmosLewis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I also test the iree and shark-turbine where the issue first found, everything looks good now. nod-ai/SHARK-ModelDev#110

@ramiro050 ramiro050 merged commit e568f7e into llvm:main Nov 29, 2023
5 checks passed
@ramiro050 ramiro050 deleted the fix-signedness branch November 29, 2023 17:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants