-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Report CycleGAN validation metrics correctly to wandb #2131
base: master
Are you sure you want to change the base?
Changes from 12 commits
1379966
3c750df
fcae3c5
797af2d
54261a9
9dcb1e8
f5a428b
090ec0b
daefd94
f4825cf
946bd2d
1a381a4
e62f047
8c40e9b
d5e12f0
9678ab4
a12557e
acb06ba
5ce73aa
c59bd99
e688e7f
b30bbb9
9c0c3b4
c7ff935
66de2b2
4956f2a
36fdb59
b5959bd
a61ba13
c624051
34f554a
67ceb48
b064a30
efc393c
07bb58a
06c50b7
c5801eb
b81bbd7
cb2633b
dd41851
b924e45
c0fdf3d
923d21c
b81d687
5395b02
3c6b68f
6ce1ae0
a16d3b9
398d35f
cdaf607
3918f33
700ac4e
1a741b8
b07c782
52f6f0e
4ef65ff
9f93c24
a0e3e15
90343a3
916d8e2
8871235
11104db
a659aa8
5c76788
868b24e
2758fd0
c765aba
bdeb021
9afd0b3
85a912d
37297df
1bb13ca
6fbed6e
1806d88
8104460
9acfee6
7014221
9280a54
3dade60
cf1099f
9c4f5fb
e71114f
2a50db1
6c55fcd
3b2d231
26de524
5f5913e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
from typing import Dict, List, Literal, Mapping, Tuple, Optional | ||
import tensorflow as tf | ||
from typing import List, Literal, Mapping, Tuple, Optional | ||
from fv3fit._shared.scaler import StandardScaler | ||
from .reloadable import CycleGAN, CycleGANModule | ||
import torch | ||
|
@@ -16,6 +15,8 @@ | |
from fv3fit import wandb | ||
import io | ||
import PIL | ||
import xarray as xr | ||
from vcm.cubedsphere import to_cross | ||
|
||
try: | ||
import matplotlib.pyplot as plt | ||
|
@@ -33,10 +34,8 @@ class CycleGANNetworkConfig: | |
Configuration for building and training a CycleGAN network. | ||
|
||
Attributes: | ||
generator_optimizer: configuration for the optimizer used to train the | ||
generator | ||
discriminator_optimizer: configuration for the optimizer used to train the | ||
discriminator | ||
optimizer: configuration for the optimizer used to train the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Merging these was necessary so a wandb sweep can operate on the learning rate, there's no way to pair two hyperparameters for wandb sweeps. |
||
generator and discriminator | ||
generator: configuration for building the generator network | ||
discriminator: configuration for building the discriminator network | ||
identity_loss: loss function used to make the generator which outputs | ||
|
@@ -51,12 +50,11 @@ class CycleGANNetworkConfig: | |
cycle_weight: weight of the cycle loss | ||
generator_weight: weight of the generator's gan loss | ||
discriminator_weight: weight of the discriminator gan loss | ||
reload_path: path to a directory containing a saved CycleGAN model to use | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was just a missing docstring entry. |
||
as a starting point for training | ||
""" | ||
|
||
generator_optimizer: OptimizerConfig = dataclasses.field( | ||
default_factory=lambda: OptimizerConfig("Adam") | ||
) | ||
discriminator_optimizer: OptimizerConfig = dataclasses.field( | ||
optimizer: OptimizerConfig = dataclasses.field( | ||
default_factory=lambda: OptimizerConfig("Adam") | ||
) | ||
generator: "GeneratorConfig" = dataclasses.field( | ||
|
@@ -92,12 +90,12 @@ def build( | |
) | ||
discriminator_a = self.discriminator.build(n_state, convolution=convolution) | ||
discriminator_b = self.discriminator.build(n_state, convolution=convolution) | ||
optimizer_generator = self.generator_optimizer.instance( | ||
optimizer_generator = self.optimizer.instance( | ||
itertools.chain( | ||
generator_a_to_b.parameters(), generator_b_to_a.parameters() | ||
) | ||
) | ||
optimizer_discriminator = self.discriminator_optimizer.instance( | ||
optimizer_discriminator = self.optimizer.instance( | ||
itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()) | ||
) | ||
model = CycleGANModule( | ||
|
@@ -284,26 +282,26 @@ def __post_init__(self): | |
self._script_disc_a = None | ||
self._script_disc_b = None | ||
|
||
def _call_generator_a_to_b(self, input): | ||
def _call_generator_a_to_b(self, input: torch.Tensor) -> torch.Tensor: | ||
if self._script_gen_a_to_b is None: | ||
self._script_gen_a_to_b = torch.jit.trace( | ||
self.generator_a_to_b.forward, input | ||
) | ||
return self._script_gen_a_to_b(input) | ||
|
||
def _call_generator_b_to_a(self, input): | ||
def _call_generator_b_to_a(self, input: torch.Tensor) -> torch.Tensor: | ||
if self._script_gen_b_to_a is None: | ||
self._script_gen_b_to_a = torch.jit.trace( | ||
self.generator_b_to_a.forward, input | ||
) | ||
return self._script_gen_b_to_a(input) | ||
|
||
def _call_discriminator_a(self, input): | ||
def _call_discriminator_a(self, input: torch.Tensor) -> torch.Tensor: | ||
if self._script_disc_a is None: | ||
self._script_disc_a = torch.jit.trace(self.discriminator_a.forward, input) | ||
return self._script_disc_a(input) | ||
|
||
def _call_discriminator_b(self, input): | ||
def _call_discriminator_b(self, input: torch.Tensor) -> torch.Tensor: | ||
if self._script_disc_b is None: | ||
self._script_disc_b = torch.jit.trace(self.discriminator_b.forward, input) | ||
return self._script_disc_b(input) | ||
|
@@ -316,76 +314,8 @@ def _init_targets(self, shape: Tuple[int, ...]): | |
torch.Tensor(shape).fill_(0.0).to(DEVICE), requires_grad=False | ||
) | ||
|
||
def evaluate_on_dataset( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function was never actually used (it was called, but only on validation data, and I never provided validation datasets before now). |
||
self, dataset: tf.data.Dataset, n_dims_keep: int = 3 | ||
) -> Dict[str, float]: | ||
stats_real_a = StatsCollector(n_dims_keep) | ||
stats_real_b = StatsCollector(n_dims_keep) | ||
stats_gen_a = StatsCollector(n_dims_keep) | ||
stats_gen_b = StatsCollector(n_dims_keep) | ||
real_a: np.ndarray | ||
real_b: np.ndarray | ||
reported_plot = False | ||
for real_a, real_b in dataset: | ||
# for now there is no time-evolution-based loss, so we fold the time | ||
# dimension into the sample dimension | ||
real_a = real_a.reshape( | ||
[real_a.shape[0] * real_a.shape[1]] + list(real_a.shape[2:]) | ||
) | ||
real_b = real_b.reshape( | ||
[real_b.shape[0] * real_b.shape[1]] + list(real_b.shape[2:]) | ||
) | ||
stats_real_a.observe(real_a) | ||
stats_real_b.observe(real_b) | ||
gen_b: np.ndarray = self.generator_a_to_b( | ||
torch.as_tensor(real_a).float().to(DEVICE) | ||
).detach().cpu().numpy() | ||
gen_a: np.ndarray = self.generator_b_to_a( | ||
torch.as_tensor(real_b).float().to(DEVICE) | ||
).detach().cpu().numpy() | ||
stats_gen_a.observe(gen_a) | ||
stats_gen_b.observe(gen_b) | ||
if not reported_plot and plt is not None: | ||
report = {} | ||
for i_tile in range(6): | ||
fig, ax = plt.subplots(2, 2, figsize=(8, 7)) | ||
im = ax[0, 0].pcolormesh(real_a[0, i_tile, 0, :, :]) | ||
plt.colorbar(im, ax=ax[0, 0]) | ||
ax[0, 0].set_title("a_real") | ||
im = ax[1, 0].pcolormesh(real_b[0, i_tile, 0, :, :]) | ||
plt.colorbar(im, ax=ax[1, 0]) | ||
ax[1, 0].set_title("b_real") | ||
im = ax[0, 1].pcolormesh(gen_b[0, i_tile, 0, :, :]) | ||
plt.colorbar(im, ax=ax[0, 1]) | ||
ax[0, 1].set_title("b_gen") | ||
im = ax[1, 1].pcolormesh(gen_a[0, i_tile, 0, :, :]) | ||
plt.colorbar(im, ax=ax[1, 1]) | ||
ax[1, 1].set_title("a_gen") | ||
plt.tight_layout() | ||
buf = io.BytesIO() | ||
plt.savefig(buf, format="png") | ||
plt.close() | ||
buf.seek(0) | ||
report[f"tile_{i_tile}_example"] = wandb.Image( | ||
PIL.Image.open(buf), caption=f"Tile {i_tile} Example", | ||
) | ||
wandb.log(report) | ||
reported_plot = True | ||
metrics = { | ||
"r2_mean_b_against_real_a": get_r2(stats_real_a.mean, stats_gen_b.mean), | ||
"r2_mean_a": get_r2(stats_real_a.mean, stats_gen_a.mean), | ||
"bias_mean_a": np.mean(stats_real_a.mean - stats_gen_a.mean), | ||
"r2_mean_b": get_r2(stats_real_b.mean, stats_gen_b.mean), | ||
"bias_mean_b": np.mean(stats_real_b.mean - stats_gen_b.mean), | ||
"r2_std_a": get_r2(stats_real_a.std, stats_gen_a.std), | ||
"bias_std_a": np.mean(stats_real_a.std - stats_gen_a.std), | ||
"r2_std_b": get_r2(stats_real_b.std, stats_gen_b.std), | ||
"bias_std_b": np.mean(stats_real_b.std - stats_gen_b.std), | ||
} | ||
return metrics | ||
|
||
def train_on_batch( | ||
self, real_a: torch.Tensor, real_b: torch.Tensor | ||
self, real_a: torch.Tensor, real_b: torch.Tensor, training: bool = True | ||
) -> Mapping[str, float]: | ||
""" | ||
Train the CycleGAN on a batch of data. | ||
|
@@ -395,6 +325,8 @@ def train_on_batch( | |
[sample, time, tile, channel, y, x] | ||
real_b: a batch of data from domain B, should have shape | ||
[sample, time, tile, channel, y, x] | ||
training: if True, the model will be trained, otherwise we will | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This allows getting training metrics on validation data. |
||
only evaluate the loss. | ||
""" | ||
# for now there is no time-evolution-based loss, so we fold the time | ||
# dimension into the sample dimension | ||
|
@@ -412,10 +344,11 @@ def train_on_batch( | |
|
||
# Generators A2B and B2A ###### | ||
|
||
# don't update discriminators when training generators to fool them | ||
set_requires_grad( | ||
[self.discriminator_a, self.discriminator_b], requires_grad=False | ||
) | ||
if training: | ||
# don't update discriminators when training generators to fool them | ||
set_requires_grad( | ||
[self.discriminator_a, self.discriminator_b], requires_grad=False | ||
) | ||
|
||
# Identity loss | ||
# G_A2B(B) should equal B if real B is fed | ||
|
@@ -448,16 +381,18 @@ def train_on_batch( | |
|
||
# Total loss | ||
loss_g: torch.Tensor = (loss_identity + loss_gan + loss_cycle) | ||
self.optimizer_generator.zero_grad() | ||
loss_g.backward() | ||
self.optimizer_generator.step() | ||
if training: | ||
self.optimizer_generator.zero_grad() | ||
loss_g.backward() | ||
self.optimizer_generator.step() | ||
|
||
# Discriminators A and B ###### | ||
|
||
# do update discriminators when training them to identify samples | ||
set_requires_grad( | ||
[self.discriminator_a, self.discriminator_b], requires_grad=True | ||
) | ||
if training: | ||
# do update discriminators when training them to identify samples | ||
set_requires_grad( | ||
[self.discriminator_a, self.discriminator_b], requires_grad=True | ||
) | ||
|
||
# Real loss | ||
pred_real = self.discriminator_a(real_a) | ||
|
@@ -466,7 +401,8 @@ def train_on_batch( | |
) | ||
|
||
# Fake loss | ||
fake_a = self.fake_a_buffer.query(fake_a) | ||
if training: | ||
fake_a = self.fake_a_buffer.query(fake_a) | ||
pred_a_fake = self.discriminator_a(fake_a.detach()) | ||
loss_d_a_fake = ( | ||
self.gan_loss(pred_a_fake, self.target_fake) * self.discriminator_weight | ||
|
@@ -479,7 +415,8 @@ def train_on_batch( | |
) | ||
|
||
# Fake loss | ||
fake_b = self.fake_b_buffer.query(fake_b) | ||
if training: | ||
fake_b = self.fake_b_buffer.query(fake_b) | ||
pred_b_fake = self.discriminator_b(fake_b.detach()) | ||
loss_d_b_fake = ( | ||
self.gan_loss(pred_b_fake, self.target_fake) * self.discriminator_weight | ||
|
@@ -490,9 +427,10 @@ def train_on_batch( | |
loss_d_b_real + loss_d_b_fake + loss_d_a_real + loss_d_a_fake | ||
) | ||
|
||
self.optimizer_discriminator.zero_grad() | ||
loss_d.backward() | ||
self.optimizer_discriminator.step() | ||
if training: | ||
self.optimizer_discriminator.zero_grad() | ||
loss_d.backward() | ||
self.optimizer_discriminator.step() | ||
|
||
return { | ||
"b_to_a_gan_loss": float(loss_gan_b_to_a), | ||
|
@@ -504,8 +442,76 @@ def train_on_batch( | |
"generator_loss": float(loss_g), | ||
"discriminator_loss": float(loss_d), | ||
"train_loss": float(loss_g + loss_d), | ||
"regularization_loss": float(loss_cycle + loss_identity), | ||
} | ||
|
||
def generate_plots( | ||
self, real_a: torch.Tensor, real_b: torch.Tensor | ||
) -> Mapping[str, wandb.Image]: | ||
""" | ||
Plot model output on the first sample of a given batch and return it as | ||
a dictionary of wandb.Image objects. | ||
|
||
Args: | ||
real_a: a batch of data from domain A, should have shape | ||
[sample, time, tile, channel, y, x] | ||
real_b: a batch of data from domain B, should have shape | ||
[sample, time, tile, channel, y, x] | ||
""" | ||
# for now there is no time-evolution-based loss, so we fold the time | ||
# dimension into the sample dimension | ||
real_a = real_a.reshape( | ||
[real_a.shape[0] * real_a.shape[1]] + list(real_a.shape[2:]) | ||
) | ||
real_b = real_b.reshape( | ||
[real_b.shape[0] * real_b.shape[1]] + list(real_b.shape[2:]) | ||
) | ||
|
||
# plot the first sample of the batch | ||
with torch.no_grad(): | ||
fake_b = self._call_generator_a_to_b(real_a[:1, :]) | ||
fake_a = self._call_generator_b_to_a(real_b[:1, :]) | ||
real_a = real_a.cpu().numpy() | ||
real_b = real_b.cpu().numpy() | ||
fake_a = fake_a.cpu().numpy() | ||
fake_b = fake_b.cpu().numpy() | ||
report = {} | ||
for i in range(real_a.shape[2]): | ||
var_real_a = to_cross( | ||
xr.DataArray(real_a[0, :, i, :, :], dims=["tile", "grid_yt", "grid_xt"]) | ||
) | ||
var_real_b = to_cross( | ||
xr.DataArray(real_b[0, :, i, :, :], dims=["tile", "grid_yt", "grid_xt"]) | ||
) | ||
var_fake_a = to_cross( | ||
xr.DataArray(fake_a[0, :, i, :, :], dims=["tile", "grid_yt", "grid_xt"]) | ||
) | ||
var_fake_b = to_cross( | ||
xr.DataArray(fake_b[0, :, i, :, :], dims=["tile", "grid_yt", "grid_xt"]) | ||
) | ||
vmin_a = min(np.min(real_a[0, :, i, :, :]), np.min(fake_a[0, :, i, :, :])) | ||
vmax_a = max(np.max(real_a[0, :, i, :, :]), np.max(fake_a[0, :, i, :, :])) | ||
vmin_b = min(np.min(real_b[0, :, i, :, :]), np.min(fake_b[0, :, i, :, :])) | ||
vmax_b = max(np.max(real_b[0, :, i, :, :]), np.max(fake_b[0, :, i, :, :])) | ||
fig, ax = plt.subplots(2, 2, figsize=(8, 7)) | ||
var_real_a.plot(ax=ax[0, 0], vmin=vmin_a, vmax=vmax_a) | ||
var_fake_b.plot(ax=ax[0, 1], vmin=vmin_b, vmax=vmax_b) | ||
var_real_b.plot(ax=ax[1, 0], vmin=vmin_b, vmax=vmax_b) | ||
var_fake_a.plot(ax=ax[1, 1], vmin=vmin_a, vmax=vmax_a) | ||
ax[0, 0].set_title("real_a") | ||
ax[0, 1].set_title("fake_b") | ||
ax[1, 0].set_title("real_b") | ||
ax[1, 1].set_title("fake_a") | ||
|
||
buf = io.BytesIO() | ||
plt.savefig(buf, format="png") | ||
plt.close(fig) | ||
buf.seek(0) | ||
report[f"example_{i}"] = wandb.Image( | ||
PIL.Image.open(buf), caption=f"Channel {i} Example", | ||
) | ||
return report | ||
|
||
|
||
def set_requires_grad(nets: List[torch.nn.Module], requires_grad=False): | ||
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import Mapping, Any | ||
|
||
|
||
class Reporter: | ||
""" | ||
Helper class to combine reported metrics to be sent to wandb. | ||
""" | ||
|
||
def __init__(self): | ||
self.metrics = {} | ||
|
||
def log(self, kwargs: Mapping[str, Any]): | ||
self.metrics.update(kwargs) | ||
|
||
def clear(self): | ||
self.metrics.clear() | ||
|
||
|
||
class NullReporter(Reporter): | ||
""" | ||
Reporter that does nothing. | ||
""" | ||
|
||
def log(self, kwargs: Mapping[str, Any]): | ||
pass | ||
|
||
def clear(self): | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like black to run before flake8 so that it can auto-fix flake8 issues before flake8 runs.