order of dimensions in advanced indexing does not match PyTorch #2137

Open
@t-vi

Description

import thunder
import torch

def fn(a, idx):
    return a[idx]

a = torch.randn(5, 5, 7, requires_grad=True)
idx = (None, [1, 2])

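# Eager PyTorch: the new axis from None comes first, then the indexed dimension.
print(fn(a, idx).shape)   # torch.Size([1, 2, 5, 7])

jfn = thunder.jit(fn)

# thunder.jit: the first two dimensions come out swapped.
print(jfn(a, idx).shape)  # torch.Size([2, 1, 5, 7])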
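
For reference, here is a minimal pure-Python sketch of the shape I would expect from eager PyTorch in this simple case. `expected_shape` is a hypothetical helper written for this illustration, not part of Thunder or PyTorch. It assumes the index tuple contains only `None`, full slices (`:`), and at most one 1-D list, and that the list's dimension stays at the position where the list appears.

```python
def expected_shape(shape, idx):
    """Expected result shape for a restricted index tuple (hypothetical helper).

    Assumption: idx holds only None, slice(None), and at most one 1-D list.
    """
    if sum(isinstance(item, list) for item in idx) > 1:
        raise NotImplementedError("only a single list index is handled here")

    out = []
    dim = 0  # next input dimension to be consumed
    for item in idx:
        if item is None:
            out.append(1)           # None inserts a new axis of size 1 here
        elif isinstance(item, list):
            out.append(len(item))   # the list replaces the dimension it indexes
            dim += 1
        elif item == slice(None):
            out.append(shape[dim])  # a full slice keeps the dimension as-is
            dim += 1
        else:
            raise NotImplementedError(f"unsupported index: {item!r}")
    out.extend(shape[dim:])         # untouched trailing dimensions
    return tuple(out)


print(expected_shape((5, 5, 7), (None, [1, 2])))  # (1, 2, 5, 7)
```

This matches the eager result in the reproducer above. The jitted result instead places the list's dimension ahead of the axis inserted by `None`.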
