Dgmr #813

Merged: 34 commits, merged on Mar 29, 2024

Changes from 15 commits

Commits (34)
6ee9b31
add bubble datafile test=develop
liaoxin2 Aug 8, 2023
5c637f5
add bubble code test=develop
liaoxin2 Aug 8, 2023
bb6f197
add bubble code test=develop
liaoxin2 Aug 8, 2023
b6f4f09
add bubble code test=develop
liaoxin2 Aug 8, 2023
6f0ae70
add bubble code
liaoxin2 Aug 9, 2023
7a09656
add bubble data
liaoxin2 Aug 9, 2023
f1da4e3
add bubble code
liaoxin2 Aug 9, 2023
c51f544
add some code for chip
liaoxin2 Mar 18, 2024
752719d
add some code for DGMR
liaoxin2 Mar 18, 2024
7660b62
add some code for DGMR
liaoxin2 Mar 20, 2024
ff619e4
add some code for DGMR
liaoxin2 Mar 20, 2024
2e05dae
add some code for DGMR
liaoxin2 Mar 20, 2024
63db897
add some dgmr code
liaoxin2 Mar 21, 2024
0760db2
add some dgmr code
liaoxin2 Mar 21, 2024
cb202a9
add some dgmr code
liaoxin2 Mar 21, 2024
1d06768
add some dgmr code
liaoxin2 Mar 25, 2024
a772306
add some dgmr code
liaoxin2 Mar 25, 2024
7802eca
add some dgmr code
liaoxin2 Mar 26, 2024
a5c524b
add some dgmr code
liaoxin2 Mar 26, 2024
ac227a7
add some dgmr code
liaoxin2 Mar 26, 2024
d81a68f
add some dgmr code
liaoxin2 Mar 26, 2024
cf7645b
add some dgmr code
liaoxin2 Mar 27, 2024
5f90e57
add some dgmr code
liaoxin2 Mar 27, 2024
239b37a
add some dgmr code
liaoxin2 Mar 27, 2024
4940b9f
add some dgmr code
liaoxin2 Mar 27, 2024
fba8ca0
add some dgmr code
liaoxin2 Mar 27, 2024
cba2e8a
add some dgmr code
liaoxin2 Mar 27, 2024
f451fa3
add some dgmr code
liaoxin2 Mar 27, 2024
01d2184
add some dgmr code
liaoxin2 Mar 27, 2024
2f01f33
add some dgmr code
liaoxin2 Mar 28, 2024
ee21c3d
add some dgmr code
liaoxin2 Mar 28, 2024
cf6c8be
add some dgmr code
liaoxin2 Mar 29, 2024
4a67425
add some dgmr code
liaoxin2 Mar 29, 2024
086e670
add some dgmr code
liaoxin2 Mar 29, 2024
1 change: 1 addition & 0 deletions docs/zh/api/arch.md
@@ -24,6 +24,7 @@
- USCNN
- NowcastNet
- HEDeepONets
- DGMR
- ChipDeepONets
- AutoEncoder
show_root_heading: true
1 change: 1 addition & 0 deletions docs/zh/api/data/dataset.md
@@ -23,4 +23,5 @@
- MeshCylinderDataset
- RadarDataset
- build_dataset
- DGMRDataset
show_root_heading: true
65 changes: 65 additions & 0 deletions examples/dgmr/conf/dgmr.yaml
@@ -0,0 +1,65 @@
hydra:
run:
# dynamic output directory according to running time and override name
dir: outputs_dgmr/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
job:
name: ${mode} # name of logfile
chdir: false # keep the current working directory unchanged
config:
override_dirname:
exclude_keys:
- TRAIN.checkpoint_path
- TRAIN.pretrained_model_path
- EVAL.pretrained_model_path
- mode
- output_dir
- log_freq
sweep:
Collaborator: Refer to the other config files and add a callback field.

# output directory for multirun
dir: ${hydra.run.dir}
subdir: ./

# general settings
mode: eval # running mode: train/eval
seed: 42
output_dir: ${hydra:run.dir}

# dataset settings
DATASET:
input_keys: 'input_frames'
label_keys: 'target_frames'
split: validation # train or validation
NUM_INPUT_FRAMES: 4
NUM_TARGET_FRAMES: 18
dataset_path: /workspace/workspace/skillful_nowcasting/openclimatefix/nimrod-uk-1km
Collaborator: Use a relative path.

number: 1

# dataLoader settings
DATALOADER:
batch_size: 1
shuffle: False
num_workers: 1
drop_last: True
Collaborator: Put the dataloader configuration in the Python file; it does not need to be written here.
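For illustration, a minimal sketch of how these loader options could instead live in the Python driver (the values mirror the YAML above; `dataset` is assumed to be the DGMRDataset instance already built in dgmr.py, so this is a suggestion, not the PR's code):

import paddle

# hard-code the dataloader options in dgmr.py instead of the config file
val_loader = paddle.io.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    drop_last=True,
)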


# model settings
MODEL:
input_keys: ['input_frames']
output_keys: ['future_images']
Collaborator (on lines +41 to +42): Remove the square brackets.

forecast_steps: 18
input_channels: 1
output_shape: 256
gen_lr: 5e-05
disc_lr: 0.0002
visualize: False
conv_type: 'standard'
num_samples: 6
grid_lambda: 20.0
beta1: 0.0
beta2: 0.999
latent_channels: 768
context_channels: 384
generation_steps: 6

# evaluation settings
EVAL:
pretrained_model_path: openclimatefix/paddle/paddle_model.pdparams
208 changes: 208 additions & 0 deletions examples/dgmr/dgmr.py
Collaborator: The docstring format does not follow the convention; regenerate it with the VS Code docstring plugin.
@@ -0,0 +1,208 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Reference: https://github.com/openclimatefix/skillful_nowcasting
"""
from os import path as osp

import hydra
import matplotlib.pyplot as plt
import numpy as np
import paddle
from omegaconf import DictConfig

import ppsci
from ppsci.utils import logger


def visualize(
cfg: DictConfig,
Collaborator: Passing just an output_dir is enough; there is no need to pass the whole cfg in (see the sketch after this function).

x: paddle.Tensor,
y: paddle.Tensor,
y_hat: paddle.Tensor,
batch_idx: int,
) -> None:
images = x[0]
future_images = y[0]
generated_images = y_hat[0]
fig, axes = plt.subplots(2, 2)
for i, ax in enumerate(axes.flat):
alpha = images[i][0].numpy()
alpha[alpha < 1] = 0
alpha[alpha > 1] = 1
ax.imshow(images[i].transpose([1, 2, 0]).numpy(), alpha=alpha, cmap="viridis")
ax.axis("off")
plt.subplots_adjust(hspace=0.1, wspace=0.1)
plt.savefig(osp.join(cfg.output_dir, "Input_Image_Stack_Frame.png"))
fig, axes = plt.subplots(3, 3)
for i, ax in enumerate(axes.flat):
alpha = future_images[i][0].numpy()
alpha[alpha < 1] = 0
alpha[alpha > 1] = 1
ax.imshow(
future_images[i].transpose([1, 2, 0]).numpy(), alpha=alpha, cmap="viridis"
)
plt.subplots_adjust(hspace=0.1, wspace=0.1)
plt.savefig(osp.join(cfg.output_dir, "Target_Image_Frame.png"))
fig, axes = plt.subplots(3, 3)
for i, ax in enumerate(axes.flat):
alpha = generated_images[i][0].numpy()
alpha[alpha < 1] = 0
alpha[alpha > 1] = 1
ax.imshow(
generated_images[i].transpose([1, 2, 0]).numpy(),
alpha=alpha,
cmap="viridis",
)
ax.axis("off")
plt.subplots_adjust(hspace=0.1, wspace=0.1)
plt.savefig(osp.join(cfg.output_dir, "Generated_Image_Frame.png"))
Collaborator: Have the logger print the save path.
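A minimal sketch combining the two suggestions above (accept an output_dir instead of cfg, and log where each figure is written); the signature and message are illustrative, not part of this PR:

from os import path as osp

import matplotlib.pyplot as plt
import paddle

from ppsci.utils import logger


def visualize(
    output_dir: str,
    x: paddle.Tensor,
    y: paddle.Tensor,
    y_hat: paddle.Tensor,
    batch_idx: int,
) -> None:
    # ... plotting code unchanged ...
    save_path = osp.join(output_dir, "Generated_Image_Frame.png")
    plt.savefig(save_path)
    # log the save path so users can locate the output figure
    logger.message(f"Visualization saved to: {save_path}")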



def validation(solver, batch, batch_idx):
Collaborator: Add type hints and a docstring.
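For example, a possible annotated signature (the types are inferred from the function body below and are only a sketch, not the PR's code):

from typing import Dict, Tuple

import paddle

import ppsci


def validation(
    solver: ppsci.solver.Solver,
    batch: Tuple[Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]],
    batch_idx: int,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Run one validation step and return (discriminator_loss, generator_loss, grid_cell_reg)."""
    ...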

images, future_images = batch
images_value = list(images.values())[0]
future_images_value = list(future_images.values())[0]
Collaborator: Why not index by the key string here instead of taking the first value of the dict by default?
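For reference, indexing by the configured key strings could look like the following (key names assumed from DATASET.input_keys / label_keys in dgmr.yaml):

images_value = images["input_frames"]
future_images_value = future_images["target_frames"]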

# Two discriminator steps per generator step
for _ in range(2):
predictions = solver.predict(images)
predictions_value = list(predictions.values())[0]
Collaborator: Same as above.

generated_sequence = paddle.concat(x=[images_value, predictions_value], axis=1)
real_sequence = paddle.concat(x=[images_value, future_images_value], axis=1)
concatenated_inputs = paddle.concat(
x=[real_sequence, generated_sequence], axis=0
)
concatenated_outputs = solver.model.discriminator(concatenated_inputs)
score_real, score_generated = paddle.split(
x=concatenated_outputs,
num_or_sections=[real_sequence.shape[0], generated_sequence.shape[0]],
axis=0,
)
score_real_spatial, score_real_temporal = paddle.split(
x=score_real, num_or_sections=score_real.shape[1], axis=1
)
score_generated_spatial, score_generated_temporal = paddle.split(
x=score_generated, num_or_sections=score_generated.shape[1], axis=1
)
discriminator_loss = loss_hinge_disc(
score_generated_spatial, score_real_spatial
) + loss_hinge_disc(score_generated_temporal, score_real_temporal)

predictions_value = [list(solver.predict(images).values())[0] for _ in range(6)]
Collaborator: Same as above.

grid_cell_reg = grid_cell_regularizer(
paddle.stack(x=predictions_value, axis=0), future_images_value
)
generated_sequence = [
paddle.concat(x=[images_value, x], axis=1) for x in predictions_value
]
real_sequence = paddle.concat(x=[images_value, future_images_value], axis=1)
generated_scores = []
for g_seq in generated_sequence:
concatenated_inputs = paddle.concat(x=[real_sequence, g_seq], axis=0)
concatenated_outputs = solver.model.discriminator(concatenated_inputs)
score_real, score_generated = paddle.split(
x=concatenated_outputs,
num_or_sections=[real_sequence.shape[0], g_seq.shape[0]],
axis=0,
)
generated_scores.append(score_generated)
generator_disc_loss = loss_hinge_gen(paddle.concat(x=generated_scores, axis=0))
generator_loss = generator_disc_loss + 20 * grid_cell_reg

return discriminator_loss, generator_loss, grid_cell_reg


def loss_hinge_disc(score_generated, score_real):
Collaborator: Rename loss_hinge_disc to _loss_hinge_disc.

"""Discriminator hinge loss."""
l1 = paddle.nn.functional.relu(x=1.0 - score_real)
loss = paddle.mean(x=l1)
l2 = paddle.nn.functional.relu(x=1.0 + score_generated)
loss += paddle.mean(x=l2)
return loss


def loss_hinge_gen(score_generated):
Collaborator: Rename loss_hinge_gen to _loss_hinge_gen.

"""Generator hinge loss."""
loss = -paddle.mean(x=score_generated)
return loss


def grid_cell_regularizer(generated_samples, batch_targets):
Collaborator: Rename grid_cell_regularizer to _grid_cell_regularizer.

"""Grid cell regularizer.

Args:
generated_samples: Tensor of size [n_samples, batch_size, 18, 256, 256, 1].
batch_targets: Tensor of size [batch_size, 18, 256, 256, 1].

Returns:
loss: A tensor of shape [batch_size].
"""
gen_mean = paddle.mean(x=generated_samples, axis=0)
weights = paddle.clip(x=batch_targets, min=0.0, max=24.0)
loss = paddle.mean(x=paddle.abs(x=gen_mean - batch_targets) * weights)
return loss


def train(cfg: DictConfig):
print("Not supported.")
Collaborator suggested change:
print("Not supported.")
raise NotImplementedError(
"Training of DGMR is not supported now."
)



def evaluate(cfg: DictConfig):
# set model
model = ppsci.arch.DGMR(**cfg.MODEL)
# load evaluate data
dataset = ppsci.data.dataset.DGMRDataset(**cfg.DATASET)
val_loader = paddle.io.DataLoader(dataset, batch_size=cfg.DATALOADER.batch_size)
# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.EVAL.pretrained_model_path,
)
solver.model.eval()

# evaluate pretrained model
d_loss = []
g_loss = []
grid_loss = []
for batch_idx, batch in enumerate(val_loader):
with paddle.no_grad():
out_dict = validation(solver, batch, batch_idx)

# visualize
images = batch[0]["input_frames"]
future_images = batch[1]["target_frames"]
generated_images = solver.predict(batch[0])["future_images"]
visualize(cfg, images, future_images, generated_images, batch_idx)

d_loss.append(out_dict[0])
g_loss.append(out_dict[1])
grid_loss.append(out_dict[2])
logger.message(f"d_loss: {np.array(d_loss).mean()}")
logger.message(f"g_loss: {np.array(g_loss).mean()}")
logger.message(f"grid_loss: {np.array(grid_loss).mean()}")


@hydra.main(version_base=None, config_path="./conf", config_name="dgmr.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
else:
raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions ppsci/arch/__init__.py
@@ -40,6 +40,7 @@
from ppsci.arch.he_deeponets import HEDeepONets # isort:skip
from ppsci.arch.chip_deeponets import ChipDeepONets # isort:skip
from ppsci.arch.cfdgcn import CFDGCN # isort:skip
from ppsci.arch.dgmr import DGMR # isort:skip
from ppsci.arch.vae import AutoEncoder # isort:skip
from ppsci.utils import logger # isort:skip

@@ -67,6 +68,7 @@
"USCNN",
"HEDeepONets",
"ChipDeepONets",
"DGMR",
"AutoEncoder",
"build_model",
"CFDGCN",