forked from swyoon/pytorch-energy-based-model
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
88 lines (69 loc) · 3.7 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import math
import torch
def potential_fn(dataset):
    """Look up a 2-D toy potential (energy) function by name.

    Code borrowed from https://github.com/kamenbliznashki/normalizing_flows/blob/master/bnaf.py

    Args:
        dataset: name of the potential; one of 'u1', 'u2', 'u3', 'u4'.

    Returns:
        A callable taking a tensor ``z`` and returning one energy value per
        row. ``z`` is indexed as ``z[:, 0]`` and ``z[:, 1]``, so it must be
        2-D with at least two columns (presumably shape (batch, 2)).

    Raises:
        RuntimeError: if ``dataset`` is not one of the names above.
    """

    def wave(z):
        # Sinusoid of the first coordinate with period 4.
        return torch.sin(2 * math.pi * z[:, 0] / 4)

    def bump(z):
        # Gaussian bump of height 3 centred at z[:, 0] == 1.
        return 3 * torch.exp(-0.5 * ((z[:, 0] - 1) / 0.6) ** 2)

    def step(z):
        # Smooth step of height 3 centred at z[:, 0] == 1.
        return 3 * torch.sigmoid((z[:, 0] - 1) / 0.3)

    def u1(z):
        radial = 0.5 * ((torch.norm(z, p=2, dim=1) - 2) / 0.4) ** 2
        first = z[:, 0]
        # The 1e-10 keeps the log finite when both exponentials underflow to 0.
        mixture = (torch.exp(-0.5 * ((first - 2) / 0.6) ** 2)
                   + torch.exp(-0.5 * ((first + 2) / 0.6) ** 2)
                   + 1e-10)
        return radial - torch.log(mixture)

    def u2(z):
        return 0.5 * ((z[:, 1] - wave(z)) / 0.4) ** 2

    def u3(z):
        offset = z[:, 1] - wave(z)
        mixture = (torch.exp(-0.5 * (offset / 0.35) ** 2)
                   + torch.exp(-0.5 * ((offset + bump(z)) / 0.35) ** 2)
                   + 1e-10)
        return -torch.log(mixture)

    def u4(z):
        offset = z[:, 1] - wave(z)
        mixture = (torch.exp(-0.5 * (offset / 0.4) ** 2)
                   + torch.exp(-0.5 * ((offset + step(z)) / 0.35) ** 2)
                   + 1e-10)
        return -torch.log(mixture)

    # Compare with == in a fixed order (no hashing) so any dataset value is
    # matched exactly as a chain of equality tests would match it.
    for name, energy in (('u1', u1), ('u2', u2), ('u3', u3), ('u4', u4)):
        if dataset == name:
            return energy
    raise RuntimeError('Invalid potential name to sample from.')
def sample_2d_data(dataset, n_samples):
    """Draw samples from a named 2-D toy distribution.

    Code borrowed from https://github.com/kamenbliznashki/normalizing_flows/blob/master/bnaf.py

    Args:
        dataset: one of '8gaussians', '2spirals', 'checkerboard', 'rings'.
        n_samples: number of points to draw.

    Returns:
        A tensor of shape (n_samples, 2).

    Raises:
        RuntimeError: if ``dataset`` is not one of the names above.
    """
    # Drawn unconditionally (even for datasets that never use it) so the global
    # RNG stream, and therefore seeded results, is the same for every dataset.
    z = torch.randn(n_samples, 2)
    if dataset == '8gaussians':
        scale = 4
        sq2 = 1 / math.sqrt(2)
        centers = [(1, 0), (-1, 0), (0, 1), (0, -1),
                   (sq2, sq2), (-sq2, sq2), (sq2, -sq2), (-sq2, -sq2)]
        centers = torch.tensor([(scale * x, scale * y) for x, y in centers])
        # Pick one of the 8 centres per sample and add scaled Gaussian noise.
        return sq2 * (0.5 * z + centers[torch.randint(len(centers), size=(n_samples,))])
    elif dataset == '2spirals':
        # Sample one spiral arm of ceil(n_samples / 2) points and mirror it
        # through the origin for the second arm. Using n_samples // 2 per arm
        # would give n_samples - 1 rows for odd n_samples, which cannot be
        # added to z; truncating to n_samples keeps the shapes consistent
        # (and is a no-op for even n_samples).
        n_arm = (n_samples + 1) // 2
        # Angles in radians, up to 540 degrees (1.5 turns).
        angle = torch.sqrt(torch.rand(n_arm)) * 540 * (2 * math.pi) / 360
        d1x = -torch.cos(angle) * angle + torch.rand(n_arm) * 0.5
        d1y = torch.sin(angle) * angle + torch.rand(n_arm) * 0.5
        arm = torch.stack([d1x, d1y], dim=1)
        x = torch.cat([arm, -arm], dim=0)[:n_samples] / 3
        return x + 0.1 * z
    elif dataset == 'checkerboard':
        x1 = torch.rand(n_samples) * 4 - 2
        x2_ = torch.rand(n_samples) - torch.randint(0, 2, (n_samples,), dtype=torch.float) * 2
        # Shift x2 by 0 or 1 depending on the parity of floor(x1), which
        # produces the alternating checkerboard cells.
        x2 = x2_ + x1.floor() % 2
        return torch.stack([x1, x2], dim=1) * 2
    elif dataset == 'rings':
        # Four concentric rings; the three outer rings get n_samples // 4
        # points each and the innermost ring takes the remainder.
        n_outer = n_samples // 4
        ring_sizes = (n_outer, n_outer, n_outer, n_samples - 3 * n_outer)
        ring_radii = (1.0, 0.75, 0.5, 0.25)
        xs, ys = [], []
        for size, radius in zip(ring_sizes, ring_radii):
            # size + 1 points with the last dropped mimics numpy's
            # endpoint=False, so the first and last points do not coincide.
            theta = torch.linspace(0, 2 * math.pi, size + 1)[:-1]
            xs.append(torch.cos(theta) * radius)
            ys.append(torch.sin(theta) * radius)
        x = torch.stack([torch.cat(xs), torch.cat(ys)], dim=1) * 3.0
        # Resample the ring points uniformly with replacement.
        x = x[torch.randint(0, n_samples, size=(n_samples,))]
        # Add isotropic Gaussian noise (std 0.08).
        return x + torch.normal(mean=torch.zeros_like(x), std=0.08 * torch.ones_like(x))
    else:
        raise RuntimeError('Invalid `dataset` to sample from.')