125 changes: 125 additions & 0 deletions examples/cylinder/2d_unsteady_continuous/2d_unstready_new.py
@@ -0,0 +1,125 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddlescience as psci
import numpy as np
import paddle

import loading_cfd_data

paddle.seed(1)
np.random.seed(1)

# time array: candidate time steps in [0, 50], from which 11 are randomly sampled
time_tmp = np.linspace(0, 50, 50, endpoint=True).astype(int)
time_array = np.random.choice(time_tmp, 11)
time_array.sort()

# NOTE: this fixed array overrides the random sampling above
time_array = np.array([1, 2, 3])

# loading data from files
dr = loading_cfd_data.DataLoader(path='./datasets/')
# interior data
i_t, i_x, i_y = dr.loading_train_inside_domain_data(time_array)
# boundary inlet and circle
b_inlet_u, b_inlet_v, b_inlet_t, b_inlet_x, b_inlet_y = dr.loading_boundary_data(
time_array)
# boundary outlet
b_outlet_p, b_outlet_t, b_outlet_x, b_outlet_y = dr.loading_outlet_data(
time_array)
# initial data
init_p, init_u, init_v, init_t, init_x, init_y = dr.loading_initial_data([1])
# supervised data
sup_p, sup_u, sup_v, sup_t, sup_x, sup_y = dr.loading_supervised_data(
time_array)

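# stack each point set into (t, x, y) network inputs; the supervised reference targets are (p, u, v)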
inputeq = np.stack((i_t, i_x, i_y), axis=1)
inputbc1 = np.stack((b_inlet_t, b_inlet_x, b_inlet_y), axis=1)
inputbc2 = np.stack((b_outlet_t, b_outlet_x, b_outlet_y), axis=1)
inputic = np.stack((init_t, init_x, init_y), axis=1)
inputsup = np.stack((sup_t, sup_x, sup_y), axis=1)
refsup = np.stack((sup_p, sup_u, sup_v), axis=1)

# Navier-Stokes equations (2D, time-dependent)
pde = psci.pde.NavierStokes(nu=0.02, rho=1.0, dim=2, time_dependent=True)

# set boundary conditions
bc_inlet_u = psci.bc.Dirichlet('u', rhs=b_inlet_u)
bc_inlet_v = psci.bc.Dirichlet('v', rhs=b_inlet_v)
bc_outlet_p = psci.bc.Dirichlet('p', rhs=b_outlet_p)

# add boundaries and their boundary conditions to the PDE
pde.set_bc("inlet", bc_inlet_u, bc_inlet_v)
pde.set_bc("outlet", bc_outlet_p)

# add initial condition
ic_u = psci.ic.IC('u', rhs=init_u)
ic_v = psci.ic.IC('v', rhs=init_v)
ic_p = psci.ic.IC('p', rhs=init_p)
pde.set_ic(ic_u, ic_v, ic_p)

# Network
net = psci.network.FCNet(
num_ins=3, num_outs=3, num_layers=6, hidden_size=50, activation='tanh')
net.initialize(path='./checkpoint/pretrained_net_params')

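# evaluate the network on every point set; these outputs feed the loss terms below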
outeq = net(inputeq)
outbc1 = net(inputbc1)
outbc2 = net(inputbc2)
outic = net(inputic)
outsup = net(inputsup)

# eq loss
losseq1 = psci.loss.EqLoss(pde.equations[0], netout=outeq)
losseq2 = psci.loss.EqLoss(pde.equations[1], netout=outeq)
losseq3 = psci.loss.EqLoss(pde.equations[2], netout=outeq)
# bc loss
lossbc1 = psci.loss.BcLoss("inlet", netout=outbc1)
lossbc2 = psci.loss.BcLoss("outlet", netout=outbc2)
# ic loss
lossic = psci.loss.IcLoss(netout=outic)
# supervised data loss
losssup = psci.loss.DataLoss(netout=outsup, ref=refsup)

# total loss
loss = losseq1 + losseq2 + losseq3 + 10.0 * lossbc1 + lossbc2 + 10.0 * lossic + 10.0 * losssup

# Algorithm
algo = psci.algorithm.PINNs(net=net, loss=loss)

# Optimizer
opt = psci.optimizer.Adam(learning_rate=0.001, parameters=net.parameters())

# Solver
solver = psci.solver.Solver(pde=pde, algo=algo, opt=opt)

# Solve
solution = solver.solve(num_epoch=1)

for i in solution:
print(i.shape)

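# number of spatial points per time step (interior points were replicated over time_array)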
n = int(i_x.shape[0] / len(time_array))

i_x = i_x.astype("float32")
i_y = i_y.astype("float32")

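# save the last time step (final n points) of solution[0], assumed to be the interior prediction, as raw VTK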
cord = np.stack((i_x[0:n], i_y[0:n]), axis=1)
psci.visu.__save_vtk_raw(cordinate=cord, data=solution[0][-n::])

exit()

# NOTE: unreachable after exit(); `pde_disc` is not defined in this script
psci.visu.save_vtk(
    time_array=time_array, geo_disc=pde_disc.geometry, data=solution)
104 changes: 31 additions & 73 deletions examples/cylinder/2d_unsteady_continuous/loading_cfd_data.py
@@ -91,10 +91,7 @@ def replicate_time_list(self, time_list, domain_shape, spatial_data):

return replicated_t, spatial_data

- def loading_train_inside_domain_data(self,
-                                      time_list,
-                                      flatten=False,
-                                      dtype='float32'):
+ def loading_train_inside_domain_data(self, time_list):
# load train_domain points
# domain_train.csv, title is p,U:0,U:1,U:2,Points:0,Points:1,Points:2
filename = 'domain_train.csv'
@@ -113,16 +110,10 @@ def loading_train_inside_domain_data(self,
x = domain_data[:, 4].reshape((-1, 1))
y = domain_data[:, 5].reshape((-1, 1))
t, xy = self.replicate_time_list(time_list, x.shape[0], [x, y])
- t = t.astype(dtype)
- xy[0] = xy[0].astype(dtype)
- xy[1] = xy[1].astype(dtype)
print("residual data shape:", t.shape[0])
- if flatten == True:
-     return t.flatten(), xy[0].flatten(), xy[1].flatten()
- else:
-     return t, xy[0], xy[1]
+ return t.flatten(), xy[0].flatten(), xy[1].flatten()

- def loading_outlet_data(self, time_list, flatten=False, dtype='float32'):
+ def loading_outlet_data(self, time_list):
filename = 'domain_outlet.csv'
path = self.path

@@ -138,17 +129,8 @@ def loading_outlet_data(self, time_list, flatten=False, dtype='float32'):
print("outlet data shape:", outlet_data.shape[0])
t, pxy = self.replicate_time_list(time_list, outlet_data.shape[0],
[p, x, y])
-
- pxy[0] = pxy[0].astype(dtype)
- pxy[1] = pxy[1].astype(dtype)
- pxy[2] = pxy[2].astype(dtype)
- t = t.astype(dtype)
-
- if flatten == True:
-     return pxy[0].flatten(), t.flatten(), pxy[1].flatten(), pxy[
-         2].flatten()
- else:
-     return pxy[0], t, pxy[1], pxy[2]
+ return pxy[0].flatten(), t.flatten(), pxy[1].flatten(), pxy[2].flatten(
+ )

def loading_inlet_data(self, time_list, path):
# title is p,U:0,U:1,U:2,Points:0,Points:1,Points:2
@@ -171,32 +153,18 @@ def loading_side_data(self, time_list, path):
# u, v, x, y
return self.loading_data(time_list, path, filename)

- def loading_boundary_data(self,
-                           time_list,
-                           num_random=None,
-                           flatten=False,
-                           dtype='float32'):
+ def loading_boundary_data(self, time_list, num_random=None):
inlet_bc = self.loading_inlet_data(time_list, self.path)
# side_bc = self.loading_side_data(time_list, self.path)
cylinder_bc = self.loading_cylinder_data(time_list, self.path)

- u = np.concatenate((inlet_bc[0], cylinder_bc[0])).astype(dtype)
- v = np.concatenate((inlet_bc[1], cylinder_bc[1])).astype(dtype)
- x = np.concatenate((inlet_bc[2], cylinder_bc[2])).astype(dtype)
- y = np.concatenate((inlet_bc[3], cylinder_bc[3])).astype(dtype)
+ u = np.concatenate((inlet_bc[0], cylinder_bc[0]))
+ v = np.concatenate((inlet_bc[1], cylinder_bc[1]))
+ x = np.concatenate((inlet_bc[2], cylinder_bc[2]))
+ y = np.concatenate((inlet_bc[3], cylinder_bc[3]))
t, uvxy = self.replicate_time_list(time_list, u.shape[0], [u, v, x, y])
-
- t = t.astype(dtype)
- uvxy[0] = uvxy[0].astype(dtype)
- uvxy[1] = uvxy[1].astype(dtype)
- uvxy[2] = uvxy[2].astype(dtype)
- uvxy[3] = uvxy[3].astype(dtype)
-
- if flatten == True:
-     return uvxy[0].flatten(), uvxy[1].flatten(), t.flatten(), uvxy[
-         2].flatten(), uvxy[3].flatten()
- else:
-     return uvxy[0], uvxy[1], t, uvxy[2], uvxy[3]
+ return uvxy[0].flatten(), uvxy[1].flatten(), t.flatten(), uvxy[
+     2].flatten(), uvxy[3].flatten()

def loading_data(self, time_list, path, filename, num_random=None):
# boundary data: cylinder/inlet/side
@@ -221,10 +189,7 @@ def loading_data(self, time_list, path, filename, num_random=None):
print("boundary data shape:", boundary_data.shape[0])
return u, v, x, y

- def loading_supervised_data(self,
-                             time_list,
-                             flatten=False,
-                             dtype='float32'):
+ def loading_supervised_data(self, time_list):
path = self.path

supervised_data = None
@@ -250,20 +215,16 @@ def loading_supervised_data(self,

print("supervised data shape:", full_supervised_data.shape[0])
# p, u, v, t, x, y
- p = full_supervised_data[:, 1].reshape((-1, 1)).astype(dtype)
- u = full_supervised_data[:, 2].reshape((-1, 1)).astype(dtype)
- v = full_supervised_data[:, 3].reshape((-1, 1)).astype(dtype)
- t = full_supervised_data[:, 0].reshape((-1, 1)).astype(dtype)
- x = full_supervised_data[:, 6].reshape((-1, 1)).astype(dtype)
- y = full_supervised_data[:, 7].reshape((-1, 1)).astype(dtype)
-
- if flatten == True:
-     return p.flatten(), u.flatten(), v.flatten(), t.flatten(
-     ), x.flatten(), y.flatten()
- else:
-     return p, u, v, t, x, y
-
- def loading_initial_data(self, time_list, flatten=False, dtype='float32'):
+ p = full_supervised_data[:, 1].reshape((-1, 1))
+ u = full_supervised_data[:, 2].reshape((-1, 1))
+ v = full_supervised_data[:, 3].reshape((-1, 1))
+ t = full_supervised_data[:, 0].reshape((-1, 1))
+ x = full_supervised_data[:, 6].reshape((-1, 1))
+ y = full_supervised_data[:, 7].reshape((-1, 1))
+ return p.flatten(), u.flatten(), v.flatten(), t.flatten(), x.flatten(
+ ), y.flatten()
+
+ def loading_initial_data(self, time_list):
# "p","U:0","U:1","U:2","vtkOriginalPointIds","Points:0","Points:1","Points:2"
path = self.path

@@ -278,14 +239,11 @@ def loading_initial_data(self, time_list, flatten=False, dtype='float32'):

print("initial data shape:", initial_data.shape[0])
# p, u, v, t, x, y
- p = initial_t_data[:, 1].reshape((-1, 1)).astype(dtype)
- u = initial_t_data[:, 2].reshape((-1, 1)).astype(dtype)
- v = initial_t_data[:, 3].reshape((-1, 1)).astype(dtype)
- t = initial_t_data[:, 0].reshape((-1, 1)).astype(dtype)
- x = initial_t_data[:, 5].reshape((-1, 1)).astype(dtype)
- y = initial_t_data[:, 6].reshape((-1, 1)).astype(dtype)
- if flatten == True:
-     return p.flatten(), u.flatten(), v.flatten(), t.flatten(
-     ), x.flatten(), y.flatten()
- else:
-     return p, u, v, t, x, y
+ p = initial_t_data[:, 1].reshape((-1, 1))
+ u = initial_t_data[:, 2].reshape((-1, 1))
+ v = initial_t_data[:, 3].reshape((-1, 1))
+ t = initial_t_data[:, 0].reshape((-1, 1))
+ x = initial_t_data[:, 5].reshape((-1, 1))
+ y = initial_t_data[:, 6].reshape((-1, 1))
+ return p.flatten(), u.flatten(), v.flatten(), t.flatten(), x.flatten(
+ ), y.flatten()