Skip to content

Commit bbfadc0

Browse files
refine arch code
1 parent c108931 commit bbfadc0

File tree

3 files changed

+63
-34
lines changed

3 files changed

+63
-34
lines changed

examples/spinn/helmholtz3d.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,24 @@ def _helmholtz3d_exact_u(a1, a2, a3, x, y, z):
4141

4242

4343
def _helmholtz3d_source_term(a1, a2, a3, x, y, z, lda=1.0):
44-
u_gt = _helmholtz3d_exact_u(a1, a2, a3, x, y, z)
44+
u_gt = _helmholtz3d_exact_u(a1, a2, a3, x, y, z)[..., None]
4545
uxx = -((a1 * np.pi) ** 2) * u_gt
4646
uyy = -((a2 * np.pi) ** 2) * u_gt
4747
uzz = -((a3 * np.pi) ** 2) * u_gt
4848
return uxx + uyy + uzz + lda * u_gt
4949

5050

5151
def generate_train_helmholtz3d(a1, a2, a3, nc):
52-
xc = np.random.uniform(-1.0, 1.0, [nc]).astype(dtype)
53-
yc = np.random.uniform(-1.0, 1.0, [nc]).astype(dtype)
54-
zc = np.random.uniform(-1.0, 1.0, [nc]).astype(dtype)
52+
xc = np.random.uniform(-1.0, 1.0, [nc, 1]).astype(dtype)
53+
yc = np.random.uniform(-1.0, 1.0, [nc, 1]).astype(dtype)
54+
zc = np.random.uniform(-1.0, 1.0, [nc, 1]).astype(dtype)
5555
# source term
5656
xcm, ycm, zcm = np.meshgrid(xc, yc, zc, indexing="ij")
5757
uc = _helmholtz3d_source_term(a1, a2, a3, xcm, ycm, zcm).astype(dtype)
5858
# boundary (hard-coded)
5959
xb = [
60-
np.asarray([1.0], dtype=dtype),
61-
np.asarray([-1.0], dtype=dtype),
60+
np.asarray([[1.0]], dtype=dtype),
61+
np.asarray([[-1.0]], dtype=dtype),
6262
xc,
6363
xc,
6464
xc,
@@ -67,8 +67,8 @@ def generate_train_helmholtz3d(a1, a2, a3, nc):
6767
yb = [
6868
yc,
6969
yc,
70-
np.asarray([1.0], dtype=dtype),
71-
np.asarray([-1.0], dtype=dtype),
70+
np.asarray([[1.0]], dtype=dtype),
71+
np.asarray([[-1.0]], dtype=dtype),
7272
yc,
7373
yc,
7474
]
@@ -77,8 +77,8 @@ def generate_train_helmholtz3d(a1, a2, a3, nc):
7777
zc,
7878
zc,
7979
zc,
80-
np.asarray([1.0], dtype=dtype),
81-
np.asarray([-1.0], dtype=dtype),
80+
np.asarray([[1.0]], dtype=dtype),
81+
np.asarray([[-1.0]], dtype=dtype),
8282
]
8383
return xc, yc, zc, uc, xb, yb, zb
8484

@@ -88,7 +88,7 @@ def generate_test_helmholtz3d(a1, a2, a3, nc_test):
8888
y = np.linspace(-1.0, 1.0, nc_test, dtype=dtype)
8989
z = np.linspace(-1.0, 1.0, nc_test, dtype=dtype)
9090
xm, ym, zm = np.meshgrid(x, y, z, indexing="ij")
91-
u_gt = _helmholtz3d_exact_u(a1, a2, a3, xm, ym, zm).astype(dtype)
91+
u_gt = _helmholtz3d_exact_u(a1, a2, a3, xm, ym, zm).astype(dtype)[..., None]
9292
x = x.reshape(-1, 1)
9393
y = y.reshape(-1, 1)
9494
z = z.reshape(-1, 1)
@@ -120,7 +120,7 @@ def _gen(self):
120120
self.xc = xc
121121
self.yc = yc
122122
self.zc = zc
123-
self.uc = uc[..., np.newaxis]
123+
self.uc = uc
124124

125125
def __call__(self):
126126
self.iter += 1

ppsci/arch/spinn.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22

33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -15,10 +15,12 @@
1515
from __future__ import annotations
1616

1717
from typing import Dict
18+
from typing import List
1819
from typing import Optional
1920
from typing import Tuple
2021
from typing import Union
2122

23+
import paddle
2224
import paddle.nn as nn
2325

2426
from ppsci.arch import base
@@ -27,12 +29,12 @@
2729

2830

2931
class SPINN(base.Arch):
30-
"""
31-
SPINN: Sparse Interaction Neural Network
32+
"""Separable Physics-Informed Neural Networks.
33+
3234
Args:
3335
input_keys (Tuple[str, ...]): Keys of input variables.
3436
output_keys (Tuple[str, ...]): Keys of output variables.
35-
r (int): Number of features for each output.
37+
r (int): Number of features for each output dimension.
3638
num_layers (int): Number of layers.
3739
hidden_size (Union[int, Tuple[int, ...]]): Size of hidden layer.
3840
activation (str, optional): Name of activation function.
@@ -46,17 +48,19 @@ class SPINN(base.Arch):
4648
>>> from ppsci.arch import SPINN
4749
>>> model = SPINN(
4850
... input_keys=('x', 'y', 'z'),
49-
... output_keys=('u',),
50-
... r=316,
51+
... output_keys=('u', 'v'),
52+
... r=32,
5153
... num_layers=4,
5254
... hidden_size=32,
5355
... )
54-
>>> input_dict = {"x": paddle.rand([10]),
55-
... "y": paddle.rand([10]),
56-
... "z": paddle.rand([10])}
56+
>>> input_dict = {"x": paddle.rand([3, 1]),
57+
... "y": paddle.rand([4, 1]),
58+
... "z": paddle.rand([5, 1])}
5759
>>> output_dict = model(input_dict)
5860
>>> print(output_dict["u"].shape)
59-
[10, 10, 10, 1]
61+
[3, 4, 5, 1]
62+
>>> print(output_dict["v"].shape)
63+
[3, 4, 5, 1]
6064
"""
6165

6266
def __init__(
@@ -106,18 +110,45 @@ def _init_weights(self):
106110
initializer.glorot_normal_(m.weight)
107111
initializer.zeros_(m.bias)
108112

109-
def forward_tensor(self, x, y, z):
113+
def _tensor_contraction(self, x: paddle.Tensor, y: paddle.Tensor) -> paddle.Tensor:
114+
"""Tensor contraction between two tensors along the last channel.
115+
116+
Args:
117+
x (Tensor): Input tensor with shape [*N, C].
118+
y (Tensor): Input tensor with shape [*M, C].
119+
120+
Returns:
121+
Tensor: Output tensor with shape [*N, *M, C].
122+
"""
123+
x_ndim = x.ndim
124+
y_ndim = y.ndim
125+
out_dim = x_ndim + y_ndim - 1
126+
127+
# Align the dimensions of x and y to out_dim
128+
if x_ndim < out_dim:
129+
# Insert singleton dimensions into x just before its last (channel) dimension
130+
x = x.unsqueeze([-2] * (out_dim - x_ndim))
131+
if y_ndim < out_dim:
132+
# Add singleton dimensions to y at the beginning of its dimensions
133+
y = y.unsqueeze([0] * (out_dim - y_ndim))
134+
135+
# Multiply x and y with implicit broadcasting
136+
out = x * y
137+
138+
return out
139+
140+
def forward_tensor(self, x, y, z) -> List[paddle.Tensor]:
110141
# forward each dim branch
111142
feature_f = []
112143
for i, input_var in enumerate((x, y, z)):
113-
input_i = {self.input_keys[i]: input_var.unsqueeze(1)}
144+
input_i = {self.input_keys[i]: input_var}
114145
output_f_i = self.branch_nets[i](input_i)
115146
feature_f.append(output_f_i["f"]) # [B, r*output_dim]
116147

117-
# dot product and sum over all branch outputs and
118148
output = []
119149
for i, key in enumerate(self.output_keys):
120150
st, ed = i * self.r, (i + 1) * self.r
151+
# do tensor contraction and sum over all branch outputs
121152
if ed - st == self.r:
122153
output_i = feature_f[0]
123154
else:
@@ -128,20 +159,18 @@ def forward_tensor(self, x, y, z):
128159
output_ii = feature_f[j]
129160
else:
130161
output_ii = feature_f[j][:, st:ed]
131-
if j != len(self.input_keys) - 1:
132-
output_i = output_i.unsqueeze(1) * output_ii.unsqueeze(0)
133-
else:
134-
output_i = (
135-
output_i.unsqueeze(2) * output_ii.unsqueeze(0).unsqueeze(0)
136-
).sum(axis=-1, keepdim=True)
162+
output_i = self._tensor_contraction(output_i, output_ii)
163+
164+
output_i = output_i.sum(-1, keepdim=True)
137165
output.append(output_i)
138166

139-
return output[-1]
167+
return output
140168

141169
def forward(self, x):
142170
if self._input_transform is not None:
143171
x = self._input_transform(x)
144-
output = [self.forward_tensor(x["x"], x["y"], x["z"])]
172+
173+
output = self.forward_tensor(*[x[key] for key in self.input_keys])
145174

146175
output = {key: output[i] for i, key in enumerate(self.output_keys)}
147176

ppsci/equation/pde/helmholtz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def hvp_revrev(f: Callable, primals: Tuple[paddle.Tensor, ...]) -> paddle.Tensor
3838
# TODO: Merge this option into ppsci.autodiff.ad
3939
g = lambda primals: paddle.incubate.autograd.jvp(f, primals)[1]
4040
tangents_out = paddle.incubate.autograd.jvp(g, primals)[1]
41-
return tangents_out
41+
return tangents_out[0]
4242

4343

4444
class Helmholtz(base.PDE):

0 commit comments

Comments
 (0)