1
1
"""Module for the Residual-Based Attention PINN solver."""
2
2
3
- from copy import deepcopy
4
3
import torch
5
4
6
5
from .pinn import PINN
@@ -73,7 +72,6 @@ def __init__(
73
72
optimizer = None ,
74
73
scheduler = None ,
75
74
weighting = None ,
76
- loss = None ,
77
75
eta = 0.001 ,
78
76
gamma = 0.999 ,
79
77
):
@@ -90,99 +88,193 @@ def __init__(
90
88
scheduler is used. Default is ``None``.
91
89
:param WeightingInterface weighting: The weighting schema to be used.
92
90
If ``None``, no weighting schema is used. Default is ``None``.
93
- :param torch.nn.Module loss: The loss function to be minimized.
94
- If ``None``, the :class:`torch.nn.MSELoss` loss is used.
95
- Default is `None`.
96
91
:param float | int eta: The learning rate for the weights of the
97
92
residuals. Default is ``0.001``.
98
93
:param float gamma: The decay parameter in the update of the weights
99
94
of the residuals. Must be between ``0`` and ``1``.
100
95
Default is ``0.999``.
96
+ :raises: ValueError if `gamma` is not in the range (0, 1).
101
97
"""
102
98
super ().__init__ (
103
99
model = model ,
104
100
problem = problem ,
105
101
optimizer = optimizer ,
106
102
scheduler = scheduler ,
107
103
weighting = weighting ,
108
- loss = loss ,
104
+ loss = torch . nn . MSELoss ( reduction = "none" ) ,
109
105
)
110
106
111
107
# check consistency
112
108
check_consistency (eta , (float , int ))
113
109
check_consistency (gamma , float )
114
- assert (
115
- 0 < gamma < 1
116
- ), f"Invalid range: expected 0 < gamma < 1, got { gamma = } "
110
+
111
+ # Validate range for gamma
112
+ if not 0 < gamma < 1 :
113
+ raise ValueError (
114
+ f"Invalid range: expected 0 < gamma < 1, but got { gamma } "
115
+ )
116
+
117
+ # Initialize parameters
117
118
self .eta = eta
118
119
self .gamma = gamma
119
120
120
- # initialize weights
121
- self .weights = {}
122
- for condition_name in problem .conditions :
123
- self .weights [condition_name ] = 0
121
+ # Initialize the weight of each point to 0
122
+ self .weights = {
123
+ cond : torch .zeros ((len (data ), 1 ), device = self .device )
124
+ for cond , data in self .problem .input_pts .items ()
125
+ }
124
126
125
- # define vectorial loss
126
- self ._vectorial_loss = deepcopy (self .loss )
127
- self ._vectorial_loss .reduction = "none"
128
-
129
- # for now RBAPINN is implemented only for batch_size = None
130
127
def on_train_start (self ):
131
128
"""
132
129
Hook method called at the beginning of training.
133
-
134
- :raises NotImplementedError: If the batch size is not ``None``.
135
130
"""
136
- if self .trainer .batch_size is not None :
137
- raise NotImplementedError (
138
- "RBAPINN only works with full batch "
139
- "size, set batch_size=None inside the "
140
- "Trainer to use the solver."
141
- )
131
+ device = self .trainer .strategy .root_device
132
+ for cond in self .weights :
133
+ self .weights [cond ] = self .weights [cond ].to (device )
142
134
return super ().on_train_start ()
143
135
144
- def _vect_to_scalar (self , loss_value ):
136
+ def training_step (self , batch , batch_idx , ** kwargs ):
137
+ """
138
+ Solver training step. It computes the optimization cycle and aggregates
139
+ the losses using the ``weighting`` attribute.
140
+
141
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
142
+ tuple containing a condition name and a dictionary of points.
143
+ :param int batch_idx: The index of the current batch.
144
+ :param dict kwargs: Additional keyword arguments passed to
145
+ ``optimization_cycle``.
146
+ :return: The loss of the training step.
147
+ :rtype: torch.Tensor
148
+ """
149
+ loss = self ._optimization_cycle (
150
+ batch = batch , batch_idx = batch_idx , ** kwargs
151
+ )
152
+ self .store_log ("train_loss" , loss , self .get_batch_size (batch ))
153
+ return loss
154
+
155
+ @torch .set_grad_enabled (True )
156
+ def validation_step (self , batch , ** kwargs ):
157
+ """
158
+ The validation step for the PINN solver. It returns the average residual
159
+ computed with the ``loss`` function not aggregated.
160
+
161
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
162
+ tuple containing a condition name and a dictionary of points.
163
+ :param dict kwargs: Additional keyword arguments passed to
164
+ ``optimization_cycle``.
165
+ :return: The loss of the validation step.
166
+ :rtype: torch.Tensor
167
+ """
168
+ losses = self .optimization_cycle (batch = batch , ** kwargs )
169
+
170
+ # Aggregate losses for each condition
171
+ for cond , loss in losses .items ():
172
+ losses [cond ] = losses [cond ].mean ()
173
+
174
+ loss = (sum (losses .values ()) / len (losses )).as_subclass (torch .Tensor )
175
+ self .store_log ("val_loss" , loss , self .get_batch_size (batch ))
176
+ return loss
177
+
178
+ @torch .set_grad_enabled (True )
179
+ def test_step (self , batch , ** kwargs ):
145
180
"""
146
- Computation of the scalar loss.
181
+ The test step for the PINN solver. It returns the average residual
182
+ computed with the ``loss`` function not aggregated.
147
183
148
- :param LabelTensor loss_value: the tensor of pointwise losses.
149
- :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
150
- :return: The computed scalar loss.
151
- :rtype: LabelTensor
184
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
185
+ tuple containing a condition name and a dictionary of points.
186
+ :param dict kwargs: Additional keyword arguments passed to
187
+ ``optimization_cycle``.
188
+ :return: The loss of the test step.
189
+ :rtype: torch.Tensor
152
190
"""
153
- if self .loss .reduction == "mean" :
154
- ret = torch .mean (loss_value )
155
- elif self .loss .reduction == "sum" :
156
- ret = torch .sum (loss_value )
157
- else :
158
- raise RuntimeError (
159
- f"Invalid reduction, got { self .loss .reduction } "
160
- "but expected mean or sum."
191
+ losses = self .optimization_cycle (batch = batch , ** kwargs )
192
+
193
+ # Aggregate losses for each condition
194
+ for cond , loss in losses .items ():
195
+ losses [cond ] = losses [cond ].mean ()
196
+
197
+ loss = (sum (losses .values ()) / len (losses )).as_subclass (torch .Tensor )
198
+ self .store_log ("test_loss" , loss , self .get_batch_size (batch ))
199
+ return loss
200
+
201
+ def _optimization_cycle (self , batch , batch_idx , ** kwargs ):
202
+ """
203
+ Aggregate the loss for each condition in the batch.
204
+
205
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
206
+ tuple containing a condition name and a dictionary of points.
207
+ :param int batch_idx: The index of the current batch.
208
+ :param dict kwargs: Additional keyword arguments passed to
209
+ ``optimization_cycle``.
210
+ :return: The losses computed for all conditions in the batch, casted
211
+ to a subclass of :class:`torch.Tensor`. It should return a dict
212
+ containing the condition name and the associated scalar loss.
213
+ :rtype: dict
214
+ """
215
+ # compute non-aggregated residuals
216
+ residuals = self .optimization_cycle (batch )
217
+
218
+ # update weights based on residuals
219
+ self ._update_weights (batch , batch_idx , residuals )
220
+
221
+ # compute losses
222
+ losses = {}
223
+ for cond , res in residuals .items ():
224
+
225
+ # Get the correct indices for the weights. Modulus is used according
226
+ # to the number of points in the condition, as in the PinaDataset.
227
+ len_res = len (res )
228
+ idx = torch .arange (
229
+ batch_idx * len_res ,
230
+ (batch_idx + 1 ) * len_res ,
231
+ device = res .device ,
232
+ ) % len (self .problem .input_pts [cond ])
233
+
234
+ losses [cond ] = (res * self .weights [cond ][idx ]).mean ()
235
+
236
+ # store log
237
+ self .store_log (
238
+ f"{ cond } _loss" , losses [cond ].item (), self .get_batch_size (batch )
161
239
)
162
- return ret
163
240
164
- def loss_phys (self , samples , equation ):
241
+ # clamp unknown parameters in InverseProblem (if needed)
242
+ self ._clamp_params ()
243
+
244
+ # aggregate
245
+ loss = self .weighting .aggregate (losses ).as_subclass (torch .Tensor )
246
+
247
+ return loss
248
+
249
+ def _update_weights (self , batch , batch_idx , residuals ):
165
250
"""
166
- Computes the physics loss for the physics-informed solver based on the
167
- provided samples and equation.
251
+ Update weights based on residuals.
168
252
169
- :param LabelTensor samples: The samples to evaluate the physics loss.
170
- :param EquationInterface equation: The governing equation.
171
- :return: The computed physics loss.
172
- :rtype: LabelTensor
253
+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
254
+ tuple containing a condition name and a dictionary of points.
255
+ :param int batch_idx: The index of the current batch.
256
+ :param dict residuals: A dictionary containing the residuals for each
257
+ condition. The keys are the condition names and the values are the
258
+ residuals as tensors.
173
259
"""
174
- residual = self . compute_residual ( samples = samples , equation = equation )
175
- cond = self . current_condition_name
260
+ # Iterate over each condition in the batch
261
+ for cond , data in batch :
176
262
177
- r_norm = (
178
- self .eta
179
- * torch .abs (residual )
180
- / (torch .max (torch .abs (residual )) + 1e-12 )
181
- )
182
- self .weights [cond ] = (self .gamma * self .weights [cond ] + r_norm ).detach ()
263
+ # Compute normalized residuals
264
+ res = residuals [cond ]
265
+ res_abs = res .abs ()
266
+ r_norm = (self .eta * res_abs ) / (res_abs .max () + 1e-12 )
183
267
184
- loss_value = self ._vectorial_loss (
185
- torch .zeros_like (residual , requires_grad = True ), residual
186
- )
268
+ # Get the correct indices for the weights. Modulus is used according
269
+ # to the number of points in the condition, as in the PinaDataset.
270
+ len_pts = len (data ["input" ])
271
+ idx = torch .arange (
272
+ batch_idx * len_pts ,
273
+ (batch_idx + 1 ) * len_pts ,
274
+ device = res .device ,
275
+ ) % len (self .problem .input_pts [cond ])
187
276
188
- return self ._vect_to_scalar (self .weights [cond ] ** 2 * loss_value )
277
+ # Update weights
278
+ weights = self .weights [cond ]
279
+ update = self .gamma * weights [idx ] + r_norm
280
+ weights [idx ] = update .detach ()
0 commit comments