-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
equilibrium_median.py
39 lines (32 loc) · 1.1 KB
/
equilibrium_median.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
r"""Replicates the experiment from `"Equilibrium Aggregation: Encoding Sets
via Optimization" <https://arxiv.org/abs/2202.12795>`_ to try and teach
`EquilibriumAggregation` to learn to take the median of a set of numbers.
This example converges slowly to being able to predict the
median similar to what is observed in the paper.
"""
import numpy as np
import torch
from torch_geometric.nn import EquilibriumAggregation
# --- Experiment configuration -------------------------------------------
input_size = 100  # number of scalars drawn per training set
steps = 10_000_000  # total number of optimisation steps
embedding_size = 10  # output width of the aggregation layer
eval_each = 1000  # print the running mean loss every `eval_each` steps

# Arguments are passed positionally, exactly as before; the second one is
# now taken from `embedding_size` instead of repeating the literal 10.
# NOTE(review): per the PyG docs the positional order is presumably
# (in_channels, out_channels, num_layers, grad_iter) -- confirm.
model = EquilibriumAggregation(1, embedding_size, [256, 256], 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Candidate distributions from which each training set is sampled.
norm = torch.distributions.normal.Normal(0.5, 0.4)
gamma = torch.distributions.gamma.Gamma(0.2, 0.5)
uniform = torch.distributions.uniform.Uniform(0, 1)

# Running totals used to report the mean loss over all steps so far.
total_loss = 0
n_loss = 0
# Train on a freshly sampled set at every step and periodically report the
# mean loss accumulated since the start of training.
for i in range(1, steps + 1):
    optimizer.zero_grad()

    # Pick one of the candidate distributions at random and draw a set of
    # `input_size` scalars from it (tensor of shape (input_size, 1)).
    dist = np.random.choice([norm, gamma, uniform])
    x = dist.sample((input_size, 1))

    # The regression target is the median of the sampled values.
    y = model(x)
    loss = (y - x.median()).norm(2) / input_size
    loss.backward()
    optimizer.step()

    # Accumulate a plain Python float. Adding the tensor itself would make
    # `total_loss` part of the autograd graph and keep every iteration's
    # history reachable, so memory would grow with the number of steps.
    total_loss += loss.item()
    n_loss += 1

    if i % eval_each == 0:
        print(f"Epoch: {i}, Loss {total_loss / n_loss:.6f}")