Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions config/evaluate/diffusion_overfitting.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Optional: if commented out, everything is handled by the default settings.
# NB: global options apply to all run_ids.
#global_plotting_options:
#  image_format: "png"   # options: "png", "pdf", "svg", "eps", "jpg", ...
#  dpi_val: 300
#  ERA5:
#    marker_size: 2
#    scale_marker_size: 1
#    marker: "o"
#    # alpha: 0.5
#    2t:
#      vmin: 250
#      vmax: 300
#    10u:
#      vmin: -40
#      vmax: 40

evaluation:
  metrics: ["rmse", "mae"]
  regions: ["global", "nhem"]
  summary_plots: true
  summary_dir: "./plots/"            # output directory for summary plots
  plot_ensemble: "members"           # supported: false, "std", "minmax", "members"
  plot_score_maps: false             # plot scores on 2D maps; slows down score computation
  print_summary: false               # print score values on screen; can be verbose
  log_scale: false
  add_grid: false
  score_cards: false
  num_processes: 0                   # options: int, "auto"; 0 means no parallelism (default)

# Per-run configuration, keyed by run id.
run_ids:
  vpqg83lz:
    label: "overfit model vpqg83lz"
    results_base_dir: "./results/"
    mini_epoch: 0                    # NOTE(review): presumably the checkpoint mini-epoch to evaluate — confirm
    rank: 0
    streams:
      ERA5:
        channels: ["2t", "10u"]      # , "10v", "z_500", "t_850", "u_850", "v_850", "q_850"
        evaluation:
          forecast_step: "all"
          sample: "all"
          ensemble: "all"            # supported: "all", "mean", [0,1,2]
        plotting:
          sample: [1]
          forecast_step: [1]
          ensemble: "all"            # supported: "all", "mean", [0,1,2]
          plot_maps: true
          plot_histograms: true
          plot_animations: true


23 changes: 16 additions & 7 deletions src/weathergen/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,22 +427,27 @@ def _prepare_logging(
]

# assert len(targets_rt) == len(preds) and len(preds) == len(self.cf.streams)
fsteps = len(targets_rt)

preds = preds[0]
preds_all: list[list[list[NDArray]]] = [
[[] for _ in self.cf.streams] for _ in range(fsteps)
[[] for _ in self.cf.streams] for _ in forecast_range
]
targets_all: list[list[list[NDArray]]] = [
[[] for _ in self.cf.streams] for _ in range(fsteps)
[[] for _ in self.cf.streams] for _ in forecast_range
]
targets_lens: list[list[list[int]]] = [[[] for _ in self.cf.streams] for _ in range(fsteps)]
targets_lens: list[list[list[int]]] = [[[] for _ in self.cf.streams] for _ in forecast_range]

# TODO: iterate over batches here in future, and change loop order to batch, stream, fstep
for fstep in range(len(targets_rt)):
print(f'forecast range is {forecast_range}')
print(f'range of targets_rt is {range(len(targets_rt))}')
for fstep in forecast_range:
print(f'Processing forecast step {fstep}...')
print(len(preds.physical))
if len(preds.physical[fstep]) == 0:
continue

for i_strm, target in enumerate(targets_rt[fstep]):
pred = preds[fstep][i_strm]
pred = preds.physical[fstep][i_strm]
idxs_inv = idxs_inv_rt[fstep][i_strm]

if not (target.shape[0] > 0 and pred.shape[0] > 0):
Expand Down Expand Up @@ -838,7 +843,11 @@ def validate(self, mini_epoch):
# log output
if bidx < cf.log_validation:
# TODO: Move _prepare_logging into write_validation by passing streams_data
streams_data: list[list[StreamData]] = batch[0]
streams_data: list[list[StreamData]] = batch[0][0]
for i in streams_data:
print(f"unpacking {i}")
for f in i:
print(f)
(
preds_all,
targets_all,
Expand Down