
Support dim order in Java Tensor #11593

Open
@GregoryComer

Description


🚀 The feature, motivation and pitch

Our C++ tensor APIs support dim order, which lets a tensor's elements be laid out in memory in a non-default dimension order. This is commonly used for channels-last image tensors, where the channel values of each pixel are stored next to each other (interleaved) instead of in the default channels-first order. Java tensors should support this too.
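To make the layout concrete, here is a minimal Python sketch. It assumes ExecuTorch's convention that a dim order lists dimensions from outermost to innermost in memory, so the default NCHW order is [0, 1, 2, 3] and channels-last is [0, 2, 3, 1]. The helpers `strides_from_dim_order` and `flat_offset` are hypothetical names used only for illustration; they compute the element strides implied by a dim order.

```python
from typing import List


def strides_from_dim_order(sizes: List[int], dim_order: List[int]) -> List[int]:
    """Return per-dimension element strides for a tensor whose dimensions are
    laid out in memory in `dim_order` (outermost first, innermost last)."""
    if sorted(dim_order) != list(range(len(sizes))):
        raise ValueError("dim_order must be a permutation of 0..rank-1")
    strides = [0] * len(sizes)
    running = 1
    # Walk from the innermost dimension outward; each dimension's stride is
    # the product of the sizes of all dimensions stored inside it.
    for dim in reversed(dim_order):
        strides[dim] = running
        running *= sizes[dim]
    return strides


def flat_offset(index: List[int], strides: List[int]) -> int:
    """Offset of a multi-dimensional index into the flat data buffer."""
    return sum(i * s for i, s in zip(index, strides))


sizes = [1, 3, 16, 16]  # N, C, H, W
print(strides_from_dim_order(sizes, [0, 1, 2, 3]))  # [768, 256, 16, 1]
print(strides_from_dim_order(sizes, [0, 2, 3, 1]))  # [768, 1, 48, 3]
```

With the channels-last strides, the three channel values of pixel (h, w) sit at consecutive offsets 48*h + 3*w + c, which is the interleaving described above.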

Goals for this task:

  • Add a dim order field to Tensor.java.
    • Like sizes, dim order is represented as a list of integers, and its accessor should closely mirror the shape accessor.
  • Add a constructor to Tensor.java that takes a dim order, and add an overload taking a dim order to each of the fromBlob methods.
    • See https://github.com/pytorch/executorch/blob/main/extension/tensor/tensor_ptr.h#L284 for an example of how the native tensor builder handles this.
    • See the existing constructor in Tensor.java.
  • Update newJTensorFromTensor to support tensors with non-default dim orders.
    • When a native tensor with non-default dim order is passed in, it should create a Java tensor with the same dim order.
  • Update JEValueToTensorImpl to support non-default dim orders.
    • When a Java tensor with non-default dim order is passed in, it should construct a native tensor with the same dim order.
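The state these goals imply can be modeled in a few lines: a flat data buffer, a shape, and a dim order that defaults to the identity permutation when omitted and must otherwise be a valid permutation. The Python sketch below is purely illustrative; `TensorSketch`, `from_blob_sketch`, and `has_default_dim_order` are hypothetical names, not the Java API, which is whatever the implementation settles on. The `has_default_dim_order` check is one way to distinguish the default case from the non-default dim orders that newJTensorFromTensor and JEValueToTensorImpl must preserve.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TensorSketch:
    """Hypothetical Python stand-in for the proposed Java Tensor state (not a real API)."""

    data: List[float]
    shape: List[int]
    dim_order: List[int] = field(default_factory=list)

    def __post_init__(self) -> None:
        rank = len(self.shape)
        if not self.dim_order:
            # No dim order given: use the default (contiguous) order 0..rank-1.
            self.dim_order = list(range(rank))
        if sorted(self.dim_order) != list(range(rank)):
            raise ValueError(f"dim_order {self.dim_order} is not a permutation of 0..{rank - 1}")
        if len(self.data) != math.prod(self.shape):
            raise ValueError("data length does not match shape")

    def has_default_dim_order(self) -> bool:
        return self.dim_order == list(range(len(self.shape)))


def from_blob_sketch(data: List[float], shape: List[int],
                     dim_order: Optional[List[int]] = None) -> TensorSketch:
    """Hypothetical analogue of a fromBlob overload that optionally takes a dim order."""
    return TensorSketch(list(data), list(shape), list(dim_order) if dim_order else [])
```

Keeping the dim order optional mirrors the idea of adding overloads: existing callers keep the default layout, and only callers that pass a dim order get a non-default one.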

Testing the change

  • We'll want to add a new test to ModuleInstrumentationTest.kt to cover inference with a channels-last tensor.
    • We'll also need to generate a test model that takes channels-last inputs (see below).
    • This test should create an input tensor in channels-last format (using the newly created methods), run Add (or another model), and check the output values (including dim order). See TestMV2FP32 for a very similar test.
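One detail the test has to get right is the channels-last input buffer itself. Assuming the test starts from values defined in logical NCHW order, the buffer handed to a dim-order-aware fromBlob overload must be reordered into NHWC memory order while the shape stays [N, C, H, W]. Below is a minimal Python sketch of that reordering; `nchw_to_channels_last` is a hypothetical helper, and the Kotlin test would need the same index arithmetic. Because Add is elementwise, and assuming the output keeps the channels-last dim order, the expected output buffer is the elementwise sum of the two channels-last input buffers.

```python
from typing import List, Sequence


def nchw_to_channels_last(data: Sequence[float], n: int, c: int, h: int, w: int) -> List[float]:
    """Reorder a flat, contiguous NCHW buffer into channels-last (NHWC) memory order.

    The logical shape stays [n, c, h, w]; only the memory layout changes, which
    corresponds to dim order [0, 2, 3, 1].
    """
    if len(data) != n * c * h * w:
        raise ValueError("data length does not match n * c * h * w")
    out = [0.0] * len(data)
    for ni in range(n):
        for ci in range(c):
            for hi in range(h):
                for wi in range(w):
                    src = ((ni * c + ci) * h + hi) * w + wi  # NCHW offset
                    dst = ((ni * h + hi) * w + wi) * c + ci  # NHWC offset
                    out[dst] = data[src]
    return out


# Two channels, one row of two pixels: channel 0 = [0, 1], channel 1 = [2, 3].
print(nchw_to_channels_last([0.0, 1.0, 2.0, 3.0], n=1, c=2, h=1, w=2))
# [0.0, 2.0, 1.0, 3.0]
```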

Generating a channels-last test model

The existing tests in this file use a simple model, ModuleAdd.pte. We'll want to create a similar model that is exported with channels-last inputs. An example export script might look like this:

```python
import torch
from executorch.exir import to_edge_transform_and_lower


class ModuleAddChannelsLast(torch.nn.Module):
    def forward(self, x, y):
        return x + y


# Example inputs in channels-last memory format (dim order [0, 2, 3, 1]).
inputs = (
    torch.randn(1, 3, 16, 16).to(memory_format=torch.channels_last),
    torch.randn(1, 3, 16, 16).to(memory_format=torch.channels_last),
)

# Export, lower to Edge, and convert to an ExecuTorch program.
ep = torch.export.export(ModuleAddChannelsLast(), inputs)
lowered = to_edge_transform_and_lower(ep).to_executorch()

# Serialize the program to a .pte file for the Android tests.
with open("ModuleAddChannelsLast.pte", "wb") as f:
    f.write(lowered.buffer)
```

We'll want to add the new model, maybe called ModuleAddChannelsLast, alongside the existing export code for ModuleAdd.pte.

The test integration might be a bit complicated. Feel free to tag me or @kirklandsign with any questions.

Alternatives

No response

Additional context

No response

RFC (Optional)

No response

cc @kirklandsign @cbilgin

Metadata

Labels: good first issue (Good for newcomers), module: android (Issues related to Android code, build, and execution)

Projects: Status: No status; Status: Todo

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests
