@@ -20,26 +20,26 @@ class AbstractProblem(metaclass=ABCMeta):
2020
2121 def __init__ (self ):
2222
23-
2423 self ._discretized_domains = {}
2524
2625 for name , domain in self .domains .items ():
2726 if isinstance (domain , (torch .Tensor , LabelTensor )):
2827 self ._discretized_domains [name ] = domain
2928
3029 for condition_name in self .conditions :
31- self .conditions [condition_name ]._problem = self
30+ self .conditions [condition_name ].set_problem (self )
31+
3232 # # variable storing all points
33- # self.input_pts = {}
33+ self .input_pts = {}
3434
3535 # # varible to check if sampling is done. If no location
3636 # # element is presented in Condition this variable is set to true
3737 # self._have_sampled_points = {}
38- # for condition_name in self.conditions:
39- # self._have_sampled_points [condition_name] = False
38+ for condition_name in self .conditions :
39+ self ._discretized_domains [condition_name ] = False
4040
4141 # # put in self.input_pts all the points that we don't need to sample
42- # self._span_condition_points()
42+ self ._span_condition_points ()
4343
4444 def __deepcopy__ (self , memo ):
4545 """
@@ -125,7 +125,7 @@ def _span_condition_points(self):
125125 if hasattr (condition , "input_points" ):
126126 samples = condition .input_points
127127 self .input_pts [condition_name ] = samples
128- self ._have_sampled_points [condition_name ] = True
128+ self ._discretized_domains [condition_name ] = True
129129 if hasattr (self , "unknown_parameter_domain" ):
130130 # initialize the unknown parameters of the inverse problem given
131131 # the domain the user gives
@@ -141,7 +141,7 @@ def _span_condition_points(self):
141141 )
142142
143143 def discretise_domain (
144- self , n , mode = "random" , variables = "all" , locations = "all"
144+ self , n , mode = "random" , variables = "all" , domains = "all"
145145 ):
146146 """
147147 Generate a set of points to span the `Location` of all the conditions of
@@ -193,24 +193,24 @@ def discretise_domain(
193193 )
194194
195195 # check consistency location
196- if locations == "all" :
197- locations = [condition for condition in self .conditions ]
196+ if domains == "all" :
197+ domains = [condition for condition in self .conditions ]
198198 else :
199- check_consistency (locations , str )
200-
201- if sorted (locations ) != sorted (self .conditions ):
199+ check_consistency (domains , str )
200+ print ( domains )
201+ if sorted (domains ) != sorted (self .conditions ):
202202 TypeError (
203203 f"Wrong locations for sampling. Location " ,
204204 f"should be in { self .conditions } ." ,
205205 )
206206
207207 # sampling
208- for location in locations :
209- condition = self .conditions [location ]
208+ for d in domains :
209+ condition = self .conditions [d ]
210210
211211 # we try to check if we have already sampled
212212 try :
213- already_sampled = [self .input_pts [location ]]
213+ already_sampled = [self .input_pts [d ]]
214214 # if we have not sampled, a key error is thrown
215215 except KeyError :
216216 already_sampled = []
@@ -219,22 +219,23 @@ def discretise_domain(
219219 # but we want to sample again we set already_sampled
220220 # to an empty list since we need to sample again, and
221221 # self._have_sampled_points to False.
222- if self ._have_sampled_points [ location ]:
222+ if self ._discretized_domains [ d ]:
223223 already_sampled = []
224- self ._have_sampled_points [location ] = False
225-
224+ self ._discretized_domains [d ] = False
225+ print (condition .domain )
226+ print (d )
226227 # build samples
227228 samples = [
228- condition . location .sample (n = n , mode = mode , variables = variables )
229+ self . domains [ d ] .sample (n = n , mode = mode , variables = variables )
229230 ] + already_sampled
230231 pts = merge_tensors (samples )
231- self .input_pts [location ] = pts
232+ self .input_pts [d ] = pts
232233
233234 # the condition is sampled if input_pts contains all labels
234- if sorted (self .input_pts [location ].labels ) == sorted (
235+ if sorted (self .input_pts [d ].labels ) == sorted (
235236 self .input_variables
236237 ):
237- self ._have_sampled_points [location ] = True
238+ self ._have_sampled_points [d ] = True
238239
239240 def add_points (self , new_points ):
240241 """
0 commit comments