-
Notifications
You must be signed in to change notification settings - Fork 184
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Replace python funcion with sympy expression #507
Changes from 1 commit
2d95737
f8bb795
b31a785
d9d30d1
6877901
2ccc9f5
bb9d71c
8783ec9
6f376b6
cfa853f
62e9855
fee4553
1b7642e
1538f0d
36e48e5
d32fd84
4e61fa5
4baf466
ba2a5c2
f057eec
ffd6b27
0809194
b0d1df1
2fa8a8b
b2a4509
13d6ff6
ffd3a93
256f31a
3a870d9
08e92d3
c2373ca
051dcd6
2eacfea
c7eea1b
615fefb
e096ea2
cb16107
899a2d2
640c080
16aae52
9b46410
f179ec1
2e37bca
6ecb31b
0a88a08
18a1638
8ccf858
978cdfc
a319e99
be67fb4
eecc5e9
cb7f777
59bc990
c0228f0
408e378
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
from __future__ import annotations | ||
|
||
from typing import Optional | ||
from typing import Tuple | ||
|
||
import sympy as sp | ||
|
||
|
@@ -61,8 +62,10 @@ def __init__( | |
rho: float = 1, | ||
dim: int = 3, | ||
time: bool = False, | ||
detach_keys: Optional[Tuple[str, ...]] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 更新docstring There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
): | ||
super().__init__() | ||
self.detach_keys = detach_keys | ||
self.dim = dim | ||
self.time = time | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,10 +15,10 @@ | |
from __future__ import annotations | ||
|
||
from typing import Callable | ||
from typing import Optional | ||
from typing import Tuple | ||
from typing import Union | ||
|
||
import sympy as sp | ||
|
||
from ppsci.equation.pde import base | ||
|
||
|
||
|
@@ -63,19 +63,37 @@ class NavierStokes(base.PDE): | |
>>> pde = ppsci.equation.NavierStokes(0.1, 1.0, 3, False) | ||
""" | ||
|
||
def __init__(self, nu: Union[float, Callable], rho: float, dim: int, time: bool): | ||
def __init__( | ||
self, | ||
nu: Union[float, Callable], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 后面代码中 nu 允许类型为 str, 类型提示需要修改下。其他变量也检查下有没有类似情况 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
rho: float, | ||
dim: int, | ||
time: bool, | ||
detach_keys: Optional[Tuple[str, ...]] = None, | ||
): | ||
super().__init__() | ||
self.detach_keys = detach_keys | ||
t, x, y, z = self.create_symbols("t x y z") | ||
invars = (x, y) | ||
if time: | ||
invars = (t,) + invars | ||
if dim == 3: | ||
invars += (z,) | ||
|
||
self.nu = nu | ||
self.rho = rho | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这两行代码中用的的 nu, rho 在后续代码中有修改,看是否需要把这两行放到修改后面 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
self.dim = dim | ||
self.time = time | ||
|
||
t, x, y, z = self.create_symbols("t x y z") | ||
u, v, w, p = self.create_symbols("u v w p") | ||
if self.dim == 2: | ||
w = sp.Number(0) | ||
if not time: | ||
t = sp.Number(0) | ||
if isinstance(nu, str): | ||
nu = self.create_function(nu, invars) | ||
if isinstance(rho, str): | ||
rho = self.create_function(rho, invars) | ||
|
||
u = self.create_function("u", invars) | ||
v = self.create_function("v", invars) | ||
w = self.create_function("w", invars) | ||
p = self.create_function("p", invars) | ||
|
||
continuity = u.diff(x) + v.diff(y) + w.diff(z) | ||
momentum_x = ( | ||
|
@@ -105,4 +123,5 @@ def __init__(self, nu: Union[float, Callable], rho: float, dim: int, time: bool) | |
self.add_equation("continuity", continuity) | ||
self.add_equation("momentum_x", momentum_x) | ||
self.add_equation("momentum_y", momentum_y) | ||
self.add_equation("momentum_z", momentum_z) | ||
if self.dim == 3: | ||
self.add_equation("momentum_z", momentum_z) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
更新docstring, 检查类型提示
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改