Skip to content

Commit de8b4e6

Browse files
authored
Fix bugs in 0.2 (#344)
* Fix some bugs
1 parent 36c447a commit de8b4e6

File tree

11 files changed

+95
-45
lines changed

11 files changed

+95
-45
lines changed

examples/problems/stokes.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,19 @@ def wall(input_, output_):
4949
value = 0.0
5050
return output_.extract(['ux', 'uy']) - value
5151

52+
domains = {
53+
'gamma_top': CartesianDomain({'x': [-2, 2], 'y': 1}),
54+
'gamma_bot': CartesianDomain({'x': [-2, 2], 'y': -1}),
55+
'gamma_out': CartesianDomain({'x': 2, 'y': [-1, 1]}),
56+
'gamma_in': CartesianDomain({'x': -2, 'y': [-1, 1]}),
57+
'D': CartesianDomain({'x': [-2, 2], 'y': [-1, 1]})
58+
}
59+
5260
# problem condition statement
5361
conditions = {
54-
'gamma_top': Condition(location=CartesianDomain({'x': [-2, 2], 'y': 1}), equation=Equation(wall)),
55-
'gamma_bot': Condition(location=CartesianDomain({'x': [-2, 2], 'y': -1}), equation=Equation(wall)),
56-
'gamma_out': Condition(location=CartesianDomain({'x': 2, 'y': [-1, 1]}), equation=Equation(outlet)),
57-
'gamma_in': Condition(location=CartesianDomain({'x': -2, 'y': [-1, 1]}), equation=Equation(inlet)),
58-
'D': Condition(location=CartesianDomain({'x': [-2, 2], 'y': [-1, 1]}), equation=SystemEquation([momentum, continuity]))
62+
'gamma_top': Condition(domain='gamma_top', equation=Equation(wall)),
63+
'gamma_bot': Condition(domain='gamma_bot', equation=Equation(wall)),
64+
'gamma_out': Condition(domain='gamma_out', equation=Equation(outlet)),
65+
'gamma_in': Condition(domain='gamma_in', equation=Equation(inlet)),
66+
'D': Condition(domain='D', equation=SystemEquation([momentum, continuity]))
5967
}

examples/run_stokes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
# create problem and discretise domain
1919
stokes_problem = Stokes()
20-
stokes_problem.discretise_domain(n=1000, locations=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
21-
stokes_problem.discretise_domain(n=2000, locations=['D'])
20+
stokes_problem.discretise_domain(n=1000, domains=['gamma_top', 'gamma_bot', 'gamma_in', 'gamma_out'])
21+
stokes_problem.discretise_domain(n=2000, domains=['D'])
2222

2323
# make the model
2424
model = FeedForward(

pina/condition/condition.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,15 @@ def __new__(cls, *args, **kwargs):
8484
return DomainEquationCondition(**kwargs)
8585
else:
8686
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
87-
87+
# TODO: remove, not used anymore
88+
'''
8889
if (
8990
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
9091
and sorted(kwargs.keys()) != sorted(["location", "equation"])
9192
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
9293
):
9394
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
94-
95+
# TODO: remove, not used anymore
9596
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
9697
raise TypeError("`input_points` must be a torch.Tensor.")
9798
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
@@ -103,3 +104,4 @@ def __new__(cls, *args, **kwargs):
103104
104105
for key, value in kwargs.items():
105106
setattr(self, key, value)
107+
'''

pina/condition/condition_interface.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,7 @@ def residual(self, model):
1515
:param model: The model to evaluate the condition.
1616
:return: The residual of the condition.
1717
"""
18-
pass
18+
pass
19+
20+
def set_problem(self, problem):
21+
self._problem = problem

pina/condition/domain_equation_condition.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,20 @@ def __init__(self, domain, equation):
1515
self.domain = domain
1616
self.equation = equation
1717

18+
def residual(self, model):
19+
"""
20+
Compute the residual of the condition.
21+
"""
22+
self.batch_residual(model, self.domain, self.equation)
23+
1824
@staticmethod
1925
def batch_residual(model, input_pts, equation):
2026
"""
2127
Compute the residual of the condition for a single batch. Input and
2228
output points are provided as arguments.
2329
2430
:param torch.nn.Module model: The model to evaluate the condition.
25-
:param torch.Tensor input_points: The input points.
26-
:param torch.Tensor output_points: The output points.
31+
:param torch.Tensor input_pts: The input points.
32+
:param torch.Tensor equation: The output points.
2733
"""
28-
return equation.residual(model(input_pts))
34+
return equation.residual(input_pts, model(input_pts))

pina/condition/domain_output_condition.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,5 @@ def batch_residual(model, input_points, output_points):
4040
:param torch.Tensor input_points: The input points.
4141
:param torch.Tensor output_points: The output points.
4242
"""
43+
4344
return output_points - model(input_points)

pina/domain/cartesian.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import torch
23

34
from .domain_interface import DomainInterface
45
from ..label_tensor import LabelTensor

pina/label_tensor.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from torch import Tensor
66

77

8-
98
# class LabelTensor(torch.Tensor):
109
# """Torch tensor with a label for any column."""
1110

@@ -307,13 +306,13 @@
307306
# s = "no labels\n"
308307
# s += super().__str__()
309308
# return s
310-
311309
def issubset(a, b):
312310
"""
313311
Check if a is a subset of b.
314312
"""
315313
return set(a).issubset(set(b))
316314

315+
317316
class LabelTensor(torch.Tensor):
318317
"""Torch tensor with a label for any column."""
319318

@@ -403,6 +402,10 @@ def extract(self, label_to_extract):
403402
return LabelTensor(new_tensor, label_to_extract)
404403

405404
def __str__(self):
405+
"""
406+
returns a string with the representation of the class
407+
"""
408+
406409
s = ''
407410
for key, value in self.labels.items():
408411
s += f"{key}: {value}\n"
@@ -431,4 +434,32 @@ def requires_grad_(self, mode=True):
431434

432435
@property
433436
def dtype(self):
434-
return super().dtype
437+
return super().dtype
438+
439+
440+
def to(self, *args, **kwargs):
441+
"""
442+
Performs Tensor dtype and/or device conversion. For more details, see
443+
:meth:`torch.Tensor.to`.
444+
"""
445+
tmp = super().to(*args, **kwargs)
446+
new = self.__class__.clone(self)
447+
new.data = tmp.data
448+
return new
449+
450+
451+
def clone(self, *args, **kwargs):
452+
"""
453+
Clone the LabelTensor. For more details, see
454+
:meth:`torch.Tensor.clone`.
455+
456+
:return: A copy of the tensor.
457+
:rtype: LabelTensor
458+
"""
459+
# # used before merging
460+
# try:
461+
# out = LabelTensor(super().clone(*args, **kwargs), self.labels)
462+
# except:
463+
# out = super().clone(*args, **kwargs)
464+
out = LabelTensor(super().clone(*args, **kwargs), self.labels)
465+
return out

pina/problem/abstract_problem.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,26 @@ class AbstractProblem(metaclass=ABCMeta):
2020

2121
def __init__(self):
2222

23-
2423
self._discretized_domains = {}
2524

2625
for name, domain in self.domains.items():
2726
if isinstance(domain, (torch.Tensor, LabelTensor)):
2827
self._discretized_domains[name] = domain
2928

3029
for condition_name in self.conditions:
31-
self.conditions[condition_name]._problem = self
30+
self.conditions[condition_name].set_problem(self)
31+
3232
# # variable storing all points
33-
# self.input_pts = {}
33+
self.input_pts = {}
3434

3535
# # varible to check if sampling is done. If no location
3636
# # element is presented in Condition this variable is set to true
3737
# self._have_sampled_points = {}
38-
# for condition_name in self.conditions:
39-
# self._have_sampled_points[condition_name] = False
38+
for condition_name in self.conditions:
39+
self._discretized_domains[condition_name] = False
4040

4141
# # put in self.input_pts all the points that we don't need to sample
42-
# self._span_condition_points()
42+
self._span_condition_points()
4343

4444
def __deepcopy__(self, memo):
4545
"""
@@ -125,7 +125,7 @@ def _span_condition_points(self):
125125
if hasattr(condition, "input_points"):
126126
samples = condition.input_points
127127
self.input_pts[condition_name] = samples
128-
self._have_sampled_points[condition_name] = True
128+
self._discretized_domains[condition_name] = True
129129
if hasattr(self, "unknown_parameter_domain"):
130130
# initialize the unknown parameters of the inverse problem given
131131
# the domain the user gives
@@ -141,7 +141,7 @@ def _span_condition_points(self):
141141
)
142142

143143
def discretise_domain(
144-
self, n, mode="random", variables="all", locations="all"
144+
self, n, mode="random", variables="all", domains="all"
145145
):
146146
"""
147147
Generate a set of points to span the `Location` of all the conditions of
@@ -193,24 +193,24 @@ def discretise_domain(
193193
)
194194

195195
# check consistency location
196-
if locations == "all":
197-
locations = [condition for condition in self.conditions]
196+
if domains == "all":
197+
domains = [condition for condition in self.conditions]
198198
else:
199-
check_consistency(locations, str)
200-
201-
if sorted(locations) != sorted(self.conditions):
199+
check_consistency(domains, str)
200+
print(domains)
201+
if sorted(domains) != sorted(self.conditions):
202202
TypeError(
203203
f"Wrong locations for sampling. Location ",
204204
f"should be in {self.conditions}.",
205205
)
206206

207207
# sampling
208-
for location in locations:
209-
condition = self.conditions[location]
208+
for d in domains:
209+
condition = self.conditions[d]
210210

211211
# we try to check if we have already sampled
212212
try:
213-
already_sampled = [self.input_pts[location]]
213+
already_sampled = [self.input_pts[d]]
214214
# if we have not sampled, a key error is thrown
215215
except KeyError:
216216
already_sampled = []
@@ -219,22 +219,23 @@ def discretise_domain(
219219
# but we want to sample again we set already_sampled
220220
# to an empty list since we need to sample again, and
221221
# self._have_sampled_points to False.
222-
if self._have_sampled_points[location]:
222+
if self._discretized_domains[d]:
223223
already_sampled = []
224-
self._have_sampled_points[location] = False
225-
224+
self._discretized_domains[d] = False
225+
print(condition.domain)
226+
print(d)
226227
# build samples
227228
samples = [
228-
condition.location.sample(n=n, mode=mode, variables=variables)
229+
self.domains[d].sample(n=n, mode=mode, variables=variables)
229230
] + already_sampled
230231
pts = merge_tensors(samples)
231-
self.input_pts[location] = pts
232+
self.input_pts[d] = pts
232233

233234
# the condition is sampled if input_pts contains all labels
234-
if sorted(self.input_pts[location].labels) == sorted(
235+
if sorted(self.input_pts[d].labels) == sorted(
235236
self.input_variables
236237
):
237-
self._have_sampled_points[location] = True
238+
self._have_sampled_points[d] = True
238239

239240
def add_points(self, new_points):
240241
"""

pina/solvers/supervised.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,6 @@ def training_step(self, batch, batch_idx):
134134
condition = self.problem.conditions[condition_name]
135135
pts = batch.input
136136
out = batch.output
137-
print(out)
138-
print(pts)
139137

140138
if condition_name not in self.problem.conditions:
141139
raise RuntimeError("Something wrong happened.")

0 commit comments

Comments
 (0)