Skip to content

Commit 014218a

Browse files
zlynnalijialin03
andauthored
GCNGalerkin+PINO论文复现 (#566)
* add submodule as a git submodule * pre commit * change the md file * md file * review * delete useless file * delete useless file * change the files according to the review * change graphGalerkin * change graphGalerkin * graphGalerkin * update1 --------- Co-authored-by: lijialin03 <lijialin03@baidu.com>
1 parent 935e08b commit 014218a

19 files changed

+3080
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "pycamotk"]
2+
path = pycamotk
3+
url = https://github.com/zlynna/pycamotk.git
Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
import sys
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
import paddle
6+
from scipy.io import loadmat
7+
8+
sys.path.insert(0, "pycamotk")
9+
from pyCaMOtk.create_dbc_strct import create_dbc_strct
10+
from pyCaMOtk.create_fem_resjac import create_fem_resjac
11+
from pyCaMOtk.create_femsp_cg import create_femsp_cg
12+
from pyCaMOtk.create_mesh_hcube import mesh_hcube
13+
from pyCaMOtk.geom_mltdim import Hypercube
14+
from pyCaMOtk.geom_mltdim import Simplex
15+
from pyCaMOtk.LinearElasticityHandCode import *
16+
from pyCaMOtk.mesh import Mesh
17+
from pyCaMOtk.mesh import get_gdof_from_bndtag
18+
from pyCaMOtk.solve_fem import solve_fem
19+
from pyCaMOtk.visualize_fem import visualize_fem
20+
21+
sys.path.insert(0, "source")
22+
import setup_prob_eqn_handcode
23+
import TensorFEMCore
24+
from GCNNModel import LinearElasticityNet2D
25+
from GCNNModel import e2vcg2connectivity
26+
from TensorFEMCore import Double
27+
from TensorFEMCore import ReshapeFix
28+
from TensorFEMCore import solve_fem_GCNN
29+
30+
sys.path.insert(0, "utils")
31+
from utils import Data
32+
33+
paddle.seed(0)
34+
35+
36+
class LinearElasticity:
37+
def __init__(self) -> None:
38+
# GCNN model
39+
self.model = LinearElasticityNet2D()
40+
41+
def train(
42+
self,
43+
Ufem,
44+
ndof,
45+
xcg,
46+
connectivity,
47+
LossF,
48+
tol,
49+
maxit,
50+
dbc,
51+
ndim,
52+
nnode,
53+
etype,
54+
e2vcg,
55+
e2bnd,
56+
):
57+
ii = 0
58+
Graph = []
59+
Ue = Double(Ufem.flatten().reshape(ndof, 1))
60+
fcn_id = Double(np.asarray([ii]))
61+
Ue_aug = paddle.concat((fcn_id, Ue), axis=0)
62+
xcg_gcnn = np.zeros((2, 2 * xcg.shape[1]))
63+
for i in range(xcg.shape[1]):
64+
xcg_gcnn[:, 2 * i] = xcg[:, i]
65+
xcg_gcnn[:, 2 * i + 1] = xcg[:, i]
66+
Uin = Double(xcg_gcnn.T)
67+
graph = Data(x=Uin, y=Ue_aug, edge_index=connectivity)
68+
Graph.append(graph)
69+
DataList = [[Graph[0]]]
70+
TrainDataloader = DataList
71+
[self.model, info] = solve_fem_GCNN(
72+
TrainDataloader, LossF, self.model, tol, maxit
73+
)
74+
np.save("modelCircleDet.npy", info)
75+
solution = self.model(Graph[0].to("cuda"))
76+
solution = ReshapeFix(paddle.clone(solution), [len(solution.flatten()), 1], "C")
77+
solution[dbc.dbc_idx] = Double(dbc.dbc_val.reshape([len(dbc.dbc_val), 1]))
78+
solution = solution.detach().cpu().numpy()
79+
xcg_defGCNN = xcg + np.reshape(solution, [ndim, nnode], order="F")
80+
msh_defGCNN = Mesh(etype, xcg_defGCNN, e2vcg, e2bnd, ndim)
81+
uabsGCNN = np.sqrt(
82+
solution[[i for i in range(ndof) if i % 2 == 0]] ** 2
83+
+ solution[[i for i in range(ndof) if i % 2 != 0]] ** 2
84+
)
85+
return msh_defGCNN, uabsGCNN
86+
87+
def plot_hard_way(self, msh_defGCNN, uabsGCNN, e2vcg, msh_def, uabs):
88+
fig = plt.figure()
89+
ax1 = plt.subplot(1, 2, 1)
90+
visualize_fem(
91+
ax1, msh_defGCNN, uabsGCNN[e2vcg], {"plot_elem": False, "nref": 1}, []
92+
)
93+
ax1.set_title("GCNN solution")
94+
ax2 = plt.subplot(1, 2, 2)
95+
visualize_fem(ax2, msh_def, uabs[e2vcg], {"plot_elem": False, "nref": 1}, [])
96+
ax2.set_title("FEM solution")
97+
fig.tight_layout(pad=3.0)
98+
plt.savefig("GCNN.pdf", bbox_inches="tight")
99+
100+
def plot_square(self, msh_defGCNN, uabsGCNN, e2vcg, msh_def, uabs):
101+
plt.figure()
102+
ax1 = plt.subplot(1, 1, 1)
103+
_, cbar1 = visualize_fem(
104+
ax1, msh_defGCNN, uabsGCNN[e2vcg], {"plot_elem": False, "nref": 4}, []
105+
)
106+
ax1.axis("off")
107+
cbar1.remove()
108+
plt.margins(0, 0)
109+
plt.savefig(
110+
"gcnn_2dlinearelasticity_square.png",
111+
bbox_inches="tight",
112+
pad_inches=0,
113+
dpi=800,
114+
)
115+
116+
plt.figure()
117+
ax2 = plt.subplot(1, 1, 1)
118+
_, cbar2 = visualize_fem(
119+
ax2, msh_def, uabs[e2vcg], {"plot_elem": False, "nref": 4}, []
120+
)
121+
ax2.axis("off")
122+
cbar2.remove()
123+
plt.margins(0, 0)
124+
plt.savefig(
125+
"fem_2dlinearelasticity_square.png",
126+
bbox_inches="tight",
127+
pad_inches=0,
128+
dpi=800,
129+
)
130+
131+
def hard_way(self):
132+
# FEM
133+
etype = "simplex"
134+
ndim = 2
135+
dat = loadmat("./msh/cylshk0a-simp-nref0p1.mat")
136+
xcg = dat["xcg"] / 10
137+
e2vcg = dat["e2vcg"] - 1
138+
e2bnd = dat["e2bnd"] - 1
139+
msh = Mesh(etype, xcg, e2vcg, e2bnd, ndim)
140+
xcg = msh.xcg
141+
e2vcg = msh.e2vcg
142+
e2bnd = msh.e2bnd
143+
porder = msh.porder
144+
[ndim, nnode] = xcg.shape
145+
nvar = ndim
146+
ndof = nnode * nvar
147+
148+
lam = lambda x, el: 1
149+
mu = lambda x, el: 1
150+
f = lambda x, el: np.zeros([ndim, 1])
151+
bnd2nbc = [0.0, 1.0, 2.0, 3.0, 4.0]
152+
tb = lambda x, n, bnd, el, fc: np.asarray([[2], [0]]) * (
153+
bnd == 2 or bnd == 2.0 or (bnd - 2) ** 2 < 1e-8
154+
) + np.asarray([[0], [0]])
155+
prob = setup_linelast_base_handcode(ndim, lam, mu, f, tb, bnd2nbc)
156+
# Create finite element space
157+
femsp = create_femsp_cg(prob, msh, porder, e2vcg, porder, e2vcg)
158+
ldof2gdof = femsp.ldof2gdof_var.ldof2gdof
159+
geo = Simplex(ndim, porder)
160+
f2v = geo.f2n
161+
dbc_idx = get_gdof_from_bndtag(
162+
[i for i in range(ndim)], [0], nvar, ldof2gdof, e2bnd, f2v
163+
)
164+
dbc_idx.sort()
165+
dbc_idx = np.asarray(dbc_idx)
166+
dbc_val = 0 * dbc_idx
167+
dbc = create_dbc_strct(ndof, dbc_idx, dbc_val)
168+
femsp.dbc = dbc
169+
tol = 1.0e-8
170+
maxit = 100000
171+
[Ufem, info] = solve_fem(
172+
"cg",
173+
msh.transfdatacontiguous,
174+
femsp.elem,
175+
femsp.elem_data,
176+
femsp.ldof2gdof_eqn.ldof2gdof,
177+
femsp.ldof2gdof_var.ldof2gdof,
178+
msh.e2e,
179+
femsp.spmat,
180+
dbc,
181+
None,
182+
tol,
183+
maxit,
184+
)
185+
186+
xcg_def = xcg + np.reshape(Ufem, [ndim, nnode], order="F")
187+
msh_def = Mesh(etype, xcg_def, e2vcg, e2bnd, ndim)
188+
uabs = np.sqrt(
189+
Ufem[[i for i in range(ndof) if i % 2 == 0]] ** 2
190+
+ Ufem[[i for i in range(ndof) if i % 2 != 0]] ** 2
191+
)
192+
fig = plt.figure()
193+
ax1 = plt.subplot(1, 1, 1)
194+
visualize_fem(ax1, msh_def, uabs[e2vcg], {"plot_elem": False, "nref": 1}, [])
195+
ax1.set_title("FEM solution")
196+
fig.tight_layout(pad=3.0)
197+
198+
idx_xcg = [
199+
i
200+
for i in range(xcg.shape[1])
201+
if 2 * i not in dbc_idx and 2 * i + 1 not in dbc_idx
202+
]
203+
204+
obsidx = np.asarray([5, 11, 26, 32, 38]) # max is 9
205+
206+
idx_whole = []
207+
for i in obsidx:
208+
idx_whole.append(2 * i)
209+
idx_whole.append(2 * i + 1)
210+
obsxcg = msh_def.xcg[:, obsidx]
211+
ax1.plot(obsxcg[0, :], obsxcg[1, :], "o")
212+
213+
dbc_idx_new = np.hstack((dbc_idx, idx_whole))
214+
dbc_val_new = Ufem[dbc_idx_new]
215+
dbc = create_dbc_strct(msh.xcg.shape[1] * nvar, dbc_idx_new, dbc_val_new)
216+
217+
Src_new = self.model.source
218+
K_new = paddle.to_tensor([[0], [0]], dtype="float32").reshape((2,))
219+
parsfuncI = lambda x: paddle.concat((Src_new[0:1], Src_new[1:2], K_new), axis=0)
220+
# GCNN
221+
connectivity = e2vcg2connectivity(e2vcg, "ele")
222+
prob = setup_prob_eqn_handcode.setup_linelast_base_handcode(
223+
ndim, lam, mu, f, tb, bnd2nbc
224+
)
225+
femsp_gcnn = create_femsp_cg(prob, msh, porder, e2vcg, porder, e2vcg, dbc)
226+
LossF = []
227+
fcn = lambda u_: TensorFEMCore.create_fem_resjac(
228+
"cg",
229+
u_,
230+
msh.transfdatacontiguous,
231+
femsp_gcnn.elem,
232+
femsp_gcnn.elem_data,
233+
femsp_gcnn.ldof2gdof_eqn.ldof2gdof,
234+
femsp_gcnn.ldof2gdof_var.ldof2gdof,
235+
msh.e2e,
236+
femsp_gcnn.spmat,
237+
dbc,
238+
[i for i in range(ndof) if i not in dbc_idx],
239+
parsfuncI,
240+
None,
241+
)
242+
LossF.append(fcn)
243+
msh_defGCNN, uabsGCNN = self.train(
244+
Ufem,
245+
ndof,
246+
xcg,
247+
connectivity,
248+
LossF,
249+
tol,
250+
maxit,
251+
dbc,
252+
ndim,
253+
nnode,
254+
etype,
255+
e2vcg,
256+
e2bnd,
257+
)
258+
self.plot_hard_way(msh_defGCNN, uabsGCNN, e2vcg, msh_def, uabs)
259+
260+
def main_square(self):
261+
# FEM
262+
nvar = 2
263+
etype = "hcube"
264+
lims = np.asarray([[0, 1], [0, 1]])
265+
nel = [2, 2]
266+
porder = 2
267+
nf = 4
268+
msh = mesh_hcube(etype, lims, nel, porder).getmsh()
269+
xcg = msh.xcg
270+
e2vcg = msh.e2vcg
271+
e2bnd = msh.e2bnd
272+
porder = msh.porder
273+
[ndim, nnode] = xcg.shape
274+
nvar = ndim
275+
ndof = nnode * nvar
276+
277+
lam = lambda x, el: 1
278+
mu = lambda x, el: 1
279+
f = lambda x, el: np.zeros([ndim, 1])
280+
bnd2nbc = np.asarray([0, 1, 2, 3])
281+
tb = lambda x, n, bnd, el, fc: np.asarray([[0.5], [0]]) * (
282+
(bnd - 2) ** 2 < 1e-8
283+
) + np.asarray([[0], [0]])
284+
prob = setup_linelast_base_handcode(ndim, lam, mu, f, tb, bnd2nbc)
285+
# Create finite element space
286+
femsp = create_femsp_cg(prob, msh, porder, e2vcg, porder, e2vcg)
287+
ldof2gdof = femsp.ldof2gdof_var.ldof2gdof
288+
geo = Hypercube(ndim, porder)
289+
f2v = geo.f2n
290+
dbc_idx = get_gdof_from_bndtag(
291+
[i for i in range(ndim)], [0], nvar, ldof2gdof, e2bnd, f2v
292+
)
293+
dbc_idx.sort()
294+
dbc_idx = np.asarray(dbc_idx)
295+
dbc_val = 0 * dbc_idx
296+
dbc = create_dbc_strct(ndof, dbc_idx, dbc_val)
297+
femsp.dbc = dbc
298+
tol = 1.0e-8
299+
maxit = 4500
300+
301+
[Ufem, info] = solve_fem(
302+
"cg",
303+
msh.transfdatacontiguous,
304+
femsp.elem,
305+
femsp.elem_data,
306+
femsp.ldof2gdof_eqn.ldof2gdof,
307+
femsp.ldof2gdof_var.ldof2gdof,
308+
msh.e2e,
309+
femsp.spmat,
310+
dbc,
311+
None,
312+
tol,
313+
maxit,
314+
)
315+
316+
xcg_def = xcg + np.reshape(Ufem, [ndim, nnode], order="F")
317+
msh_def = Mesh(etype, xcg_def, e2vcg, e2bnd, ndim)
318+
uabs = np.sqrt(
319+
Ufem[[i for i in range(ndof) if i % 2 == 0]] ** 2
320+
+ Ufem[[i for i in range(ndof) if i % 2 != 0]] ** 2
321+
)
322+
# GCNN
323+
connectivity = e2vcg2connectivity(e2vcg, "ele")
324+
prob = setup_prob_eqn_handcode.setup_linelast_base_handcode(
325+
ndim, lam, mu, f, tb, bnd2nbc
326+
)
327+
femsp_gcnn = create_femsp_cg(prob, msh, porder, e2vcg, porder, e2vcg, dbc)
328+
LossF = []
329+
fcn = lambda u_: TensorFEMCore.create_fem_resjac(
330+
"cg",
331+
u_,
332+
msh.transfdatacontiguous,
333+
femsp_gcnn.elem,
334+
femsp_gcnn.elem_data,
335+
femsp_gcnn.ldof2gdof_eqn.ldof2gdof,
336+
femsp_gcnn.ldof2gdof_var.ldof2gdof,
337+
msh.e2e,
338+
femsp_gcnn.spmat,
339+
dbc,
340+
)
341+
fcn_fem = lambda u_: create_fem_resjac(
342+
"cg",
343+
u_,
344+
msh.transfdatacontiguous,
345+
femsp.elem,
346+
femsp.elem_data,
347+
femsp.ldof2gdof_eqn.ldof2gdof,
348+
femsp.ldof2gdof_var.ldof2gdof,
349+
msh.e2e,
350+
femsp.spmat,
351+
dbc,
352+
)
353+
LossF.append(fcn)
354+
msh_defGCNN, uabsGCNN = self.train(
355+
Ufem,
356+
ndof,
357+
xcg,
358+
connectivity,
359+
LossF,
360+
tol,
361+
maxit,
362+
dbc,
363+
ndim,
364+
nnode,
365+
etype,
366+
e2vcg,
367+
e2bnd,
368+
)
369+
self.plot_square(msh_defGCNN, uabsGCNN, e2vcg, msh_def, uabs)
370+
371+
372+
if __name__ == "__main__":
373+
le_obj = LinearElasticity()
374+
le_obj.hard_way()
375+
le_obj.main_square()

0 commit comments

Comments
 (0)