Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[New example] Add nls-mb example #838

Merged
merged 11 commits into from
Apr 18, 2024

Conversation

xusuyong
Copy link
Contributor

@xusuyong xusuyong commented Apr 8, 2024

PR types

Others

PR changes

Others

Describe

增加光孤子和光学怪波的案例

Copy link

paddle-bot bot commented Apr 8, 2024

Thanks for your contribution!

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

辛苦suyong研发两个新的光学方程案例👍
有几处代码前段时间做了优化,可以按照review,进一步简化一下代码

- mode
- output_dir
- log_freq
sweep:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sweep上方可以加上callbacks字段,这个字段可以自动进行随机种子和文件夹创建操作

- log_freq
callbacks:
init_callback:
_target_: ppsci.utils.callbacks.InitCallback
sweep:

# model settings
MODEL:
input_keys: ["x", "t"]
output_keys: ["Eu","Ev","pu","pv","eta"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

List字符之间加上空格

iters_per_epoch: 1
lbfgs:
iters_per_epoch: ${TRAIN.iters_per_epoch}
output_dir: ./outputs_NLS-MB_soliton_L-BFGS
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里使用引用语法感觉会更好?:output_dir: ${output_dir}LBFGS

from ppsci.equation.pde import base


class NLSMB(base.PDE):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否能加上对应的docstring呢(公式使用katex,语法与latex基本相同,可以本地mkdocs 渲染自测)?这样文档中可以渲染出公式
image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docs/zh/api/equation.md里也添加一下NLSMB,这样文档能检索到

Comment on lines 130 to 134
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": False,
},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drop_last和shuffle都是False的话,"sampler"字段可以删除,

sampler_cfg = cfg.pop("sampler", None)
if sampler_cfg is not None:
batch_sampler_cls = sampler_cfg.pop("name")
if batch_sampler_cls == "BatchSampler":
if world_size > 1:
batch_sampler_cls = "DistributedBatchSampler"
logger.warning(
f"Automatically use 'DistributedBatchSampler' instead of "
f"'BatchSampler' when world_size({world_size}) > 1."
)
sampler_cfg["batch_size"] = cfg["batch_size"]
batch_sampler = getattr(io, batch_sampler_cls)(_dataset, **sampler_cfg)
else:
batch_sampler_cls = "BatchSampler"
if world_size > 1:
batch_sampler_cls = "DistributedBatchSampler"
logger.warning(
f"Automatically use 'DistributedBatchSampler' instead of "
f"'BatchSampler' when world_size({world_size}) > 1."
)
batch_sampler = getattr(io, batch_sampler_cls)(
_dataset,
batch_size=cfg["batch_size"],
shuffle=False,
drop_last=False,
)
logger.message(
"'shuffle' and 'drop_last' are both set to False in default as sampler config is not specified."
)

Comment on lines 289 to 293
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(cfg.seed)
# initialize logger
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上可以删除

Comment on lines 48 to 53
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(cfg.seed)

# initialize logger
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,可以删除

Comment on lines 108 to 112
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": False,
},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,可以删除

Comment on lines 15 to 17
- output_dir
- log_freq
sweep:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,可以加上callbacks字段

@xusuyong xusuyong changed the title Add nls-mb example [New example] Add nls-mb example Apr 8, 2024
Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

辛苦再补充一下导出推理代码和文档内的执行脚本,以及结果展示

```

## 5. 结果展示

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -0,0 +1,196 @@
# NLS-MB

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines 349 to 379
plt.figure(figsize=(10, 10))
plt.subplot(3, 3, 1)
plt.title("E_ref")
plt.tricontourf(x, t, E_ref[:, 0], levels=256, cmap="jet")
plt.subplot(3, 3, 2)
plt.title("E_pred")
plt.tricontourf(x, t, E_pred[:, 0], levels=256, cmap="jet")
plt.subplot(3, 3, 3)
plt.title("E_diff")
plt.tricontourf(x, t, np.abs(E_ref[:, 0] - E_pred[:, 0]), levels=256, cmap="jet")
plt.subplot(3, 3, 4)
plt.title("p_ref")
plt.tricontourf(x, t, p_ref[:, 0], levels=256, cmap="jet")
plt.subplot(3, 3, 5)
plt.title("p_pred")
plt.tricontourf(x, t, p_pred[:, 0], levels=256, cmap="jet")
plt.subplot(3, 3, 6)
plt.title("p_diff")
plt.tricontourf(x, t, np.abs(p_ref[:, 0] - p_pred[:, 0]), levels=256, cmap="jet")
plt.subplot(3, 3, 7)
plt.title("eta_ref")
plt.tricontourf(x, t, eta_ref[:, 0], levels=256, cmap="jet")
plt.subplot(3, 3, 8)
plt.title("eta_pred")
plt.tricontourf(x, t, eta_pred[:, 0], levels=256, cmap="jet")
plt.subplot(3, 3, 9)
plt.title("eta_diff")
plt.tricontourf(
x, t, np.abs(eta_ref[:, 0] - eta_pred[:, 0]), levels=256, cmap="jet"
)
plt.savefig(osp.join(cfg.output_dir, "pred_optical_rogue_wave.png"))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. plot相关代码可以提取变成一个函数,给train、evaluate、inference使用,类似:
    def plot(
    t_star: np.ndarray,
    x_star: np.ndarray,
    u_ref: np.ndarray,
    u_pred: np.ndarray,
    output_dir: str,
    ):
    fig = plt.figure(figsize=(18, 5))
    TT, XX = np.meshgrid(t_star, x_star, indexing="ij")
    u_ref = u_ref.reshape([len(t_star), len(x_star)])
    plt.subplot(1, 3, 1)
    plt.pcolor(TT, XX, u_ref, cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Exact")
    plt.tight_layout()
    plt.subplot(1, 3, 2)
    plt.pcolor(TT, XX, u_pred, cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Predicted")
    plt.tight_layout()
    plt.subplot(1, 3, 3)
    plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Absolute error")
    plt.tight_layout()
    fig_path = osp.join(output_dir, "ac.png")
    print(f"Saving figure to {fig_path}")
    fig.savefig(fig_path, bbox_inches="tight", dpi=400)
    plt.close()
  2. plot完毕后,建议调用一下plt.close()

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@HydrogenSulfate HydrogenSulfate merged commit fd3a7f9 into PaddlePaddle:develop Apr 18, 2024
3 of 4 checks passed
@xusuyong xusuyong deleted the add_NLS-MB_example branch April 22, 2024 08:29
huohuohuohuohuo123 pushed a commit to huohuohuohuohuo123/PaddleScience that referenced this pull request Aug 12, 2024
* add NLS-MB example

* fix

* fix

* fix

* fix

* modify

* modify

* modify

* fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants