Skip to content

[Fea] Support detach_keys argument for all PDE #889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 73 additions & 15 deletions ppsci/equation/pde/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Union

import paddle
import sympy
import sympy as sp
from paddle import nn

DETACH_FUNC_NAME = "detach"
Expand All @@ -33,7 +33,7 @@ class PDE:

def __init__(self):
super().__init__()
self.equations = {}
self.equations: Dict[str, Union[Callable, sp.Basic]] = {}
# for PDE which has learnable parameter(s)
self.learnable_parameters = nn.ParameterList()

Expand All @@ -42,7 +42,7 @@ def __init__(self):
@staticmethod
def create_symbols(
symbol_str: str,
) -> Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]:
) -> Union[sp.Symbol, Tuple[sp.Symbol, ...]]:
"""create symbolic variables.

Args:
Expand All @@ -61,11 +61,9 @@ def create_symbols(
>>> print(symbols_xyz)
(x, y, z)
"""
return sympy.symbols(symbol_str)
return sp.symbols(symbol_str)

def create_function(
self, name: str, invars: Tuple[sympy.Symbol, ...]
) -> sympy.Function:
def create_function(self, name: str, invars: Tuple[sp.Symbol, ...]) -> sp.Function:
"""Create named function depending on given invars.

Args:
Expand All @@ -86,14 +84,73 @@ def create_function(
>>> print(f)
f(x, y, z)
"""
expr = sympy.Function(name)(*invars)
expr = sp.Function(name)(*invars)

# wrap `expression(...)` to `detach(expression(...))`
# if name of expression is in given detach_keys
if self.detach_keys and name in self.detach_keys:
expr = sympy.Function(DETACH_FUNC_NAME)(expr)
return expr

def _apply_detach(self):
"""
Wrap detached sub_expr into detach(sub_expr) to prevent gradient
back-propagation, only for those items speicified in self.detach_keys.

NOTE: This function is expected to be called after self.equations is ready in PDE.__init__.

Examples:
>>> import ppsci
>>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False)
>>> print(ns)
NavierStokes
continuity: Derivative(u(x, y), x) + Derivative(v(x, y), y)
momentum_x: u(x, y)*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
momentum_y: u(x, y)*Derivative(v(x, y), x) + v(x, y)*Derivative(v(x, y), y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
>>> detach_keys = ("u", "v__y")
>>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False, detach_keys=detach_keys)
>>> print(ns)
NavierStokes
continuity: detach(Derivative(v(x, y), y)) + Derivative(u(x, y), x)
momentum_x: detach(u(x, y))*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
momentum_y: detach(u(x, y))*Derivative(v(x, y), x) + detach(Derivative(v(x, y), y))*v(x, y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
"""
if self.detach_keys is None:
return

from copy import deepcopy

from sympy.core.traversal import postorder_traversal

from ppsci.utils.symbolic import _cvt_to_key

for name, expr in self.equations.items():
if not isinstance(expr, sp.Basic):
continue
# only process sympy expression
expr_ = deepcopy(expr)
for item in postorder_traversal(expr):
if _cvt_to_key(item) in self.detach_keys:
# inplace all related sub_expr into detach(sub_expr)
expr_ = expr_.replace(item, sp.Function(DETACH_FUNC_NAME)(item))

# remove all detach wrapper for more-than-once wrapped items to prevent duplicated wrapping
expr_ = expr_.replace(
sp.Function(DETACH_FUNC_NAME)(
sp.Function(DETACH_FUNC_NAME)(item)
),
sp.Function(DETACH_FUNC_NAME)(item),
)

# remove unccessary detach wrapping for the first arg of Derivative
for item_ in list(postorder_traversal(expr_)):
if isinstance(item_, sp.Derivative):
if item_.args[0].name == DETACH_FUNC_NAME:
expr_ = expr_.replace(
item_,
sp.Derivative(
item_.args[0].args[0], *item_.args[1:]
),
)

self.equations[name] = expr_

def add_equation(self, name: str, equation: Callable):
"""Add an equation.

Expand All @@ -110,7 +167,8 @@ def add_equation(self, name: str, equation: Callable):
>>> equation = sympy.diff(u, x) + sympy.diff(u, y)
>>> pde.add_equation('linear_pde', equation)
>>> print(pde)
PDE, linear_pde: 2*x + 2*y
PDE
linear_pde: 2*x + 2*y
"""
self.equations.update({name: equation})

Expand Down Expand Up @@ -181,7 +239,7 @@ def set_state_dict(
return self.learnable_parameters.set_state_dict(state_dict)

def __str__(self):
return ", ".join(
return "\n".join(
[self.__class__.__name__]
+ [f"{name}: {eq}" for name, eq in self.equations.items()]
+ [f" {name}: {eq}" for name, eq in self.equations.items()]
)
2 changes: 2 additions & 0 deletions ppsci/equation/pde/biharmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,5 @@ def __init__(
biharmonic += u.diff(invar_i, 2).diff(invar_j, 2)

self.add_equation("biharmonic", biharmonic)

self._apply_detach()
2 changes: 2 additions & 0 deletions ppsci/equation/pde/heat_exchanger.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,5 @@ def __init__(
self.add_equation("heat_boundary", heat_boundary)
self.add_equation("cold_boundary", cold_boundary)
self.add_equation("wall", wall)

self._apply_detach()
2 changes: 2 additions & 0 deletions ppsci/equation/pde/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,5 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None):
laplace += u.diff(invar, 2)

self.add_equation("laplace", laplace)

self._apply_detach()
2 changes: 2 additions & 0 deletions ppsci/equation/pde/linear_elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,5 @@ def __init__(
self.add_equation("traction_y", traction_y)
if self.dim == 3:
self.add_equation("traction_z", traction_z)

self._apply_detach()
2 changes: 2 additions & 0 deletions ppsci/equation/pde/navier_stokes.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,5 @@ def __init__(
self.add_equation("momentum_y", momentum_y)
if self.dim == 3:
self.add_equation("momentum_z", momentum_z)

self._apply_detach()
2 changes: 2 additions & 0 deletions ppsci/equation/pde/nls_m_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,5 @@ def __init__(
self.add_equation("Maxwell_1", Maxwell_1)
self.add_equation("Maxwell_2", Maxwell_2)
self.add_equation("Bloch", Bloch)

self._apply_detach()
2 changes: 2 additions & 0 deletions ppsci/equation/pde/normal_dot_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,5 @@ def __init__(
normal_dot_vec += normal * vec

self.add_equation("normal_dot_vec", normal_dot_vec)

self._apply_detach()
2 changes: 2 additions & 0 deletions ppsci/equation/pde/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,5 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None):
poisson += p.diff(invar, 2)

self.add_equation("poisson", poisson)

self._apply_detach()
2 changes: 2 additions & 0 deletions ppsci/equation/pde/viv.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,5 @@ def __init__(self, rho: float, k1: float, k2: float):
k2 = self.create_symbols(self.k2.name)
f = self.rho * eta.diff(t_f, 2) + sp.exp(k1) * eta.diff(t_f) + sp.exp(k2) * eta
self.add_equation("f", f)

self._apply_detach()
20 changes: 13 additions & 7 deletions ppsci/utils/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

__all__ = [
"lambdify",
"_cvt_to_key",
]


Expand Down Expand Up @@ -116,14 +117,18 @@ def _cvt_to_key(expr: sp.Basic) -> str:
Returns:
str: Converted string key.
"""
if isinstance(expr, sp.Function) and str(expr.func) == equation.DETACH_FUNC_NAME:
return f"{_cvt_to_key(expr.args[0])}_{equation.DETACH_FUNC_NAME}"

if isinstance(expr, (sp.Symbol, sp.core.function.UndefinedFunction, sp.Function)):
# use name of custom function(e.g. "f") instead of itself(e.g. "f(x, y)")
# for simplicity.
if hasattr(expr, "name"):
# use name of custom function instead of itself.
return expr.name
else:
return str(expr)
elif isinstance(expr, sp.Derivative):
# convert Derivative(u(x,y),(x,2),(y,2)) to "u__x__x__y__y"
# convert "Derivative(u(x,y),(x,2),(y,2))" to "u__x__x__y__y"
expr_str = expr.args[0].name
for symbol, order in expr.args[1:]:
expr_str += f"__{symbol}" * order
Expand Down Expand Up @@ -813,12 +818,13 @@ def _expr_to_callable_nodes(
else:
callable_nodes.append(OperatorNode(node))
elif isinstance(node, sp.Function):
if node.name == equation.DETACH_FUNC_NAME:
if str(node.func) == equation.DETACH_FUNC_NAME:
callable_nodes.append(DetachNode(node))
logger.debug(f"Detected detach node {node}")
else:
match_index = None
for j, model in enumerate(models):
if str(node.func.name) in model.output_keys:
if str(node.func) in model.output_keys:
callable_nodes.append(
LayerNode(
node,
Expand All @@ -828,13 +834,13 @@ def _expr_to_callable_nodes(
if match_index is not None:
raise ValueError(
f"Name of function: '{node}' should be unique along given"
f" models, but got same output_key: '{node.func.name}' "
f" models, but got same output_key: '{str(node.func)}' "
f"in given models[{match_index}] and models[{j}]."
)
match_index = j
# NOTE: Skip 'sdf' function, which should be already generated in
# given data_dict
if match_index is None and node.name != "sdf":
if match_index is None and str(node.func) != "sdf":
raise ValueError(
f"Node {node} can not match any model in given model(s)."
)
Expand Down Expand Up @@ -925,7 +931,7 @@ def _expr_to_callable_nodes(
logger.debug(
f"Fused {len(candidate_pos)} derivatives nodes: "
f"{[callable_nodes_group[i][j].expr for i, j in candidate_pos]} into"
f" fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
f" {len(fused_node_seq)} fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
)

# mark merged node
Expand Down
Loading