
Commit 2ae6077

add batching for rbapinns
1 parent 067fceb commit 2ae6077

File tree: 2 files changed, +164 -84 lines changed

pina/solver/physics_informed_solver/rba_pinn.py

Lines changed: 152 additions & 60 deletions
@@ -1,6 +1,5 @@
 """Module for the Residual-Based Attention PINN solver."""
 
-from copy import deepcopy
 import torch
 
 from .pinn import PINN
@@ -73,7 +72,6 @@ def __init__(
         optimizer=None,
         scheduler=None,
         weighting=None,
-        loss=None,
         eta=0.001,
         gamma=0.999,
     ):
@@ -90,99 +88,193 @@ def __init__(
             scheduler is used. Default is ``None``.
         :param WeightingInterface weighting: The weighting schema to be used.
             If ``None``, no weighting schema is used. Default is ``None``.
-        :param torch.nn.Module loss: The loss function to be minimized.
-            If ``None``, the :class:`torch.nn.MSELoss` loss is used.
-            Default is `None`.
         :param float | int eta: The learning rate for the weights of the
             residuals. Default is ``0.001``.
         :param float gamma: The decay parameter in the update of the weights
             of the residuals. Must be between ``0`` and ``1``.
             Default is ``0.999``.
+        :raises ValueError: If ``gamma`` is not in the range (0, 1).
         """
         super().__init__(
             model=model,
             problem=problem,
             optimizer=optimizer,
             scheduler=scheduler,
             weighting=weighting,
-            loss=loss,
+            loss=torch.nn.MSELoss(reduction="none"),
         )
 
         # check consistency
         check_consistency(eta, (float, int))
         check_consistency(gamma, float)
-        assert (
-            0 < gamma < 1
-        ), f"Invalid range: expected 0 < gamma < 1, got {gamma=}"
+
+        # Validate range for gamma
+        if not 0 < gamma < 1:
+            raise ValueError(
+                f"Invalid range: expected 0 < gamma < 1, but got {gamma}"
+            )
+
+        # Initialize parameters
         self.eta = eta
         self.gamma = gamma
 
-        # initialize weights
-        self.weights = {}
-        for condition_name in problem.conditions:
-            self.weights[condition_name] = 0
+        # Initialize the weight of each point to 0
+        self.weights = {
+            cond: torch.zeros((len(data), 1), device=self.device)
+            for cond, data in self.problem.input_pts.items()
+        }
 
-        # define vectorial loss
-        self._vectorial_loss = deepcopy(self.loss)
-        self._vectorial_loss.reduction = "none"
-
-    # for now RBAPINN is implemented only for batch_size = None
     def on_train_start(self):
         """
         Hook method called at the beginning of training.
-
-        :raises NotImplementedError: If the batch size is not ``None``.
         """
-        if self.trainer.batch_size is not None:
-            raise NotImplementedError(
-                "RBAPINN only works with full batch "
-                "size, set batch_size=None inside the "
-                "Trainer to use the solver."
-            )
+        device = self.trainer.strategy.root_device
+        for cond in self.weights:
+            self.weights[cond] = self.weights[cond].to(device)
         return super().on_train_start()
 
-    def _vect_to_scalar(self, loss_value):
+    def training_step(self, batch, batch_idx, **kwargs):
+        """
+        Solver training step. It computes the optimization cycle and aggregates
+        the losses using the ``weighting`` attribute.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the training step.
+        :rtype: torch.Tensor
+        """
+        loss = self._optimization_cycle(
+            batch=batch, batch_idx=batch_idx, **kwargs
+        )
+        self.store_log("train_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    @torch.set_grad_enabled(True)
+    def validation_step(self, batch, **kwargs):
+        """
+        The validation step for the PINN solver. It returns the average residual
+        computed with the ``loss`` function, not aggregated.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
+        """
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = losses[cond].mean()
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("val_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    @torch.set_grad_enabled(True)
+    def test_step(self, batch, **kwargs):
         """
-        Computation of the scalar loss.
+        The test step for the PINN solver. It returns the average residual
+        computed with the ``loss`` function, not aggregated.
 
-        :param LabelTensor loss_value: the tensor of pointwise losses.
-        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
-        :return: The computed scalar loss.
-        :rtype: LabelTensor
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
         """
-        if self.loss.reduction == "mean":
-            ret = torch.mean(loss_value)
-        elif self.loss.reduction == "sum":
-            ret = torch.sum(loss_value)
-        else:
-            raise RuntimeError(
-                f"Invalid reduction, got {self.loss.reduction} "
-                "but expected mean or sum."
+        losses = self.optimization_cycle(batch=batch, **kwargs)
+
+        # Aggregate losses for each condition
+        for cond, loss in losses.items():
+            losses[cond] = losses[cond].mean()
+
+        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
+        self.store_log("test_loss", loss, self.get_batch_size(batch))
+        return loss
+
+    def _optimization_cycle(self, batch, batch_idx, **kwargs):
+        """
+        Aggregate the loss for each condition in the batch.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict kwargs: Additional keyword arguments passed to
+            ``optimization_cycle``.
+        :return: The losses computed for all conditions in the batch, cast
+            to a subclass of :class:`torch.Tensor`. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict
+        """
+        # compute non-aggregated residuals
+        residuals = self.optimization_cycle(batch)
+
+        # update weights based on residuals
+        self._update_weights(batch, batch_idx, residuals)
+
+        # compute losses
+        losses = {}
+        for cond, res in residuals.items():
+
+            # Get the correct indices for the weights. Modulus is used according
+            # to the number of points in the condition, as in the PinaDataset.
+            len_res = len(res)
+            idx = torch.arange(
+                batch_idx * len_res,
+                (batch_idx + 1) * len_res,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
+
+            losses[cond] = (res * self.weights[cond][idx]).mean()
+
+            # store log
+            self.store_log(
+                f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
             )
-        return ret
 
-    def loss_phys(self, samples, equation):
+        # clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+
+        # aggregate
+        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
+
+        return loss
+
+    def _update_weights(self, batch, batch_idx, residuals):
         """
-        Computes the physics loss for the physics-informed solver based on the
-        provided samples and equation.
+        Update weights based on residuals.
 
-        :param LabelTensor samples: The samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation.
-        :return: The computed physics loss.
-        :rtype: LabelTensor
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :param int batch_idx: The index of the current batch.
+        :param dict residuals: A dictionary containing the residuals for each
+            condition. The keys are the condition names and the values are the
+            residuals as tensors.
         """
-        residual = self.compute_residual(samples=samples, equation=equation)
-        cond = self.current_condition_name
+        # Iterate over each condition in the batch
+        for cond, data in batch:
 
-        r_norm = (
-            self.eta
-            * torch.abs(residual)
-            / (torch.max(torch.abs(residual)) + 1e-12)
-        )
-        self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
+            # Compute normalized residuals
+            res = residuals[cond]
+            res_abs = res.abs()
+            r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)
 
-        loss_value = self._vectorial_loss(
-            torch.zeros_like(residual, requires_grad=True), residual
-        )
+            # Get the correct indices for the weights. Modulus is used according
+            # to the number of points in the condition, as in the PinaDataset.
+            len_pts = len(data["input"])
+            idx = torch.arange(
+                batch_idx * len_pts,
+                (batch_idx + 1) * len_pts,
+                device=res.device,
+            ) % len(self.problem.input_pts[cond])
 
-        return self._vect_to_scalar(self.weights[cond] ** 2 * loss_value)
+            # Update weights
+            weights = self.weights[cond]
+            update = self.gamma * weights[idx] + r_norm
+            weights[idx] = update.detach()
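
The rule implemented in _update_weights is, for each point i of a condition,

    w_i <- gamma * w_i + eta * |r_i| / (max_j |r_j| + 1e-12)

so points whose residuals stay large relative to the rest of the batch accumulate larger weights. Below is a minimal, self-contained sketch of this rule together with the modulus-based index mapping from the diff; the function name update_weights and the shapes are illustrative, not part of PINA's API.

import torch

eta, gamma = 0.001, 0.999

# Hypothetical condition with 100 collocation points and one weight per
# point, mirroring torch.zeros((len(data), 1)) in the diff.
weights = torch.zeros(100, 1)

def update_weights(weights, residual, batch_idx):
    # Normalize residual magnitudes so the additive term is at most eta.
    res_abs = residual.abs()
    r_norm = eta * res_abs / (res_abs.max() + 1e-12)

    # Map batch positions to global point indices. The modulus mirrors how
    # the dataset wraps around the condition's points when batching.
    batch_len = len(residual)
    idx = torch.arange(
        batch_idx * batch_len, (batch_idx + 1) * batch_len
    ) % len(weights)

    # Decay the old weights and add the normalized residuals; detach so
    # the weights never enter the autograd graph.
    weights[idx] = (gamma * weights[idx] + r_norm).detach()

# A batch of 10 residuals at batch index 3 touches global points 30..39.
update_weights(weights, torch.randn(10, 1), batch_idx=3)

Only the rows visited by the current batch are updated, which is what makes the scheme compatible with mini-batching. Note also that the new _optimization_cycle scales each pointwise loss by w, whereas the removed loss_phys used w ** 2.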

tests/test_solver/test_rba_pinn.py

Lines changed: 12 additions & 24 deletions
@@ -42,10 +42,11 @@
 @pytest.mark.parametrize("eta", [1, 0.001])
 @pytest.mark.parametrize("gamma", [0.5, 0.9])
 def test_constructor(problem, eta, gamma):
-    with pytest.raises(AssertionError):
-        solver = RBAPINN(model=model, problem=problem, gamma=1.5)
     solver = RBAPINN(model=model, problem=problem, eta=eta, gamma=gamma)
 
+    with pytest.raises(ValueError):
+        solver = RBAPINN(model=model, problem=problem, gamma=1.5)
+
     assert solver.accepted_conditions_types == (
         InputTargetCondition,
         InputEquationCondition,
@@ -54,30 +55,15 @@ def test_constructor(problem, eta, gamma):
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
-def test_wrong_batch(problem):
-    with pytest.raises(NotImplementedError):
-        solver = RBAPINN(model=model, problem=problem)
-        trainer = Trainer(
-            solver=solver,
-            max_epochs=2,
-            accelerator="cpu",
-            batch_size=10,
-            train_size=1.0,
-            val_size=0.0,
-            test_size=0.0,
-        )
-        trainer.train()
-
-
-@pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_train(problem, compile):
+def test_solver_train(problem, batch_size, compile):
     solver = RBAPINN(model=model, problem=problem)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
         train_size=1.0,
         val_size=0.0,
         test_size=0.0,
@@ -89,14 +75,15 @@ def test_solver_train(problem, compile):
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_validation(problem, compile):
+def test_solver_validation(problem, batch_size, compile):
     solver = RBAPINN(model=model, problem=problem)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
         train_size=0.9,
         val_size=0.1,
         test_size=0.0,
@@ -108,14 +95,15 @@ def test_solver_validation(problem, compile):
 
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
-def test_solver_test(problem, compile):
+def test_solver_test(problem, batch_size, compile):
     solver = RBAPINN(model=model, problem=problem)
     trainer = Trainer(
         solver=solver,
         max_epochs=2,
         accelerator="cpu",
-        batch_size=None,
+        batch_size=batch_size,
         train_size=0.7,
         val_size=0.2,
         test_size=0.1,
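
Taken together, the updated tests exercise training, validation, and testing with finite batch sizes. A schematic usage sketch follows; model and problem are placeholders assumed to be defined as in the test module, and the import paths are inferred from this commit's file layout, so check them against your PINA version.

from pina import Trainer
from pina.solver import RBAPINN

# model: a torch.nn.Module; problem: a PINA problem with sampled points.
# Both are placeholders here, defined elsewhere as in the test module.
solver = RBAPINN(model=model, problem=problem, eta=0.001, gamma=0.999)

# Any batch_size now works; before this commit, a non-None batch_size
# raised NotImplementedError in on_train_start.
trainer = Trainer(
    solver=solver,
    max_epochs=2,
    accelerator="cpu",
    batch_size=10,
    train_size=0.9,
    val_size=0.1,
    test_size=0.0,
)
trainer.train()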
