Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions examples/poisson/poisson2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddlescience as psci
from paddlescience.optimizer.lr import Cosine

# Geometry
npoints = 10201
seed_num = 42
sampler_method = 'uniform'
# Network
epochs = 20000
num_layers = 5
hidden_size = 20
activation = 'tanh'
# Optimizer
learning_rate = 0.001
# Post-processing
solution_filename = 'output_laplace2d'
vtk_filename = 'output_laplace2d'
checkpoint_path = 'checkpoints'

paddle.seed(seed_num)
np.random.seed(seed_num)


def replicate_t(t_array, data):
"""
replicate_t
"""
full_data = None
t_len = data.shape[0]
for time in t_array:
t_extended = np.array(
[time] * t_len, dtype="float32").reshape((-1, 1)) # [N, 1]
t_data = np.concatenate(
(t_extended, data), axis=1) # [N, xyz]->[N, txyz]
if full_data is None:
full_data = t_data
else:
full_data = np.concatenate((full_data, t_data)) # [N*t_step,txyz]

return full_data


# time
start_time = 0.1
end_time = 1.5
time_step = 0.1
time_num = int((end_time - start_time + 0.5 * time_step) / time_step) + 1
time_tmp = np.linspace(start_time, end_time, time_num, endpoint=True)
time_array = time_tmp
print(f"time_num = {time_num}, time_array = {time_array}")

# set geometry and boundary
# geo = psci.geometry.Rectangular(origin=(0.0, 0.0), extent=(1.0, 1.0))
geo = psci.neo_geometry.Disk((0.0, 0.0), 1.0)
geo.add_sample_config("interior", 10000)
geo.add_sample_config("boundary", 1000)

points_dict = geo.fetch_batch_data()
geo_disc = geo
geo_disc.interior = points_dict["interior"]
geo_disc.boundary = {
"around":
(points_dict["boundary"], geo.boundary_normal(points_dict["boundary"])),
}
geo_disc.user = None
geo_disc.normal = {"around": None}

# Poisson
pde = psci.pde.Poisson(dim=2, alpha=0.1, rhs=1.0, weight=1.0) # weight ?

# define boundary condition dT/dn = q
bc_around = psci.bc.Neumann('T', rhs=1.0)

# set bounday condition
pde.set_bc("around", bc_around)

# define initial condition T
ic_T = psci.ic.IC('T', rhs=0.0)

# set initial condition T
pde.set_ic(ic_T)

# Network
# TODO: remove num_ins and num_outs
net = psci.network.FCNet(
num_ins=3,
num_outs=1,
num_layers=num_layers,
hidden_size=hidden_size,
activation=activation)
# net.initialize("./checkpoint/dynamic_net_params_20000.pdparams")

# eq loss
cords_interior = geo_disc.interior
num_cords = cords_interior.shape[0]
print("num_cords = ", num_cords)
inputeq = replicate_t(time_array, cords_interior)
outeq = net(inputeq)
losseq = psci.loss.EqLoss(pde.equations[0], netout=outeq)

# ic loss
inputic = replicate_t([0.0], cords_interior)
print("inputic.shape: ", inputic.shape)
outic = net(inputic)
lossic = psci.loss.IcLoss(netout=outic[:, :1])

# bc loss
inputbc = geo_disc.boundary["around"][0]
inputbc_n = geo_disc.boundary["around"][1]
inputbc = replicate_t(time_array, inputbc)
inputbc_n = replicate_t(time_array, inputbc_n)
outbc = net((inputbc, inputbc_n))
lossbc = psci.loss.BcLoss("around", netout=outbc)

# total loss
loss = losseq + 10.0 * lossic + 10.0 * lossbc

# Algorithm
algo = psci.algorithm.PINNs(net=net, loss=loss)

# Optimizer
learning_rate = Cosine(
epochs, 1, learning_rate, warmup_epoch=int(epochs * 0.05), by_epoch=True)()
opt = psci.optimizer.Adam(
learning_rate=learning_rate, parameters=net.parameters())

# Solver
solver = psci.solver.Solver(pde=pde, algo=algo, opt=opt)
solution = solver.solve(num_epoch=epochs)
# solution = solver.predict()
for i in range(len(solution)):
print(f"solution[{i}]={solution[i].shape}")

# Save result to vtk
for i in range(time_num):
psci.visu.__save_vtk_raw(
filename=f"./vtk/disk_poisson2d_output_time{i}",
cordinate=geo_disc.interior,
data=solution[0][i * num_cords:(i + 1) * num_cords])
140 changes: 140 additions & 0 deletions examples/poisson/poisson3d_robin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import paddle
import paddlescience as psci
from paddlescience.optimizer.lr import Cosine
from paddlescience.utils import replicate_t

# Geometry
npoints = 10201
seed_num = 42
sampler_method = "uniform"
# Network
epochs = 20000
num_layers = 5
hidden_size = 20
activation = "tanh"
# Optimizer
learning_rate = 0.001
# Post-processing
solution_filename = "output_laplace3d"
vtk_filename = "output_laplace3d"
checkpoint_path = "checkpoints"

paddle.seed(seed_num)
np.random.seed(seed_num)

# time
start_time = 0.1
end_time = 1.5
time_step = 0.1
time_num = int((end_time - start_time + 0.5 * time_step) / time_step) + 1
time_tmp = np.linspace(start_time, end_time, time_num, endpoint=True)
time_array = time_tmp
print(f"time_num = {time_num}, time_array = {time_array}")

# set geometry and boundary
geo = psci.neo_geometry.Sphere((0.0, 0.0, 0.0), 1.0)
geo.add_sample_config("interior", 40000)
geo.add_sample_config("boundary", 4000)

points_dict = geo.fetch_batch_data()
geo_disc = geo
geo_disc.interior = points_dict["interior"]
geo_disc.boundary = {
"around":
(points_dict["boundary"], geo.boundary_normal(points_dict["boundary"])),
}
geo_disc.user = None
geo_disc.normal = {"around": None, }

# Poisson
pde = psci.pde.Poisson(dim=3, alpha=0.1, rhs=1.0, weight=1.0)

# define boundary condition dT/dn+hT-hT_amb=0
bc_around = psci.bc.Robin("T", a=1.0, b=1.0, rhs=1.0)

# set bounday condition
pde.set_bc("around", bc_around)

# define initial condition T
ic_T = psci.ic.IC("T", rhs=0.0)

# set initial condition T
pde.set_ic(ic_T)

# Network
# TODO: remove num_ins and num_outs
net = psci.network.FCNet(
num_ins=4,
num_outs=1,
num_layers=num_layers,
hidden_size=hidden_size,
activation=activation)
# net.initialize("./checkpoint/dynamic_net_params_20000.pdparams")

# eq loss
cords_interior = geo_disc.interior
num_cords = cords_interior.shape[0]
print("num_cords = ", num_cords)
inputeq = psci.utils.replicate_t(time_array, cords_interior)
outeq = net(inputeq)
losseq = psci.loss.EqLoss(pde.equations[0], netout=outeq)

# ic loss
inputic = psci.utils.replicate_t([0.0], cords_interior)
print("inputic.shape: ", inputic.shape)
outic = net(inputic)
lossic = psci.loss.IcLoss(netout=outic[:, :1])

# bc loss
inputbc = geo_disc.boundary["around"][0]
inputbc_n = geo_disc.boundary["around"][1]
inputbc = psci.utils.replicate_t(time_array, inputbc)
inputbc_n = psci.utils.replicate_t(time_array, inputbc_n)
outbc = net((inputbc, inputbc_n))
lossbc = psci.loss.BcLoss("around", netout=outbc)

# total loss
loss = losseq + 10.0 * lossic + 10.0 * lossbc

# Algorithm
algo = psci.algorithm.PINNs(net=net, loss=loss)

# Optimizer
learning_rate = Cosine(
epochs, 1, learning_rate, warmup_epoch=int(epochs * 0.05), by_epoch=True)()
opt = psci.optimizer.Adam(
learning_rate=learning_rate, parameters=net.parameters())

# Solver
solver = psci.solver.Solver(pde=pde, algo=algo, opt=opt)
solution = solver.solve(num_epoch=epochs)
# solution = solver.predict()

for i in range(len(solution)):
print(f"solution[{i}]={solution[i].shape}")

# Save result to vtk
dirname = "vtk_3d_robin_bs10240_PRver"
os.makedirs(f"./{dirname}", exist_ok=True)
for i in range(time_num):
psci.visu.__save_vtk_raw(
filename=f"./{dirname}/disk_poisson3d_output_time{i}",
cordinate=geo_disc.interior,
data=solution[0][i * num_cords:(i + 1) * num_cords])
3 changes: 2 additions & 1 deletion paddlescience/algorithm/algorithm_pinns.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,6 @@ def compute(self, params, *inputs_labels, ninputs, inputs_attr, nlabels,
# interior points: compute eq_loss
for name_i, input_attr in inputs_attr["interior"].items():
input = inputs[n]

# print("int: ", len(input))
# print(input[0:5, :])

Expand Down Expand Up @@ -702,6 +701,8 @@ def __sqrt(self, x):
return paddle.sqrt(x)

def __padding_array(self, nprocs, array):
if nprocs == 1:
return array
npad = (nprocs - len(array) % nprocs) % nprocs # pad npad elements
if array.ndim == 2:
datapad = array[-1, :].reshape((-1, array[-1, :].shape[0]))
Expand Down
20 changes: 15 additions & 5 deletions paddlescience/bc/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,13 @@ def __init__(self, name, rhs=0.0, weight=1.0):
self.rhs = rhs

def to_formula(self, indvar):
n = sympy.Symbol('n')
u = sympy.Function(self.name)(*indvar)
self.formula = sympy.Derivative(u, n)
self.formula = 0
for indv in indvar:
if indv.name == "t":
continue
self.formula += sympy.Derivative(
u, indv) * sympy.Symbol(f"n_{indv.name}")

def discretize(self, indvar):
bc_disc = copy.deepcopy(self)
Expand All @@ -112,15 +116,21 @@ class Robin(BC):
>>> bc2 = psci.bc.Robin("u", rhs=lambda x, y: 0.0)
"""

def __init__(self, name, rhs=0.0, weight=1.0):
def __init__(self, name, a, b, rhs=0.0, weight=1.0):
super(Robin, self).__init__(name, weight)
self.category = "Robin"
self.a = a
self.b = b
self.rhs = rhs

def to_formula(self, indvar):
n = sympy.Symbol('n')
u = sympy.Function(self.name)(*indvar)
self.formula = u + sympy.Derivative(u, n)
self.formula = self.a * u
for indv in indvar:
if indv.name == "t":
continue
self.formula += sympy.Derivative(u, indv) * \
sympy.Symbol(f"n_{indv.name}") * self.b

def discretize(self, indvar):
bc_disc = copy.deepcopy(self)
Expand Down
6 changes: 5 additions & 1 deletion paddlescience/loss/loss_L2.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,18 @@ def bc_loss(self,
cmploss = CompFormula(pde, net)

# compute outs, jacobian, hessian
cmploss.compute_outs_der(input, bs,
cmploss.compute_outs_der(input[0]
if isinstance(input,
(tuple, list)) else input, bs,
params) # TODO: dirichlet not need der

loss = 0.0
for i in range(len(pde.bc[name_b])):
# TODO: hard code bs

normal_b = labels_attr["bc"][name_b][i]["normal"]
if isinstance(input, (tuple, list)):
normal_b = input[1]
if type(normal_b) == LabelInt:
normal = labels[normal_b]
else:
Expand Down
Loading