2 changes: 1 addition & 1 deletion docs/zh/install_setup.md
@@ -95,7 +95,7 @@

``` sh
cd PaddleScience/
set PYTHONPATH=%cd%
set PYTHONPATH=%PYTHONPATH%;%CD%
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple # manually install requirements
```
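
For comparison, a minimal sketch of the equivalent setup on a POSIX shell (bash/zsh); this is not part of the diff, which only touches the Windows `cmd` instructions:

``` sh
cd PaddleScience/
export PYTHONPATH=$PYTHONPATH:$(pwd)  # append the repo root while keeping existing entries
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple # manually install requirements
```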

6 changes: 6 additions & 0 deletions ppsci/loss/mtl/agda.py
@@ -14,6 +14,7 @@

from __future__ import annotations

from typing import ClassVar
from typing import List

import paddle
@@ -30,6 +31,10 @@ class AGDA(base.LossAggregator):

NOTE: This loss aggregator is only suitable for two-task learning and the first task loss must be PDE loss.

Attributes:
should_persist (bool): Whether to persist the loss aggregator when saving.
Aggregators with parameters and/or buffers should be persisted.

Args:
model (nn.Layer): Training model.
M (int, optional): Smoothing period. Defaults to 100.
@@ -49,6 +54,7 @@ class AGDA(base.LossAggregator):
... bc_loss = paddle.sum((y2 - 2) ** 2)
... loss_aggregator({'pde_loss': pde_loss, 'bc_loss': bc_loss}).backward()
"""
should_persist: ClassVar[bool] = False

def __init__(self, model: nn.Layer, M: int = 100, gamma: float = 0.999) -> None:
super().__init__(model)
14 changes: 14 additions & 0 deletions ppsci/loss/mtl/base.py
@@ -15,6 +15,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import ClassVar
from typing import Dict
from typing import Union

@@ -27,10 +28,16 @@
class LossAggregator(nn.Layer):
"""Base class of loss aggregator mainly for multitask learning.

Attributes:
should_persist (bool): Whether to persist the loss aggregator when saving.
Aggregators with parameters and/or buffers should be persisted.

Args:
model (nn.Layer): Training model.
"""

should_persist: ClassVar[bool] = False

def __init__(self, model: nn.Layer) -> None:
super().__init__()
self.model = model
@@ -52,3 +59,10 @@ def backward(self) -> None:
raise NotImplementedError(
f"'backward' should be implemented in subclass {self.__class__.__name__}"
)

def state_dict(self):
agg_state = super().state_dict()
model_state = self.model.state_dict()
# remove model parameters from the aggregator state dict; they are already saved in *.pdparams
agg_state = {k: v for k, v in agg_state.items() if k not in model_state}
return agg_state
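
A minimal sketch (not part of this diff; `ToyAggregator` is a hypothetical name) of how the new `should_persist` flag and the `state_dict` override are meant to be used by a subclass that owns its own state:

``` py
from typing import ClassVar

import paddle
from paddle import nn

from ppsci.loss.mtl import base


class ToyAggregator(base.LossAggregator):
    # owns a buffer of its own, so it opts into checkpoint persistence
    should_persist: ClassVar[bool] = True

    def __init__(self, model: nn.Layer) -> None:
        super().__init__(model)
        # aggregator-owned running statistics that should survive a restart
        self.register_buffer("loss_scale", paddle.ones([2]))


model = nn.Linear(4, 2)  # placeholder network
agg = ToyAggregator(model)
if agg.should_persist:
    # the state_dict() override above drops the wrapped model's parameters,
    # since those are already written to the *.pdparams file
    paddle.save(agg.state_dict(), "./latest.pdagg")
```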
6 changes: 6 additions & 0 deletions ppsci/loss/mtl/grad_norm.py
@@ -14,6 +14,7 @@

from __future__ import annotations

from typing import ClassVar
from typing import Dict
from typing import List

@@ -42,6 +43,10 @@ class GradNorm(base.LossAggregator):
\end{align*}
$$

Attributes:
should_persist (bool): Whether to persist the loss aggregator when saving.
Aggregators with parameters and/or buffers should be persisted.

Args:
model (nn.Layer): Training model.
num_losses (int, optional): Number of losses. Defaults to 1.
@@ -63,6 +68,7 @@ class GradNorm(base.LossAggregator):
... loss2 = paddle.sum((y2 - 2) ** 2)
... loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
"""
should_persist: ClassVar[bool] = True
weight: paddle.Tensor

def __init__(
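
As a small illustration (not part of this diff; assumes `GradNorm` is exported from `ppsci.loss.mtl`), GradNorm keeps a per-loss `weight` tensor, so it marks itself persistable and that tensor ends up in the `*.pdagg` file:

``` py
from paddle import nn

from ppsci.loss import mtl

model = nn.Linear(4, 2)  # placeholder network for illustration
agg = mtl.GradNorm(model, num_losses=2)
assert agg.should_persist  # True: the per-loss weights should be checkpointed
```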
8 changes: 7 additions & 1 deletion ppsci/loss/mtl/ntk.py
@@ -14,15 +14,21 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import ClassVar
from typing import List

import paddle
from paddle import nn

from ppsci.loss.mtl import base

if TYPE_CHECKING:
from paddle import nn


class NTK(base.LossAggregator):
should_persist: ClassVar[bool] = True

def __init__(
self,
model: nn.Layer,
6 changes: 6 additions & 0 deletions ppsci/loss/mtl/pcgrad.py
@@ -14,6 +14,7 @@

from __future__ import annotations

from typing import ClassVar
from typing import List

import numpy as np
@@ -31,6 +32,10 @@ class PCGrad(base.LossAggregator):

Code reference: [https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py](https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py)

Attributes:
should_persist (bool): Whether to persist the loss aggregator when saving.
Aggregators with parameters and/or buffers should be persisted.

Args:
model (nn.Layer): Training model.

@@ -48,6 +53,7 @@ class PCGrad(base.LossAggregator):
... loss2 = paddle.sum((y2 - 2) ** 2)
... loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
"""
should_persist: ClassVar[bool] = False

def __init__(self, model: nn.Layer) -> None:
super().__init__(model)
6 changes: 6 additions & 0 deletions ppsci/loss/mtl/relobralo.py
@@ -14,6 +14,7 @@

from __future__ import annotations

from typing import ClassVar
from typing import Dict

import paddle
@@ -26,6 +27,10 @@ class Relobralo(nn.Layer):

[Multi-Objective Loss Balancing for Physics-Informed Deep Learning](https://arxiv.org/abs/2110.09813)

Attributes:
should_persist (bool): Whether to persist the loss aggregator when saving.
Aggregators with parameters and/or buffers should be persisted.

Args:
num_losses (int): Number of losses.
alpha (float, optional): Ability to remember past losses, as described in the paper. Defaults to 0.95.
@@ -49,6 +54,7 @@
... loss2 = paddle.sum((y2 - 2) ** 2)
... loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
"""
should_persist: ClassVar[bool] = True

def __init__(
self,
7 changes: 7 additions & 0 deletions ppsci/loss/mtl/sum.py
@@ -20,6 +20,8 @@
if TYPE_CHECKING:
import paddle

from typing import ClassVar

from ppsci.loss.mtl.base import LossAggregator


@@ -30,7 +32,12 @@ class Sum(LossAggregator):
$$
loss = \sum_i^N losses_i
$$

Attributes:
should_persist (bool): Whether to persist the loss aggregator when saving.
Aggregators with parameters and/or buffers should be persisted.
"""
should_persist: ClassVar[bool] = False

def __init__(self) -> None:
self.step = 0
4 changes: 4 additions & 0 deletions ppsci/solver/solver.py
@@ -334,6 +334,7 @@ def __init__(
self.scaler,
self.equation,
self.ema_model,
self.loss_aggregator,
)
if isinstance(loaded_metric, dict):
self.best_metric.update(loaded_metric)
@@ -567,6 +568,7 @@ def train(self) -> None:
self.output_dir,
"best_model",
self.equation,
aggregator=self.loss_aggregator,
)
logger.info(
f"[Eval][Epoch {epoch_id}]"
@@ -633,6 +635,7 @@ def train(self) -> None:
f"epoch_{epoch_id}",
self.equation,
ema_model=self.ema_model,
aggregator=self.loss_aggregator,
)

# save the latest model for convenient resume training
@@ -646,6 +649,7 @@ def train(self) -> None:
self.equation,
print_log=(epoch_id == start_epoch),
ema_model=self.ema_model,
aggregator=self.loss_aggregator,
)

def finetune(self, pretrained_model_path: str) -> None:
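
A minimal sketch (not part of this diff) of the wiring these calls rely on: the solver forwards its loss aggregator to every `save_checkpoint` call, and a `*.pdagg` file is written only when `should_persist` is `True`. The network and aggregator below are placeholders, and the `loss_aggregator` keyword of `Solver` is assumed from the `self.loss_aggregator` attribute used in this diff:

``` py
import ppsci
from ppsci.loss import mtl

model = ppsci.arch.MLP(("x",), ("u",), 3, 16)  # placeholder network
agg = mtl.GradNorm(model, num_losses=2)        # should_persist is True

# constraints, optimizer, etc. omitted; this only shows how the aggregator
# reaches the checkpoint helpers through the solver.
solver = ppsci.solver.Solver(
    model,
    output_dir="./output",
    loss_aggregator=agg,
)
```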
38 changes: 35 additions & 3 deletions ppsci/utils/save_load.py
@@ -31,6 +31,7 @@
from paddle import optimizer

from ppsci import equation
from ppsci.loss import mtl
from ppsci.utils import ema


@@ -42,7 +43,10 @@


def _load_pretrain_from_path(
path: str, model: nn.Layer, equation: Optional[Dict[str, equation.PDE]] = None
path: str,
model: nn.Layer,
equation: Optional[Dict[str, equation.PDE]] = None,
loss_aggregator: Optional[mtl.LossAggregator] = None,
):
"""Load pretrained model from given path.

@@ -77,9 +81,26 @@ def _load_pretrain_from_path(
f"Finish loading pretrained equation parameters from: {path}.pdeqn"
)

if loss_aggregator is not None:
if not os.path.exists(f"{path}.pdagg"):
if loss_aggregator.should_persist:
logger.warning(
f"Given loss_aggregator({type(loss_aggregator)}) has persistable"
f"parameters or buffers, but {path}.pdagg not found."
)
else:
aggregator_dict = paddle.load(f"{path}.pdagg")
loss_aggregator.set_state_dict(aggregator_dict)
logger.message(
f"Finish loading pretrained equation parameters from: {path}.pdagg"
)


def load_pretrain(
model: nn.Layer, path: str, equation: Optional[Dict[str, equation.PDE]] = None
model: nn.Layer,
path: str,
equation: Optional[Dict[str, equation.PDE]] = None,
loss_aggregator: Optional[mtl.LossAggregator] = None,
):
"""
Load pretrained model from given path or url.
@@ -121,7 +142,7 @@ def is_url_accessible(url: str):
# remove ".pdparams" in suffix of path for convenient
if path.endswith(".pdparams"):
path = path[:-9]
_load_pretrain_from_path(path, model, equation)
_load_pretrain_from_path(path, model, equation, loss_aggregator)


def load_checkpoint(
@@ -131,6 +152,7 @@
grad_scaler: Optional[amp.GradScaler] = None,
equation: Optional[Dict[str, equation.PDE]] = None,
ema_model: Optional[ema.AveragedModel] = None,
aggregator: Optional[mtl.LossAggregator] = None,
) -> Dict[str, Any]:
"""Load from checkpoint.

@@ -141,6 +163,7 @@
grad_scaler (Optional[amp.GradScaler]): GradScaler for AMP. Defaults to None.
equation (Optional[Dict[str, equation.PDE]]): Equations. Defaults to None.
ema_model (Optional[ema.AveragedModel]): Averaged model. Defaults to None.
aggregator (Optional[mtl.LossAggregator]): Loss aggregator. Defaults to None.

Returns:
Dict[str, Any]: Loaded metric information.
@@ -189,6 +212,10 @@ def load_checkpoint(
avg_param_dict = paddle.load(f"{path}_ema.pdparams")
ema_model.set_state_dict(avg_param_dict)

if aggregator is not None and aggregator.should_persist:
aggregator_dict = paddle.load(f"{path}.pdagg")
aggregator.set_state_dict(aggregator_dict)

logger.message(f"Finish loading checkpoint from {path}")
return metric_dict

@@ -203,6 +230,7 @@ def save_checkpoint(
equation: Optional[Dict[str, equation.PDE]] = None,
print_log: bool = True,
ema_model: Optional[ema.AveragedModel] = None,
aggregator: Optional[mtl.LossAggregator] = None,
):
"""
Save checkpoint, including model params, optimizer params, metric information.
@@ -219,6 +247,7 @@
keeping log tidy without duplicate 'Finish saving checkpoint ...' log strings.
Defaults to True.
ema_model (Optional[ema.AveragedModel]): Averaged model. Defaults to None.
aggregator (Optional[mtl.LossAggregator]): Loss aggregator. Defaults to None.

Examples:
>>> import ppsci
@@ -258,6 +287,9 @@
if ema_model:
paddle.save(ema_model.state_dict(), f"{ckpt_path}_ema.pdparams")

if aggregator and aggregator.should_persist:
paddle.save(aggregator.state_dict(), f"{ckpt_path}.pdagg")

if print_log:
log_str = f"Finish saving checkpoint to: {ckpt_path}"
if prefix == "latest":
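
Finally, a minimal sketch (not part of this diff) of the loading side through the updated `load_pretrain` signature; the network, aggregator, and path are placeholders:

``` py
from paddle import nn

from ppsci.loss import mtl
from ppsci.utils import save_load

model = nn.Linear(4, 2)  # placeholder network
aggregator = mtl.GradNorm(model, num_losses=2)

# the ".pdparams"/".pdeqn"/".pdagg" suffixes are resolved internally from the
# common path prefix; the path below is only an example layout.
save_load.load_pretrain(
    model,
    "./output/checkpoints/latest",
    loss_aggregator=aggregator,
)
```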