Report CycleGAN validation metrics correctly to wandb #2131

Status: Open (wants to merge 87 commits into base: master)

Changes from 12 commits

Commits
1379966
report validation metrics correctly to wandb
mcgibbon Jan 5, 2023
3c750df
restore sample plotting to cyclegan training
mcgibbon Jan 5, 2023
fcae3c5
consolidate optimizer configs for generator and discriminator
mcgibbon Jan 5, 2023
797af2d
add explicit transfer to cpu numpy array before plotting
mcgibbon Jan 6, 2023
54261a9
add wandb sweep workflow, fix cyclegan saving on gs
mcgibbon Jan 11, 2023
9dcb1e8
add docstring for reload_path
mcgibbon Jan 11, 2023
f5a428b
fix _load_pytorch for models on cloud
mcgibbon Jan 12, 2023
090ec0b
commit current project data for cyclegan
mcgibbon Jan 12, 2023
daefd94
use consistent vmin/vmax within domains a and b for cyclegan plotting
Jan 12, 2023
f4825cf
add (untested) cross-plots for all channels in cyclegan
Jan 13, 2023
946bd2d
add ability to save checkpoint models to CycleGAN training
Jan 13, 2023
1a381a4
use three digits, not two, for epoch label
Jan 13, 2023
e62f047
updates to cyclegan projects folder
Jan 13, 2023
8c40e9b
Merge branch 'master' into feature/improve_wandb_cyclegan_reporting
mcgibbon Jan 13, 2023
d5e12f0
Merge branch 'master' into feature/improve_wandb_cyclegan_reporting
mcgibbon Jan 13, 2023
9678ab4
remove output folder from repo
mcgibbon Jan 13, 2023
a12557e
fix path joining for external urls when saving pytorch models
mcgibbon Jan 17, 2023
acb06ba
fix plotting transpose, fix path join when saving models
mcgibbon Jan 17, 2023
5ce73aa
fix generator for cyclegan to work for odd strided kernel size
mcgibbon Jan 18, 2023
c59bd99
reduce memory footprint of cyclegan test
mcgibbon Jan 18, 2023
e688e7f
commit project files to xfer systems
mcgibbon Jan 20, 2023
b30bbb9
add histogram plotting to cyclegan training
mcgibbon Jan 23, 2023
9c0c3b4
remove instance norm from first layer of discriminator in cyclegan, t…
mcgibbon Jan 24, 2023
c7ff935
use inline conv definition instead of convblock without norm
mcgibbon Jan 24, 2023
66de2b2
update tfdataset loader to work when sample dimension exists
mcgibbon Jan 30, 2023
4956f2a
update cyclegan project files
mcgibbon Feb 2, 2023
36fdb59
add mean bias plots to cyclegan inline diagnostics
mcgibbon Feb 3, 2023
b5959bd
update project files, fix validation data config bug
mcgibbon Feb 3, 2023
a61ba13
updating project files
mcgibbon Feb 13, 2023
c624051
revert online bias plots, add non-negativity loss
mcgibbon Feb 13, 2023
34f554a
add pattern bias metrics during training
mcgibbon Feb 14, 2023
67ceb48
initialize geographic bias with random small values
mcgibbon Feb 16, 2023
b064a30
update GeographicBias so init will actually be used
mcgibbon Feb 16, 2023
efc393c
Revert "update GeographicBias so init will actually be used"
mcgibbon Feb 16, 2023
07bb58a
Revert "initialize geographic bias with random small values"
mcgibbon Feb 16, 2023
06c50b7
add use_geographic_features option to cyclegan training
mcgibbon Feb 17, 2023
c5801eb
add online histogram aggregation to cyclegan training
mcgibbon Feb 23, 2023
b81bbd7
add plumbing to propagate time of samples into cyclegan training
mcgibbon Mar 1, 2023
cb2633b
fix bug where training data was used for validation losses
mcgibbon Mar 1, 2023
dd41851
Merge branch 'master' into feature/improve_wandb_cyclegan_reporting
mcgibbon Mar 2, 2023
b924e45
fix cyclegan stacking so it retains requested order, decode time to u…
mcgibbon Mar 2, 2023
c0fdf3d
update cyclegan to include time_x and time_y in geographic features
mcgibbon Mar 2, 2023
923d21c
delete autoencoder model so we don't have to refactor it, was used fo…
mcgibbon Mar 2, 2023
b81d687
fix training part of cyclegan test (reloadable still needs fixing)
mcgibbon Mar 2, 2023
5395b02
update cyclegan reloadable to work with time input
mcgibbon Mar 2, 2023
3c6b68f
fix fmr model by reverting to pre-cyclegan-changes
mcgibbon Mar 3, 2023
6ce1ae0
add percentile metrics, fix bug where normalization constants are cha…
mcgibbon Mar 9, 2023
a16d3b9
log checkpoint save path for training run
mcgibbon Mar 9, 2023
398d35f
multiply time feature by cos(lat) to remove polar discontinuities
mcgibbon Mar 13, 2023
cdaf607
remove deleted reload_path attribute from docstring in CycleGANNetwor…
mcgibbon Mar 13, 2023
3918f33
add use_geographic_embedded_bias option to generator and discriminato…
mcgibbon Mar 13, 2023
700ac4e
restore backwards compatibility for GeographicBias layer
mcgibbon Mar 13, 2023
1a741b8
restore backwards compatibility for GeographicFeatures layer
mcgibbon Mar 13, 2023
b07c782
Merge branch 'feature/improve_wandb_cyclegan_reporting' of github.com…
mcgibbon Mar 13, 2023
52f6f0e
include aggregator metrics for validation data
mcgibbon Mar 13, 2023
4ef65ff
remove wait on unit tests for pytorch image compilation
mcgibbon Mar 13, 2023
9f93c24
remove unnecessary FoldFirstDimension on GeographicBias
mcgibbon Mar 14, 2023
a0e3e15
fix userwarning from target shape not matching final batch shape
mcgibbon Mar 14, 2023
90343a3
reintroduce bug with training data used as validation data, for testi…
mcgibbon Mar 14, 2023
916d8e2
remove bug with training data used as validation data
mcgibbon Mar 14, 2023
8871235
add SchedulerConfig to cyclegan training
mcgibbon Mar 14, 2023
11104db
delte unused validation_batch_size configuration setting
mcgibbon Mar 14, 2023
a659aa8
update cyclegan project files
mcgibbon Mar 15, 2023
5c76788
Merge branch 'master' into feature/improve_wandb_cyclegan_reporting
mcgibbon Mar 15, 2023
868b24e
update project files
mcgibbon Mar 16, 2023
2758fd0
Merge branch 'master' into feature/improve_wandb_cyclegan_reporting
mcgibbon Mar 16, 2023
c765aba
update project files
mcgibbon Mar 17, 2023
bdeb021
update project files
mcgibbon Mar 20, 2023
9afd0b3
Merge branch 'master' into feature/improve_wandb_cyclegan_reporting
mcgibbon Mar 22, 2023
85a912d
update project files
mcgibbon Mar 23, 2023
37297df
Merge branch 'master' into feature/improve_wandb_cyclegan_reporting
mcgibbon Mar 28, 2023
1bb13ca
fix cyclegan when use_geographic_features=False
mcgibbon Mar 28, 2023
6fbed6e
fix fmr model breakage from main merge
mcgibbon Mar 28, 2023
1806d88
fix linting error in ramping.py
mcgibbon Mar 28, 2023
8104460
fix time handling for multi-climate data
mcgibbon Mar 29, 2023
9acfee6
update project files with aggregate processing
mcgibbon Apr 17, 2023
7014221
add 1e-5 and 1e-6 percentiles to metrics
mcgibbon Apr 18, 2023
9280a54
update project files, daily mean vals for histogram
mcgibbon Apr 18, 2023
3dade60
update project files, ramping uses new data
mcgibbon Apr 21, 2023
cf1099f
add disable_temporal_features option for cyclegan
mcgibbon May 16, 2023
9c4f5fb
update project files
mcgibbon May 17, 2023
e71114f
add option for whether to plot colorbar in UpdateablePColormesh
mcgibbon May 22, 2023
2a50db1
allow training model with perturbation as context
mcgibbon May 23, 2023
6c55fcd
import Protocol from typing instead of typing_extensions
mcgibbon May 23, 2023
3b2d231
update Reloadable to work for new with-perturbation-context model
mcgibbon May 26, 2023
26de524
update project files
mcgibbon Jun 22, 2023
5f5913e
Merge branch 'master' into feature/improve_wandb_cyclegan_reporting
mcgibbon Jun 22, 2023
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
@@ -1,5 +1,10 @@
exclude: "external/gcsfs/"
repos:
- repo: https://github.com/psf/black
Review comment (author):
I'd like black to run before flake8 so that it can auto-fix flake8 issues before flake8 runs.

rev: 19.10b0
hooks:
- id: black
additional_dependencies: ["click==8.0.4"]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
@@ -20,11 +25,6 @@ repos:
files: "__init__.py"
# ignore unused import error in __init__.py files
args: ["--ignore=F401,E203", --config, setup.cfg]
- repo: https://github.com/psf/black
rev: 19.10b0
hooks:
- id: black
additional_dependencies: ["click==8.0.4"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.770
hooks:
208 changes: 107 additions & 101 deletions external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py
@@ -1,5 +1,4 @@
from typing import Dict, List, Literal, Mapping, Tuple, Optional
import tensorflow as tf
from typing import List, Literal, Mapping, Tuple, Optional
from fv3fit._shared.scaler import StandardScaler
from .reloadable import CycleGAN, CycleGANModule
import torch
@@ -16,6 +15,8 @@
from fv3fit import wandb
import io
import PIL
import xarray as xr
from vcm.cubedsphere import to_cross

try:
import matplotlib.pyplot as plt
@@ -33,10 +34,8 @@ class CycleGANNetworkConfig:
Configuration for building and training a CycleGAN network.

Attributes:
generator_optimizer: configuration for the optimizer used to train the
generator
discriminator_optimizer: configuration for the optimizer used to train the
discriminator
optimizer: configuration for the optimizer used to train the
Review comment (author):
Merging these was necessary so a wandb sweep can operate on the learning rate; there is no way to pair two hyperparameters in a wandb sweep.
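For context, a minimal sweep-configuration sketch (not part of this PR) of how a single shared optimizer config exposes one learning-rate knob to a wandb sweep; the nested parameter path, metric name, and project name are assumptions about the config layout, not taken from this diff.

```python
import wandb

# Illustrative only: one "lr" parameter now drives both the generator and
# discriminator optimizers, since they share a single OptimizerConfig.
sweep_config = {
    "method": "random",
    "metric": {"name": "val_train_loss", "goal": "minimize"},  # metric name is an assumption
    "parameters": {
        # hypothetical path into the training config; adjust to the real layout
        "network.optimizer.kwargs.lr": {"values": [1e-4, 2e-4, 1e-3]},
    },
}
sweep_id = wandb.sweep(sweep_config, project="cyclegan")  # project name is a placeholder
```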

generator and discriminator
generator: configuration for building the generator network
discriminator: configuration for building the discriminator network
identity_loss: loss function used to make the generator which outputs
@@ -51,12 +50,11 @@
cycle_weight: weight of the cycle loss
generator_weight: weight of the generator's gan loss
discriminator_weight: weight of the discriminator gan loss
reload_path: path to a directory containing a saved CycleGAN model to use
Review comment (author):
This was just a missing docstring entry.

as a starting point for training
"""

generator_optimizer: OptimizerConfig = dataclasses.field(
default_factory=lambda: OptimizerConfig("Adam")
)
discriminator_optimizer: OptimizerConfig = dataclasses.field(
optimizer: OptimizerConfig = dataclasses.field(
default_factory=lambda: OptimizerConfig("Adam")
)
generator: "GeneratorConfig" = dataclasses.field(
@@ -92,12 +90,12 @@ def build(
)
discriminator_a = self.discriminator.build(n_state, convolution=convolution)
discriminator_b = self.discriminator.build(n_state, convolution=convolution)
optimizer_generator = self.generator_optimizer.instance(
optimizer_generator = self.optimizer.instance(
itertools.chain(
generator_a_to_b.parameters(), generator_b_to_a.parameters()
)
)
optimizer_discriminator = self.discriminator_optimizer.instance(
optimizer_discriminator = self.optimizer.instance(
itertools.chain(discriminator_a.parameters(), discriminator_b.parameters())
)
model = CycleGANModule(
@@ -284,26 +282,26 @@ def __post_init__(self):
self._script_disc_a = None
self._script_disc_b = None

def _call_generator_a_to_b(self, input):
def _call_generator_a_to_b(self, input: torch.Tensor) -> torch.Tensor:
if self._script_gen_a_to_b is None:
self._script_gen_a_to_b = torch.jit.trace(
self.generator_a_to_b.forward, input
)
return self._script_gen_a_to_b(input)

def _call_generator_b_to_a(self, input):
def _call_generator_b_to_a(self, input: torch.Tensor) -> torch.Tensor:
if self._script_gen_b_to_a is None:
self._script_gen_b_to_a = torch.jit.trace(
self.generator_b_to_a.forward, input
)
return self._script_gen_b_to_a(input)

def _call_discriminator_a(self, input):
def _call_discriminator_a(self, input: torch.Tensor) -> torch.Tensor:
if self._script_disc_a is None:
self._script_disc_a = torch.jit.trace(self.discriminator_a.forward, input)
return self._script_disc_a(input)

def _call_discriminator_b(self, input):
def _call_discriminator_b(self, input: torch.Tensor) -> torch.Tensor:
if self._script_disc_b is None:
self._script_disc_b = torch.jit.trace(self.discriminator_b.forward, input)
return self._script_disc_b(input)
@@ -316,76 +314,8 @@ def _init_targets(self, shape: Tuple[int, ...]):
torch.Tensor(shape).fill_(0.0).to(DEVICE), requires_grad=False
)

def evaluate_on_dataset(
Review comment (author):
This function was never actually used (it was called, but only on validation data, and I never provided validation datasets before now).

self, dataset: tf.data.Dataset, n_dims_keep: int = 3
) -> Dict[str, float]:
stats_real_a = StatsCollector(n_dims_keep)
stats_real_b = StatsCollector(n_dims_keep)
stats_gen_a = StatsCollector(n_dims_keep)
stats_gen_b = StatsCollector(n_dims_keep)
real_a: np.ndarray
real_b: np.ndarray
reported_plot = False
for real_a, real_b in dataset:
# for now there is no time-evolution-based loss, so we fold the time
# dimension into the sample dimension
real_a = real_a.reshape(
[real_a.shape[0] * real_a.shape[1]] + list(real_a.shape[2:])
)
real_b = real_b.reshape(
[real_b.shape[0] * real_b.shape[1]] + list(real_b.shape[2:])
)
stats_real_a.observe(real_a)
stats_real_b.observe(real_b)
gen_b: np.ndarray = self.generator_a_to_b(
torch.as_tensor(real_a).float().to(DEVICE)
).detach().cpu().numpy()
gen_a: np.ndarray = self.generator_b_to_a(
torch.as_tensor(real_b).float().to(DEVICE)
).detach().cpu().numpy()
stats_gen_a.observe(gen_a)
stats_gen_b.observe(gen_b)
if not reported_plot and plt is not None:
report = {}
for i_tile in range(6):
fig, ax = plt.subplots(2, 2, figsize=(8, 7))
im = ax[0, 0].pcolormesh(real_a[0, i_tile, 0, :, :])
plt.colorbar(im, ax=ax[0, 0])
ax[0, 0].set_title("a_real")
im = ax[1, 0].pcolormesh(real_b[0, i_tile, 0, :, :])
plt.colorbar(im, ax=ax[1, 0])
ax[1, 0].set_title("b_real")
im = ax[0, 1].pcolormesh(gen_b[0, i_tile, 0, :, :])
plt.colorbar(im, ax=ax[0, 1])
ax[0, 1].set_title("b_gen")
im = ax[1, 1].pcolormesh(gen_a[0, i_tile, 0, :, :])
plt.colorbar(im, ax=ax[1, 1])
ax[1, 1].set_title("a_gen")
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png")
plt.close()
buf.seek(0)
report[f"tile_{i_tile}_example"] = wandb.Image(
PIL.Image.open(buf), caption=f"Tile {i_tile} Example",
)
wandb.log(report)
reported_plot = True
metrics = {
"r2_mean_b_against_real_a": get_r2(stats_real_a.mean, stats_gen_b.mean),
"r2_mean_a": get_r2(stats_real_a.mean, stats_gen_a.mean),
"bias_mean_a": np.mean(stats_real_a.mean - stats_gen_a.mean),
"r2_mean_b": get_r2(stats_real_b.mean, stats_gen_b.mean),
"bias_mean_b": np.mean(stats_real_b.mean - stats_gen_b.mean),
"r2_std_a": get_r2(stats_real_a.std, stats_gen_a.std),
"bias_std_a": np.mean(stats_real_a.std - stats_gen_a.std),
"r2_std_b": get_r2(stats_real_b.std, stats_gen_b.std),
"bias_std_b": np.mean(stats_real_b.std - stats_gen_b.std),
}
return metrics

def train_on_batch(
self, real_a: torch.Tensor, real_b: torch.Tensor
self, real_a: torch.Tensor, real_b: torch.Tensor, training: bool = True
) -> Mapping[str, float]:
"""
Train the CycleGAN on a batch of data.
@@ -395,6 +325,8 @@ def train_on_batch(
[sample, time, tile, channel, y, x]
real_b: a batch of data from domain B, should have shape
[sample, time, tile, channel, y, x]
training: if True, the model will be trained, otherwise we will
Review comment (author):
This allows getting training metrics on validation data.
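A rough usage sketch (not from this diff) of how the flag can be used to compute the same loss metrics on validation data; the `trainer`, `reporter`, and `validation_batches` names and the `val_` prefix are illustrative assumptions.

```python
# Hypothetical validation pass: reuse train_on_batch with training=False so the
# optimizer steps and image-pool query are skipped and only the losses are computed.
for real_a, real_b in validation_batches:  # assumed iterable of (tensor, tensor) pairs
    val_metrics = trainer.train_on_batch(real_a, real_b, training=False)
    reporter.log({f"val_{name}": value for name, value in val_metrics.items()})
```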

only evaluate the loss.
"""
# for now there is no time-evolution-based loss, so we fold the time
# dimension into the sample dimension
@@ -412,10 +344,11 @@

# Generators A2B and B2A ######

# don't update discriminators when training generators to fool them
set_requires_grad(
[self.discriminator_a, self.discriminator_b], requires_grad=False
)
if training:
# don't update discriminators when training generators to fool them
set_requires_grad(
[self.discriminator_a, self.discriminator_b], requires_grad=False
)

# Identity loss
# G_A2B(B) should equal B if real B is fed
@@ -448,16 +381,18 @@

# Total loss
loss_g: torch.Tensor = (loss_identity + loss_gan + loss_cycle)
self.optimizer_generator.zero_grad()
loss_g.backward()
self.optimizer_generator.step()
if training:
self.optimizer_generator.zero_grad()
loss_g.backward()
self.optimizer_generator.step()

# Discriminators A and B ######

# do update discriminators when training them to identify samples
set_requires_grad(
[self.discriminator_a, self.discriminator_b], requires_grad=True
)
if training:
# do update discriminators when training them to identify samples
set_requires_grad(
[self.discriminator_a, self.discriminator_b], requires_grad=True
)

# Real loss
pred_real = self.discriminator_a(real_a)
Expand All @@ -466,7 +401,8 @@ def train_on_batch(
)

# Fake loss
fake_a = self.fake_a_buffer.query(fake_a)
if training:
fake_a = self.fake_a_buffer.query(fake_a)
pred_a_fake = self.discriminator_a(fake_a.detach())
loss_d_a_fake = (
self.gan_loss(pred_a_fake, self.target_fake) * self.discriminator_weight
Expand All @@ -479,7 +415,8 @@ def train_on_batch(
)

# Fake loss
fake_b = self.fake_b_buffer.query(fake_b)
if training:
fake_b = self.fake_b_buffer.query(fake_b)
pred_b_fake = self.discriminator_b(fake_b.detach())
loss_d_b_fake = (
self.gan_loss(pred_b_fake, self.target_fake) * self.discriminator_weight
Expand All @@ -490,9 +427,10 @@ def train_on_batch(
loss_d_b_real + loss_d_b_fake + loss_d_a_real + loss_d_a_fake
)

self.optimizer_discriminator.zero_grad()
loss_d.backward()
self.optimizer_discriminator.step()
if training:
self.optimizer_discriminator.zero_grad()
loss_d.backward()
self.optimizer_discriminator.step()

return {
"b_to_a_gan_loss": float(loss_gan_b_to_a),
Expand All @@ -504,8 +442,76 @@ def train_on_batch(
"generator_loss": float(loss_g),
"discriminator_loss": float(loss_d),
"train_loss": float(loss_g + loss_d),
"regularization_loss": float(loss_cycle + loss_identity),
}

def generate_plots(
self, real_a: torch.Tensor, real_b: torch.Tensor
) -> Mapping[str, wandb.Image]:
"""
Plot model output on the first sample of a given batch and return it as
a dictionary of wandb.Image objects.

Args:
real_a: a batch of data from domain A, should have shape
[sample, time, tile, channel, y, x]
real_b: a batch of data from domain B, should have shape
[sample, time, tile, channel, y, x]
"""
# for now there is no time-evolution-based loss, so we fold the time
# dimension into the sample dimension
real_a = real_a.reshape(
[real_a.shape[0] * real_a.shape[1]] + list(real_a.shape[2:])
)
real_b = real_b.reshape(
[real_b.shape[0] * real_b.shape[1]] + list(real_b.shape[2:])
)

# plot the first sample of the batch
with torch.no_grad():
fake_b = self._call_generator_a_to_b(real_a[:1, :])
fake_a = self._call_generator_b_to_a(real_b[:1, :])
real_a = real_a.cpu().numpy()
real_b = real_b.cpu().numpy()
fake_a = fake_a.cpu().numpy()
fake_b = fake_b.cpu().numpy()
report = {}
for i in range(real_a.shape[2]):
var_real_a = to_cross(
xr.DataArray(real_a[0, :, i, :, :], dims=["tile", "grid_yt", "grid_xt"])
)
var_real_b = to_cross(
xr.DataArray(real_b[0, :, i, :, :], dims=["tile", "grid_yt", "grid_xt"])
)
var_fake_a = to_cross(
xr.DataArray(fake_a[0, :, i, :, :], dims=["tile", "grid_yt", "grid_xt"])
)
var_fake_b = to_cross(
xr.DataArray(fake_b[0, :, i, :, :], dims=["tile", "grid_yt", "grid_xt"])
)
vmin_a = min(np.min(real_a[0, :, i, :, :]), np.min(fake_a[0, :, i, :, :]))
vmax_a = max(np.max(real_a[0, :, i, :, :]), np.max(fake_a[0, :, i, :, :]))
vmin_b = min(np.min(real_b[0, :, i, :, :]), np.min(fake_b[0, :, i, :, :]))
vmax_b = max(np.max(real_b[0, :, i, :, :]), np.max(fake_b[0, :, i, :, :]))
fig, ax = plt.subplots(2, 2, figsize=(8, 7))
var_real_a.plot(ax=ax[0, 0], vmin=vmin_a, vmax=vmax_a)
var_fake_b.plot(ax=ax[0, 1], vmin=vmin_b, vmax=vmax_b)
var_real_b.plot(ax=ax[1, 0], vmin=vmin_b, vmax=vmax_b)
var_fake_a.plot(ax=ax[1, 1], vmin=vmin_a, vmax=vmax_a)
ax[0, 0].set_title("real_a")
ax[0, 1].set_title("fake_b")
ax[1, 0].set_title("real_b")
ax[1, 1].set_title("fake_a")

buf = io.BytesIO()
plt.savefig(buf, format="png")
plt.close(fig)
buf.seek(0)
report[f"example_{i}"] = wandb.Image(
PIL.Image.open(buf), caption=f"Channel {i} Example",
)
return report


def set_requires_grad(nets: List[torch.nn.Module], requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
28 changes: 28 additions & 0 deletions external/fv3fit/fv3fit/pytorch/cyclegan/reporter.py
@@ -0,0 +1,28 @@
from typing import Mapping, Any


class Reporter:
"""
Helper class to combine reported metrics to be sent to wandb.
"""

def __init__(self):
self.metrics = {}

def log(self, kwargs: Mapping[str, Any]):
self.metrics.update(kwargs)

def clear(self):
self.metrics.clear()


class NullReporter(Reporter):
"""
Reporter that does nothing.
"""

def log(self, kwargs: Mapping[str, Any]):
pass

def clear(self):
pass
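A minimal usage sketch for the new Reporter (illustrative, not part of the diff): metrics from several places can be merged and sent to wandb in a single call per step. The import path is inferred from the new file's location, and the metric names are made up for the example.

```python
from fv3fit import wandb  # thin wandb wrapper already used in cyclegan_trainer.py
from fv3fit.pytorch.cyclegan.reporter import Reporter  # path inferred from this diff

reporter = Reporter()
reporter.log({"train_loss": 0.42, "generator_loss": 0.31})  # training metrics
reporter.log({"val_train_loss": 0.55})                      # validation metrics merged in
wandb.log(reporter.metrics)                                 # one combined report per step
reporter.clear()                                            # reset before the next step
```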
7 changes: 1 addition & 6 deletions external/fv3fit/fv3fit/pytorch/cyclegan/test_cyclegan.py
@@ -197,13 +197,8 @@ def test_cyclegan_runs_without_errors(tmpdir, conv_type: str, regtest):
generator=fv3fit.pytorch.GeneratorConfig(
n_convolutions=2, n_resnet=5, max_filters=128, kernel_size=3
),
generator_optimizer=fv3fit.pytorch.OptimizerConfig(
name="Adam", kwargs={"lr": 0.001}
),
optimizer=fv3fit.pytorch.OptimizerConfig(name="Adam", kwargs={"lr": 0.001}),
discriminator=fv3fit.pytorch.DiscriminatorConfig(kernel_size=3),
discriminator_optimizer=fv3fit.pytorch.OptimizerConfig(
name="Adam", kwargs={"lr": 0.001}
),
convolution_type=conv_type,
identity_weight=0.01,
cycle_weight=10.0,