-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dgmr #813
Dgmr #813
Changes from 15 commits
6ee9b31
5c637f5
bb6f197
b6f4f09
6f0ae70
7a09656
f1da4e3
c51f544
752719d
7660b62
ff619e4
2e05dae
63db897
0760db2
cb202a9
1d06768
a772306
7802eca
a5c524b
ac227a7
d81a68f
cf7645b
5f90e57
239b37a
4940b9f
fba8ca0
cba2e8a
f451fa3
01d2184
2f01f33
ee21c3d
cf6c8be
4a67425
086e670
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
- USCNN | ||
- NowcastNet | ||
- HEDeepONets | ||
- DGMR | ||
- ChipDeepONets | ||
- AutoEncoder | ||
show_root_heading: true | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,4 +23,5 @@ | |
- MeshCylinderDataset | ||
- RadarDataset | ||
- build_dataset | ||
- DGMRDataset | ||
show_root_heading: true |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
hydra: | ||
run: | ||
# dynamic output directory according to running time and override name | ||
dir: outputs_dgmr/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} | ||
job: | ||
name: ${mode} # name of logfile | ||
chdir: false # keep current working direcotry unchaned | ||
config: | ||
override_dirname: | ||
exclude_keys: | ||
- TRAIN.checkpoint_path | ||
- TRAIN.pretrained_model_path | ||
- EVAL.pretrained_model_path | ||
- mode | ||
- output_dir | ||
- log_freq | ||
sweep: | ||
# output directory for multirun | ||
dir: ${hydra.run.dir} | ||
subdir: ./ | ||
|
||
# general settings | ||
mode: eval # running mode: train/eval | ||
seed: 42 | ||
output_dir: ${hydra:run.dir} | ||
|
||
# dataset settings | ||
DATASET: | ||
input_keys: 'input_frames' | ||
label_keys: 'target_frames' | ||
split: validation # train or validation | ||
NUM_INPUT_FRAMES: 4 | ||
NUM_TARGET_FRAMES: 18 | ||
dataset_path: /workspace/workspace/skillful_nowcasting/openclimatefix/nimrod-uk-1km | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 使用相对路径 |
||
number: 1 | ||
|
||
# dataLoader settings | ||
DATALOADER: | ||
batch_size: 1 | ||
shuffle: False | ||
num_workers: 1 | ||
drop_last: True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dataloader配置放py文件,不需要写在这里 |
||
|
||
# model settings | ||
MODEL: | ||
input_keys: ['input_frames'] | ||
output_keys: ['future_images'] | ||
Comment on lines
+41
to
+42
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 去掉中括号 |
||
forecast_steps: 18 | ||
input_channels: 1 | ||
output_shape: 256 | ||
gen_lr: 5e-05 | ||
disc_lr: 0.0002 | ||
visualize: False | ||
conv_type: 'standard' | ||
num_samples: 6 | ||
grid_lambda: 20.0 | ||
beta1: 0.0 | ||
beta2: 0.999 | ||
latent_channels: 768 | ||
context_channels: 384 | ||
generation_steps: 6 | ||
|
||
# evaluation settings | ||
EVAL: | ||
pretrained_model_path: openclimatefix/paddle/paddle_model.pdparams |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. docstring格式不规范,重新用vscode插件生成下 |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,208 @@ | ||||||||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||||||||||
|
||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||
# you may not use this file except in compliance with the License. | ||||||||||
# You may obtain a copy of the License at | ||||||||||
|
||||||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||
|
||||||||||
# Unless required by applicable law or agreed to in writing, software | ||||||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||
# See the License for the specific language governing permissions and | ||||||||||
# limitations under the License. | ||||||||||
|
||||||||||
""" | ||||||||||
Reference: https://github.com/openclimatefix/skillful_nowcasting | ||||||||||
""" | ||||||||||
from os import path as osp | ||||||||||
|
||||||||||
import hydra | ||||||||||
import matplotlib.pyplot as plt | ||||||||||
import numpy as np | ||||||||||
import paddle | ||||||||||
from omegaconf import DictConfig | ||||||||||
|
||||||||||
import ppsci | ||||||||||
from ppsci.utils import logger | ||||||||||
|
||||||||||
|
||||||||||
def visualize( | ||||||||||
cfg: DictConfig, | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 传一个 |
||||||||||
x: paddle.Tensor, | ||||||||||
y: paddle.Tensor, | ||||||||||
y_hat: paddle.Tensor, | ||||||||||
batch_idx: int, | ||||||||||
) -> None: | ||||||||||
images = x[0] | ||||||||||
future_images = y[0] | ||||||||||
generated_images = y_hat[0] | ||||||||||
fig, axes = plt.subplots(2, 2) | ||||||||||
for i, ax in enumerate(axes.flat): | ||||||||||
alpha = images[i][0].numpy() | ||||||||||
alpha[alpha < 1] = 0 | ||||||||||
alpha[alpha > 1] = 1 | ||||||||||
ax.imshow(images[i].transpose([1, 2, 0]).numpy(), alpha=alpha, cmap="viridis") | ||||||||||
ax.axis("off") | ||||||||||
plt.subplots_adjust(hspace=0.1, wspace=0.1) | ||||||||||
plt.savefig(osp.join(cfg.output_dir, "Input_Image_Stack_Frame.png")) | ||||||||||
fig, axes = plt.subplots(3, 3) | ||||||||||
for i, ax in enumerate(axes.flat): | ||||||||||
alpha = future_images[i][0].numpy() | ||||||||||
alpha[alpha < 1] = 0 | ||||||||||
alpha[alpha > 1] = 1 | ||||||||||
ax.imshow( | ||||||||||
future_images[i].transpose([1, 2, 0]).numpy(), alpha=alpha, cmap="viridis" | ||||||||||
) | ||||||||||
plt.subplots_adjust(hspace=0.1, wspace=0.1) | ||||||||||
plt.savefig(osp.join(cfg.output_dir, "Target_Image_Frame.png")) | ||||||||||
fig, axes = plt.subplots(3, 3) | ||||||||||
for i, ax in enumerate(axes.flat): | ||||||||||
alpha = generated_images[i][0].numpy() | ||||||||||
alpha[alpha < 1] = 0 | ||||||||||
alpha[alpha > 1] = 1 | ||||||||||
ax.imshow( | ||||||||||
generated_images[i].transpose([1, 2, 0]).numpy(), | ||||||||||
alpha=alpha, | ||||||||||
cmap="viridis", | ||||||||||
) | ||||||||||
ax.axis("off") | ||||||||||
plt.subplots_adjust(hspace=0.1, wspace=0.1) | ||||||||||
plt.savefig(osp.join(cfg.output_dir, "Generated_Image_Frame.png")) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logger打印下保存路径 |
||||||||||
|
||||||||||
|
||||||||||
def validation(solver, batch, batch_idx): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 加上typehint和docstring |
||||||||||
images, future_images = batch | ||||||||||
images_value = list(images.values())[0] | ||||||||||
future_images_value = list(future_images.values())[0] | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里为什么不用字符串进行索引,而是默认取dict的第一个? |
||||||||||
# Two discriminator steps per generator step | ||||||||||
for _ in range(2): | ||||||||||
predictions = solver.predict(images) | ||||||||||
predictions_value = list(predictions.values())[0] | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上, |
||||||||||
generated_sequence = paddle.concat(x=[images_value, predictions_value], axis=1) | ||||||||||
real_sequence = paddle.concat(x=[images_value, future_images_value], axis=1) | ||||||||||
concatenated_inputs = paddle.concat( | ||||||||||
x=[real_sequence, generated_sequence], axis=0 | ||||||||||
) | ||||||||||
concatenated_outputs = solver.model.discriminator(concatenated_inputs) | ||||||||||
score_real, score_generated = paddle.split( | ||||||||||
x=concatenated_outputs, | ||||||||||
num_or_sections=[real_sequence.shape[0], generated_sequence.shape[0]], | ||||||||||
axis=0, | ||||||||||
) | ||||||||||
score_real_spatial, score_real_temporal = paddle.split( | ||||||||||
x=score_real, num_or_sections=score_real.shape[1], axis=1 | ||||||||||
) | ||||||||||
score_generated_spatial, score_generated_temporal = paddle.split( | ||||||||||
x=score_generated, num_or_sections=score_generated.shape[1], axis=1 | ||||||||||
) | ||||||||||
discriminator_loss = loss_hinge_disc( | ||||||||||
score_generated_spatial, score_real_spatial | ||||||||||
) + loss_hinge_disc(score_generated_temporal, score_real_temporal) | ||||||||||
|
||||||||||
predictions_value = [list(solver.predict(images).values())[0] for _ in range(6)] | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||||||||||
grid_cell_reg = grid_cell_regularizer( | ||||||||||
paddle.stack(x=predictions_value, axis=0), future_images_value | ||||||||||
) | ||||||||||
generated_sequence = [ | ||||||||||
paddle.concat(x=[images_value, x], axis=1) for x in predictions_value | ||||||||||
] | ||||||||||
real_sequence = paddle.concat(x=[images_value, future_images_value], axis=1) | ||||||||||
generated_scores = [] | ||||||||||
for g_seq in generated_sequence: | ||||||||||
concatenated_inputs = paddle.concat(x=[real_sequence, g_seq], axis=0) | ||||||||||
concatenated_outputs = solver.model.discriminator(concatenated_inputs) | ||||||||||
score_real, score_generated = paddle.split( | ||||||||||
x=concatenated_outputs, | ||||||||||
num_or_sections=[real_sequence.shape[0], g_seq.shape[0]], | ||||||||||
axis=0, | ||||||||||
) | ||||||||||
generated_scores.append(score_generated) | ||||||||||
generator_disc_loss = loss_hinge_gen(paddle.concat(x=generated_scores, axis=0)) | ||||||||||
generator_loss = generator_disc_loss + 20 * grid_cell_reg | ||||||||||
|
||||||||||
return discriminator_loss, generator_loss, grid_cell_reg | ||||||||||
|
||||||||||
|
||||||||||
def loss_hinge_disc(score_generated, score_real): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||
"""Discriminator hinge loss.""" | ||||||||||
l1 = paddle.nn.functional.relu(x=1.0 - score_real) | ||||||||||
loss = paddle.mean(x=l1) | ||||||||||
l2 = paddle.nn.functional.relu(x=1.0 + score_generated) | ||||||||||
loss += paddle.mean(x=l2) | ||||||||||
return loss | ||||||||||
|
||||||||||
|
||||||||||
def loss_hinge_gen(score_generated): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||
"""Generator hinge loss.""" | ||||||||||
loss = -paddle.mean(x=score_generated) | ||||||||||
return loss | ||||||||||
|
||||||||||
|
||||||||||
def grid_cell_regularizer(generated_samples, batch_targets): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||
"""Grid cell regularizer. | ||||||||||
|
||||||||||
Args: | ||||||||||
generated_samples: Tensor of size [n_samples, batch_size, 18, 256, 256, 1]. | ||||||||||
batch_targets: Tensor of size [batch_size, 18, 256, 256, 1]. | ||||||||||
|
||||||||||
Returns: | ||||||||||
loss: A tensor of shape [batch_size]. | ||||||||||
""" | ||||||||||
gen_mean = paddle.mean(x=generated_samples, axis=0) | ||||||||||
weights = paddle.clip(x=batch_targets, min=0.0, max=24.0) | ||||||||||
loss = paddle.mean(x=paddle.abs(x=gen_mean - batch_targets) * weights) | ||||||||||
return loss | ||||||||||
|
||||||||||
|
||||||||||
def train(cfg: DictConfig): | ||||||||||
print("Not supported.") | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
|
||||||||||
def evaluate(cfg: DictConfig): | ||||||||||
# set model | ||||||||||
model = ppsci.arch.DGMR(**cfg.MODEL) | ||||||||||
# load evaluate data | ||||||||||
dataset = ppsci.data.dataset.DGMRDataset(**cfg.DATASET) | ||||||||||
val_loader = paddle.io.DataLoader(dataset, batch_size=cfg.DATALOADER.batch_size) | ||||||||||
# initialize solver | ||||||||||
solver = ppsci.solver.Solver( | ||||||||||
model, | ||||||||||
pretrained_model_path=cfg.EVAL.pretrained_model_path, | ||||||||||
) | ||||||||||
solver.model.eval() | ||||||||||
|
||||||||||
# evaluate pretrained model | ||||||||||
d_loss = [] | ||||||||||
g_loss = [] | ||||||||||
grid_loss = [] | ||||||||||
for batch_idx, batch in enumerate(val_loader): | ||||||||||
with paddle.no_grad(): | ||||||||||
out_dict = validation(solver, batch, batch_idx) | ||||||||||
|
||||||||||
# visualize | ||||||||||
images = batch[0]["input_frames"] | ||||||||||
future_images = batch[1]["target_frames"] | ||||||||||
generated_images = solver.predict(batch[0])["future_images"] | ||||||||||
visualize(cfg, images, future_images, generated_images, batch_idx) | ||||||||||
|
||||||||||
d_loss.append(out_dict[0]) | ||||||||||
g_loss.append(out_dict[1]) | ||||||||||
grid_loss.append(out_dict[2]) | ||||||||||
logger.message(f"d_loss: {np.array(d_loss).mean()}") | ||||||||||
logger.message(f"g_loss: {np.array(g_loss).mean()}") | ||||||||||
logger.message(f"grid_loss: {np.array(grid_loss).mean()}") | ||||||||||
|
||||||||||
|
||||||||||
@hydra.main(version_base=None, config_path="./conf", config_name="dgmr.yaml") | ||||||||||
def main(cfg: DictConfig): | ||||||||||
if cfg.mode == "train": | ||||||||||
train(cfg) | ||||||||||
elif cfg.mode == "eval": | ||||||||||
evaluate(cfg) | ||||||||||
else: | ||||||||||
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") | ||||||||||
|
||||||||||
|
||||||||||
if __name__ == "__main__": | ||||||||||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考其他配置文件,加上callback字段