Merged
1 change: 1 addition & 0 deletions README.md
@@ -92,6 +92,7 @@ PaddleScience, developed on the deep learning framework PaddlePaddle, is a scientific comput…

| Problem type | Case name | Optimization algorithm | Model type | Training method | Dataset | References |
|-----|---------|-----|---------|----|---------|---------|
| Weather forecasting | [Extformer-MoE weather forecasting](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/extformer_moe.md) | Data-driven | FourCastNet | Supervised learning | [enso](https://tianchi.aliyun.com/dataset/98942) | - |
| Weather forecasting | [FourCastNet weather forecasting](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/fourcastnet) | Data-driven | FourCastNet | Supervised learning | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
| Weather forecasting | [NowCastNet weather forecasting](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/nowcastnet) | Data-driven | NowCastNet | Supervised learning | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) |
| Weather forecasting | [GraphCast weather forecasting](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/graphcast) | Data-driven | GraphCastNet | Supervised learning | - | [Paper](https://arxiv.org/abs/2212.12794) |
1 change: 1 addition & 0 deletions docs/index.md
@@ -137,6 +137,7 @@

| Problem type | Case name | Optimization algorithm | Model type | Training method | Dataset | References |
|-----|---------|-----|---------|----|---------|---------|
| Weather forecasting | [Extformer-MoE weather forecasting](./zh/examples/extformer_moe.md) | Data-driven | FourCastNet | Supervised learning | [enso](https://tianchi.aliyun.com/dataset/98942) | - |
| Weather forecasting | [FourCastNet weather forecasting](./zh/examples/fourcastnet.md) | Data-driven | FourCastNet | Supervised learning | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
| Weather forecasting | [NowCastNet weather forecasting](./zh/examples/nowcastnet.md) | Data-driven | NowCastNet | Supervised learning | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) |
| Weather forecasting | [GraphCast weather forecasting](./zh/examples/graphcast.md) | Data-driven | GraphCastNet | Supervised learning | - | [Paper](https://arxiv.org/abs/2212.12794) |
9 changes: 6 additions & 3 deletions docs/zh/examples/extformer_moe.md
@@ -1,14 +1,18 @@
# Extformer-MoE

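Before training or evaluation, please download the dataset and update FILE_PATH in the yaml configuration file accordingly
!!! note

[ICAR-ENSO dataset](https://tianchi.aliyun.com/dataset/98942)
1. Before training or evaluation, download the [ICAR-ENSO dataset](https://tianchi.aliyun.com/dataset/98942) and set `FILE_PATH` in the yaml configuration file to the path of the extracted dataset.
2. Before training or evaluation, install `xarray` and `h5netcdf`: `pip install -r requirements.txt`
3. If GPU memory is insufficient during training, set `MODEL.checkpoint_level` to `1` or `2`; training then runs in recompute mode, trading training time for GPU memory.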

=== "Model training command"

``` sh
# Model pretrained on ICAR-ENSO data: Extformer-MoE
python extformer_moe_enso_train.py
# python extformer_moe_enso_train.py MODEL.checkpoint_level=1 # use recompute to run on devices with limited GPU memory
# python extformer_moe_enso_train.py MODEL.checkpoint_level=2 # use recompute to run on devices with limited GPU memory
```
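
The commented commands above pass `MODEL.checkpoint_level` as a dotted `key=value` override on the command line. As a rough illustration of how such an override maps onto a nested configuration, here is a minimal, framework-free sketch. It is not Hydra's or OmegaConf's implementation: `apply_override` and the sample `cfg` contents (including the starting `checkpoint_level` of `0`) are hypothetical, and only simple integer and string values are handled.

```python
from copy import deepcopy


def apply_override(cfg: dict, override: str) -> dict:
    """Return a copy of cfg with one dotted `a.b.c=value` override applied."""
    dotted_key, sep, raw_value = override.partition("=")
    if not sep:
        raise ValueError(f"expected key=value, got {override!r}")
    # Interpret plain integers as int; keep everything else as a string.
    try:
        value = int(raw_value)
    except ValueError:
        value = raw_value
    result = deepcopy(cfg)
    node = result
    *parents, leaf = dotted_key.split(".")
    for name in parents:
        # Walk down the nested dicts, creating levels that do not exist yet.
        node = node.setdefault(name, {})
    node[leaf] = value
    return result


# Hypothetical stand-in for the yaml configuration (values are assumptions).
cfg = {"MODEL": {"checkpoint_level": 0}, "FILE_PATH": "/path/to/enso"}
new_cfg = apply_override(cfg, "MODEL.checkpoint_level=1")
print(new_cfg["MODEL"]["checkpoint_level"])
```

The original `cfg` is left unchanged, so the default and the overridden configuration can be compared side by side.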

=== "Model evaluation command"
@@ -46,7 +50,6 @@ Earthformer, a space-time Transformer for Earth system forecasting. To better

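Rank-N-Contrast (RNC) is a representation learning method that aims to learn regression-aware sample representations. It orders the distances between samples in the embedding space according to their distances in the continuous label space, and then uses the learned representation to predict the final continuous label. For extreme-event prediction in Earth systems, RNC can regularize the representations of meteorological data so that they are continuous in the embedding space and aligned with the label space, which ultimately mitigates over-smoothing in predictions of extreme events.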


## 2. Model principles

### 2.1 Earthformer
2 changes: 1 addition & 1 deletion examples/extformer_moe/extformer_moe_enso_train.py
@@ -1,10 +1,10 @@
import enso_metric
import hydra
import paddle
from omegaconf import DictConfig
from omegaconf import OmegaConf
from paddle import nn

import examples.extformer_moe.enso_metric as enso_metric
import ppsci


2 changes: 2 additions & 0 deletions examples/extformer_moe/requirements.txt
@@ -0,0 +1,2 @@
h5netcdf
xarray==2024.2.0
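
Because the example depends on these two packages, a quick pre-flight check can catch a missing install before a long training run starts. The sketch below uses only the standard library. `missing_packages` is a hypothetical helper, not part of PaddleScience, and it checks importability only, not the pinned `xarray` version.

```python
from importlib.util import find_spec


def missing_packages(names):
    """Return the subset of top-level package names that cannot be imported."""
    return [name for name in names if find_spec(name) is None]


required = ["h5netcdf", "xarray"]  # the packages listed in requirements.txt
missing = missing_packages(required)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```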
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -84,6 +84,7 @@ nav:
- Materials Science (AI for Material):
- hPINNs: zh/examples/hpinns.md
- Earth Science (AI for Earth Science):
- Extformer-MoE: zh/examples/extformer_moe.md
- FourCastNet: zh/examples/fourcastnet.md
- NowcastNet: zh/examples/nowcastnet.md
- DGMR: zh/examples/dgmr.md
2 changes: 1 addition & 1 deletion ppsci/arch/activation.py
@@ -51,7 +51,7 @@ def __init__(self, beta: float = 1.0):
super().__init__()
self.beta = self.create_parameter(
shape=[],
default_initializer=paddle.nn.initializer.Constant(beta),
default_initializer=nn.initializer.Constant(beta),
)

def forward(self, x):
28 changes: 14 additions & 14 deletions ppsci/arch/amgnet.py
@@ -238,21 +238,21 @@ def faster_graph_connectivity(perm, edge_index, edge_weight, score, pos, N, norm
value_A = edge_weight.clone()

value_A = paddle.squeeze(value_A)
model_1 = paddle.nn.Sequential(
("l1", paddle.nn.Linear(128, 256)),
("act1", paddle.nn.ReLU()),
("l2", paddle.nn.Linear(256, 256)),
("act2", paddle.nn.ReLU()),
("l4", paddle.nn.Linear(256, 128)),
("act4", paddle.nn.ReLU()),
("l5", paddle.nn.Linear(128, 1)),
model_1 = nn.Sequential(
("l1", nn.Linear(128, 256)),
("act1", nn.ReLU()),
("l2", nn.Linear(256, 256)),
("act2", nn.ReLU()),
("l4", nn.Linear(256, 128)),
("act4", nn.ReLU()),
("l5", nn.Linear(128, 1)),
)
model_2 = paddle.nn.Sequential(
("l1", paddle.nn.Linear(1, 64)),
("act1", paddle.nn.ReLU()),
("l2", paddle.nn.Linear(64, 128)),
("act2", paddle.nn.ReLU()),
("l4", paddle.nn.Linear(128, 128)),
model_2 = nn.Sequential(
("l1", nn.Linear(1, 64)),
("act1", nn.ReLU()),
("l2", nn.Linear(64, 128)),
("act2", nn.ReLU()),
("l4", nn.Linear(128, 128)),
)

val_A = model_1(value_A)
56 changes: 26 additions & 30 deletions ppsci/arch/cuboid_transformer.py
@@ -16,7 +16,7 @@
"""A space-time Transformer with Cuboid Attention"""


class InitialEncoder(paddle.nn.Layer):
class InitialEncoder(nn.Layer):
def __init__(
self,
dim,
@@ -38,39 +38,35 @@ def __init__(
for i in range(num_conv_layers):
if i == 0:
conv_block.append(
paddle.nn.Conv2D(
nn.Conv2D(
kernel_size=(3, 3),
padding=(1, 1),
in_channels=dim,
out_channels=out_dim,
)
)
conv_block.append(
paddle.nn.GroupNorm(num_groups=16, num_channels=out_dim)
)
conv_block.append(nn.GroupNorm(num_groups=16, num_channels=out_dim))
conv_block.append(
act_mod.get_activation(activation)
if activation != "leaky_relu"
else nn.LeakyReLU(NEGATIVE_SLOPE)
)
else:
conv_block.append(
paddle.nn.Conv2D(
nn.Conv2D(
kernel_size=(3, 3),
padding=(1, 1),
in_channels=out_dim,
out_channels=out_dim,
)
)
conv_block.append(
paddle.nn.GroupNorm(num_groups=16, num_channels=out_dim)
)
conv_block.append(nn.GroupNorm(num_groups=16, num_channels=out_dim))
conv_block.append(
act_mod.get_activation(activation)
if activation != "leaky_relu"
else nn.LeakyReLU(NEGATIVE_SLOPE)
)
self.conv_block = paddle.nn.Sequential(*conv_block)
self.conv_block = nn.Sequential(*conv_block)
if isinstance(downsample_scale, int):
patch_merge_downsample = (1, downsample_scale, downsample_scale)
elif len(downsample_scale) == 2:
@@ -121,7 +117,7 @@ def forward(self, x):
return x


class FinalDecoder(paddle.nn.Layer):
class FinalDecoder(nn.Layer):
def __init__(
self,
target_thw: Tuple[int, ...],
@@ -142,20 +138,20 @@ def __init__(
conv_block = []
for i in range(num_conv_layers):
conv_block.append(
paddle.nn.Conv2D(
nn.Conv2D(
kernel_size=(3, 3),
padding=(1, 1),
in_channels=dim,
out_channels=dim,
)
)
conv_block.append(paddle.nn.GroupNorm(num_groups=16, num_channels=dim))
conv_block.append(nn.GroupNorm(num_groups=16, num_channels=dim))
conv_block.append(
act_mod.get_activation(activation)
if activation != "leaky_relu"
else nn.LeakyReLU(NEGATIVE_SLOPE)
)
self.conv_block = paddle.nn.Sequential(*conv_block)
self.conv_block = nn.Sequential(*conv_block)
self.upsample = cuboid_decoder.Upsample3DLayer(
dim=dim,
out_dim=dim,
@@ -196,7 +192,7 @@ def forward(self, x):
return x


class InitialStackPatchMergingEncoder(paddle.nn.Layer):
class InitialStackPatchMergingEncoder(nn.Layer):
def __init__(
self,
num_merge: int,
@@ -220,8 +216,8 @@ def __init__(
self.downsample_scale_list = downsample_scale_list[:num_merge]
self.num_conv_per_merge_list = num_conv_per_merge_list
self.num_group_list = [max(1, out_dim // 4) for out_dim in self.out_dim_list]
self.conv_block_list = paddle.nn.LayerList()
self.patch_merge_list = paddle.nn.LayerList()
self.conv_block_list = nn.LayerList()
self.patch_merge_list = nn.LayerList()
for i in range(num_merge):
if i == 0:
in_dim = in_dim
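
The `num_group_list` rule in this hunk picks `max(1, out_dim // 4)` groups per stage. Group normalization splits the channels evenly across groups, so the channel count is generally expected to be divisible by the group count. The short check below, in plain Python, shows which channel widths satisfy that under this rule. `groups_for`, `is_evenly_grouped`, and the sample widths are illustrative assumptions, not values taken from the repository.

```python
def groups_for(out_dim: int) -> int:
    """Group count under the rule used for num_group_list."""
    return max(1, out_dim // 4)


def is_evenly_grouped(out_dim: int) -> bool:
    """True if out_dim channels split evenly into groups_for(out_dim) groups."""
    return out_dim % groups_for(out_dim) == 0


for out_dim in (2, 16, 64, 14):  # sample widths, chosen for illustration
    print(out_dim, groups_for(out_dim), is_evenly_grouped(out_dim))
```

Widths that are multiples of 4 always pass, since `out_dim // 4` then divides `out_dim` exactly; a width such as 14 gives 3 groups and does not split evenly.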
@@ -236,15 +232,15 @@ def __init__(
else:
conv_in_dim = out_dim
conv_block.append(
paddle.nn.Conv2D(
nn.Conv2D(
kernel_size=(3, 3),
padding=(1, 1),
in_channels=conv_in_dim,
out_channels=out_dim,
)
)
conv_block.append(
paddle.nn.GroupNorm(
nn.GroupNorm(
num_groups=self.num_group_list[i], num_channels=out_dim
)
)
@@ -253,7 +249,7 @@ def __init__(
if activation != "leaky_relu"
else nn.LeakyReLU(NEGATIVE_SLOPE)
)
conv_block = paddle.nn.Sequential(*conv_block)
conv_block = nn.Sequential(*conv_block)
self.conv_block_list.append(conv_block)
patch_merge = cuboid_encoder.PatchMerging3D(
dim=out_dim,
@@ -303,7 +299,7 @@ def forward(self, x):
return x


class FinalStackUpsamplingDecoder(paddle.nn.Layer):
class FinalStackUpsamplingDecoder(nn.Layer):
def __init__(
self,
target_shape_list: Tuple[Tuple[int, ...]],
@@ -326,8 +322,8 @@ def __init__(
self.in_dim = in_dim
self.num_conv_per_up_list = num_conv_per_up_list
self.num_group_list = [max(1, out_dim // 4) for out_dim in self.out_dim_list]
self.conv_block_list = paddle.nn.LayerList()
self.upsample_list = paddle.nn.LayerList()
self.conv_block_list = nn.LayerList()
self.upsample_list = nn.LayerList()
for i in range(self.num_upsample):
if i == 0:
in_dim = in_dim
@@ -349,15 +345,15 @@ def __init__(
else:
conv_in_dim = out_dim
conv_block.append(
paddle.nn.Conv2D(
nn.Conv2D(
kernel_size=(3, 3),
padding=(1, 1),
in_channels=conv_in_dim,
out_channels=out_dim,
)
)
conv_block.append(
paddle.nn.GroupNorm(
nn.GroupNorm(
num_groups=self.num_group_list[i], num_channels=out_dim
)
)
@@ -366,7 +362,7 @@ def __init__(
if activation != "leaky_relu"
else nn.LeakyReLU(NEGATIVE_SLOPE)
)
conv_block = paddle.nn.Sequential(*conv_block)
conv_block = nn.Sequential(*conv_block)
self.conv_block_list.append(conv_block)
self.reset_parameters()

@@ -686,7 +682,7 @@ def __init__(
embed_dim=base_units, typ=pos_embed_type, maxH=H_in, maxW=W_in, maxT=T_in
)
mem_shapes = self.encoder.get_mem_shapes()
self.z_proj = paddle.nn.Linear(
self.z_proj = nn.Linear(
in_features=mem_shapes[-1][-1], out_features=mem_shapes[-1][-1]
)
self.dec_pos_embed = cuboid_decoder.PosEmbed(
@@ -799,7 +795,7 @@ def get_initial_encoder_final_decoder(
new_input_shape = self.initial_encoder.patch_merge.get_out_shape(
self.input_shape
)
self.dec_final_proj = paddle.nn.Linear(
self.dec_final_proj = nn.Linear(
in_features=self.base_units, out_features=C_out
)
elif self.initial_downsample_type == "stack_conv":
@@ -839,7 +835,7 @@ def get_initial_encoder_final_decoder(
linear_init_mode=self.down_up_linear_init_mode,
norm_init_mode=self.norm_init_mode,
)
self.dec_final_proj = paddle.nn.Linear(
self.dec_final_proj = nn.Linear(
in_features=dec_target_shape_list[-1][-1], out_features=C_out
)
new_input_shape = self.initial_encoder.get_out_shape_list(self.input_shape)[
@@ -892,7 +888,7 @@ def get_initial_z(self, final_mem, T_out):
shape=[B, -1, -1, -1, -1]
)
elif self.z_init_method == "nearest_interp":
initial_z = paddle.nn.functional.interpolate(
initial_z = nn.functional.interpolate(
x=final_mem.transpose(perm=[0, 4, 1, 2, 3]),
size=(T_out, final_mem.shape[2], final_mem.shape[3]),
).transpose(perm=[0, 2, 3, 4, 1])
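
The `nearest_interp` branch in this hunk moves the channel axis forward, resizes the time axis, and then moves the channel axis back. The shape bookkeeping can be traced without a framework. The sketch below assumes `final_mem` is laid out as `(B, T, H, W, C)`, which is inferred from the permutations rather than stated in the diff, and the helper names and sizes are hypothetical.

```python
def permute(shape, perm):
    """Shape after a transpose with the given axis permutation."""
    return tuple(shape[axis] for axis in perm)


def initial_z_shape(final_mem_shape, t_out):
    """Trace shapes through transpose -> interpolate -> transpose."""
    # (B, T, H, W, C) -> (B, C, T, H, W)
    channels_first = permute(final_mem_shape, [0, 4, 1, 2, 3])
    # Interpolation keeps the batch and channel axes and resizes the rest to
    # size=(T_out, H, W), where H and W are taken from final_mem.
    resized = channels_first[:2] + (t_out, final_mem_shape[2], final_mem_shape[3])
    # (B, C, T_out, H, W) -> (B, T_out, H, W, C)
    return permute(resized, [0, 2, 3, 4, 1])


print(initial_z_shape((2, 10, 12, 12, 64), t_out=14))
```

Only the time axis changes length; the spatial and channel sizes pass through untouched.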