
Commit 20e457c

nikhilaravi authored and facebook-github-bot committed
[pytorch3d] padded to packed function in struct utils
Summary: Added a padded to packed utils function which takes either split sizes or a padding value to remove padded elements from a tensor.

Reviewed By: gkioxari

Differential Revision: D20454238

fbshipit-source-id: 180b807ff44c74c4ee9d5c1ac3b5c4a9b4be57c7
1 parent 4d3c886 commit 20e457c
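
A minimal usage sketch of the new utility, not part of the commit itself: it assumes the module is importable as pytorch3d.structures.utils (inferred from the file changed below) and shows that the split_size and pad_value modes agree when pad rows are entirely filled with the pad value.

import torch
from pytorch3d.structures import utils as struct_utils  # assumed import path

# Padded batch: N=2 elements, padded to M=3 rows of K=2 features, pad value -1.
x = torch.tensor(
    [
        [[1.0, 1.0], [2.0, 2.0], [-1.0, -1.0]],    # 2 valid rows, 1 pad row
        [[3.0, 3.0], [-1.0, -1.0], [-1.0, -1.0]],  # 1 valid row, 2 pad rows
    ]
)

# Either give the per-element sizes explicitly...
packed_a = struct_utils.padded_to_packed(x, split_size=[2, 1])
# ...or let the pad value identify fully padded rows to drop (only one of the two).
packed_b = struct_utils.padded_to_packed(x, pad_value=-1.0)

assert packed_a.shape == (3, 2) and torch.equal(packed_a, packed_b)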

File tree

2 files changed (+165, -2 lines)


pytorch3d/structures/utils.py

Lines changed: 79 additions & 2 deletions
@@ -66,11 +66,16 @@ def padded_to_list(
 
     Args:
       x: tensor
-      split_size: the shape of the final tensor to be returned (of length N).
+      split_size: list, tuple or int defining the number of items for each tensor
+        in the output list.
+
+    Returns:
+      x_list: a list of tensors
     """
     if x.ndim != 3:
         raise ValueError("Supports only 3-dimensional input tensors")
     x_list = list(x.unbind(0))
+
     if split_size is None:
         return x_list
 
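
For context on the clarified split_size argument, here is a small sketch of padded_to_list (again assuming the pytorch3d.structures.utils import path; the slicing of the leading rows follows the docstring, not this diff):

import torch
from pytorch3d.structures import utils as struct_utils  # assumed import path

# Padded batch: N=2 elements, each padded to M=4 rows of K=3 features.
x = torch.zeros(2, 4, 3)

# split_size[i] is the number of items kept for output tensor i.
x_list = struct_utils.padded_to_list(x, split_size=[2, 4])
assert [tuple(t.shape) for t in x_list] == [(2, 3), (4, 3)]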

@@ -141,9 +146,81 @@ def packed_to_list(x: torch.Tensor, split_size: Union[list, int]):
 
     Args:
       x: tensor
-      split_size: list or int defining the number of items for each split
+      split_size: list, tuple or int defining the number of items for each tensor
+        in the output list.
 
     Returns:
       x_list: A list of Tensors
     """
     return x.split(split_size, dim=0)
+
+
+def padded_to_packed(
+    x: torch.Tensor,
+    split_size: Union[list, tuple, None] = None,
+    pad_value: Union[float, int, None] = None,
+):
+    r"""
+    Transforms a padded tensor of shape (N, M, K) into a packed tensor
+    of shape:
+     - (sum(Mi), K) where (Mi, K) are the dimensions of
+       each of the tensors in the batch and Mi is specified by split_size(i)
+     - (N*M, K) if split_size is None
+
+    Support only for 3-dimensional input tensor and 1-dimensional split size.
+
+    Args:
+      x: tensor
+      split_size: list, tuple or int defining the number of items for each tensor
+        in the output list.
+      pad_value: optional value to use to filter the padded values in the input
+        tensor.
+
+    Only one of split_size or pad_value should be provided, or both can be None.
+
+    Returns:
+      x_packed: a packed tensor.
+    """
+    if x.ndim != 3:
+        raise ValueError("Supports only 3-dimensional input tensors")
+
+    N, M, D = x.shape
+
+    if split_size is not None and pad_value is not None:
+        raise ValueError(
+            "Only one of split_size or pad_value should be provided."
+        )
+
+    x_packed = x.view(-1, D)  # flatten padded
+
+    if pad_value is None and split_size is None:
+        return x_packed
+
+    # Convert to packed using pad value
+    if pad_value is not None:
+        mask = x_packed.ne(pad_value).any(-1)
+        x_packed = x_packed[mask]
+        return x_packed
+
+    # Convert to packed using split sizes
+    N = len(split_size)
+    if x.shape[0] != N:
+        raise ValueError(
+            "Split size must be of same length as inputs first dimension"
+        )
+
+    if not all(isinstance(i, int) for i in split_size):
+        raise ValueError(
+            "Support only 1-dimensional unbinded tensor. \
+            Split size for more dimensions provided"
+        )
+
+    padded_to_packed_idx = torch.cat(
+        [
+            torch.arange(v, dtype=torch.int64, device=x.device) + i * M
+            for (i, v) in enumerate(split_size)
+        ],
+        dim=0,
+    )
+
+    return x_packed[padded_to_packed_idx]
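
To make the split_size branch above concrete, a standalone worked example of the index construction using plain torch only (variable names mirror the code above):

import torch

N, M, K = 2, 3, 2
split_size = [2, 1]
x = torch.arange(N * M * K, dtype=torch.float32).view(N, M, K)

# For element i with v valid rows, i * M + arange(v) addresses its first v rows
# in the flattened (N * M, K) view.
padded_to_packed_idx = torch.cat(
    [torch.arange(v, dtype=torch.int64) + i * M for i, v in enumerate(split_size)]
)
# padded_to_packed_idx == tensor([0, 1, 3]): rows 0-1 of element 0, row 0 of element 1.
x_packed = x.view(-1, K)[padded_to_packed_idx]
assert x_packed.shape == (sum(split_size), K)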

tests/test_struct_utils.py

Lines changed: 86 additions & 0 deletions
@@ -97,6 +97,92 @@ def test_padded_to_list(self):
             split_size = torch.randint(1, K, size=(N,)).tolist()
             struct_utils.padded_to_list(x, split_size)
 
+    def test_padded_to_packed(self):
+        device = torch.device("cuda:0")
+        N = 5
+        K = 20
+        ndim = 2
+        dims = [K] * ndim
+        x = torch.rand([N] + dims, device=device)
+
+        # Case 1: no split_size or pad_value provided
+        # Check output is just the flattened input.
+        x_packed = struct_utils.padded_to_packed(x)
+        self.assertTrue(x_packed.shape == (x.shape[0] * x.shape[1], x.shape[2]))
+        self.assertClose(x_packed, x.reshape(-1, K))
+
+        # Case 2: pad_value is provided.
+        # Check each section of the packed tensor matches the
+        # corresponding unpadded elements of the padded tensor.
+        # Check that only rows where all the values are padded
+        # are removed in the conversion to packed.
+        pad_value = -1
+        x_list = []
+        split_size = []
+        for _ in range(N):
+            dim = torch.randint(K, size=(1,)).item()
+            # Add some random values in the input which are the same as the pad_value.
+            # These should not be filtered out.
+            x_list.append(
+                torch.randint(
+                    low=pad_value, high=10, size=(dim, K), device=device
+                )
+            )
+            split_size.append(dim)
+        x_padded = struct_utils.list_to_padded(x_list, pad_value=pad_value)
+        x_packed = struct_utils.padded_to_packed(x_padded, pad_value=pad_value)
+        curr = 0
+        for i in range(N):
+            self.assertClose(
+                x_packed[curr : curr + split_size[i], ...], x_list[i]
+            )
+            self.assertClose(torch.cat(x_list), x_packed)
+            curr += split_size[i]
+
+        # Case 3: split_size is provided.
+        # Check each section of the packed tensor matches the corresponding
+        # unpadded elements.
+        x_packed = struct_utils.padded_to_packed(
+            x_padded, split_size=split_size
+        )
+        curr = 0
+        for i in range(N):
+            self.assertClose(
+                x_packed[curr : curr + split_size[i], ...], x_list[i]
+            )
+            self.assertClose(torch.cat(x_list), x_packed)
+            curr += split_size[i]
+
+        # Case 4: split_size of the wrong shape is provided.
+        # Raise an error.
+        split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0)
+        with self.assertRaisesRegex(ValueError, "1-dimensional"):
+            x_packed = struct_utils.padded_to_packed(
+                x_padded, split_size=split_size
+            )
+
+        split_size = torch.randint(1, K, size=(2 * N,)).view(N * 2).tolist()
+        with self.assertRaisesRegex(
+            ValueError, "same length as inputs first dimension"
+        ):
+            x_packed = struct_utils.padded_to_packed(
+                x_padded, split_size=split_size
+            )
+
+        # Case 5: both pad_value and split_size are provided.
+        # Raise an error.
+        with self.assertRaisesRegex(ValueError, "Only one of"):
+            x_packed = struct_utils.padded_to_packed(
+                x_padded, split_size=split_size, pad_value=-1
+            )
+
+        # Case 6: Input has more than 3 dims.
+        # Raise an error.
+        with self.assertRaisesRegex(ValueError, "Supports only"):
+            x = torch.rand((N, K, K, K, K), device=device)
+            split_size = torch.randint(1, K, size=(N,)).tolist()
+            struct_utils.padded_to_list(x, split_size)
+
     def test_list_to_packed(self):
         device = torch.device("cuda:0")
         N = 5
