
Conversation

@Yang-Changhui (Contributor) commented Mar 21, 2024

PR types

New features

PR changes

Others

Describe

Add the EarthFormer model

paddle-bot bot commented Mar 21, 2024

Thanks for your contribution!

CLAassistant commented Mar 21, 2024

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
1 out of 2 committers have signed the CLA.

✅ zhiminzhang0830
❌ Yang-Changhui


Yang-Changhui seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

paddle-bot bot commented Mar 21, 2024

❌ The PR is not created using PR's template. You can refer to this Demo.
Please use PR's template, it helps save our maintainers' time so that more developers get helped.

@HydrogenSulfate (Collaborator)

You need to sign the CLA agreement first.

@Yang-Changhui (Contributor Author)

It shows here that I have already agreed.


USE_SAMPLED_DATA: false
# set train and evaluate data path
FILE_PATH: /home/aistudio/data/data260191/enso_round1_train_20210201
Collaborator

This can be changed to a relative path.

Comment on lines 154 to 157




Collaborator

Delete the redundant blank lines.

max_batch_size: 16
num_cpu_threads: 4
batch_size: 1
data_path: /home/aistudio/data/data260191/enso_round1_train_20210201/SODA_train.nc
Collaborator

Use a relative path.

batch_size: 16

INFER:
pretrained_model_path: /home/aistudio/best_model.pdparams
Collaborator

Use a relative path.

from functools import lru_cache
from typing import Tuple
from collections import OrderedDict
from .cuboid_transformer_patterns import CuboidSelfAttentionPatterns, CuboidCrossAttentionPatterns
Collaborator

It is recommended to use absolute imports, i.e. ppsci.xxx.
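A minimal sketch of that import style, assuming the module ends up under ppsci/arch (the exact package path is an assumption):

```python
# Absolute imports instead of relative ones; adjust the path to the module's
# actual location inside the ppsci package.
from ppsci.arch.cuboid_transformer_patterns import CuboidCrossAttentionPatterns
from ppsci.arch.cuboid_transformer_patterns import CuboidSelfAttentionPatterns
```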

Comment on lines 22 to 65
def get_activation(act, inplace=False, **kwargs):
"""

Parameters
----------
act
Name of the activation
inplace
Whether to perform inplace activation

Returns
-------
activation_layer
The activation
"""
if act is None:
return lambda x: x
if isinstance(act, str):
if act == 'leaky':
negative_slope = kwargs.get('negative_slope', 0.1)
return paddle.nn.LeakyReLU(negative_slope=negative_slope)
elif act == 'identity':
return paddle.nn.Identity()
elif act == 'elu':
return paddle.nn.ELU()
elif act == 'gelu':
return paddle.nn.GELU()
elif act == 'relu':
return paddle.nn.ReLU()
elif act == 'sigmoid':
return paddle.nn.Sigmoid()
elif act == 'tanh':
return paddle.nn.Tanh()
elif act == 'softrelu' or act == 'softplus':
return paddle.nn.Softplus()
elif act == 'softsign':
return paddle.nn.Softsign()
else:
raise NotImplementedError(
'act="{}" is not supported. Try to include it if you can find that in '
'https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Overview_cn.html'
.format(act))
else:
return act
Collaborator

There should already be a similar function in activation.py.

Contributor Author

That function does exist, but the negative_slope parameter of the LeakyReLU used there has been changed here, and not all of the activation functions are covered. How should this case be handled?



class RMSNorm(paddle.nn.Layer):

Collaborator

Delete the blank line.

Comment on lines 91 to 106
out_4 = paddle.create_parameter(
shape=init_data.shape,
dtype=init_data.dtype,
default_initializer=nn.initializer.Assign(init_data))
out_4.stop_gradient = not True
self.scale = out_4
self.add_parameter(name='scale', parameter=self.scale)
if self.bias:
init_data = paddle.zeros(d)
out_5 = paddle.create_parameter(
shape=init_data.shape,
dtype=init_data.dtype,
default_initializer=nn.initializer.Assign(init_data))
out_5.stop_gradient = not True
self.offset = out_5
self.add_parameter(name='offset', parameter=self.offset)
Collaborator

Rename the converted variables; do not use names whose meaning is unclear.
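A minimal sketch of the renaming, assuming these parameters are the RMSNorm scale and offset (mirroring the quoted snippet):

```python
# Inside RMSNorm.__init__ (paddle / nn imported at module level).
# `out_4` becomes `scale`; the name now states what the parameter is.
init_data = paddle.ones(d)
self.scale = paddle.create_parameter(
    shape=init_data.shape,
    dtype=init_data.dtype,
    default_initializer=nn.initializer.Assign(init_data),
)
self.scale.stop_gradient = False
self.add_parameter(name="scale", parameter=self.scale)
```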

Comment on lines 313 to 317





Collaborator

Delete the redundant blank lines.

@HydrogenSulfate (Collaborator)

Please install pre-commit before committing code.

- HEDeepONets
- ChipDeepONets
- AutoEncoder
- CuboidTransformerModel
Collaborator

It is recommended to rename CuboidTransformerModel to CuboidTransformer.

input_shape: [12, 24, 48, 1]
target_shape: [14, 24, 48, 1]
base_units: 64
# block_units: null
Collaborator

Is the block_units parameter actually needed? If not, it can be removed.

Comment on lines 80 to 88
num_blocks = len(cfg.MODEL.afno["enc_depth"])
if isinstance(cfg.MODEL["self_pattern"], str):
enc_attn_patterns = [cfg.MODEL["self_pattern"]] * num_blocks

if isinstance(cfg.MODEL["cross_self_pattern"], str):
dec_self_attn_patterns = [cfg.MODEL["cross_self_pattern"]] * num_blocks

if isinstance(cfg.MODEL["cross_pattern"], str):
dec_cross_attn_patterns = [cfg.MODEL["cross_pattern"]] * num_blocks
Collaborator

This logic is controlled only by num_blocks. Could a num_blocks parameter be added to CuboidTransformer and this configuration-generation logic be moved into CuboidTransformer.__init__? I think that would be a more reasonable approach.

You can refer to how hidden_size is handled.
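A rough sketch of the suggestion, assuming the pattern arguments keep their current config names (num_blocks and the default pattern strings below are illustrative, not the model's real defaults):

```python
import paddle


class CuboidTransformer(paddle.nn.Layer):
    """Sketch only: the real signature has many more arguments."""

    def __init__(
        self,
        num_blocks: int = 4,
        self_pattern: str = "axial",
        cross_self_pattern: str = "axial",
        cross_pattern: str = "cross_1x1",
    ):
        super().__init__()
        # Expand a single pattern name into one entry per encoder/decoder block,
        # so callers no longer have to pre-build these lists outside the model.
        self.enc_attn_patterns = [self_pattern] * num_blocks
        self.dec_self_attn_patterns = [cross_self_pattern] * num_blocks
        self.dec_cross_attn_patterns = [cross_pattern] * num_blocks
```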


# # init optimizer and lr scheduler
lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
Collaborator

You could also move this line up next to eta_min=xxx, so there is no need to convert to a dict and call update().
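A sketch of what the call could look like instead, assuming a ppsci lr scheduler class such as Cosine is being built from cfg.TRAIN.lr_scheduler (the exact scheduler class is an assumption):

```python
import ppsci

# Pass iters_per_epoch as one more keyword argument rather than copying the
# config into a dict and calling update() on it.
lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(
    **cfg.TRAIN.lr_scheduler,
    iters_per_epoch=ITERS_PER_EPOCH,
)()
```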

Comment on lines 100 to 103
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
"weight_decay": cfg.TRAIN.wd,
},
Collaborator

Is it necessary to set wd for the non-decay parameters here? The Optimizer instantiation already contains the same setting.

Contributor Author

The weight_decay parameter defaults to 0.01, while cfg.TRAIN.wd is 1e-5. Which one does the Optimizer instantiation actually use?

Collaborator

The weight_decay parameter defaults to 0.01, while cfg.TRAIN.wd is 1e-5. Which one does the Optimizer instantiation actually use?

I checked the official docs; writing it this way is fine for now, not a big issue.

@Yang-Changhui (Contributor Author) Mar 28, 2024

I tried removing this parameter and retrained; it made no noticeable difference to the results, so I think it can be removed.

@HydrogenSulfate (Collaborator) Mar 28, 2024

I tried removing this parameter and retrained; it made no noticeable difference to the results, so I think it can be removed.

I checked: the outer wd is a global setting and can be overridden by the wd specified inside the params entries.

For learning_rate the relationship is a scaling factor: the final lr is the global lr multiplied by the lr specified inside the params dict.

So deleting the field that duplicates the global wd should be fine.

The behavior above should be consistent with torch.
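A minimal sketch of the override behavior described above, using paddle.optimizer.AdamW with an illustrative two-group split:

```python
import paddle

model = paddle.nn.Linear(4, 4)

# The per-group "weight_decay" (0.0 for the bias group) overrides the global
# value; groups without an explicit entry fall back to the global 1e-5.
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    weight_decay=1e-5,  # global wd
    parameters=[
        {"params": [model.weight]},  # uses the global wd
        {"params": [model.bias], "weight_decay": 0.0},  # overrides the global wd
    ],
)
```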

Collaborator

Introducing a registry mechanism is not recommended: it makes the code harder to read and debug, and in particular it can obscure the Python error traceback.

Comment on lines 25 to 69
def get_activation(act, inplace=False, **kwargs):
"""
Parameters
----------
act
Name of the activation
inplace
Whether to perform inplace activation

Returns
-------
activation_layer
The activation
"""
if act is None:
return lambda x: x
if isinstance(act, str):
if act == "leaky":
negative_slope = kwargs.get("negative_slope", 0.1)
return paddle.nn.LeakyReLU(negative_slope=negative_slope)
elif act == "identity":
return paddle.nn.Identity()
elif act == "elu":
return paddle.nn.ELU()
elif act == "gelu":
return paddle.nn.GELU()
elif act == "relu":
return paddle.nn.ReLU()
elif act == "sigmoid":
return paddle.nn.Sigmoid()
elif act == "tanh":
return paddle.nn.Tanh()
elif act == "softrelu" or act == "softplus":
return paddle.nn.Softplus()
elif act == "softsign":
return paddle.nn.Softsign()
else:
raise NotImplementedError(
'act="{}" is not supported. Try to include it if you can find that in '
"https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Overview_cn.html".format(
act
)
)
else:
return act
Collaborator

It is recommended to improve ppsci/arch/activation.py and reuse the existing code as much as possible.
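A sketch of what reusing the shared module could look like; the dict name act_func_dict and its entries are assumptions about ppsci/arch/activation.py, the point being to extend the shared mapping instead of keeping a local get_activation:

```python
import paddle.nn as nn

# Hypothetical shape of ppsci/arch/activation.py: one shared name -> layer map.
act_func_dict = {
    "elu": nn.ELU(),
    "gelu": nn.GELU(),
    "identity": nn.Identity(),
    "leaky_relu": nn.LeakyReLU(negative_slope=0.1),  # slope used by EarthFormer
    "relu": nn.ReLU(),
    "sigmoid": nn.Sigmoid(),
    "softplus": nn.Softplus(),
    "softsign": nn.Softsign(),
    "tanh": nn.Tanh(),
}


def get_activation(act_name: str) -> nn.Layer:
    """Return the activation layer registered under act_name."""
    if act_name not in act_func_dict:
        raise ValueError(f"act_name({act_name}) not found in act_func_dict")
    return act_func_dict[act_name]
```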

whether use bias term for RMSNorm, disabled by
default because RMSNorm doesn't enforce re-centering invariance.
"""
super(RMSNorm, self).__init__()
@HydrogenSulfate (Collaborator) Mar 26, 2024

Change this to super().__init__().

self.scale = paddle.create_parameter(
shape=init_data.shape,
dtype=init_data.dtype,
default_initializer=nn.initializer.Assign(init_data),
Collaborator

You can use nn.initializer.Constant(1.0) directly.

Comment on lines 93 to 109
init_data = paddle.ones(d)
self.scale = paddle.create_parameter(
shape=init_data.shape,
dtype=init_data.dtype,
default_initializer=nn.initializer.Assign(init_data),
)
self.scale.stop_gradient = not True
self.add_parameter(name="scale", parameter=self.scale)
if self.bias:
init_data = paddle.zeros(d)
self.offset = paddle.create_parameter(
shape=init_data.shape,
dtype=init_data.dtype,
default_initializer=nn.initializer.Assign(init_data),
)
self.offset.stop_gradient = not True
self.add_parameter(name="offset", parameter=self.offset)
Collaborator

You can use nn.initializer.Constant(value) instead of the Assign initialization to reduce the amount of code.
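A minimal sketch of that change, assuming d here is an integer dimension (paddle/nn already imported in the module):

```python
# Constant initializers remove the need for paddle.ones/zeros plus Assign.
# Assigning the created parameter to a Layer attribute registers it, so the
# extra add_parameter call can be dropped as well.
self.scale = self.create_parameter(
    shape=[d],
    default_initializer=nn.initializer.Constant(1.0),
)
if self.bias:
    self.offset = self.create_parameter(
        shape=[d],
        default_initializer=nn.initializer.Constant(0.0),
    )
```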

import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.distributed.fleet.utils import recompute
Collaborator

It is recommended to call this through fleet.utils.recompute.
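A minimal sketch of the suggested call style; block and x are placeholders for the checkpointed sub-layer and its input:

```python
from paddle.distributed import fleet

# Keep the fleet.utils namespace visible at the call site instead of importing
# recompute directly from paddle.distributed.fleet.utils.
out = fleet.utils.recompute(block, x)
```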

Comment on lines 100 to 103
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
"weight_decay": cfg.TRAIN.wd,
},
Collaborator

The weight_decay parameter defaults to 0.01, while cfg.TRAIN.wd is 1e-5. Which one does the Optimizer instantiation actually use?

I checked the official docs; writing it this way is fine for now, not a big issue.


input_spec = [
{
key: InputSpec([1, 12, 24, 48, 1], "float32", name=key)
Collaborator

Can the first dimension be changed to None?

#816 (comment) Because of the reshape operations, setting it to None causes dimension conflicts.

Oh, OK. Let's write it as 1 for now, then.
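For reference, a sketch of the resulting spec with the batch dimension pinned to 1 (the input key name is a placeholder):

```python
from paddle.static import InputSpec

# Batch dimension fixed to 1: the model's internal reshape ops need a static
# shape, so None (dynamic batch) would cause dimension conflicts here.
input_spec = [
    {key: InputSpec([1, 12, 24, 48, 1], "float32", name=key) for key in ("sst_data",)}
]
```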

optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
"weight_decay": cfg.TRAIN.wd,
Collaborator

Same as above; this can be removed.

},
]

# # init optimizer and lr scheduler
Collaborator

Delete the redundant # sign.

assert self.target_nino.shape[1] == self.out_len - NINO_WINDOW_T + 1
return self.sst, self.target_nino

def GetDataShape(self):
Collaborator

Class method names should be in lowercase (e.g. get_data_shape).

Collaborator

For validating input arguments, use if + raise XXXError; assert is for checking that internally produced variables match expectations. The two have different use cases.
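A toy sketch of the distinction (the function and its arguments are purely illustrative):

```python
import numpy as np


def normalize_sst(sst: np.ndarray, scale: float) -> np.ndarray:
    # Caller-supplied argument: validate with if + raise and a clear error type.
    if scale <= 0:
        raise ValueError(f"scale should be positive, but got {scale}.")
    out = sst / scale
    # Value produced inside the function: an assert on the internal invariant.
    assert out.shape == sst.shape
    return out
```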

Collaborator

get_parameter_names is a highly customized, not very general function; adding it to utils is not recommended.

Collaborator

The function can be placed next to the code that calls it, but creating a separate file just for it is not recommended; it is not really necessary.

@@ -0,0 +1,261 @@
import enso_metric
Collaborator

Local Python modules should be imported at the bottom, together with ppsci. Why has this one ended up at the top?
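A sketch of the expected ordering, assuming enso_metric sits next to this example's train.py (exact grouping depends on the project's isort config):

```python
# Third-party imports first.
import hydra
import numpy as np
import paddle
import xarray as xr

# First-party / local modules together at the bottom.
import ppsci
import enso_metric
```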

import hydra
import numpy as np
import paddle
import xarray as xr
@HydrogenSulfate (Collaborator) Mar 28, 2024

For module imports like this that are not general-purpose and are only needed by a few examples, there are two cases to consider:

1. Modules that are not imported by the ppsci package, such as train.py: use the following pattern instead. It tells the user how to install the module, and the error is only raised when that file is actually run.

try:
    import xxx
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "Please install xxx with `pip install xxx`."
    )

2. For submodules that are imported by the ppsci package, such as enso_dataset.py, make the following change so that the error is only raised when the class is actually used:

import importlib.util

try:
    import xxx
except ModuleNotFoundError:
    pass  # this does not affect users who do not need the module


class XXXDataset:
    def __init__(self, *args, **kwargs):
        super().__init__()
        if importlib.util.find_spec("xxx") is None:
            raise ModuleNotFoundError(
                "To use XXXDataset, please install 'xxx' via: `pip install "
                "xxx` first."
            )  # check and prompt the install command when the dataset is instantiated

Otherwise, as more code is added, the project's unnecessary dependencies will keep growing, which is clearly unreasonable.

@HydrogenSulfate (Collaborator)

Please re-run a complete pre-commit pass; the code-style-check pipeline failed.

cross_last_n_frames (int, optional): The cross_last_n_frames of decoder. Defaults to None.
qkv_bias (bool, optional): Whether to enable bias in calculating qkv attention. Defaults to False.
qk_scale (float, optional): Whether to enable scale factor when calculating the attention. Defaults to None.
attn_drop (float, optional): The attention dropout.. Defaults to 0.0.
Collaborator

Delete the extra punctuation after "dropout".

qkv_bias (bool, optional): Whether to enable bias in calculating qkv attention. Defaults to False.
qk_scale (float, optional): Whether to enable scale factor when calculating the attention. Defaults to None.
attn_drop (float, optional): The attention dropout.. Defaults to 0.0.
proj_drop (float, optional): The projrction dropout.. Defaults to 0.0.
Collaborator

Delete the extra punctuation after "dropout".

max_temporal_relative (int, optional): The max temporal. Defaults to 50.
norm_layer (str, optional): The normalization layer. Defaults to "layer_norm".
use_global_vector (bool, optional): Whether to use the global vector or not. Defaults to True.
separate_global_qkv (bool, optional): Whether to use different network to calc q_global, k_global, v_global. . Defaults to False.
Collaborator

Delete the extra punctuation after "v_global".

normalization: str = "layer_norm",
layer_norm_eps: float = 1e-05,
pre_norm: bool = False,
linear_init_mode="0",
Collaborator

Same as above: complete the type annotations. The other key classes need this as well, e.g.:

linear_init_mode: str = "0",
norm_init_mode: str = "0",

def compute_enso_score(
y_pred, y_true, acc_weight: Optional[Union[str, np.ndarray, paddle.Tensor]] = None
):
"""_summary_
Collaborator

Complete the summary.

Args:
y_pred (paddle.Tensor): predict data
y_true (paddle.Tensor): true data
acc_weight (Optional[Union[str, np.ndarray, paddle.Tensor]], optional): _description_. Defaults to None.use default acc_weight specified at https://tianchi.aliyun.com/competition/entrance/531871/information
Collaborator

Complete the description.

Returns:
acc:
rmse:
Collaborator

Remove the Returns section.



def sst_to_nino(sst: paddle.Tensor, normalize_sst: bool = True, detach: bool = True):
"""_summary_
Collaborator

Complete the summary.

Args:
sst (paddle.Tensor): Shape = (N, T, H, W)
normalize_sst (bool, optional): Defaults to True.
detach (bool, optional): Defaults to True.
Collaborator

Add the descriptions.

Raises:
NotImplementedError: _description_
"""
Collaborator

Delete this part.

dtype=init_data.dtype,
default_initializer=nn.initializer.Constant(1.0),
)
self.scale.stop_gradient = not True
Collaborator

Change not True to False; update other similar code as well.

from ppsci.data.dataset.csv_dataset import CSVDataset
from ppsci.data.dataset.csv_dataset import IterableCSVDataset
from ppsci.data.dataset.cylinder_dataset import MeshCylinderDataset
from ppsci.data.dataset.dgmr_dataset import DGMRDataset
Collaborator

Do not delete this line of code.

@Yang-Changhui (Contributor Author)

@HydrogenSulfate Hi, I now have a simple forward pass of initencoder: the same network, the same input, and the same weights. Checked with reprod_log, the error is 0.00067. This is the minimal test unit.
paddle_torch_forward_error: https://aistudio.baidu.com/projectdetail/7781433?sUid=650866&shared=1&ts=1713940796780 . Could you please help take a look? Thanks.
