Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
aedbfc1
allow empty optimizer when saving checkpoint
HydrogenSulfate Apr 9, 2024
e97886d
add model averaging module
HydrogenSulfate Apr 9, 2024
381c702
fix return dtype inconsistency with global dtype
HydrogenSulfate Apr 9, 2024
cf8b742
use python func instead of sympy function for pow(u,3) get a bit poor…
HydrogenSulfate Apr 10, 2024
d6f4b2b
refine AllenCahn docstring
HydrogenSulfate Apr 10, 2024
39bb505
support save and load for average model module
HydrogenSulfate Apr 10, 2024
265b7bb
add 3 ema unitests
HydrogenSulfate Apr 10, 2024
f8ce9f8
update 2023 to 2024
HydrogenSulfate Apr 10, 2024
5f61f8a
add ema config pydantic scheme
HydrogenSulfate Apr 10, 2024
13ea6c1
add avg_range for SWA
HydrogenSulfate Apr 10, 2024
c1a44fe
update field_validator for swa and ema
HydrogenSulfate Apr 10, 2024
e5fca6a
support period embedding for MLP
HydrogenSulfate Apr 10, 2024
5c88550
Keep non-float data when reading file
HydrogenSulfate Apr 10, 2024
4ce144f
Merge branch 'add_allen_cahn' into add_period_layer
HydrogenSulfate Apr 10, 2024
1be0892
update ema and save_load, printer and eval, solver module code
HydrogenSulfate Apr 10, 2024
1e77b75
add allen_cahn example
HydrogenSulfate Apr 11, 2024
21a7278
refine code
HydrogenSulfate Apr 11, 2024
6a0b36a
save buffer and non-grad required params in ema
HydrogenSulfate Apr 11, 2024
5d2de5f
add unitest for ema with buffer
HydrogenSulfate Apr 11, 2024
a0f7c33
Merge branch 'develop' into add_allen_cahn_example
HydrogenSulfate Apr 11, 2024
c9e0133
fix epoch_ema saving
HydrogenSulfate Apr 11, 2024
9d3063f
add unitest for ema state_dict
HydrogenSulfate Apr 11, 2024
186648b
refine allen_cahn_plain.py
HydrogenSulfate Apr 11, 2024
aa23fcc
fix string to floating conversion in reader.py
HydrogenSulfate Apr 11, 2024
3f56fcf
fix string to floating conversion in reader.py
HydrogenSulfate Apr 11, 2024
ac52573
remove print code in solver
HydrogenSulfate Apr 11, 2024
b4c1cbf
Update allen_cahn_plain.py
HydrogenSulfate Apr 11, 2024
7a4c960
Update misc.py
HydrogenSulfate Apr 11, 2024
c4c7f14
Merge branch 'develop' into add_allen_cahn_example
zhiminzhang0830 Apr 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deploy/python_infer/pinn_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def predict(

# inference by batch
for batch_id in range(1, batch_num + 1):
if batch_id % self.log_freq == 0 or batch_id == batch_num:
if batch_id == 1 or batch_id % self.log_freq == 0 or batch_id == batch_num:
logger.info(f"Predicting batch {batch_id}/{batch_num}")

# prepare batch input dict
Expand Down
1 change: 1 addition & 0 deletions docs/zh/api/data/dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- ChipHeatDataset
- CSVDataset
- IterableCSVDataset
- ContinuousNamedArrayDataset
- ERA5Dataset
- ERA5SampledDataset
- IterableMatDataset
Expand Down
298 changes: 298 additions & 0 deletions examples/allen_cahn/allen_cahn_plain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
"""
Reference: https://docs.nvidia.com/deeplearning/modulus/modulus-v2209/user_guide/intermediate/adding_stl_files.html
"""

from os import path as osp

import hydra
import numpy as np
import paddle
import scipy.io as sio
from matplotlib import pyplot as plt
from omegaconf import DictConfig

import ppsci
from ppsci.utils import misc

dtype = paddle.get_default_dtype()


def plot(
t_star: np.ndarray,
x_star: np.ndarray,
u_ref: np.ndarray,
u_pred: np.ndarray,
output_dir: str,
):
fig = plt.figure(figsize=(18, 5))
TT, XX = np.meshgrid(t_star, x_star, indexing="ij")
u_ref = u_ref.reshape([len(t_star), len(x_star)])

plt.subplot(1, 3, 1)
plt.pcolor(TT, XX, u_ref, cmap="jet")
plt.colorbar()
plt.xlabel("t")
plt.ylabel("x")
plt.title("Exact")
plt.tight_layout()

plt.subplot(1, 3, 2)
plt.pcolor(TT, XX, u_pred, cmap="jet")
plt.colorbar()
plt.xlabel("t")
plt.ylabel("x")
plt.title("Predicted")
plt.tight_layout()

plt.subplot(1, 3, 3)
plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet")
plt.colorbar()
plt.xlabel("t")
plt.ylabel("x")
plt.title("Absolute error")
plt.tight_layout()

fig_path = osp.join(output_dir, "ac.png")
fig.savefig(fig_path, bbox_inches="tight", dpi=400)
plt.close()


def train(cfg: DictConfig):
# set model
model = ppsci.arch.MLP(**cfg.MODEL)

# set equation
equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)}

data = sio.loadmat(cfg.DATA_PATH)
u_ref = data["usol"].astype(dtype) # (nt, nx)
t_star = data["t"].flatten().astype(dtype) # [nt, ]
x_star = data["x"].flatten().astype(dtype) # [nx, ]

u0 = u_ref[0, :] # [nx, ]

t0 = t_star[0] # float
t1 = t_star[-1] # float

x0 = x_star[0] # float
x1 = x_star[-1] # float

# set constraint
def gen_input_batch():
tx = np.random.uniform(
[t0, x0],
[t1, x1],
(cfg.TRAIN.batch_size, 2),
).astype(dtype)
return {
"t": tx[:, 0:1],
"x": tx[:, 1:2],
}

def gen_label_batch(input_batch):
return {"allen_cahn": np.zeros([cfg.TRAIN.batch_size, 1], dtype)}

pde_constraint = ppsci.constraint.SupervisedConstraint(
{
"dataset": {
"name": "ContinuousNamedArrayDataset",
"input": gen_input_batch,
"label": gen_label_batch,
},
},
output_expr=equation["AllenCahn"].equations,
loss=ppsci.loss.MSELoss(),
name="PDE",
)

ic_input = {"t": np.full([len(x_star), 1], t0), "x": x_star.reshape([-1, 1])}
ic_label = {"u": u0.reshape([-1, 1])}
ic = ppsci.constraint.SupervisedConstraint(
{
"dataset": {
"name": "IterableNamedArrayDataset",
"input": ic_input,
"label": ic_label,
},
},
output_expr={"u": lambda out: out["u"]},
loss=ppsci.loss.MSELoss("mean"),
name="IC",
)
# wrap constraints together
constraint = {
pde_constraint.name: pde_constraint,
ic.name: ic,
}

# set optimizer
lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay(
**cfg.TRAIN.lr_scheduler
)()
optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

# set validator
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
eval_label = {"u": u_ref.reshape([-1, 1])}
u_validator = ppsci.validate.SupervisedValidator(
{
"dataset": {
"name": "NamedArrayDataset",
"input": eval_data,
"label": eval_label,
},
"batch_size": cfg.EVAL.batch_size,
},
ppsci.loss.MSELoss("mean"),
{"u": lambda out: out["u"]},
metric={"L2Rel": ppsci.metric.L2Rel()},
name="u_validator",
)
validator = {u_validator.name: u_validator}

# initialize solver
solver = ppsci.solver.Solver(
model,
constraint,
cfg.output_dir,
optimizer,
lr_scheduler,
cfg.TRAIN.epochs,
cfg.TRAIN.iters_per_epoch,
save_freq=cfg.TRAIN.save_freq,
log_freq=cfg.log_freq,
eval_during_train=True,
eval_freq=cfg.TRAIN.eval_freq,
seed=cfg.seed,
equation=equation,
validator=validator,
pretrained_model_path=cfg.TRAIN.pretrained_model_path,
checkpoint_path=cfg.TRAIN.checkpoint_path,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
use_tbd=True,
cfg=cfg,
)
# train model
solver.train()
# evaluate after finished training
solver.eval()
# visualize prediction after finished training
u_pred = solver.predict(
eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
)["u"]
u_pred = u_pred.reshape([len(t_star), len(x_star)])

# plot
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


def evaluate(cfg: DictConfig):
# set model
model = ppsci.arch.MLP(**cfg.MODEL)

data = sio.loadmat(cfg.DATA_PATH)
u_ref = data["usol"].astype(dtype) # (nt, nx)
t_star = data["t"].flatten().astype(dtype) # [nt, ]
x_star = data["x"].flatten().astype(dtype) # [nx, ]

# set validator
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
eval_label = {"u": u_ref.reshape([-1, 1])}
u_validator = ppsci.validate.SupervisedValidator(
{
"dataset": {
"name": "NamedArrayDataset",
"input": eval_data,
"label": eval_label,
},
"batch_size": cfg.EVAL.batch_size,
},
ppsci.loss.MSELoss("mean"),
{"u": lambda out: out["u"]},
metric={"L2Rel": ppsci.metric.L2Rel()},
name="u_validator",
)
validator = {u_validator.name: u_validator}

# initialize solver
solver = ppsci.solver.Solver(
model,
output_dir=cfg.output_dir,
log_freq=cfg.log_freq,
validator=validator,
pretrained_model_path=cfg.EVAL.pretrained_model_path,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)

# evaluate after finished training
solver.eval()
# visualize prediction after finished training
u_pred = solver.predict(
eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
)["u"]
u_pred = u_pred.reshape([len(t_star), len(x_star)])

# plot
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


def export(cfg: DictConfig):
# set model
model = ppsci.arch.MLP(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
# export model
from paddle.static import InputSpec

input_spec = [
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
]
solver.export(input_spec, cfg.INFER.export_path, with_onnx=False)


def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

predictor = pinn_predictor.PINNPredictor(cfg)
data = sio.loadmat(cfg.DATA_PATH)
u_ref = data["usol"].astype(dtype) # (nt, nx)
t_star = data["t"].flatten().astype(dtype) # [nt, ]
x_star = data["x"].flatten().astype(dtype) # [nx, ]
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)

input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}
u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
# mapping data to cfg.INFER.output_keys

plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


@hydra.main(version_base=None, config_path="./conf", config_name="allen_cahn.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
main()
Loading