-
Notifications
You must be signed in to change notification settings - Fork 184
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[New example] Add nls-mb example #838
[New example] Add nls-mb example #838
Conversation
Thanks for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
辛苦suyong研发两个新的光学方程案例👍
有几处代码前段时间做了优化,可以按照review,进一步简化一下代码
- mode | ||
- output_dir | ||
- log_freq | ||
sweep: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sweep上方可以加上callbacks
字段,这个字段可以自动进行随机种子和文件夹创建操作
PaddleScience/examples/bracket/conf/bracket.yaml
Lines 16 to 20 in 07b440f
- log_freq | |
callbacks: | |
init_callback: | |
_target_: ppsci.utils.callbacks.InitCallback | |
sweep: |
# model settings | ||
MODEL: | ||
input_keys: ["x", "t"] | ||
output_keys: ["Eu","Ev","pu","pv","eta"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
List字符之间加上空格
iters_per_epoch: 1 | ||
lbfgs: | ||
iters_per_epoch: ${TRAIN.iters_per_epoch} | ||
output_dir: ./outputs_NLS-MB_soliton_L-BFGS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里使用引用语法感觉会更好?:output_dir: ${output_dir}LBFGS
from ppsci.equation.pde import base | ||
|
||
|
||
class NLSMB(base.PDE): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否能加上对应的docstring呢(公式使用katex,语法与latex基本相同,可以本地mkdocs 渲染自测)?这样文档中可以渲染出公式
ppsci/equation/__init__.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docs/zh/api/equation.md
里也添加一下NLSMB,这样文档能检索到
"sampler": { | ||
"name": "BatchSampler", | ||
"drop_last": False, | ||
"shuffle": False, | ||
}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
drop_last和shuffle都是False的话,"sampler"字段可以删除,
PaddleScience/ppsci/data/__init__.py
Lines 70 to 100 in 07b440f
sampler_cfg = cfg.pop("sampler", None) | |
if sampler_cfg is not None: | |
batch_sampler_cls = sampler_cfg.pop("name") | |
if batch_sampler_cls == "BatchSampler": | |
if world_size > 1: | |
batch_sampler_cls = "DistributedBatchSampler" | |
logger.warning( | |
f"Automatically use 'DistributedBatchSampler' instead of " | |
f"'BatchSampler' when world_size({world_size}) > 1." | |
) | |
sampler_cfg["batch_size"] = cfg["batch_size"] | |
batch_sampler = getattr(io, batch_sampler_cls)(_dataset, **sampler_cfg) | |
else: | |
batch_sampler_cls = "BatchSampler" | |
if world_size > 1: | |
batch_sampler_cls = "DistributedBatchSampler" | |
logger.warning( | |
f"Automatically use 'DistributedBatchSampler' instead of " | |
f"'BatchSampler' when world_size({world_size}) > 1." | |
) | |
batch_sampler = getattr(io, batch_sampler_cls)( | |
_dataset, | |
batch_size=cfg["batch_size"], | |
shuffle=False, | |
drop_last=False, | |
) | |
logger.message( | |
"'shuffle' and 'drop_last' are both set to False in default as sampler config is not specified." | |
) |
# set random seed for reproducibility | ||
ppsci.utils.misc.set_random_seed(cfg.seed) | ||
# initialize logger | ||
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上可以删除
# set random seed for reproducibility | ||
ppsci.utils.misc.set_random_seed(cfg.seed) | ||
|
||
# initialize logger | ||
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,可以删除
"sampler": { | ||
"name": "BatchSampler", | ||
"drop_last": False, | ||
"shuffle": False, | ||
}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,可以删除
- output_dir | ||
- log_freq | ||
sweep: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,可以加上callbacks字段
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
辛苦再补充一下导出推理代码和文档内的执行脚本,以及结果展示
``` | ||
|
||
## 5. 结果展示 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -0,0 +1,196 @@ | |||
# NLS-MB | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
plt.figure(figsize=(10, 10)) | ||
plt.subplot(3, 3, 1) | ||
plt.title("E_ref") | ||
plt.tricontourf(x, t, E_ref[:, 0], levels=256, cmap="jet") | ||
plt.subplot(3, 3, 2) | ||
plt.title("E_pred") | ||
plt.tricontourf(x, t, E_pred[:, 0], levels=256, cmap="jet") | ||
plt.subplot(3, 3, 3) | ||
plt.title("E_diff") | ||
plt.tricontourf(x, t, np.abs(E_ref[:, 0] - E_pred[:, 0]), levels=256, cmap="jet") | ||
plt.subplot(3, 3, 4) | ||
plt.title("p_ref") | ||
plt.tricontourf(x, t, p_ref[:, 0], levels=256, cmap="jet") | ||
plt.subplot(3, 3, 5) | ||
plt.title("p_pred") | ||
plt.tricontourf(x, t, p_pred[:, 0], levels=256, cmap="jet") | ||
plt.subplot(3, 3, 6) | ||
plt.title("p_diff") | ||
plt.tricontourf(x, t, np.abs(p_ref[:, 0] - p_pred[:, 0]), levels=256, cmap="jet") | ||
plt.subplot(3, 3, 7) | ||
plt.title("eta_ref") | ||
plt.tricontourf(x, t, eta_ref[:, 0], levels=256, cmap="jet") | ||
plt.subplot(3, 3, 8) | ||
plt.title("eta_pred") | ||
plt.tricontourf(x, t, eta_pred[:, 0], levels=256, cmap="jet") | ||
plt.subplot(3, 3, 9) | ||
plt.title("eta_diff") | ||
plt.tricontourf( | ||
x, t, np.abs(eta_ref[:, 0] - eta_pred[:, 0]), levels=256, cmap="jet" | ||
) | ||
plt.savefig(osp.join(cfg.output_dir, "pred_optical_rogue_wave.png")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- plot相关代码可以提取变成一个函数,给train、evaluate、inference使用,类似:
PaddleScience/examples/allen_cahn/allen_cahn_plain.py
Lines 20 to 58 in 8d7dadd
def plot( t_star: np.ndarray, x_star: np.ndarray, u_ref: np.ndarray, u_pred: np.ndarray, output_dir: str, ): fig = plt.figure(figsize=(18, 5)) TT, XX = np.meshgrid(t_star, x_star, indexing="ij") u_ref = u_ref.reshape([len(t_star), len(x_star)]) plt.subplot(1, 3, 1) plt.pcolor(TT, XX, u_ref, cmap="jet") plt.colorbar() plt.xlabel("t") plt.ylabel("x") plt.title("Exact") plt.tight_layout() plt.subplot(1, 3, 2) plt.pcolor(TT, XX, u_pred, cmap="jet") plt.colorbar() plt.xlabel("t") plt.ylabel("x") plt.title("Predicted") plt.tight_layout() plt.subplot(1, 3, 3) plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet") plt.colorbar() plt.xlabel("t") plt.ylabel("x") plt.title("Absolute error") plt.tight_layout() fig_path = osp.join(output_dir, "ac.png") print(f"Saving figure to {fig_path}") fig.savefig(fig_path, bbox_inches="tight", dpi=400) plt.close() - plot完毕后,建议调用一下plt.close()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* add NLS-MB example * fix * fix * fix * fix * modify * modify * modify * fix
PR types
Others
PR changes
Others
Describe
增加光孤子和光学怪波的案例