Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dgmr #813

Merged
merged 34 commits into from
Mar 29, 2024
Merged

Dgmr #813

merged 34 commits into from
Mar 29, 2024

Conversation

liaoxin2
Copy link
Contributor

PR types

One of dgmr eval code

PR changes

One of APIs

Describe

dgmr eval code

Copy link

paddle-bot bot commented Mar 20, 2024

Thanks for your contribution!

- mode
- output_dir
- log_freq
sweep:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考其他配置文件,加上callback字段

split: validation # train or validation
NUM_INPUT_FRAMES: 4
NUM_TARGET_FRAMES: 18
dataset_path: /workspace/workspace/skillful_nowcasting/openclimatefix/nimrod-uk-1km
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用相对路径

Comment on lines 38 to 42
DATALOADER:
batch_size: 1
shuffle: False
num_workers: 1
drop_last: True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dataloader配置放py文件,不需要写在这里

Comment on lines +46 to +47
input_keys: ['input_frames']
output_keys: ['future_images']
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

去掉中括号



def train(cfg: DictConfig):
print("Not supported.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print("Not supported.")
raise NotImplementedError(
"Training of DGMR is not supported now."
)

Comment on lines 948 to 955
"""

Args:
x: tensor on the correct device, to move over the latent distribution

Returns:

"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删多余空行

def __init__(
self, input_channels: int, output_channels: int, ratio_kq=8, ratio_v=8
):
super(AttentionLayer, self).__init__()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super(XXX, self).__init__()的写法全部改为super().__init__()

Comment on lines 50 to 51
NUM_INPUT_FRAMES: int = 4,
NUM_TARGET_FRAMES: int = 18,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参数变量名小写

Comment on lines 59 to 60
self.NUM_INPUT_FRAMES = NUM_INPUT_FRAMES
self.NUM_TARGET_FRAMES = NUM_TARGET_FRAMES
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

小写

def __len__(self):
return self.number

def __getitem__(self, item):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

item==>idx

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所有paddle.nn.xxx的调用方式改为nn.xxx

Comment on lines 407 to 415
"""
Spatial discriminator from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf

Args:
input_channels: Number of input channels per timestep
num_timesteps: Number of timesteps to use, in the paper 8/18 timesteps were chosen
num_layers: Number of intermediate DBlock layers to use
conv_type: Type of 2d convolutions to use, see satflow/models/utils.py for options
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__init__方法的 docstring 移动到SpatialDiscriminator下方

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class的__init__方法的 docstring 移动到class的下方

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring格式不规范,重新用vscode插件生成下

Comment on lines 12 to 15
raise ModuleNotFoundError(
"Please install einops with 'pip install einops'"
" before exporting DGMR model."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要抛出异常,改为pass即可

Comment on lines 321 to 329
"""
Temporal Discriminator from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf

Args:
input_channels: Number of channels per timestep
crop_size: Size of the crop, in the paper half the width of the input images
num_layers: Number of intermediate DBlock layers to use
conv_type: Type of 2d convolutions to use, see satflow/models/utils.py for options
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python类的__init__的docstring,统一放到类下方,而不是__init__下方


def forward(self, x: paddle.Tensor) -> paddle.Tensor:
"""

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除空行


import numpy as np
import paddle
from datasets import load_dataset
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要从模块中直接导入方法,而是通过导入模块再调用,改为像einops一样用try导入 import datasets

self.num_input_frames = num_input_frames
self.num_target_frames = num_target_frames
self.number = number
self.reader = load_dataset(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

datasets.load_dataset

Comment on lines 3 to 13
=== "模型训练命令"

``` sh
python dgmr.py
```

=== "模型评估命令"

``` sh
python dgmr.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/dgmr/dgmr_pretrained.pdparams
```
Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate Mar 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

训练和验证命令补充数据集下载的脚本,否则用户还需要手动去找数据集怎么下载
image

)
plt.subplots_adjust(hspace=0.1, wspace=0.1)
plt.savefig(osp.join(output_dir, f"Generated_Image_Frame_{batch_idx}.png"))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

末尾加一行plt.close()

cfg: DictConfig,
solver: ppsci.solver.Solver,
batch: tuple,
batch_idx: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参数没用到,删除

def validation(
cfg: DictConfig,
solver: ppsci.solver.Solver,
batch: tuple,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 使用Tuple而不是tuple
  2. 写清楚Tuple里的内容

Args:
cfg (DictConfig): Configuration object.
solver (ppsci.solver.Solver): Solver object containing the model and related components.
batch (tuple): Input batch consisting of images and corresponding future images.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同batch修改

try:
row = next(self.iter_reader)
except Exception:
rng = default_rng()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里得固定下seed,不然会有随机


Examples:
>>> import ppsci
>>> dataset = ppsci.data.dataset.DGMRDataset(("input", ), ("output", ))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

45行末尾加上: # doctest: +SKIP

Comment on lines 44 to 45
>>> output_dict = model(input_dict)
>>> print(output_dict["output"].shape)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

末尾加上: # doctest: +SKIP

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@HydrogenSulfate HydrogenSulfate merged commit f4ca04a into PaddlePaddle:develop Mar 29, 2024
3 of 4 checks passed
huohuohuohuohuo123 pushed a commit to huohuohuohuohuo123/PaddleScience that referenced this pull request Aug 12, 2024
* add bubble datafile test=develop

* add bubble code test=develop

* add bubble code test=develop

* add bubble code test=develop

* add bubble code

* add bubble data

* add bubble code

* add some code for chip

* add some code for DGMR

* add some code for DGMR

* add some code for DGMR

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code

* add some dgmr code
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants