Skip to content

Commit a1e7720

Browse files
authored
Metrics (#51)
* Add physcial metrics * Delete resolution folder * Add inference model notebook * Add metrics notebooks * Update notebooks * Yapf * Update docs * Bump version * Modify nets visualize method api * Rename metrics functions * Log images in trainer * Nested functions * Change visualize signature * Add physical figures logginh * Debug physical metrics * turn of logging * Fix logging * Upd notebook * Add space characteristic * Fix dims space * Refactor * Close the right figure * Upd get time prediction * Upd get time_values * Upd * Add multiprocessing to get_time_values
1 parent 077de14 commit a1e7720

20 files changed

+349
-146
lines changed

generation/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.12.2'
1+
__version__ = '0.12.3'

generation/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
"d_lr": 1e-4,
4141
"epochs_num": 1000,
4242
"batch_size": 32,
43-
"log_each": 1,
43+
"log_each": 5,
4444
"decay_epoch": 0,
4545
"save_each": 2,
4646
"device": "cuda:1",

generation/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from generation.metrics.physical import get_physical_figs

generation/metrics/physical.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
import tqdm
4+
import multiprocessing as mp
5+
6+
from generation.dataset.data_utils import postprocess_signal
7+
8+
_BINS_NUM = 20
9+
_PROCESSES_NUM = 24
10+
11+
12+
def get_energy_values(signals):
13+
"""
14+
Returns energy characteristic values for a given set of signals.
15+
Energy characteristic is a ratio of a signal energy to an amplitude
16+
:param signals: signals np array of shape [detectors_num, signals_num, signal_size]
17+
:returns: energy characteristic values
18+
"""
19+
amplitudes = np.max(signals, axis=2)
20+
energies = np.sum(signals, axis=2)
21+
ratios = energies / amplitudes
22+
return ratios
23+
24+
25+
def _calculate_centre_mass(amplitudes):
26+
"""
27+
Returns centre mass for a given amplitudes array
28+
:param amplitudes: np array with shape [detectors_num, signals_num]
29+
:returns: mass centres array with shape [signals_num]
30+
"""
31+
coords = np.array([[-1, 1], [0, 1], [1, 1], \
32+
[-1, 0], [0, 0], [1, 0], \
33+
[-1, -1], [0, -1], [1, -1]])
34+
return coords.T @ amplitudes
35+
36+
37+
def get_space_values(signals):
38+
"""
39+
Returns space characteristic values for a given set of signals.
40+
Space characteristic is a centre mass of detector coordinates,
41+
where weights are corresping amplitudes
42+
:param signals: signals np array of shape [detectors_num, signals_num, signal_size]
43+
:returns: space characteristic values
44+
"""
45+
amplitudes = np.max(signals, axis=2)
46+
mass_centres = _calculate_centre_mass(amplitudes)
47+
return mass_centres
48+
49+
50+
def _get_space_fig(real_mass_centres, fake_mass_centres):
51+
"""
52+
Returns a figure with real and fake mass centres distributions
53+
:param real_mass_centres: np array with shape [detectors_num, 2]
54+
:param fake_mass_centres: np array with shape [detectors_num, 2]
55+
:returns: figure with distributions
56+
"""
57+
fig, ax = plt.subplots(1, 1, figsize=(3, 3))
58+
ax.scatter(real_mass_centres[0, :], real_mass_centres[1, :])
59+
ax.scatter(fake_mass_centres[0, :], fake_mass_centres[1, :])
60+
ax.legend(['Real', 'Fake'])
61+
return fig
62+
63+
64+
def _get_ref_time_pred(signal):
65+
half_amplitude = np.min(signal) + (np.max(signal) - np.min(signal)) / 2
66+
for idx, cur_amplitude in enumerate(signal):
67+
if cur_amplitude > half_amplitude:
68+
return idx
69+
70+
71+
def get_time_values(signals):
72+
"""
73+
Returns time characteristic values for a given set of signals.
74+
Time characteristic is a signal reference time.
75+
:param signals: signals np array of shape [detectors_num, signals_num, signal_size]
76+
:returns: time characteristic values
77+
"""
78+
postprocessed_signals = []
79+
with mp.Pool(_PROCESSES_NUM) as pool:
80+
for detector_signals in signals:
81+
postprocessed_detector_signals = pool.map(postprocess_signal, detector_signals)
82+
postprocessed_signals.append(postprocessed_detector_signals)
83+
time_values = [[_get_ref_time_pred(signal) for signal in detector_signals] for detector_signals in postprocessed_signals]
84+
time_values = np.array(time_values)
85+
return time_values
86+
87+
88+
def _get_energy_time_fig(real_values, fake_values, bins_num=_BINS_NUM):
89+
fig, ax = plt.subplots(3, 3, figsize=(10, 10))
90+
for i in range(9): # TODO: (@whiteRa2bit, 2021-01-05) Replace with config constant
91+
real_detector_values = real_values[i]
92+
fake_detector_values = fake_values[i]
93+
bins = np.histogram(np.hstack((real_detector_values, fake_detector_values)), bins=bins_num)[1]
94+
ax[i // 3][i % 3].hist(real_detector_values, bins=bins, alpha=0.6)
95+
ax[i // 3][i % 3].hist(fake_detector_values, bins=bins, alpha=0.6)
96+
ax[i // 3][i % 3].legend(["Real", "Fake"])
97+
return fig
98+
99+
100+
def _transform_signals(signals_tensor):
101+
"""
102+
Transforms torch signals tensor to np array and reshapes it
103+
:param signals_tensor: torch tensor with shape [batch_size, detectors_num, x_dim]
104+
:returns: np array with shape [detectors_num, batch_size, x_dim]
105+
"""
106+
signals_array = signals_tensor.cpu().detach().numpy()
107+
signals_array = np.transpose(signals_array, (1, 0, 2))
108+
return signals_array
109+
110+
111+
def get_physical_figs(real_signals_tensor, fake_signals_tensor):
112+
real_signals = _transform_signals(real_signals_tensor)
113+
fake_signals = _transform_signals(fake_signals_tensor)
114+
115+
real_energy_values = get_energy_values(real_signals)
116+
fake_energy_values = get_energy_values(fake_signals)
117+
energy_fig = _get_energy_time_fig(real_energy_values, fake_energy_values)
118+
119+
real_time_values = get_time_values(real_signals)
120+
fake_time_values = get_time_values(fake_signals)
121+
time_fig = _get_energy_time_fig(real_time_values, fake_time_values)
122+
123+
real_space_values = get_space_values(real_signals)
124+
fake_space_values = get_space_values(fake_signals)
125+
space_fig = _get_space_fig(real_space_values, fake_space_values)
126+
127+
return energy_fig, time_fig, space_fig

generation/nets/abstract_net.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def forward(self, x, debug=False):
1616

1717
@staticmethod
1818
@abstractmethod
19-
def visualize(generated_sample, real_sample):
19+
def get_rel_fake_fig(real_sample, fake_sample):
2020
raise NotImplementedError
2121

2222

generation/nets/amplitudes_net.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,18 @@ def forward(self, x, debug=False):
2424
return torch.clamp(x, 0, 1)
2525

2626
@staticmethod
27-
def visualize(generated_sample, real_sample):
28-
generated_sample = generated_sample.cpu().data
27+
def get_rel_fake_fig(real_sample, fake_sample):
2928
real_sample = real_sample.cpu().data
29+
fake_sample = fake_sample.cpu().data
3030

31-
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
32-
ax[0].set_title("Generated")
33-
ax[0].plot(generated_sample)
34-
ax[1].set_title("Real")
35-
ax[1].plot(real_sample)
36-
wandb.log({"generated_real": fig})
3731
plt.clf()
32+
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
33+
ax[0].set_title("Real")
34+
ax[0].plot(real_sample)
35+
ax[1].set_title("Fake")
36+
ax[1].plot(fake_sample)
37+
return fig
38+
3839

3940

4041
class Discriminator(AbstractDiscriminator):

generation/nets/images_net.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,17 @@ def _debug():
5454
return x
5555

5656
@staticmethod
57-
def visualize(generated_sample, real_sample):
58-
generated_sample = generated_sample.cpu().data
57+
def get_rel_fake_fig(real_sample, fake_sample):
5958
real_sample = real_sample.cpu().data
59+
fake_sample = fake_sample.cpu().data
6060

61-
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
62-
ax[0].set_title("Generated")
63-
ax[0].imshow(generated_sample)
64-
ax[1].set_title("Real")
65-
ax[1].imshow(real_sample)
66-
wandb.log({"generated_real": fig})
6761
plt.clf()
62+
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
63+
ax[0].set_title("Real")
64+
ax[0].imshow(real_sample)
65+
ax[1].set_title("Fake")
66+
ax[1].imshow(fake_sample)
67+
return fig
6868

6969

7070
class Discriminator(AbstractDiscriminator):

generation/nets/shapes_net.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,18 @@ def _debug():
5454
return torch.sigmoid(x.squeeze(1))
5555

5656
@staticmethod
57-
def visualize(generated_sample, real_sample):
58-
generated_sample = generated_sample.cpu().data
57+
def get_rel_fake_fig(real_sample, fake_sample):
5958
real_sample = real_sample.cpu().data
59+
fake_sample = fake_sample.cpu().data
6060

61-
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
62-
ax[0].set_title("Generated")
63-
ax[0].plot(generated_sample)
64-
ax[1].set_title("Real")
65-
ax[1].plot(real_sample)
66-
wandb.log({"generated_real": fig})
6761
plt.clf()
62+
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
63+
ax[0].set_title("Real")
64+
ax[0].plot(real_sample)
65+
ax[1].set_title("Fake")
66+
ax[1].plot(fake_sample)
67+
return fig
68+
6869

6970

7071
class Discriminator(AbstractDiscriminator):

generation/nets/signals_net.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,15 @@ def _debug():
6767
return torch.tanh(x)
6868

6969
@staticmethod
70-
def visualize(generated_sample, real_sample):
71-
def get_figure(sample):
72-
fig, ax = plt.subplots(3, 3, figsize=(10, 10))
73-
for i in range(9):
74-
ax[i // 3][i % 3].plot(sample[i])
75-
return fig
76-
77-
generated_sample = generated_sample.cpu().data
70+
def get_rel_fake_fig(real_sample, fake_sample):
7871
real_sample = real_sample.cpu().data
79-
fig_gen = get_figure(generated_sample)
80-
fig_real = get_figure(real_sample)
81-
wandb.log({"generated": fig_gen, "real": fig_real})
82-
plt.clf()
72+
fake_sample = fake_sample.cpu().data
73+
74+
fig, ax = plt.subplots(3, 6, figsize=(10, 20))
75+
for i in range(9): # TODO: (@whiteRa2bit, 2021-01-05) Replace with config constant
76+
ax[i // 3][i % 3].plot(real_sample[i])
77+
ax[i // 3][3 + i % 3].plot(fake_sample[i])
78+
return fig
8379

8480

8581
class Discriminator(AbstractDiscriminator):

generation/training/gan_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def run_train(self, dataset):
5454

5555
if epoch % self.config['log_each'] == 0:
5656
wandb.log({"D loss": d_loss.cpu().data, "G loss": g_loss.cpu().data}, step=epoch)
57-
self.generator.visualize(g_sample[0], X[0])
57+
self.generator.get_rel_fake_fig(X[0], g_sample[0])
5858
if epoch % self.config['save_each'] == 0:
5959
self._save_checkpoint(self.generator, f"generator_{epoch}")
6060
self._save_checkpoint(self.discriminator, f"discriminator_{epoch}")

0 commit comments

Comments
 (0)