Skip to content

Commit f82085d

Browse files
committed
fix issues + speedup test
1 parent 05f7d2f commit f82085d

15 files changed

Lines changed: 87 additions & 64 deletions

pina/solver/__init__.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"CompetitivePINN",
1212
"SelfAdaptivePINN",
1313
"RBAPINN",
14+
"SupervisedSolverInterface",
1415
"SupervisedSolver",
1516
"ReducedOrderModelSolver",
1617
"DeepEnsembleSolverInterface",
@@ -20,7 +21,23 @@
2021
]
2122

2223
from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
23-
from .physics_informed_solver import *
24-
from .supervised_solver import *
25-
from .ensemble_solver import *
24+
from .physics_informed_solver import (
25+
PINNInterface,
26+
PINN,
27+
GradientPINN,
28+
CausalPINN,
29+
CompetitivePINN,
30+
SelfAdaptivePINN,
31+
RBAPINN,
32+
)
33+
from .supervised_solver import (
34+
SupervisedSolverInterface,
35+
SupervisedSolver,
36+
ReducedOrderModelSolver,
37+
)
38+
from .ensemble_solver import (
39+
DeepEnsembleSolverInterface,
40+
DeepEnsembleSupervisedSolver,
41+
DeepEnsemblePINN,
42+
)
2643
from .garom import GAROM

pina/solver/ensemble_solver/ensemble_pinn.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,11 @@ def __init__(
9898
If ``None``, no weighting schema is used. Default is ``None``.
9999
:param int ensemble_dim: The dimension along which the ensemble
100100
outputs are stacked. Default is 0.
101+
:raises NotImplementedError: If an inverse problem is passed.
101102
"""
102103
if isinstance(problem, InverseProblem):
103104
raise NotImplementedError(
104-
"DeepEnsemblePINN does not work on inverse problems."
105+
"DeepEnsemblePINN can not be used to solve inverse problems."
105106
)
106107
super().__init__(
107108
problem=problem,
@@ -126,11 +127,12 @@ def loss_data(self, input, target):
126127
:return: The supervised loss, averaged over the number of observations.
127128
:rtype: torch.Tensor
128129
"""
130+
predictions = self.forward(input)
129131
loss = sum(
130-
self._loss_fn(self.forward(input, idx), target)
131-
for idx in range(self.num_ensembles)
132+
self._loss_fn(predictions[idx], target)
133+
for idx in range(self.num_ensemble)
132134
)
133-
return loss / self.num_ensembles
135+
return loss / self.num_ensemble
134136

135137
def loss_phys(self, samples, equation):
136138
"""
@@ -160,7 +162,9 @@ def _residual_loss(self, samples, equation):
160162
:rtype: torch.Tensor
161163
"""
162164
loss = 0
163-
for idx in range(self.num_ensembles):
164-
residuals = equation.residual(samples, self.forward(samples, idx))
165-
loss = loss + self._loss_fn(residuals, torch.zeros_like(residuals))
166-
return loss / self.num_ensembles
165+
predictions = self.forward(samples)
166+
for idx in range(self.num_ensemble):
167+
residuals = equation.residual(samples, predictions[idx])
168+
target = torch.zeros_like(residuals, requires_grad=True)
169+
loss = loss + self._loss_fn(residuals, target)
170+
return loss / self.num_ensemble

pina/solver/ensemble_solver/ensemble_solver_interface.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ def __init__(
6161
6262
:param AbstractProblem problem: The problem to be solved.
6363
:param torch.nn.Module models: The neural network models to be used.
64-
:param int ensemble_dim: The dimension along which the ensemble
65-
outputs are stacked. Default is 0.
6664
:param Optimizer optimizer: The optimizer to be used.
6765
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
6866
Default is ``None``.
@@ -73,6 +71,8 @@ def __init__(
7371
If ``None``, no weighting schema is used. Default is ``None``.
7472
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
7573
Default is ``True``.
74+
:param int ensemble_dim: The dimension along which the ensemble
75+
outputs are stacked. Default is 0.
7676
"""
7777
super().__init__(
7878
problem, models, optimizers, schedulers, weighting, use_lt
@@ -90,7 +90,8 @@ def forward(self, x, ensemble_idx=None):
9090
9191
:param LabelTensor x: The input tensor to the models.
9292
:param int ensemble_idx: Optional index to select a specific
93-
model from the ensemble.
93+
model from the ensemble. If ``None`` results for all models are
94+
stacked in ``ensemble_dim`` dimension. Default is ``None``.
9495
:return: The output of the selected model or the stacked
9596
outputs from all models.
9697
:rtype: LabelTensor
@@ -100,7 +101,7 @@ def forward(self, x, ensemble_idx=None):
100101
return self.models[ensemble_idx].forward(x)
101102
# otherwise return the stacked output
102103
return torch.stack(
103-
[self.forward(x, idx) for idx in range(self.num_ensembles)],
104+
[self.forward(x, idx) for idx in range(self.num_ensemble)],
104105
dim=self.ensemble_dim,
105106
)
106107

@@ -125,8 +126,9 @@ def training_step(self, batch):
125126
# perform backpropagation
126127
self.manual_backward(loss)
127128
# optimize
128-
for opt in self.optimizers:
129+
for opt, sched in zip(self.optimizers, self.schedulers):
129130
opt.instance.step()
131+
sched.instance.step()
130132
return loss
131133

132134
@property
@@ -140,7 +142,7 @@ def ensemble_dim(self):
140142
return self._ensemble_dim
141143

142144
@property
143-
def num_ensembles(self):
145+
def num_ensemble(self):
144146
"""
145147
The number of models in the ensemble.
146148

pina/solver/ensemble_solver/ensemble_supervised.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,9 @@ def loss_data(self, input, target):
114114
:return: The supervised loss, averaged over the number of observations.
115115
:rtype: torch.Tensor
116116
"""
117+
predictions = self.forward(input)
117118
loss = sum(
118-
self._loss_fn(self.forward(input, idx), target)
119-
for idx in range(self.num_ensembles)
119+
self._loss_fn(predictions[idx], target)
120+
for idx in range(self.num_ensemble)
120121
)
121-
return loss / self.num_ensembles
122+
return loss / self.num_ensemble

pina/solver/physics_informed_solver/pinn_interface.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,9 @@ def loss_data(self, input, target):
142142
between the network's output and the true solution. This method should
143143
be overridden by the derived class.
144144
145-
:param input: The input to the neural network.
146-
:type input: LabelTensor
147-
:param target: The target to compare with the network's output.
148-
:type target: LabelTensor
145+
:param LabelTensor input: The input to the neural network.
146+
:param LabelTensor target: The target to compare with the
147+
network's output.
149148
:return: The supervised loss, averaged over the number of observations.
150149
:rtype: LabelTensor
151150
"""

tests/test_solver/test_causal_pinn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ class DummySpatialProblem(SpatialProblem):
2727

2828
# define problems
2929
problem = DiffusionReactionProblem()
30-
problem.discretise_domain(50)
30+
problem.discretise_domain(10)
3131

3232
# add input-output condition to test supervised learning
33-
input_pts = torch.rand(50, len(problem.input_variables))
33+
input_pts = torch.rand(10, len(problem.input_variables))
3434
input_pts = LabelTensor(input_pts, problem.input_variables)
35-
output_pts = torch.rand(50, len(problem.output_variables))
35+
output_pts = torch.rand(10, len(problem.output_variables))
3636
output_pts = LabelTensor(output_pts, problem.output_variables)
3737
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)
3838

tests/test_solver/test_competitive_pinn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,19 @@
1919

2020
# define problems
2121
problem = Poisson()
22-
problem.discretise_domain(50)
22+
problem.discretise_domain(10)
2323
inverse_problem = InversePoisson()
24-
inverse_problem.discretise_domain(50)
24+
inverse_problem.discretise_domain(10)
2525

2626
# reduce the number of data points to speed up testing
2727
data_condition = inverse_problem.conditions["data"]
2828
data_condition.input = data_condition.input[:10]
2929
data_condition.target = data_condition.target[:10]
3030

3131
# add input-output condition to test supervised learning
32-
input_pts = torch.rand(50, len(problem.input_variables))
32+
input_pts = torch.rand(10, len(problem.input_variables))
3333
input_pts = LabelTensor(input_pts, problem.input_variables)
34-
output_pts = torch.rand(50, len(problem.output_variables))
34+
output_pts = torch.rand(10, len(problem.output_variables))
3535
output_pts = LabelTensor(output_pts, problem.output_variables)
3636
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)
3737

tests/test_solver/test_ensemble_pinn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
problem.discretise_domain(10)
2020

2121
# add input-output condition to test supervised learning
22-
input_pts = torch.rand(50, len(problem.input_variables))
22+
input_pts = torch.rand(10, len(problem.input_variables))
2323
input_pts = LabelTensor(input_pts, problem.input_variables)
24-
output_pts = torch.rand(50, len(problem.output_variables))
24+
output_pts = torch.rand(10, len(problem.output_variables))
2525
output_pts = LabelTensor(output_pts, problem.output_variables)
2626
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)
2727

@@ -42,7 +42,7 @@ def test_constructor():
4242
InputEquationCondition,
4343
DomainEquationCondition,
4444
)
45-
assert solver.num_ensembles == 5
45+
assert solver.num_ensemble == 5
4646

4747

4848
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])

tests/test_solver/test_ensemble_supervised_solver.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ class TensorProblem(AbstractProblem):
3030
}
3131

3232

33-
x = torch.rand((100, 20, 5))
34-
pos = torch.rand((100, 20, 2))
35-
output_ = torch.rand((100, 20, 1))
33+
x = torch.rand((15, 20, 5))
34+
pos = torch.rand((15, 20, 2))
35+
output_ = torch.rand((15, 20, 1))
3636
input_ = [
3737
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
3838
for x_, pos_ in zip(x, pos)
@@ -44,9 +44,9 @@ class GraphProblem(AbstractProblem):
4444
conditions = {"data": Condition(input=input_, target=output_)}
4545

4646

47-
x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"])
48-
pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"])
49-
output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"])
47+
x = LabelTensor(torch.rand((15, 20, 5)), ["a", "b", "c", "d", "e"])
48+
pos = LabelTensor(torch.rand((15, 20, 2)), ["x", "y"])
49+
output_ = LabelTensor(torch.rand((15, 20, 1)), ["u"])
5050
input_ = [
5151
KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True)
5252
for i in range(len(x))
@@ -96,7 +96,7 @@ def test_constructor():
9696
assert DeepEnsembleSupervisedSolver.accepted_conditions_types == (
9797
InputTargetCondition
9898
)
99-
assert solver.num_ensembles == 10
99+
assert solver.num_ensemble == 10
100100

101101

102102
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])

tests/test_solver/test_garom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class TensorProblem(AbstractProblem):
1515
input_variables = ["u_0", "u_1"]
1616
output_variables = ["u"]
1717
conditions = {
18-
"data": Condition(target=torch.randn(50, 2), input=torch.randn(50, 1))
18+
"data": Condition(target=torch.randn(10, 2), input=torch.randn(10, 1))
1919
}
2020

2121

0 commit comments

Comments
 (0)