Support batching for rbapinns #589
base: dev
@@ -1,6 +1,5 @@
"""Module for the Residual-Based Attention PINN solver."""

from copy import deepcopy
import torch

from .pinn import PINN
@@ -73,7 +72,6 @@ def __init__(
        optimizer=None,
        scheduler=None,
        weighting=None,
        loss=None,
        eta=0.001,
        gamma=0.999,
    ):
@@ -90,99 +88,193 @@ def __init__(
            scheduler is used. Default is ``None``.
        :param WeightingInterface weighting: The weighting schema to be used.
            If ``None``, no weighting schema is used. Default is ``None``.
        :param torch.nn.Module loss: The loss function to be minimized.
            If ``None``, the :class:`torch.nn.MSELoss` loss is used.
            Default is ``None``.
        :param float | int eta: The learning rate for the weights of the
            residuals. Default is ``0.001``.
        :param float gamma: The decay parameter in the update of the weights
            of the residuals. Must be between ``0`` and ``1``.
            Default is ``0.999``.
        :raises ValueError: If ``gamma`` is not in the range (0, 1).
        """
        super().__init__(
            model=model,
            problem=problem,
            optimizer=optimizer,
            scheduler=scheduler,
            weighting=weighting,
            loss=loss,
            loss=torch.nn.MSELoss(reduction="none"),
        )

        # check consistency
        check_consistency(eta, (float, int))
        check_consistency(gamma, float)
        assert (
            0 < gamma < 1
        ), f"Invalid range: expected 0 < gamma < 1, got {gamma=}"

        # Validate range for gamma
Review comment: Also a check on eta?
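A minimal sketch of what such a check could look like (hypothetical: the PR does not specify valid bounds for eta, so the positivity requirement below is an assumption mirroring the gamma validation that follows):

```python
# Hypothetical eta validation, mirroring the gamma range check below.
# Assumption: eta is a strictly positive learning rate for the residual weights.
if not isinstance(eta, (float, int)) or eta <= 0:
    raise ValueError(f"Invalid value: expected eta > 0, but got {eta}")
```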
        if not 0 < gamma < 1:
            raise ValueError(
                f"Invalid range: expected 0 < gamma < 1, but got {gamma}"
            )

        # Initialize parameters
        self.eta = eta
        self.gamma = gamma

        # initialize weights
        self.weights = {}
        for condition_name in problem.conditions:
            self.weights[condition_name] = 0
        # Initialize the weight of each point to 0
        self.weights = {
            cond: torch.zeros((len(data), 1), device=self.device)
            for cond, data in self.problem.input_pts.items()
        }

Review comment: Here I would use register buffer (see here). In this case we can restore the training without problems.
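A rough sketch of the reviewer's register_buffer suggestion (assumptions: the solver inherits from torch.nn.Module, so buffers end up in the state_dict and follow the module across devices; the buffer naming scheme is hypothetical):

```python
# Hypothetical: store the RBA weights as registered buffers so they are
# checkpointed and restored together with the model parameters.
for cond, data in self.problem.input_pts.items():
    self.register_buffer(
        f"_rba_weights_{cond}",          # hypothetical buffer name
        torch.zeros((len(data), 1)),
        persistent=True,
    )
```

Since registered buffers move with the module, this would also make the explicit `.to(device)` loop in `on_train_start` below unnecessary.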

        # define vectorial loss
        self._vectorial_loss = deepcopy(self.loss)
        self._vectorial_loss.reduction = "none"

    # for now RBAPINN is implemented only for batch_size = None
    def on_train_start(self):
        """
        Hook method called at the beginning of training.

        :raises NotImplementedError: If the batch size is not ``None``.
        """
        if self.trainer.batch_size is not None:
            raise NotImplementedError(
                "RBAPINN only works with full batch "
                "size, set batch_size=None inside the "
                "Trainer to use the solver."
            )
        device = self.trainer.strategy.root_device
        for cond in self.weights:
            self.weights[cond] = self.weights[cond].to(device)
        return super().on_train_start()

    def _vect_to_scalar(self, loss_value):
    def training_step(self, batch, batch_idx, **kwargs):
Review comment: I don't like the fact that we are touching the training/val/test steps nor the optimization cycle. We split this in #542 to make the code more maintainable. Can we avoid it? From what I can see, the train/val/test steps can be kept the same (we don't need to take out the global averaging of the losses; the user can decide, imho, whether to use it or not in addition to the RBA weights).
""" | ||
Solver training step. It computes the optimization cycle and aggregates | ||
the losses using the ``weighting`` attribute. | ||
|
||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a | ||
tuple containing a condition name and a dictionary of points. | ||
:param int batch_idx: The index of the current batch. | ||
:param dict kwargs: Additional keyword arguments passed to | ||
``optimization_cycle``. | ||
:return: The loss of the training step. | ||
:rtype: torch.Tensor | ||
""" | ||
loss = self._optimization_cycle( | ||
batch=batch, batch_idx=batch_idx, **kwargs | ||
) | ||
self.store_log("train_loss", loss, self.get_batch_size(batch)) | ||
return loss | ||
|
||
@torch.set_grad_enabled(True) | ||
def validation_step(self, batch, **kwargs): | ||
""" | ||
The validation step for the PINN solver. It returns the average residual | ||
computed with the ``loss`` function not aggregated. | ||
|
||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a | ||
tuple containing a condition name and a dictionary of points. | ||
:param dict kwargs: Additional keyword arguments passed to | ||
``optimization_cycle``. | ||
:return: The loss of the validation step. | ||
:rtype: torch.Tensor | ||
""" | ||
losses = self.optimization_cycle(batch=batch, **kwargs) | ||
|
||
# Aggregate losses for each condition | ||
for cond, loss in losses.items(): | ||
losses[cond] = losses[cond].mean() | ||
|
||
loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor) | ||
self.store_log("val_loss", loss, self.get_batch_size(batch)) | ||
return loss | ||

    @torch.set_grad_enabled(True)
    def test_step(self, batch, **kwargs):
        """
        Computation of the scalar loss.
        The test step for the PINN solver. It returns the average residual
        computed with the ``loss`` function not aggregated.

        :param LabelTensor loss_value: the tensor of pointwise losses.
        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
        :return: The computed scalar loss.
        :rtype: LabelTensor
        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
        :param dict kwargs: Additional keyword arguments passed to
            ``optimization_cycle``.
        :return: The loss of the test step.
        :rtype: torch.Tensor
        """
        if self.loss.reduction == "mean":
            ret = torch.mean(loss_value)
        elif self.loss.reduction == "sum":
            ret = torch.sum(loss_value)
        else:
            raise RuntimeError(
                f"Invalid reduction, got {self.loss.reduction} "
                "but expected mean or sum."
        losses = self.optimization_cycle(batch=batch, **kwargs)

        # Aggregate losses for each condition
        for cond, loss in losses.items():
            losses[cond] = losses[cond].mean()

        loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
        self.store_log("test_loss", loss, self.get_batch_size(batch))
        return loss

    def _optimization_cycle(self, batch, batch_idx, **kwargs):
Review comment: I would put this logic in loss_phys. It shouldn't be a big deal. In such a way we only need to override loss_phys.
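For context, a rough sketch of what such an override might look like, under the simplifying assumption of full-batch training (batch_size=None), where the per-condition weights align one-to-one with the residuals; with mini-batches, the batch-index bookkeeping done in _optimization_cycle below would still have to live somewhere:

```python
def loss_phys(self, samples, equation):
    """Hypothetical override applying the RBA weighting inside loss_phys."""
    residual = self.compute_residual(samples=samples, equation=equation)
    cond = self.current_condition_name

    # RBA update: w <- gamma * w + eta * |r| / (max |r| + eps), detached.
    r_norm = self.eta * residual.abs() / (residual.abs().max() + 1e-12)
    self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()

    # Weighted pointwise loss, reduced to a scalar for the base-class cycle.
    pointwise = self.loss(torch.zeros_like(residual), residual)
    return (self.weights[cond] * pointwise).mean()
```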
""" | ||
Aggregate the loss for each condition in the batch. | ||
|
||
:param list[tuple[str, dict]] batch: A batch of data. Each element is a | ||
tuple containing a condition name and a dictionary of points. | ||
:param int batch_idx: The index of the current batch. | ||
:param dict kwargs: Additional keyword arguments passed to | ||
``optimization_cycle``. | ||
:return: The losses computed for all conditions in the batch, casted | ||
to a subclass of :class:`torch.Tensor`. It should return a dict | ||
containing the condition name and the associated scalar loss. | ||
:rtype: dict | ||
""" | ||
# compute non-aggregated residuals | ||
residuals = self.optimization_cycle(batch) | ||
|
||
# update weights based on residuals | ||
self._update_weights(batch, batch_idx, residuals) | ||
|
||
# compute losses | ||
losses = {} | ||
for cond, res in residuals.items(): | ||
|
||
# Get the correct indices for the weights. Modulus is used according | ||
# to the number of points in the condition, as in the PinaDataset. | ||
len_res = len(res) | ||
idx = torch.arange( | ||
batch_idx * len_res, | ||
(batch_idx + 1) * len_res, | ||
device=res.device, | ||
) % len(self.problem.input_pts[cond]) | ||
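            # Worked example (hypothetical numbers): with 100 points in this
            # condition and 30 residuals per batch, batch_idx=3 gives the raw
            # index range 90..119, which the modulus wraps to 90..99, 0..19,
            # matching the cyclic indexing referenced in the comment above.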

            losses[cond] = (res * self.weights[cond][idx]).mean()

            # store log
            self.store_log(
                f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
            )
        return ret

    def loss_phys(self, samples, equation):
        # clamp unknown parameters in InverseProblem (if needed)
        self._clamp_params()

        # aggregate
        loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)

        return loss

    def _update_weights(self, batch, batch_idx, residuals):
        """
        Computes the physics loss for the physics-informed solver based on the
        provided samples and equation.
        Update weights based on residuals.

        :param LabelTensor samples: The samples to evaluate the physics loss.
        :param EquationInterface equation: The governing equation.
        :return: The computed physics loss.
        :rtype: LabelTensor
        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
            tuple containing a condition name and a dictionary of points.
        :param int batch_idx: The index of the current batch.
        :param dict residuals: A dictionary containing the residuals for each
            condition. The keys are the condition names and the values are the
            residuals as tensors.
        """
        residual = self.compute_residual(samples=samples, equation=equation)
        cond = self.current_condition_name
        # Iterate over each condition in the batch
        for cond, data in batch:

            r_norm = (
                self.eta
                * torch.abs(residual)
                / (torch.max(torch.abs(residual)) + 1e-12)
            )
            self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
            # Compute normalized residuals
            res = residuals[cond]
            res_abs = res.abs()
            r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)

            loss_value = self._vectorial_loss(
                torch.zeros_like(residual, requires_grad=True), residual
            )
            # Get the correct indices for the weights. Modulus is used according
            # to the number of points in the condition, as in the PinaDataset.
            len_pts = len(data["input"])
            idx = torch.arange(
                batch_idx * len_pts,
                (batch_idx + 1) * len_pts,
                device=res.device,
            ) % len(self.problem.input_pts[cond])

            return self._vect_to_scalar(self.weights[cond] ** 2 * loss_value)
            # Update weights
            weights = self.weights[cond]
            update = self.gamma * weights[idx] + r_norm
            weights[idx] = update.detach()
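For reference, the update applied in the loop above corresponds to the residual-based attention rule this solver implements, per point $i$ of a condition:

$$ w_i \leftarrow \gamma\, w_i + \eta\, \frac{|r_i|}{\max_j |r_j| + 10^{-12}} $$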
Review comment: Can we do this internally? The user will for sure forget to pass reduction="none" when passing a custom loss.
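One way to handle this internally could be to keep the user-facing `loss` argument and coerce it to a pointwise reduction before passing it to the parent class (a sketch; the helper name `_as_pointwise_loss` is hypothetical):

```python
import copy

import torch


def _as_pointwise_loss(loss=None):
    """Return a copy of ``loss`` with ``reduction="none"`` (MSELoss if None)."""
    if loss is None:
        return torch.nn.MSELoss(reduction="none")
    pointwise = copy.deepcopy(loss)
    if hasattr(pointwise, "reduction"):
        pointwise.reduction = "none"
    return pointwise


# Inside __init__, instead of hard-coding the loss:
#     super().__init__(..., loss=_as_pointwise_loss(loss))
```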