Skip to content

Support batching for rbapinns #589

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 152 additions & 60 deletions pina/solver/physics_informed_solver/rba_pinn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Module for the Residual-Based Attention PINN solver."""

from copy import deepcopy
import torch

from .pinn import PINN
Expand Down Expand Up @@ -73,7 +72,6 @@ def __init__(
optimizer=None,
scheduler=None,
weighting=None,
loss=None,
eta=0.001,
gamma=0.999,
):
Expand All @@ -90,99 +88,193 @@ def __init__(
scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is `None`.
:param float | int eta: The learning rate for the weights of the
residuals. Default is ``0.001``.
:param float gamma: The decay parameter in the update of the weights
of the residuals. Must be between ``0`` and ``1``.
Default is ``0.999``.
:raises: ValueError if `gamma` is not in the range (0, 1).
"""
super().__init__(
model=model,
problem=problem,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
loss=loss,
loss=torch.nn.MSELoss(reduction="none"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do this internally? The user will for sure forget to pass reduction=none when passing a custom loss

)

# check consistency
check_consistency(eta, (float, int))
check_consistency(gamma, float)
assert (
0 < gamma < 1
), f"Invalid range: expected 0 < gamma < 1, got {gamma=}"

# Validate range for gamma
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also a check on eta?

if not 0 < gamma < 1:
raise ValueError(
f"Invalid range: expected 0 < gamma < 1, but got {gamma}"
)

# Initialize parameters
self.eta = eta
self.gamma = gamma

# initialize weights
self.weights = {}
for condition_name in problem.conditions:
self.weights[condition_name] = 0
# Initialize the weight of each point to 0
self.weights = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I would use register buffer see here. In this case we can restore the training without problems

cond: torch.zeros((len(data), 1), device=self.device)
for cond, data in self.problem.input_pts.items()
}

# define vectorial loss
self._vectorial_loss = deepcopy(self.loss)
self._vectorial_loss.reduction = "none"

# for now RBAPINN is implemented only for batch_size = None
def on_train_start(self):
"""
Hook method called at the beginning of training.

:raises NotImplementedError: If the batch size is not ``None``.
"""
if self.trainer.batch_size is not None:
raise NotImplementedError(
"RBAPINN only works with full batch "
"size, set batch_size=None inside the "
"Trainer to use the solver."
)
device = self.trainer.strategy.root_device
for cond in self.weights:
self.weights[cond] = self.weights[cond].to(device)
return super().on_train_start()

def _vect_to_scalar(self, loss_value):
def training_step(self, batch, batch_idx, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like the fact that we are touching training/val/test step nor the optimization cycle. We splitted this in #542 to make the code more maintainable. Can we avoid it?

From what I can see, the train/val/test can be kept the same (we don't need to take out global averaging of the losses, the user can decide imho whether to use it or not in addition to RBA weights)

"""
Solver training step. It computes the optimization cycle and aggregates
the losses using the ``weighting`` attribute.

:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The loss of the training step.
:rtype: torch.Tensor
"""
loss = self._optimization_cycle(
batch=batch, batch_idx=batch_idx, **kwargs
)
self.store_log("train_loss", loss, self.get_batch_size(batch))
return loss

@torch.set_grad_enabled(True)
def validation_step(self, batch, **kwargs):
"""
The validation step for the PINN solver. It returns the average residual
computed with the ``loss`` function not aggregated.

:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The loss of the validation step.
:rtype: torch.Tensor
"""
losses = self.optimization_cycle(batch=batch, **kwargs)

# Aggregate losses for each condition
for cond, loss in losses.items():
losses[cond] = losses[cond].mean()

loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
self.store_log("val_loss", loss, self.get_batch_size(batch))
return loss

@torch.set_grad_enabled(True)
def test_step(self, batch, **kwargs):
"""
Computation of the scalar loss.
The test step for the PINN solver. It returns the average residual
computed with the ``loss`` function not aggregated.

:param LabelTensor loss_value: the tensor of pointwise losses.
:raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
:return: The computed scalar loss.
:rtype: LabelTensor
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The loss of the test step.
:rtype: torch.Tensor
"""
if self.loss.reduction == "mean":
ret = torch.mean(loss_value)
elif self.loss.reduction == "sum":
ret = torch.sum(loss_value)
else:
raise RuntimeError(
f"Invalid reduction, got {self.loss.reduction} "
"but expected mean or sum."
losses = self.optimization_cycle(batch=batch, **kwargs)

# Aggregate losses for each condition
for cond, loss in losses.items():
losses[cond] = losses[cond].mean()

loss = (sum(losses.values()) / len(losses)).as_subclass(torch.Tensor)
self.store_log("test_loss", loss, self.get_batch_size(batch))
return loss

def _optimization_cycle(self, batch, batch_idx, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would put this logic in loss_phys. It shouldn't be a big deal. In such a way we only need to override loss_phys.

"""
Aggregate the loss for each condition in the batch.

:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
:param dict kwargs: Additional keyword arguments passed to
``optimization_cycle``.
:return: The losses computed for all conditions in the batch, casted
to a subclass of :class:`torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict
"""
# compute non-aggregated residuals
residuals = self.optimization_cycle(batch)

# update weights based on residuals
self._update_weights(batch, batch_idx, residuals)

# compute losses
losses = {}
for cond, res in residuals.items():

# Get the correct indices for the weights. Modulus is used according
# to the number of points in the condition, as in the PinaDataset.
len_res = len(res)
idx = torch.arange(
batch_idx * len_res,
(batch_idx + 1) * len_res,
device=res.device,
) % len(self.problem.input_pts[cond])

losses[cond] = (res * self.weights[cond][idx]).mean()

# store log
self.store_log(
f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
)
return ret

def loss_phys(self, samples, equation):
# clamp unknown parameters in InverseProblem (if needed)
self._clamp_params()

# aggregate
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)

return loss

def _update_weights(self, batch, batch_idx, residuals):
"""
Computes the physics loss for the physics-informed solver based on the
provided samples and equation.
Update weights based on residuals.

:param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation.
:return: The computed physics loss.
:rtype: LabelTensor
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:param int batch_idx: The index of the current batch.
:param dict residuals: A dictionary containing the residuals for each
condition. The keys are the condition names and the values are the
residuals as tensors.
"""
residual = self.compute_residual(samples=samples, equation=equation)
cond = self.current_condition_name
# Iterate over each condition in the batch
for cond, data in batch:

r_norm = (
self.eta
* torch.abs(residual)
/ (torch.max(torch.abs(residual)) + 1e-12)
)
self.weights[cond] = (self.gamma * self.weights[cond] + r_norm).detach()
# Compute normalized residuals
res = residuals[cond]
res_abs = res.abs()
r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)

loss_value = self._vectorial_loss(
torch.zeros_like(residual, requires_grad=True), residual
)
# Get the correct indices for the weights. Modulus is used according
# to the number of points in the condition, as in the PinaDataset.
len_pts = len(data["input"])
idx = torch.arange(
batch_idx * len_pts,
(batch_idx + 1) * len_pts,
device=res.device,
) % len(self.problem.input_pts[cond])

return self._vect_to_scalar(self.weights[cond] ** 2 * loss_value)
# Update weights
weights = self.weights[cond]
update = self.gamma * weights[idx] + r_norm
weights[idx] = update.detach()
36 changes: 12 additions & 24 deletions tests/test_solver/test_rba_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@
@pytest.mark.parametrize("eta", [1, 0.001])
@pytest.mark.parametrize("gamma", [0.5, 0.9])
def test_constructor(problem, eta, gamma):
with pytest.raises(AssertionError):
solver = RBAPINN(model=model, problem=problem, gamma=1.5)
solver = RBAPINN(model=model, problem=problem, eta=eta, gamma=gamma)

with pytest.raises(ValueError):
solver = RBAPINN(model=model, problem=problem, gamma=1.5)

assert solver.accepted_conditions_types == (
InputTargetCondition,
InputEquationCondition,
Expand All @@ -54,30 +55,15 @@ def test_constructor(problem, eta, gamma):


@pytest.mark.parametrize("problem", [problem, inverse_problem])
def test_wrong_batch(problem):
with pytest.raises(NotImplementedError):
solver = RBAPINN(model=model, problem=problem)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=10,
train_size=1.0,
val_size=0.0,
test_size=0.0,
)
trainer.train()


@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_train(problem, compile):
def test_solver_train(problem, batch_size, compile):
solver = RBAPINN(model=model, problem=problem)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=None,
batch_size=batch_size,
train_size=1.0,
val_size=0.0,
test_size=0.0,
Expand All @@ -89,14 +75,15 @@ def test_solver_train(problem, compile):


@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_validation(problem, compile):
def test_solver_validation(problem, batch_size, compile):
solver = RBAPINN(model=model, problem=problem)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=None,
batch_size=batch_size,
train_size=0.9,
val_size=0.1,
test_size=0.0,
Expand All @@ -108,14 +95,15 @@ def test_solver_validation(problem, compile):


@pytest.mark.parametrize("problem", [problem, inverse_problem])
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_test(problem, compile):
def test_solver_test(problem, batch_size, compile):
solver = RBAPINN(model=model, problem=problem)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=None,
batch_size=batch_size,
train_size=0.7,
val_size=0.2,
test_size=0.1,
Expand Down