Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pina/callbacks/adaptive_refinment_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _compute_residual(self, trainer):
"""

# extract the solver and device from trainer
solver = trainer._model
solver = trainer.solver
device = trainer._accelerator_connector._accelerator_flag
precision = trainer.precision
if precision == "64-true":
Expand All @@ -67,7 +67,7 @@ def _compute_residual(self, trainer):
# compute residual
res_loss = {}
tot_loss = []
for location in self._sampling_locations:
for location in self._sampling_locations: #TODO fix for new collector
condition = solver.problem.conditions[location]
pts = solver.problem.input_pts[location]
# send points to correct device
Expand All @@ -79,6 +79,8 @@ def _compute_residual(self, trainer):
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target))

print(tot_loss)

return torch.vstack(tot_loss), res_loss

def _r3_routine(self, trainer):
Expand Down Expand Up @@ -139,7 +141,7 @@ def on_train_start(self, trainer, _):
:rtype: None
"""
# extract locations for sampling
problem = trainer._model.problem
problem = trainer.solver.problem
locations = []
for condition_name in problem.conditions:
condition = problem.conditions[condition_name]
Expand Down
60 changes: 20 additions & 40 deletions pina/callbacks/optimizer_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,61 +3,45 @@
from lightning.pytorch.callbacks import Callback
import torch
from ..utils import check_consistency
from pina.optim import TorchOptimizer


class SwitchOptimizer(Callback):

def __init__(self, new_optimizers, new_optimizers_kwargs, epoch_switch):
def __init__(self, new_optimizers, epoch_switch):
"""
PINA Implementation of a Lightning Callback to switch optimizer during training.
PINA Implementation of a Lightning Callback to switch optimizer during
training.

This callback allows for switching between different optimizers during training, enabling
the exploration of multiple optimization strategies without the need to stop training.
This callback allows for switching between different optimizers during
training, enabling the exploration of multiple optimization strategies
without the need to stop training.

:param new_optimizers: The model optimizers to switch to. Can be a single
:class:`torch.optim.Optimizer` or a list of them for multiple model solvers.
:type new_optimizers: torch.optim.Optimizer | list
:param new_optimizers_kwargs: The keyword arguments for the new optimizers. Can be a single dictionary
or a list of dictionaries corresponding to each optimizer.
:type new_optimizers_kwargs: dict | list
:param new_optimizers: The model optimizers to switch to. Can be a
single :class:`torch.optim.Optimizer` or a list of them for multiple
model solvers.
:type new_optimizers: pina.optim.TorchOptimizer | list
:param epoch_switch: The epoch at which to switch to the new optimizer.
:type epoch_switch: int

:raises ValueError: If `epoch_switch` is less than 1 or if there is a mismatch in the number of
optimizers and their corresponding keyword argument dictionaries.

Example:
>>> switch_callback = SwitchOptimizer(new_optimizers=[optimizer1, optimizer2],
>>> new_optimizers_kwargs=[{'lr': 0.001}, {'lr': 0.01}],
>>> switch_callback = SwitchOptimizer(new_optimizers=optimizer,
>>> epoch_switch=10)
"""
super().__init__()

# check type consistency
check_consistency(new_optimizers, torch.optim.Optimizer, subclass=True)
check_consistency(new_optimizers_kwargs, dict)
check_consistency(epoch_switch, int)

if epoch_switch < 1:
raise ValueError("epoch_switch must be greater than one.")

if not isinstance(new_optimizers, list):
new_optimizers = [new_optimizers]
new_optimizers_kwargs = [new_optimizers_kwargs]
len_optimizer = len(new_optimizers)
len_optimizer_kwargs = len(new_optimizers_kwargs)

if len_optimizer_kwargs != len_optimizer:
raise ValueError(
"You must define one dictionary of keyword"
" arguments for each optimizers."
f" Got {len_optimizer} optimizers, and"
f" {len_optimizer_kwargs} dicitionaries"
)

# check type consistency
for optimizer in new_optimizers:
check_consistency(optimizer, TorchOptimizer)
check_consistency(epoch_switch, int)
# save new optimizers
self._new_optimizers = new_optimizers
self._new_optimizers_kwargs = new_optimizers_kwargs
self._epoch_switch = epoch_switch

def on_train_epoch_start(self, trainer, __):
Expand All @@ -73,13 +57,9 @@ def on_train_epoch_start(self, trainer, __):
"""
if trainer.current_epoch == self._epoch_switch:
optims = []
for idx, (optim, optim_kwargs) in enumerate(
zip(self._new_optimizers, self._new_optimizers_kwargs)
):
optims.append(
optim(
trainer._model.models[idx].parameters(), **optim_kwargs
)
)

for idx, optim in enumerate(self._new_optimizers):
optim.hook(trainer.solver.models[idx].parameters())
optims.append(optim.instance)

trainer.optimizers = optims
7 changes: 5 additions & 2 deletions pina/callbacks/processing_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class MetricTracker(Callback):

def __init__(self):
def __init__(self, metrics_to_track=None):
"""
PINA Implementation of a Lightning Callback for Metric Tracking.

Expand All @@ -37,6 +37,9 @@ def __init__(self):
"""
super().__init__()
self._collection = []
if metrics_to_track is not None:
metrics_to_track = ['train_loss_epoch', 'train_loss_step', 'val_loss']
self.metrics_to_track = metrics_to_track

def on_train_epoch_end(self, trainer, pl_module):
"""
Expand Down Expand Up @@ -72,7 +75,7 @@ class PINAProgressBar(TQDMProgressBar):

BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"

def __init__(self, metrics="mean", **kwargs):
def __init__(self, metrics="val_loss", **kwargs):
"""
PINA Implementation of a Lightning Callback for enriching the progress
bar.
Expand Down
4 changes: 3 additions & 1 deletion pina/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__all__ = [
'LpLoss',

'PowerLoss',
'weightningInterface',
'LossInterface'
]

from .loss_interface import LossInterface
Expand Down
3 changes: 3 additions & 0 deletions pina/optim/torch_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ def __init__(self, optimizer_class, **kwargs):
def hook(self, parameters):
self.optimizer_instance = self.optimizer_class(parameters,
**self.kwargs)
@property
def instance(self):
return self.optimizer_instance
2 changes: 1 addition & 1 deletion pina/solvers/pinns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"RBAPINN",
]

from .basepinn import PINNInterface
from .pinn_interface import PINNInterface
from .pinn import PINN
from .gpinn import GPINN
from .causalpinn import CausalPINN
Expand Down
2 changes: 1 addition & 1 deletion pina/solvers/pinns/competitive_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from torch.optim.lr_scheduler import ConstantLR

from .basepinn import PINNInterface
from .pinn_interface import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem

Expand Down
6 changes: 2 additions & 4 deletions pina/solvers/pinns/pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
) # torch < 2.0


from .basepinn import PINNInterface
from .pinn_interface import PINNInterface
from ...problem import InverseProblem


Expand Down Expand Up @@ -60,7 +60,6 @@ def __init__(
self,
problem,
model,
extra_features=None,
loss=None,
optimizer=None,
scheduler=None,
Expand All @@ -82,10 +81,9 @@ def __init__(
super().__init__(
models=model,
problem=problem,
loss=loss,
optimizers=optimizer,
schedulers=scheduler,
extra_features=extra_features,
loss=loss,
)

# assign variables
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABCMeta, abstractmethod
import torch
from torch.nn.modules.loss import _Loss
from ...solvers.solver import SolverInterface
from ..solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
Expand Down Expand Up @@ -33,10 +33,9 @@ def __init__(
self,
models,
problem,
optimizers,
schedulers,
extra_features,
loss,
loss=None,
optimizers=None,
schedulers=None,
):
"""
:param models: Multiple torch neural network models instances.
Expand Down Expand Up @@ -70,7 +69,6 @@ def __init__(
problem=problem,
optimizers=optimizers,
schedulers=schedulers,
extra_features=extra_features,
)

# check consistency
Expand Down Expand Up @@ -198,6 +196,11 @@ def loss_phys(self, samples, equation):
"""
pass

def configure_optimizers(self):
self._optimizer.hook(self._model)
self.schedulers.hook(self._optimizer)
return [self.optimizers.instance]#, self.schedulers.scheduler_instance

def compute_residual(self, samples, equation):
"""
Compute the residual for Physics Informed learning. This function
Expand Down
2 changes: 1 addition & 1 deletion pina/solvers/pinns/sapinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
_LRScheduler as LRScheduler,
) # torch < 2.0

from .basepinn import PINNInterface
from .pinn_interface import PINNInterface
from pina.utils import check_consistency
from pina.problem import InverseProblem

Expand Down
4 changes: 1 addition & 3 deletions pina/solvers/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(self,
problem,
optimizers,
schedulers,
extra_features,
use_lt=True):
"""
:param model: A torch neural network model instance.
Expand Down Expand Up @@ -56,7 +55,6 @@ def __init__(self,
model=model,
input_variables=problem.input_variables,
output_variables=problem.output_variables,
extra_features=extra_features,
)

# Check scheduler consistency + encapsulation
Expand Down Expand Up @@ -98,7 +96,7 @@ def training_step(self, batch):

@abstractmethod
def configure_optimizers(self):
pass
raise NotImplementedError

@property
def models(self):
Expand Down
49 changes: 25 additions & 24 deletions tests/test_callbacks/test_adaptive_refinment_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,35 @@
boundaries = ['nil_g1', 'nil_g2', 'nil_g3', 'nil_g4']
n = 10
poisson_problem.discretise_domain(n, 'grid', locations=boundaries)
poisson_problem.discretise_domain(n, 'grid', locations='laplace_D')
model = FeedForward(len(poisson_problem.input_variables),
len(poisson_problem.output_variables))

# make the solver
solver = PINN(problem=poisson_problem, model=model)


def test_r3constructor():
R3Refinement(sample_every=10)


def test_r3refinment_routine():
# make the trainer
trainer = Trainer(solver=solver,
callbacks=[R3Refinement(sample_every=1)],
accelerator='cpu',
max_epochs=5)
trainer.train()

def test_r3refinment_routine():
model = FeedForward(len(poisson_problem.input_variables),
len(poisson_problem.output_variables))
solver = PINN(problem=poisson_problem, model=model)
trainer = Trainer(solver=solver,
callbacks=[R3Refinement(sample_every=1)],
accelerator='cpu',
max_epochs=5)
before_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
trainer.train()
after_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
assert before_n_points == after_n_points
# def test_r3constructor():
# R3Refinement(sample_every=10)


# def test_r3refinment_routine():
# # make the trainer
# trainer = Trainer(solver=solver,
# callbacks=[R3Refinement(sample_every=1)],
# accelerator='cpu',
# max_epochs=5)
# trainer.train()

# def test_r3refinment_routine():
# model = FeedForward(len(poisson_problem.input_variables),
# len(poisson_problem.output_variables))
# solver = PINN(problem=poisson_problem, model=model)
# trainer = Trainer(solver=solver,
# callbacks=[R3Refinement(sample_every=1)],
# accelerator='cpu',
# max_epochs=5)
# before_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
# trainer.train()
# after_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
# assert before_n_points == after_n_points
Loading
Loading