Closed
Description
Unless I am misunderstanding how stacking works, there seems to be an issue with stacking Transform3d objects.
To reproduce:
import torch
from pytorch3d.transforms import Transform3d
transform3 = Transform3d().rotate(torch.stack([torch.eye(3)]*3)).translate(torch.zeros(3,3))
transform1 = Transform3d()
transform4 = transform1.stack(transform3)
print(len(transform3))
print(len(transform1))
print(len(transform4))
transform4.transform_points(torch.zeros(4,5,3))
The last line errors with:
ValueError: Expected batch dim for bmm to be equal or 1; got torch.Size([4, 5, 4]), torch.Size([2, 4, 4])
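For reference, here is a minimal sketch of the behavior I expected, written in plain torch so it does not depend on the library internals. It assumes that stacking should concatenate the composed 4x4 matrices along the batch dimension, giving a batch of 1 + 3 = 4. The names m1, m3 and stacked are hypothetical stand-ins, not PyTorch3D attributes, and the matrices are identities only because the repro uses identity rotations and zero translations.

```python
import torch

# Hypothetical stand-ins for the composed 4x4 matrices of each transform.
m1 = torch.eye(4).unsqueeze(0)       # batch of 1, like transform1
m3 = torch.eye(4).expand(3, 4, 4)    # batch of 3, like transform3

# Assumed stacking semantics: concatenate along the batch dimension.
stacked = torch.cat([m1, m3], dim=0)  # batch of 4

# Homogeneous points of shape (4, 5, 4), mirroring the bmm in the error above.
points = torch.zeros(4, 5, 3)
points_h = torch.cat([points, torch.ones(4, 5, 1)], dim=-1)

# With matching batch sizes the batched matmul goes through.
out = torch.bmm(points_h, stacked)[..., :3]
print(stacked.shape, out.shape)
```

With a batch of 4 matrices the bmm succeeds, whereas the error above reports a batch of only 2.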