Skip to content

Commit 9af1fa5

Browse files
committed
aaa
1 parent 9f1470e commit 9af1fa5

File tree

6 files changed

+812
-101
lines changed

6 files changed

+812
-101
lines changed

examples/cylinder/3d_unsteady_continuous/load_lbm_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@
159159
import copy
160160
from pyevtk.hl import pointsToVTK
161161

162-
dir_ic = '/workspace/hesensen/PaddleScience_dev_3d/examples/cylinder/3d_unsteady_continuous/data/ic_data/'
162+
dir_ic = '/workspace/hesensen/PaddleScience_cqp/examples/cylinder/3d_unsteady_continuous/data/ic_data/'
163163
# dir_sp = '/home/aistudio/work/data/supervised_data/'
164-
dir_sp_mode1 = '/workspace/hesensen/PaddleScience_dev_3d/examples/cylinder/3d_unsteady_continuous/data/new_sp_data/'
164+
dir_sp_mode1 = '/workspace/hesensen/PaddleScience_cqp/examples/cylinder/3d_unsteady_continuous/data/new_sp_data/'
165165

166166
#file_pattern = 'point_70.000000_9.000000_27.000000.dat'
167167

examples/cylinder/3d_unsteady_continuous/test_3d_new.py

Lines changed: 206 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,14 @@
196196
# limitations under the License.
197197

198198
import copy
199+
from abc import abstractmethod
200+
from typing import Union
199201

200202
import numpy as np
201203
import paddle
202204
import paddle.distributed as dist
203205
import paddlescience as psci
206+
from paddle.optimizer import lr
204207
from pyevtk.hl import pointsToVTK
205208

206209
import sample_boundary_training_data as sample_data
@@ -214,7 +217,7 @@
214217
# time arraep
215218
ic_t = 200000
216219
t_start = 200050
217-
t_end = 200250
220+
t_end = 200100
218221
t_step = 50
219222
time_num = int((t_end - t_start) / t_step) + 1
220223
time_tmp = np.linspace(t_start - ic_t, t_end - ic_t, time_num, endpoint=True)
@@ -229,76 +232,76 @@
229232
cylinder_wgt = 5.0
230233
top_wgt = 2.0
231234
bottom_wgt = 2.0
232-
eq_wgt= 10.0
235+
eq_wgt= 2.0
233236
ic_wgt = 5.0
234237
sup_wgt = 10.0
235238

236239
# initial value
237240
txyz_uvwpe_ic = load_ic_data(ic_t)
238-
init_t = txyz_uvwpe_ic[:, 0]
239-
init_x = txyz_uvwpe_ic[:, 1]
240-
init_y = txyz_uvwpe_ic[:, 2]
241-
init_z = txyz_uvwpe_ic[:, 3]
242-
init_u = txyz_uvwpe_ic[:, 4]
243-
init_v = txyz_uvwpe_ic[:, 5]
244-
init_w = txyz_uvwpe_ic[:, 6]
245-
init_p = txyz_uvwpe_ic[:, 7]
241+
init_t = txyz_uvwpe_ic[:, 0]; print(f"init_t={init_t.shape} {init_t.mean().item():.10f}")
242+
init_x = txyz_uvwpe_ic[:, 1]; print(f"init_x={init_x.shape} {init_x.mean().item():.10f}")
243+
init_y = txyz_uvwpe_ic[:, 2]; print(f"init_y={init_y.shape} {init_y.mean().item():.10f}")
244+
init_z = txyz_uvwpe_ic[:, 3]; print(f"init_z={init_z.shape} {init_z.mean().item():.10f}")
245+
init_u = txyz_uvwpe_ic[:, 4]; print(f"init_u={init_u.shape} {init_u.mean().item():.10f}")
246+
init_v = txyz_uvwpe_ic[:, 5]; print(f"init_v={init_v.shape} {init_v.mean().item():.10f}")
247+
init_w = txyz_uvwpe_ic[:, 6]; print(f"init_w={init_w.shape} {init_w.mean().item():.10f}")
248+
init_p = txyz_uvwpe_ic[:, 7]; print(f"init_p={init_p.shape} {init_p.mean().item():.10f}")
246249

247250
# num of supervised points
248251
n_sup = 2000
249252

250253
# supervised data
251254
txyz_uvwpe_s = load_supervised_data(t_start, t_end, t_step, ic_t, n_sup)
252-
sup_t = txyz_uvwpe_s[:, 0]
253-
sup_x = txyz_uvwpe_s[:, 1]
254-
sup_y = txyz_uvwpe_s[:, 2]
255-
sup_z = txyz_uvwpe_s[:, 3]
256-
sup_u = txyz_uvwpe_s[:, 4]
257-
sup_v = txyz_uvwpe_s[:, 5]
258-
sup_w = txyz_uvwpe_s[:, 6]
259-
sup_p = txyz_uvwpe_s[:, 7]
255+
sup_t = txyz_uvwpe_s[:, 0]; print(f"sup_t={sup_t.shape} {sup_t.mean().item():.10f}")
256+
sup_x = txyz_uvwpe_s[:, 1]; print(f"sup_x={sup_x.shape} {sup_x.mean().item():.10f}")
257+
sup_y = txyz_uvwpe_s[:, 2]; print(f"sup_y={sup_y.shape} {sup_y.mean().item():.10f}")
258+
sup_z = txyz_uvwpe_s[:, 3]; print(f"sup_z={sup_z.shape} {sup_z.mean().item():.10f}")
259+
sup_u = txyz_uvwpe_s[:, 4]; print(f"sup_u={sup_u.shape} {sup_u.mean().item():.10f}")
260+
sup_v = txyz_uvwpe_s[:, 5]; print(f"sup_v={sup_v.shape} {sup_v.mean().item():.10f}")
261+
sup_w = txyz_uvwpe_s[:, 6]; print(f"sup_w={sup_w.shape} {sup_w.mean().item():.10f}")
262+
sup_p = txyz_uvwpe_s[:, 7]; print(f"sup_p={sup_p.shape} {sup_p.mean().item():.10f}")
260263

261264
# num points to sample per GPU
262-
# num_points = 30000
263-
num_points = 15000
265+
num_points = 30000
266+
# num_points = 15000
264267
# discretize node by geo
265268
inlet_txyz, outlet_txyz, top_txyz, bottom_txyz, cylinder_txyz, interior_txyz = sample_data.sample_data(t_step=time_num, nr_points=num_points)
266269

267270
# interior nodes discre
268-
i_t = interior_txyz[:, 0]
269-
i_x = interior_txyz[:, 1]
270-
i_y = interior_txyz[:, 2]
271-
i_z = interior_txyz[:, 3]
271+
i_t = interior_txyz[:, 0]; print(f"i_t={i_t.shape} {i_t.mean().item():.10f}")
272+
i_x = interior_txyz[:, 1]; print(f"i_x={i_x.shape} {i_x.mean().item():.10f}")
273+
i_y = interior_txyz[:, 2]; print(f"i_y={i_y.shape} {i_y.mean().item():.10f}")
274+
i_z = interior_txyz[:, 3]; print(f"i_z={i_z.shape} {i_z.mean().item():.10f}")
272275

273276
# bc inlet nodes discre
274-
b_inlet_t = inlet_txyz[:, 0]
275-
b_inlet_x = inlet_txyz[:, 1]
276-
b_inlet_y = inlet_txyz[:, 2]
277-
b_inlet_z = inlet_txyz[:, 3]
277+
b_inlet_t = inlet_txyz[:, 0]; print(f"b_inlet_t={b_inlet_t.shape} {b_inlet_t.mean().item():.10f}")
278+
b_inlet_x = inlet_txyz[:, 1]; print(f"b_inlet_x={b_inlet_x.shape} {b_inlet_x.mean().item():.10f}")
279+
b_inlet_y = inlet_txyz[:, 2]; print(f"b_inlet_y={b_inlet_y.shape} {b_inlet_y.mean().item():.10f}")
280+
b_inlet_z = inlet_txyz[:, 3]; print(f"b_inlet_z={b_inlet_z.shape} {b_inlet_z.mean().item():.10f}")
278281

279282
# bc outlet nodes discre
280-
b_outlet_t = outlet_txyz[:, 0]
281-
b_outlet_x = outlet_txyz[:, 1]
282-
b_outlet_y = outlet_txyz[:, 2]
283-
b_outlet_z = outlet_txyz[:, 3]
283+
b_outlet_t = outlet_txyz[:, 0]; print(f"b_outlet_t={b_outlet_t.shape} {b_outlet_t.mean().item():.10f}")
284+
b_outlet_x = outlet_txyz[:, 1]; print(f"b_outlet_x={b_outlet_x.shape} {b_outlet_x.mean().item():.10f}")
285+
b_outlet_y = outlet_txyz[:, 2]; print(f"b_outlet_y={b_outlet_y.shape} {b_outlet_y.mean().item():.10f}")
286+
b_outlet_z = outlet_txyz[:, 3]; print(f"b_outlet_z={b_outlet_z.shape} {b_outlet_z.mean().item():.10f}")
284287

285288
# bc cylinder nodes discre
286-
b_cylinder_t = cylinder_txyz[:, 0]
287-
b_cylinder_x = cylinder_txyz[:, 1]
288-
b_cylinder_y = cylinder_txyz[:, 2]
289-
b_cylinder_z = cylinder_txyz[:, 3]
289+
b_cylinder_t = cylinder_txyz[:, 0]; print(f"b_cylinder_t={b_cylinder_t.shape} {b_cylinder_t.mean().item():.10f}")
290+
b_cylinder_x = cylinder_txyz[:, 1]; print(f"b_cylinder_x={b_cylinder_x.shape} {b_cylinder_x.mean().item():.10f}")
291+
b_cylinder_y = cylinder_txyz[:, 2]; print(f"b_cylinder_y={b_cylinder_y.shape} {b_cylinder_y.mean().item():.10f}")
292+
b_cylinder_z = cylinder_txyz[:, 3]; print(f"b_cylinder_z={b_cylinder_z.shape} {b_cylinder_z.mean().item():.10f}")
290293

291294
# bc-top nodes discre
292-
b_top_t = top_txyz[:, 0] # value = [1, 2, 3, 4, 5]
293-
b_top_x = top_txyz[:, 1]
294-
b_top_y = top_txyz[:, 2]
295-
b_top_z = top_txyz[:, 3]
296-
295+
b_top_t = top_txyz[:, 0]; print(f"b_top_t={b_top_t.shape} {b_top_t.mean().item():.10f}") # value = [1, 2, 3, 4, 5]
296+
b_top_x = top_txyz[:, 1]; print(f"b_top_x={b_top_x.shape} {b_top_x.mean().item():.10f}")
297+
b_top_y = top_txyz[:, 2]; print(f"b_top_y={b_top_y.shape} {b_top_y.mean().item():.10f}")
298+
b_top_z = top_txyz[:, 3]; print(f"b_top_z={b_top_z.shape} {b_top_z.mean().item():.10f}")
299+
297300
# bc-bottom nodes discre
298-
b_bottom_t = bottom_txyz[:, 0] # value = [1, 2, 3, 4, 5]
299-
b_bottom_x = bottom_txyz[:, 1]
300-
b_bottom_y = bottom_txyz[:, 2]
301-
b_bottom_z = bottom_txyz[:, 3]
301+
b_bottom_t = bottom_txyz[:, 0]; print(f"b_bottom_t={b_bottom_t.shape} {b_bottom_t.mean().item():.10f}") # value = [1, 2, 3, 4, 5]
302+
b_bottom_x = bottom_txyz[:, 1]; print(f"b_bottom_x={b_bottom_x.shape} {b_bottom_x.mean().item():.10f}")
303+
b_bottom_y = bottom_txyz[:, 2]; print(f"b_bottom_y={b_bottom_y.shape} {b_bottom_y.mean().item():.10f}")
304+
b_bottom_z = bottom_txyz[:, 3]; print(f"b_bottom_z={b_bottom_z.shape} {b_bottom_z.mean().item():.10f}")
302305

303306
# bc & interior nodes for nn
304307
inputeq = np.stack((i_t, i_x, i_y, i_z), axis=1)
@@ -316,17 +319,17 @@
316319
pde = psci.pde.NavierStokes(nu=0.0205, rho=1.0, dim=3, time_dependent=True)
317320

318321
# set bounday condition
319-
bc_inlet_u = psci.bc.Dirichlet("u", rhs=1)
320-
bc_inlet_v = psci.bc.Dirichlet("v", rhs=0)
321-
bc_inlet_w = psci.bc.Dirichlet("w", rhs=0)
322-
bc_cylinder_u = psci.bc.Dirichlet("u", rhs=0)
323-
bc_cylinder_v = psci.bc.Dirichlet("v", rhs=0)
324-
bc_cylinder_w = psci.bc.Dirichlet("w", rhs=0)
325-
bc_outlet_p = psci.bc.Dirichlet("p", rhs=0)
326-
bc_top_u = psci.bc.Dirichlet("u", rhs=0)
327-
bc_top_v = psci.bc.Dirichlet("v", rhs=0)
328-
bc_bottom_u = psci.bc.Dirichlet("u", rhs=0)
329-
bc_bottom_v = psci.bc.Dirichlet("v", rhs=0)
322+
bc_inlet_u = psci.bc.Dirichlet("u", rhs=1.0)
323+
bc_inlet_v = psci.bc.Dirichlet("v", rhs=0.0)
324+
bc_inlet_w = psci.bc.Dirichlet("w", rhs=0.0)
325+
bc_cylinder_u = psci.bc.Dirichlet("u", rhs=0.0)
326+
bc_cylinder_v = psci.bc.Dirichlet("v", rhs=0.0)
327+
bc_cylinder_w = psci.bc.Dirichlet("w", rhs=0.0)
328+
bc_outlet_p = psci.bc.Dirichlet("p", rhs=0.0)
329+
bc_top_u = psci.bc.Dirichlet("u", rhs=0.0)
330+
bc_top_v = psci.bc.Dirichlet("v", rhs=0.0)
331+
bc_bottom_u = psci.bc.Dirichlet("u", rhs=0.0)
332+
bc_bottom_v = psci.bc.Dirichlet("v", rhs=0.0)
330333

331334
# add bounday and boundary condition
332335
pde.set_bc("inlet", bc_inlet_u, bc_inlet_v, bc_inlet_w)
@@ -345,8 +348,6 @@
345348
# Network
346349
net = psci.network.FCNet(
347350
num_ins=4, num_outs=4, num_layers=6, hidden_size=50, activation="tanh")
348-
# net.initialize("checkpoint/static_model_params_10000.pdparams")
349-
# net.initialize("checkpoint/dynamic_net_params_100000.pdparams")
350351

351352
outeq = net(inputeq)
352353
outbc1 = net(inputbc1)
@@ -391,25 +392,161 @@
391392
algo = psci.algorithm.PINNs(net=net, loss=loss)
392393

393394
# Optimizer
394-
opt = psci.optimizer.Adam(learning_rate=0.01, parameters=net.parameters())
395+
class LRBase(object):
396+
"""Base class for custom learning rates
397+
398+
Args:
399+
epochs (int): total epoch(s)
400+
step_each_epoch (int): number of iterations within an epoch
401+
learning_rate (float): learning rate
402+
warmup_epoch (int): number of warmup epoch(s)
403+
warmup_start_lr (float): start learning rate within warmup
404+
last_epoch (int): last epoch
405+
by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
406+
verbose (bool): If True, prints a message to stdout for each update. Defaults to False
407+
"""
408+
409+
def __init__(self,
410+
epochs: int,
411+
step_each_epoch: int,
412+
learning_rate: float,
413+
warmup_epoch: int,
414+
warmup_start_lr: float,
415+
last_epoch: int,
416+
by_epoch: bool,
417+
verbose: bool=False) -> None:
418+
"""Initialize and record the necessary parameters
419+
"""
420+
super(LRBase, self).__init__()
421+
if warmup_epoch >= epochs:
422+
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
423+
print(msg)
424+
warmup_epoch = epochs
425+
self.epochs = epochs
426+
self.step_each_epoch = step_each_epoch
427+
self.learning_rate = learning_rate
428+
self.warmup_epoch = warmup_epoch
429+
self.warmup_steps = self.warmup_epoch if by_epoch else round(
430+
self.warmup_epoch * self.step_each_epoch)
431+
self.warmup_start_lr = warmup_start_lr
432+
self.last_epoch = last_epoch
433+
self.by_epoch = by_epoch
434+
self.verbose = verbose
435+
436+
@abstractmethod
437+
def __call__(self, *kargs, **kwargs) -> lr.LRScheduler:
438+
"""generate an learning rate scheduler
439+
440+
Returns:
441+
lr.LinearWarmup: learning rate scheduler
442+
"""
443+
pass
444+
445+
def linear_warmup(
446+
self,
447+
learning_rate: Union[float, lr.LRScheduler]) -> lr.LinearWarmup:
448+
"""Add an Linear Warmup before learning_rate
449+
450+
Args:
451+
learning_rate (Union[float, lr.LRScheduler]): original learning rate without warmup
452+
453+
Returns:
454+
lr.LinearWarmup: learning rate scheduler with warmup
455+
"""
456+
warmup_lr = lr.LinearWarmup(
457+
learning_rate=learning_rate,
458+
warmup_steps=self.warmup_steps,
459+
start_lr=self.warmup_start_lr,
460+
end_lr=self.learning_rate,
461+
last_epoch=self.last_epoch,
462+
verbose=self.verbose)
463+
return warmup_lr
464+
465+
466+
class Constant(lr.LRScheduler):
467+
"""Constant learning rate Class implementation
468+
469+
Args:
470+
learning_rate (float): The initial learning rate
471+
last_epoch (int, optional): The index of last epoch. Default: -1.
472+
"""
473+
474+
def __init__(self, learning_rate, last_epoch=-1, **kwargs):
475+
self.learning_rate = learning_rate
476+
self.last_epoch = last_epoch
477+
super(Constant, self).__init__()
478+
479+
def get_lr(self) -> float:
480+
"""always return the same learning rate
481+
"""
482+
return self.learning_rate
483+
484+
485+
class Cosine(LRBase):
486+
"""Cosine learning rate decay
487+
488+
``lr = 0.05 * (math.cos(epoch * (math.pi / epochs)) + 1)``
489+
490+
Args:
491+
epochs (int): total epoch(s)
492+
step_each_epoch (int): number of iterations within an epoch
493+
learning_rate (float): learning rate
494+
eta_min (float, optional): Minimum learning rate. Defaults to 0.0.
495+
warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
496+
warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
497+
last_epoch (int, optional): last epoch. Defaults to -1.
498+
by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
499+
"""
500+
501+
def __init__(self,
502+
epochs,
503+
step_each_epoch,
504+
learning_rate,
505+
eta_min=0.0,
506+
warmup_epoch=0,
507+
warmup_start_lr=0.0,
508+
last_epoch=-1,
509+
by_epoch=False,
510+
**kwargs):
511+
super(Cosine, self).__init__(epochs, step_each_epoch, learning_rate,
512+
warmup_epoch, warmup_start_lr, last_epoch,
513+
by_epoch)
514+
self.T_max = (self.epochs - self.warmup_epoch) * self.step_each_epoch
515+
self.eta_min = eta_min
516+
if self.by_epoch:
517+
self.T_max = self.epochs - self.warmup_epoch
518+
519+
def __call__(self):
520+
learning_rate = lr.CosineAnnealingDecay(
521+
learning_rate=self.learning_rate,
522+
T_max=self.T_max,
523+
eta_min=self.eta_min,
524+
last_epoch=self.last_epoch) if self.T_max > 0 else Constant(
525+
self.learning_rate)
526+
527+
if self.warmup_steps > 0:
528+
learning_rate = self.linear_warmup(learning_rate)
529+
530+
setattr(learning_rate, "by_epoch", self.by_epoch)
531+
return learning_rate
532+
533+
num_epoch = 100000
534+
_lr = Cosine(num_epoch, 1, 0.001, warmup_epoch=5000, by_epoch=True)()
535+
# _lr = 0.001
536+
opt = psci.optimizer.Adam(learning_rate=_lr, parameters=net.parameters())
395537

396538
# Solver
397539
solver = psci.solver.Solver(pde=pde, algo=algo, opt=opt)
398540

399541
# Solve
400-
solution = solver.solve(num_epoch=100000)
401-
# solution = solver.predict()
402-
# print(type(solution), len(solution), solution[0].shape)
403-
# exit()
404-
405-
# solution = [
406-
# np.stack([init_u, init_v, init_w, init_p], axis=1)
407-
# ]
542+
solution = solver.solve(num_epoch=num_epoch)
543+
544+
# print shape of every subset points' output
408545
for idx, si in enumerate(solution):
409546
print(f"solution[{idx}].shape = {si.shape}")
410547

548+
# only coord at start time is needed
411549
n = int(i_x.shape[0] / len(time_array))
412-
# n = 1
413550
i_x = i_x.astype("float32")
414551
i_y = i_y.astype("float32")
415552
i_z = i_z.astype("float32")
@@ -420,6 +557,5 @@
420557
i_z = i_z * 320
421558

422559
cord = np.stack((i_x[0:n], i_y[0:n], i_z[0:n]), axis=1)
423-
# cord = np.stack((i_x, i_y, i_z), axis=1)
424560
# psci.visu.__save_vtk_raw(cordinate=cord, data=solution[0][-n::])
425-
psci.visu.save_vtk_cord(filename="./vtk/output_2023_1_12", time_array=time_array, cord=cord, data=solution)
561+
psci.visu.save_vtk_cord(filename="./vtk/output_2023_1_13_new", time_array=time_array, cord=cord, data=solution)

0 commit comments

Comments
 (0)