forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest_transforms.py
476 lines (412 loc) · 19.1 KB
/
test_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
# Owner(s): ["module: distributions"]
from numbers import Number
import pytest
import torch
from torch.autograd.functional import jacobian
from torch.distributions import Dirichlet, Independent, Normal, TransformedDistribution, constraints
from torch.distributions.transforms import (AbsTransform, AffineTransform, ComposeTransform,
CorrCholeskyTransform, CumulativeDistributionTransform,
ExpTransform, IndependentTransform,
LowerCholeskyTransform, PowerTransform,
ReshapeTransform, SigmoidTransform, TanhTransform,
SoftmaxTransform, SoftplusTransform, StickBreakingTransform,
identity_transform, Transform, _InverseTransform)
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
def get_transforms(cache_size):
    """Build the catalogue of transforms exercised by the parametrized tests.

    Every transform whose constructor is given ``cache_size`` here receives the
    caller's value; ``ReshapeTransform`` and ``CumulativeDistributionTransform``
    are constructed without a ``cache_size`` argument, and the
    ``ComposeTransform``/``IndependentTransform`` wrappers only forward it via
    their parts.

    Returns:
        A list holding the forward transforms followed by their inverses, in
        the same order.
    """
    def random_affine(*shape):
        # Draw loc before scale, matching the left-to-right evaluation of
        # AffineTransform(torch.randn(...), torch.randn(...)).
        loc = torch.randn(*shape)
        scale = torch.randn(*shape)
        return AffineTransform(loc, scale, cache_size=cache_size)

    def random_power():
        # Exponent is a scalar tensor filled with a standard-normal draw.
        return PowerTransform(exponent=torch.tensor(5.).normal_(), cache_size=cache_size)

    forward = [
        AbsTransform(cache_size=cache_size),
        ExpTransform(cache_size=cache_size),
        PowerTransform(exponent=2, cache_size=cache_size),
        random_power(),
        random_power(),
        SigmoidTransform(cache_size=cache_size),
        TanhTransform(cache_size=cache_size),
        AffineTransform(0, 1, cache_size=cache_size),
        AffineTransform(1, -2, cache_size=cache_size),
        random_affine(5),
        random_affine(4, 5),
        SoftmaxTransform(cache_size=cache_size),
        SoftplusTransform(cache_size=cache_size),
        StickBreakingTransform(cache_size=cache_size),
        LowerCholeskyTransform(cache_size=cache_size),
        CorrCholeskyTransform(cache_size=cache_size),
        ComposeTransform([random_affine(4, 5)]),
        ComposeTransform([random_affine(4, 5), ExpTransform(cache_size=cache_size)]),
        ComposeTransform([
            AffineTransform(0, 1, cache_size=cache_size),
            random_affine(4, 5),
            AffineTransform(1, -2, cache_size=cache_size),
            random_affine(4, 5),
        ]),
        ReshapeTransform((4, 5), (2, 5, 2)),
        IndependentTransform(random_affine(5), 1),
        CumulativeDistributionTransform(Normal(0, 1)),
    ]
    return forward + [t.inv for t in forward]
def reshape_transform(transform, shape):
    """Return an equivalent transform whose tensor parameters have shape ``shape``.

    Needed to squash batch dims for testing the jacobian. Only
    ``AffineTransform`` parameters are rewritten; a ``ComposeTransform`` is
    rebuilt part by part, and inverses of Affine/Compose transforms are handled
    by reshaping the forward transform and inverting again. Any other transform
    is returned unchanged.

    Args:
        transform: the transform to adapt.
        shape: target shape for tensor-valued ``loc``/``scale`` parameters.

    Returns:
        A transform with the same ``event_dim`` and ``_cache_size`` as the
        input (or the input itself when nothing needs reshaping).
    """
    def reshape_param(param):
        # Python scalars already broadcast against any shape.
        if isinstance(param, Number):
            return param
        try:
            return param.expand(shape)
        except RuntimeError:
            # Not broadcastable (e.g. (4, 5) -> (20,)): fall back to a reshape,
            # which requires the same number of elements.
            return param.reshape(shape)

    if isinstance(transform, AffineTransform):
        if isinstance(transform.loc, Number) and isinstance(transform.scale, Number):
            return transform
        # Reshape loc and scale independently so mixed scalar/tensor parameters
        # work, and keep event_dim so the log-det reduction is unchanged.
        return AffineTransform(reshape_param(transform.loc),
                               reshape_param(transform.scale),
                               event_dim=transform.event_dim,
                               cache_size=transform._cache_size)
    if isinstance(transform, ComposeTransform):
        return ComposeTransform([reshape_transform(p, shape) for p in transform.parts],
                                cache_size=transform._cache_size)
    if isinstance(transform.inv, (AffineTransform, ComposeTransform)):
        return reshape_transform(transform.inv, shape).inv
    return transform
def transform_id(x):
    """Generate a readable pytest id for transform ``x``.

    Inverse transforms render as ``Inv(<ForwardClass>)``; every id ends with
    the transform's cache size, e.g. ``ExpTransform(cache_size=0)``.
    """
    assert isinstance(x, Transform)
    if isinstance(x, _InverseTransform):
        # _inv holds the forward transform this object inverts.
        label = f'Inv({type(x._inv).__name__})'
    else:
        label = type(x).__name__
    return f'{label}(cache_size={x._cache_size})'
def generate_data(transform):
    """Return a deterministic input tensor lying in the domain of ``transform``.

    The global RNG is reseeded to 1 on every call, so repeated calls for the
    same transform return identical tensors.

    Raises:
        ValueError: if the (unwrapped) domain is not one handled below.
    """
    torch.manual_seed(1)
    # IndependentTransform only reinterprets dims; sample for its base transform.
    while isinstance(transform, IndependentTransform):
        transform = transform.base_transform
    # Reshapes accept any real values but need exactly the configured shape.
    if isinstance(transform, ReshapeTransform):
        return torch.randn(transform.in_shape)
    if isinstance(transform.inv, ReshapeTransform):
        return torch.randn(transform.inv.out_shape)

    # Peel independent() wrappers down to the elementwise constraint, but stop
    # at real_vector, which is handled with its own vector-shaped input below.
    domain = transform.domain
    while isinstance(domain, constraints.independent) and domain is not constraints.real_vector:
        domain = domain.base_constraint
    codomain = transform.codomain

    x = torch.empty(4, 5)
    if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky:
        # Square-matrix input for LowerCholeskyTransform and its inverse.
        return torch.empty(6, 6).normal_()
    if domain is constraints.real:
        return x.normal_()
    if domain is constraints.real_vector:
        # For corr_cholesky the last dim of the vector must have size
        # D * (D - 1) // 2 for a D x D factor; 6 entries give D = 4.
        return torch.empty(3, 6).normal_()
    if domain is constraints.positive:
        return x.normal_().exp()
    if domain is constraints.unit_interval:
        return x.uniform_()
    if isinstance(domain, constraints.interval):
        # Rescale U[0, 1) onto [lower_bound, upper_bound).
        return x.uniform_().mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound)
    if domain is constraints.simplex:
        x = x.normal_().exp()
        x /= x.sum(-1, True)
        return x
    if domain is constraints.corr_cholesky:
        # Lower-triangular matrices with unit-norm rows and a non-negative diagonal.
        x = torch.empty(4, 5, 5).normal_().tril()
        x /= x.norm(dim=-1, keepdim=True)
        x.diagonal(dim1=-2, dim2=-1).copy_(x.diagonal(dim1=-2, dim2=-1).abs())
        return x
    raise ValueError('Unsupported domain: {}'.format(domain))
# Every transform under test, built once with caching enabled (cache_size=1)
# and once with caching disabled (cache_size=0); ALL_TRANSFORMS additionally
# includes the identity transform.
TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1)
TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0)
ALL_TRANSFORMS = [*TRANSFORMS_CACHE_ACTIVE, *TRANSFORMS_CACHE_INACTIVE, identity_transform]
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
def test_inv_inv(transform):
    """Inverting a transform twice must return the very same object."""
    assert transform.inv.inv is transform
@pytest.mark.parametrize('x', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize('y', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_equality(x, y):
    """A transform equals itself; distinct catalogue entries compare unequal.

    Also checks that the identity transform equals its own inverse.
    """
    if x is not y:
        assert x != y
    else:
        assert x == y
    assert identity_transform == identity_transform.inv
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
def test_with_cache(transform):
    """With cache_size=1, calling the transform twice on the same input returns the same object."""
    cached = transform.with_cache(1) if transform._cache_size == 0 else transform
    assert cached._cache_size == 1
    x = generate_data(cached).requires_grad_()
    try:
        first = cached(x)
    except NotImplementedError:
        pytest.skip('Not implemented.')
    assert cached(x) is first
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize('test_cached', [True, False])
def test_forward_inverse(transform, test_cached):
    """Round-trip data through ``transform`` and ``transform.inv``.

    Bijective transforms must satisfy t.inv(t(x)) == x; the others only the
    weaker t(t.inv(t(x))) == t(x). With ``test_cached`` the inverse receives
    the exact tensor produced by the forward call; otherwise it receives a
    clone, and a NotImplementedError from the inverse skips the test.
    """
    x = generate_data(transform).requires_grad_()
    try:
        y = transform(x)
    except NotImplementedError:
        pytest.skip('Not implemented.')
    assert y.shape == transform.forward_shape(x.shape)

    if not test_cached:
        try:
            x2 = transform.inv(y.clone())  # bypass cache
        except NotImplementedError:
            pytest.skip('Not implemented.')
    else:
        x2 = transform.inv(y)  # should be implemented at least by caching
    assert x2.shape == transform.inverse_shape(y.shape)

    y2 = transform(x2)
    if not transform.bijective:
        # verify weaker function pseudo-inverse
        assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), '\n'.join([
            '{} t(t.inv(t(-))) error'.format(transform),
            'x = {}'.format(x),
            'y = t(x) = {}'.format(y),
            'x2 = t.inv(y) = {}'.format(x2),
            'y2 = t(x2) = {}'.format(y2),
        ])
        return
    # verify function inverse
    assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), '\n'.join([
        '{} t.inv(t(-)) error'.format(transform),
        'x = {}'.format(x),
        'y = t(x) = {}'.format(y),
        'x2 = t.inv(y) = {}'.format(x2),
    ])
def test_compose_transform_shapes():
    """Check event_dim of three transforms and of pairwise compositions of them."""
    exp_t, softmax_t, cholesky_t = ExpTransform(), SoftmaxTransform(), LowerCholeskyTransform()
    assert [t.event_dim for t in (exp_t, softmax_t, cholesky_t)] == [0, 1, 2]

    compositions = [
        ([exp_t, softmax_t], 1),
        ([exp_t, cholesky_t], 2),
        ([softmax_t, cholesky_t], 2),
    ]
    for parts, expected_event_dim in compositions:
        assert ComposeTransform(parts).event_dim == expected_event_dim
# Module-level fixtures referenced by the parametrization of
# test_transformed_distribution_shapes.
transform0, transform1, transform2 = ExpTransform(), SoftmaxTransform(), LowerCholeskyTransform()
base_dist0 = Normal(loc=torch.zeros(4, 4), scale=torch.ones(4, 4))
base_dist1 = Dirichlet(concentration=torch.ones(4, 4))
base_dist2 = Normal(loc=torch.zeros(3, 4, 4), scale=torch.ones(3, 4, 4))
@pytest.mark.parametrize('batch_shape, event_shape, dist', [
    ((4, 4), (), base_dist0),
    ((4,), (4,), base_dist1),
    ((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
    ((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
    ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
    ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
    ((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
    ((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
    ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
    ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
    ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
    ((3, 4, 4), (), base_dist2),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
    ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
])
def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
    """Check batch/event shapes, then that log_prob of an rsample'd value runs."""
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == event_shape
    sample = dist.rsample()
    try:
        dist.log_prob(sample)  # this should not crash
    except NotImplementedError:
        pytest.skip('Not implemented.')
@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_fwd(transform):
    """A torch.jit.trace of the forward call must match eager execution on new inputs."""
    x = generate_data(transform).requires_grad_()

    def f(x):
        return transform(x)

    try:
        traced_f = torch.jit.trace(f, (x,))
    except NotImplementedError:
        pytest.skip('Not implemented.')

    # check on different inputs: generate_data() reseeds the RNG on every call,
    # so calling it again would reproduce the tracing input exactly. Reverse
    # the leading dim instead to get new values with the same shape.
    x = generate_data(transform).flip(0).requires_grad_()
    assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)
@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_inv(transform):
    """A torch.jit.trace of the inverse call must match eager execution on new inputs."""
    y = generate_data(transform.inv).requires_grad_()

    def f(y):
        return transform.inv(y)

    try:
        traced_f = torch.jit.trace(f, (y,))
    except NotImplementedError:
        pytest.skip('Not implemented.')

    # check on different inputs: generate_data() reseeds the RNG on every call,
    # so calling it again would reproduce the tracing input exactly. Reverse
    # the leading dim instead to get new values with the same shape.
    y = generate_data(transform.inv).flip(0).requires_grad_()
    assert torch.allclose(f(y), traced_f(y), atol=1e-5, equal_nan=True)
@pytest.mark.parametrize('transform', TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_jacobian(transform):
    """A torch.jit.trace of log_abs_det_jacobian must match eager execution on new inputs."""
    x = generate_data(transform).requires_grad_()

    def f(x):
        y = transform(x)
        return transform.log_abs_det_jacobian(x, y)

    try:
        traced_f = torch.jit.trace(f, (x,))
    except NotImplementedError:
        pytest.skip('Not implemented.')

    # check on different inputs: generate_data() reseeds the RNG on every call,
    # so calling it again would reproduce the tracing input exactly. Reverse
    # the leading dim instead to get new values with the same shape.
    x = generate_data(transform).flip(0).requires_grad_()
    assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)
@pytest.mark.parametrize('transform', ALL_TRANSFORMS, ids=transform_id)
def test_jacobian(transform):
    """Compare log_abs_det_jacobian with log|det J| computed via autograd.

    The input's batch dims are squashed into one leading dim so the
    per-element Jacobian blocks can be extracted from the full Jacobian
    returned by ``torch.autograd.functional.jacobian``.
    """
    x = generate_data(transform)
    try:
        y = transform(x)
        actual = transform.log_abs_det_jacobian(x, y)
    except NotImplementedError:
        pytest.skip('Not implemented.')

    # Test shape: one log-det per batch element.
    target_shape = x.shape[:x.dim() - transform.domain.event_dim]
    assert actual.shape == target_shape

    # Expand if required
    transform = reshape_transform(transform, x.shape)
    # Number of leading batch dims of x; the remaining dims form one event.
    batch_ndims = x.dim() - transform.domain.event_dim
    x_ = x.view((-1,) + x.shape[batch_ndims:])
    n = x_.shape[0]
    # Reshape to squash batch dims to a single batch dim
    transform = reshape_transform(transform, x_.shape)

    # 1. Transforms with unit jacobian: log|det J| is zero for every batch element.
    if isinstance(transform, ReshapeTransform) or isinstance(transform.inv, ReshapeTransform):
        expected = x.new_zeros(target_shape)
    # 2. Transforms with 0 off-diagonal elements
    elif transform.domain.event_dim == 0:
        jac = jacobian(transform, x_)
        # assert off-diagonal elements are zero
        assert torch.allclose(jac, jac.diagonal().diag_embed())
        expected = jac.diagonal().abs().log().reshape(x.shape)
    # 3. Transforms with non-0 off-diagonal elements
    else:
        if isinstance(transform, CorrCholeskyTransform):
            jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
        elif isinstance(transform.inv, CorrCholeskyTransform):
            jac = jacobian(lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
                           tril_matrix_to_vec(x_, diag=-1))
        elif isinstance(transform, StickBreakingTransform):
            jac = jacobian(lambda x: transform(x)[..., :-1], x_)
        else:
            jac = jacobian(transform, x_)

        # Note that jacobian will have shape (batch_dims, y_event_dims, batch_dims, x_event_dims)
        # However, batches are independent so this can be converted into a (batch_dims, event_dims, event_dims)
        # after reshaping the event dims (see above) to give a batched square matrix whose determinant
        # can be computed. The gather below picks, for batch element i, the block
        # at index i of the second batch dim.
        gather_idx_shape = list(jac.shape)
        gather_idx_shape[-2] = 1
        gather_idxs = torch.arange(n).reshape((n,) + (1,) * (jac.dim() - 1)).expand(gather_idx_shape)
        jac = jac.gather(-2, gather_idxs).squeeze(-2)
        out_ndims = jac.shape[-2]
        jac = jac[..., :out_ndims]  # Remove extra zero-valued dims (for inverse stick-breaking).
        expected = torch.slogdet(jac).logabsdet

    assert torch.allclose(actual, expected, atol=1e-5)
@pytest.mark.parametrize("event_dims",
                         [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)],
                         ids=str)
def test_compose_affine(event_dims):
    """Composed affine transforms take the largest event_dim of their parts.

    The composition is also applied on top of a real-valued Normal base and a
    Dirichlet base, and the resulting support event_dim is checked.
    """
    max_event_dim = max(event_dims)
    transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max_event_dim
    assert transform.domain.event_dim == max_event_dim
    domain_event_dim = transform.domain.event_dim

    # Real-valued base, expanded to provide as many dims as the composition consumes.
    normal = Normal(0, 1)
    if domain_event_dim:
        normal = normal.expand((1,) * domain_event_dim)
    dist = TransformedDistribution(normal, transform.parts)
    assert dist.support.event_dim == max_event_dim

    # Dirichlet base already has one event dim; expand only for the remainder.
    dirichlet = Dirichlet(torch.ones(5))
    extra_dims = domain_event_dim - 1
    if extra_dims > 0:
        dirichlet = dirichlet.expand((1,) * extra_dims)
    dist = TransformedDistribution(dirichlet, transforms)
    assert dist.support.event_dim == max(1, max_event_dim)
@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
    """A chain of reshapes mapping (3, 2) events to (2, 3) events composes correctly."""
    reshapes = [
        ReshapeTransform((), ()),
        ReshapeTransform((2,), (1, 2)),
        ReshapeTransform((3, 1, 2), (6,)),
        ReshapeTransform((6,), (2, 3)),
    ]
    composed = ComposeTransform(reshapes)
    assert composed.codomain.event_dim == 2
    assert composed.domain.event_dim == 2

    data = torch.randn(batch_shape + (3, 2))
    assert composed(data).shape == batch_shape + (2, 3)

    dist = TransformedDistribution(Normal(data, 1), reshapes)
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == (2, 3)
    assert dist.support.event_dim == 2
@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str)
@pytest.mark.parametrize("transform_dim", [0, 1, 2])
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim,
                                  num_transforms, sample_shape):
    """Exercise TransformedDistribution over many base/transform dim combinations.

    Checks constructor validation when the base has too few dims, sample
    shape and diversity, and log_prob shapes on full and partial samples.
    """
    shape = torch.Size([2, 3, 4, 5])
    base_ndims = base_batch_dim + base_event_dim
    base_dist = Normal(0, 1).expand(shape[len(shape) - base_ndims:])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)
    transforms = [
        AffineTransform(torch.zeros(shape[len(shape) - transform_dim:]), 1),
        ReshapeTransform((4, 5), (20,)),
        ReshapeTransform((3, 20), (6, 10)),
    ][:num_transforms]
    transform = ComposeTransform(transforms)

    # Check validation in .__init__().
    if base_ndims < transform.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return

    d = TransformedDistribution(base_dist, transforms)

    # Check sampling is sufficiently expanded.
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape
    assert len(set(x.reshape(-1).tolist())) >= 0.9 * x.numel()

    # Check log_prob shape on full samples.
    assert d.log_prob(x).shape == sample_shape + d.batch_shape

    # Check log_prob shape on partial samples.
    partial = x
    while partial.dim() > len(d.event_shape):
        partial = partial[0]
    assert d.log_prob(partial).shape == d.batch_shape
if __name__ == '__main__':
    # Allow running this file directly by handing it to pytest.
    pytest.main(args=[__file__])