-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
1,132 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,5 +20,6 @@ | |
- VtuDataset | ||
- MeshAirfoilDataset | ||
- MeshCylinderDataset | ||
- RadarDataset | ||
- build_dataset | ||
show_root_heading: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# NowcastNet | ||
|
||
=== "模型训练命令" | ||
|
||
暂无 | ||
|
||
=== "模型评估命令" | ||
|
||
``` sh | ||
# linux | ||
wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/nowcastnet/nowcastnet.zip | ||
# windows | ||
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/nowcastnet/nowcastnet.zip --output nowcastnet.zip | ||
unzip nowcastnet.zip -d datasets/ | ||
python nowcastnet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/nowcastnet/nowcastnet_pretrained.pdparams | ||
``` | ||
|
||
## 1. 背景简介 | ||
|
||
近年来,深度学习方法已被应用于天气预报,尤其是雷达观测的降水预报。这些方法利用大量雷达复合观测数据来训练神经网络模型,以端到端的方式进行训练,无需明确参考降水过程的物理定律。 | ||
这里复现了一个针对极端降水的非线性短临预报模型——NowcastNet,该模型将物理演变方案和条件学习法统一到一个神经网络框架中,实现了端到端的优化。 | ||
|
||
## 2. 模型原理 | ||
|
||
本章节仅对 NowcastNet 的模型原理进行简单地介绍,详细的理论推导请阅读 [Skilful nowcasting of extreme precipitation with NowcastNet](https://www.nature.com/articles/s41586-023-06184-4#Abs1)。 | ||
|
||
模型的总体结构如图所示: | ||
|
||
<figure markdown> | ||
![nowcastnet-arch](nowcastnet/nowcastnet.png){ loading=lazy style="margin:0 auto"} | ||
<figcaption>NowcastNet 网络模型</figcaption> | ||
</figure> | ||
|
||
模型使用预训练权重推理,接下来将介绍模型的推理过程。 | ||
|
||
## 3. 模型构建 | ||
|
||
在该案例中,用 PaddleScience 代码表示如下: | ||
|
||
``` py linenums="24" title="examples/nowcastnet/nowcastnet.py" | ||
--8<-- | ||
examples/nowcastnet/nowcastnet.py:24:36 | ||
--8<-- | ||
``` | ||
|
||
``` yaml linenums="35" title="examples/nowcastnet/conf/nowcastnet.yaml" | ||
--8<-- | ||
examples/nowcastnet/conf/nowcastnet.yaml:35:53 | ||
--8<-- | ||
``` | ||
|
||
其中,`input_keys` 和 `output_keys` 分别代表网络模型输入、输出变量的名称。 | ||
|
||
## 4 模型评估可视化 | ||
|
||
完成上述设置之后,将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`: | ||
|
||
``` py linenums="57" title="examples/nowcastnet/nowcastnet.py" | ||
--8<-- | ||
examples/nowcastnet/nowcastnet.py:57:61 | ||
--8<-- | ||
``` | ||
|
||
然后构建 VisualizerRadar 生成图片结果: | ||
|
||
``` py linenums="69" title="examples/nowcastnet/nowcastnet.py" | ||
--8<-- | ||
examples/nowcastnet/nowcastnet.py:69:82 | ||
--8<-- | ||
|
||
## 5. 完整代码 | ||
|
||
``` py linenums="1" title="examples/nowcastnet/nowcastnet.py" | ||
--8<-- | ||
examples/nowcastnet/nowcastnet.py | ||
--8<-- | ||
``` | ||
|
||
## 6. 结果展示 | ||
|
||
下图展示了模型的预测结果和真值结果。 | ||
|
||
<figure markdown> | ||
![result](nowcastnet/pd.gif){ loading=lazy style="margin:0 auto;"} | ||
<figcaption>模型预测结果</figcaption> | ||
</figure> | ||
|
||
<figure markdown> | ||
![result](nowcastnet/gt.gif){ loading=lazy style="margin:0 auto;"} | ||
<figcaption>模型真值结果</figcaption> | ||
</figure> |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
hydra: | ||
run: | ||
# dynamic output directory according to running time and override name | ||
dir: outputs_nowcastnet/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} | ||
job: | ||
name: ${mode} # name of logfile | ||
chdir: false # keep current working direcotry unchaned | ||
config: | ||
override_dirname: | ||
exclude_keys: | ||
- TRAIN.checkpoint_path | ||
- TRAIN.pretrained_model_path | ||
- EVAL.pretrained_model_path | ||
- mode | ||
- output_dir | ||
- log_freq | ||
sweep: | ||
# output directory for multirun | ||
dir: ${hydra.run.dir} | ||
subdir: ./ | ||
|
||
# general settings | ||
mode: eval # running mode: train/eval | ||
seed: 42 | ||
output_dir: ${hydra:run.dir} | ||
NORMAL_DATASET_PATH: datasets/mrms/figure | ||
LARGE_DATASET_PATH: datasets/mrms/large_figure | ||
|
||
# set working condition | ||
CASE_TYPE: normal # normal/large | ||
NUM_SAVE_SAMPLES: 10 | ||
CPU_WORKER: 0 | ||
|
||
# model settings | ||
MODEL: | ||
normal: | ||
input_keys: ["input"] | ||
output_keys: ["output"] | ||
input_length: 9 | ||
total_length: 29 | ||
image_width: 512 | ||
image_height: 512 | ||
image_ch: 2 | ||
ngf: 32 | ||
large: | ||
input_keys: ["input"] | ||
output_keys: ["output"] | ||
input_length: 9 | ||
total_length: 29 | ||
image_width: 1024 | ||
image_height: 1024 | ||
image_ch: 2 | ||
ngf: 32 | ||
|
||
# evaluation settings | ||
EVAL: | ||
pretrained_model_path: checkpoints/paddle_mrms_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
Reference: https://codeocean.com/capsule/3935105/tree/v1 | ||
""" | ||
from os import path as osp | ||
|
||
import hydra | ||
import paddle | ||
from omegaconf import DictConfig | ||
|
||
import ppsci | ||
from ppsci.utils import logger | ||
|
||
|
||
def train(cfg: DictConfig): | ||
print("Not supported.") | ||
|
||
|
||
def evaluate(cfg: DictConfig): | ||
# set random seed for reproducibility | ||
ppsci.utils.misc.set_random_seed(cfg.seed) | ||
# initialize logger | ||
logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info") | ||
|
||
if cfg.CASE_TYPE == "large": | ||
dataset_path = cfg.LARGE_DATASET_PATH | ||
model_cfg = cfg.MODEL.large | ||
output_dir = osp.join(cfg.output_dir, "large") | ||
elif cfg.CASE_TYPE == "normal": | ||
dataset_path = cfg.NORMAL_DATASET_PATH | ||
model_cfg = cfg.MODEL.normal | ||
output_dir = osp.join(cfg.output_dir, "normal") | ||
else: | ||
raise ValueError( | ||
f"cfg.CASE_TYPE should in ['normal', 'large'], but got '{cfg.mode}'" | ||
) | ||
model = ppsci.arch.NowcastNet(**model_cfg) | ||
|
||
input_keys = ("radar_frames",) | ||
dataset_param = { | ||
"input_keys": input_keys, | ||
"label_keys": (), | ||
"image_width": model_cfg.image_width, | ||
"image_height": model_cfg.image_height, | ||
"total_length": model_cfg.total_length, | ||
"dataset_path": dataset_path, | ||
"data_type": paddle.get_default_dtype(), | ||
} | ||
test_data_loader = paddle.io.DataLoader( | ||
ppsci.data.dataset.RadarDataset(**dataset_param), | ||
batch_size=1, | ||
shuffle=False, | ||
num_workers=cfg.CPU_WORKER, | ||
drop_last=True, | ||
) | ||
|
||
# initialize solver | ||
solver = ppsci.solver.Solver( | ||
model, | ||
output_dir=output_dir, | ||
pretrained_model_path=cfg.EVAL.pretrained_model_path, | ||
) | ||
|
||
for batch_id, test_ims in enumerate(test_data_loader): | ||
test_ims = test_ims[0][input_keys[0]].numpy() | ||
frames_tensor = paddle.to_tensor( | ||
data=test_ims, dtype=paddle.get_default_dtype() | ||
) | ||
if batch_id <= cfg.NUM_SAVE_SAMPLES: | ||
visualizer = { | ||
"v_nowcastnet": ppsci.visualize.VisualizerRadar( | ||
{"input": frames_tensor}, | ||
{ | ||
"output": lambda out: out["output"], | ||
}, | ||
prefix="v_nowcastnet", | ||
case_type=cfg.CASE_TYPE, | ||
total_length=model_cfg.total_length, | ||
) | ||
} | ||
solver.visualizer = visualizer | ||
# visualize prediction | ||
solver.visualize(batch_id) | ||
|
||
|
||
@hydra.main(version_base=None, config_path="./conf", config_name="nowcastnet.yaml") | ||
def main(cfg: DictConfig): | ||
if cfg.mode == "train": | ||
train(cfg) | ||
elif cfg.mode == "eval": | ||
evaluate(cfg) | ||
else: | ||
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.