import pytest

import torch
from torch.distributions import biject_to, constraints, transform_to
from torch.testing._internal.common_cuda import TEST_CUDA
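

# Constraint fixtures: each tuple holds a constraint (or a constraint factory)
# followed by the positional arguments used to build it; list arguments are
# turned into tensors by build_constraint so vectorized bounds are covered.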
CONSTRAINTS = [
    (constraints.real,),
    (constraints.positive,),
    (constraints.greater_than, [-10., -2, 0, 2, 10]),
    (constraints.greater_than, 0),
    (constraints.greater_than, 2),
    (constraints.greater_than, -2),
    (constraints.greater_than_eq, 0),
    (constraints.greater_than_eq, 2),
    (constraints.greater_than_eq, -2),
    (constraints.less_than, [-10., -2, 0, 2, 10]),
    (constraints.less_than, 0),
    (constraints.less_than, 2),
    (constraints.less_than, -2),
    (constraints.unit_interval,),
    (constraints.interval, [-4., -2, 0, 2, 4], [-3., 3, 1, 5, 5]),
    (constraints.interval, -2, -1),
    (constraints.interval, 1, 2),
    (constraints.half_open_interval, [-4., -2, 0, 2, 4], [-3., 3, 1, 5, 5]),
    (constraints.half_open_interval, -2, -1),
    (constraints.half_open_interval, 1, 2),
    (constraints.simplex,),
    (constraints.lower_cholesky,),
]
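

# Build a constraint instance from a CONSTRAINTS entry. Scalar arguments pass
# through unchanged; list arguments become DoubleTensors (CUDA tensors when
# is_cuda is set) so constraints with tensor-valued bounds are exercised too.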
def build_constraint(constraint_fn, args, is_cuda=False):
    if not args:
        return constraint_fn
    t = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor
    return constraint_fn(*(t(x) if isinstance(x, list) else x for x in args))
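

# biject_to(constraint) must return a bijective transform: its outputs must
# satisfy the constraint, its inverse must recover the input, and its
# log_abs_det_jacobian must have the batch shape, i.e. the input shape with
# the trailing event_dim dimensions removed.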
@pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS])
@pytest.mark.parametrize('is_cuda', [False,
                                     pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA,
                                                                                 reason='CUDA not found.'))])
def test_biject_to(constraint_fn, args, is_cuda):
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    try:
        t = biject_to(constraint)
    except NotImplementedError:
        pytest.skip('`biject_to` not implemented.')
    assert t.bijective, "biject_to({}) is not bijective".format(constraint)
    x = torch.randn(5, 5, dtype=torch.double)
    if is_cuda:
        x = x.cuda()
    y = t(x)
    assert constraint.check(y).all(), '\n'.join([
        "Failed to biject_to({})".format(constraint),
        "x = {}".format(x),
        "biject_to(...)(x) = {}".format(y),
    ])
    x2 = t.inv(y)
    assert torch.allclose(x, x2), "Error in biject_to({}) inverse".format(constraint)
    j = t.log_abs_det_jacobian(x, y)
    assert j.shape == x.shape[:x.dim() - t.event_dim]
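

# transform_to(constraint) is only required to map onto the constrained set,
# not to be bijective: outputs must satisfy the constraint and the transform
# must act as a pseudoinverse, i.e. t(t.inv(y)) == y, even though the
# original x need not be recoverable.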
@pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS])
@pytest.mark.parametrize('is_cuda', [False,
                                     pytest.param(True, marks=pytest.mark.skipif(not TEST_CUDA,
                                                                                 reason='CUDA not found.'))])
def test_transform_to(constraint_fn, args, is_cuda):
    constraint = build_constraint(constraint_fn, args, is_cuda=is_cuda)
    t = transform_to(constraint)
    x = torch.randn(5, 5, dtype=torch.double)
    if is_cuda:
        x = x.cuda()
    y = t(x)
    assert constraint.check(y).all(), "Failed to transform_to({})".format(constraint)
    x2 = t.inv(y)
    y2 = t(x2)
    assert torch.allclose(y, y2), "Error in transform_to({}) pseudoinverse".format(constraint)


if __name__ == '__main__':
    pytest.main([__file__])