1717import sys
1818
1919import numpy as np
20+ from numpy .compat import unicode
2021
2122from diffpy .srmise .modelparts import ModelPart , ModelParts
2223from diffpy .srmise .srmiseerrors import *
2324
2425logger = logging .getLogger ("diffpy.srmise" )
2526
27+
2628class BaseFunction (object ):
2729 """Base class for mathematical functions which model numeric sequences.
2830
@@ -61,7 +63,15 @@ class BaseFunction(object):
6163 transform_parameters()
6264 """
6365
64- def __init__ (self , parameterdict , parformats , default_formats , metadict , base = None , Cache = None ):
66+ def __init__ (
67+ self ,
68+ parameterdict ,
69+ parformats ,
70+ default_formats ,
71+ metadict ,
72+ base = None ,
73+ Cache = None ,
74+ ):
6575 """Set parameterdict defined by subclass
6676
6777 Parameters
@@ -96,23 +106,31 @@ def __init__(self, parameterdict, parformats, default_formats, metadict, base=No
96106 vals = self .parameterdict .values ()
97107 vals .sort ()
98108 if vals != range (self .npars ):
99- emsg = "Argument parameterdict's values must uniquely specify " + \
100- "the index of each parameter defined by its keys."
109+ emsg = (
110+ "Argument parameterdict's values must uniquely specify "
111+ + "the index of each parameter defined by its keys."
112+ )
101113 raise ValueError (emsg )
102114
103115 self .parformats = parformats
104116
105117 # Check validity of default_formats
106118 self .default_formats = default_formats
107- if not ("default_input" in self .default_formats and
108- "default_output" in self .default_formats ):
109- emsg = "Argument default_formats must specify 'default_input' " + \
110- "and 'default_output' as keys."
119+ if not (
120+ "default_input" in self .default_formats
121+ and "default_output" in self .default_formats
122+ ):
123+ emsg = (
124+ "Argument default_formats must specify 'default_input' "
125+ + "and 'default_output' as keys."
126+ )
111127 raise ValueError (emsg )
112128 for f in self .default_formats .values ():
113129 if not f in self .parformats :
114- emsg = "Keys of argument default_formats must map to a " + \
115- "value within argument parformats."
130+ emsg = (
131+ "Keys of argument default_formats must map to a "
132+ + "value within argument parformats."
133+ )
116134 raise ValueError ()
117135
118136 # Set metadictionary
@@ -126,12 +144,11 @@ def __init__(self, parameterdict, parformats, default_formats, metadict, base=No
126144 # of PeakFunction.
127145 # Object to cache: (basefunctioninstance, tuple of parameters)
128146 if Cache is not None :
129- #self.value = Cache(self.value, "value")
130- #self.jacobian = Cache(self.jacobian, "jacobian")
147+ # self.value = Cache(self.value, "value")
148+ # self.jacobian = Cache(self.jacobian, "jacobian")
131149 pass
132150 return
133151
134-
135152 #### "Virtual" class methods ####
136153
137154 def actualize (self , * args , ** kwds ):
@@ -164,7 +181,6 @@ def _valueraw(self, *args, **kwds):
164181 emsg = "_valueraw must() be implemented in a BaseFunction subclass."
165182 raise NotImplementedError (emsg )
166183
167-
168184 #### Class methods ####
169185
170186 def jacobian (self , p , r , rng = None ):
@@ -179,8 +195,10 @@ def jacobian(self, p, r, rng=None):
179195 previously calculated values instead.
180196 """
181197 if self is not p ._owner :
182- emsg = "Argument 'p' must be evaluated by the BaseFunction " + \
183- "subclass which owns it."
198+ emsg = (
199+ "Argument 'p' must be evaluated by the BaseFunction "
200+ + "subclass which owns it."
201+ )
184202 raise ValueError (emsg )
185203
186204 # normally r will be a sequence, but also allow single numeric values
@@ -192,7 +210,7 @@ def jacobian(self, p, r, rng=None):
192210 output = [None for j in jac ]
193211 for idx in range (len (output )):
194212 if jac [idx ] is not None :
195- output [idx ] = r * 0.
213+ output [idx ] = r * 0.0
196214 output [idx ][rng ] = jac [idx ]
197215 return output
198216 except TypeError :
@@ -201,10 +219,10 @@ def jacobian(self, p, r, rng=None):
201219 def transform_derivatives (self , pars , in_format = None , out_format = None ):
202220 """Return gradient matrix for pars converted from in_format to out_format.
203221
204- Parameters
205- pars - Sequence of parameters
206- in_format - A format defined for this class
207- out_format - A format defined for this class
222+ Parameters
223+ pars - Sequence of parameters
224+ in_format - A format defined for this class
225+ out_format - A format defined for this class
208226 """
209227 # Map unspecified formats to specific formats defined in default_formats
210228 if in_format is None :
@@ -223,25 +241,29 @@ def transform_derivatives(self, pars, in_format=None, out_format=None):
223241 out_format = self .default_formats ["default_input" ]
224242
225243 if not in_format in self .parformats :
226- raise ValueError ("Argument 'in_format' must be one of %s." \
227- % self .parformats )
244+ raise ValueError (
245+ "Argument 'in_format' must be one of %s." % self .parformats
246+ )
228247 if not out_format in self .parformats :
229- raise ValueError ("Argument 'out_format' must be one of %s." \
230- % self .parformats )
248+ raise ValueError (
249+ "Argument 'out_format' must be one of %s." % self .parformats
250+ )
231251 if in_format == out_format :
232252 return np .identity (self .npars )
233- return self ._transform_derivativesraw (pars , in_format = in_format , out_format = out_format )
253+ return self ._transform_derivativesraw (
254+ pars , in_format = in_format , out_format = out_format
255+ )
234256
235257 def transform_parameters (self , pars , in_format = None , out_format = None ):
236258 """Return new sequence with pars converted from in_format to out_format.
237259
238- Also restores parameters to a preferred range if it permits multiple
239- values that correspond to the same physical result.
260+ Also restores parameters to a preferred range if it permits multiple
261+ values that correspond to the same physical result.
240262
241- Parameters
242- pars - Sequence of parameters
243- in_format - A format defined for this class
244- out_format - A format defined for this class
263+ Parameters
264+ pars - Sequence of parameters
265+ in_format - A format defined for this class
266+ out_format - A format defined for this class
245267 """
246268 # Map unspecified formats to specific formats defined in default_formats
247269 if in_format is None :
@@ -260,15 +282,18 @@ def transform_parameters(self, pars, in_format=None, out_format=None):
260282 out_format = self .default_formats ["default_input" ]
261283
262284 if not in_format in self .parformats :
263- raise ValueError ("Argument 'in_format' must be one of %s." \
264- % self .parformats )
285+ raise ValueError (
286+ "Argument 'in_format' must be one of %s." % self .parformats
287+ )
265288 if not out_format in self .parformats :
266- raise ValueError ("Argument 'out_format' must be one of %s." \
267- % self .parformats )
268- #if in_format == out_format:
289+ raise ValueError (
290+ "Argument 'out_format' must be one of %s." % self .parformats
291+ )
292+ # if in_format == out_format:
269293 # return pars
270- return self ._transform_parametersraw (pars , in_format = in_format , out_format = out_format )
271-
294+ return self ._transform_parametersraw (
295+ pars , in_format = in_format , out_format = out_format
296+ )
272297
273298 def value (self , p , r , rng = None ):
274299 """Calculate value of ModelPart over r, possibly restricted by range.
@@ -282,16 +307,18 @@ def value(self, p, r, rng=None):
282307 previously calculated values instead.
283308 """
284309 if self is not p ._owner :
285- emsg = "Argument 'p' must be evaluated by the BaseFunction " + \
286- "subclass which owns it."
310+ emsg = (
311+ "Argument 'p' must be evaluated by the BaseFunction "
312+ + "subclass which owns it."
313+ )
287314 raise ValueError (emsg )
288315
289316 # normally r will be a sequence, but also allow single numeric values
290317 try :
291318 if rng is None :
292319 rng = slice (0 , len (r ))
293320 rpart = r [rng ]
294- output = r * 0.
321+ output = r * 0.0
295322 output [rng ] = self ._valueraw (p .pars , rpart )
296323 return output
297324 except TypeError :
@@ -338,17 +365,17 @@ def writestr(self, baselist):
338365 raise ValueError ("emsg" )
339366 lines = []
340367 # Write function type
341- lines .append ("function=%s" % repr (self .__class__ .__name__ ))
342- lines .append ("module=%s" % repr (self .getmodule ()))
368+ lines .append ("function=%s" % repr (self .__class__ .__name__ ))
369+ lines .append ("module=%s" % repr (self .getmodule ()))
343370 # Write base
344371 if self .base is not None :
345- lines .append ("base=%s" % repr (baselist .index (self .base )))
372+ lines .append ("base=%s" % repr (baselist .index (self .base )))
346373 else :
347- lines .append ("base=%s" % repr (None ))
374+ lines .append ("base=%s" % repr (None ))
348375 # Write all other metadata
349376 for k , (v , f ) in self .metadict .iteritems ():
350- lines .append ("%s=%s" % (k , f (v )))
351- datastring = "\n " .join (lines )+ "\n "
377+ lines .append ("%s=%s" % (k , f (v )))
378+ datastring = "\n " .join (lines ) + "\n "
352379 return datastring
353380
354381 @staticmethod
@@ -367,19 +394,19 @@ def factory(functionstr, baselist):
367394
368395 # populate dictionary with parameter definition
369396 # "key=value"->{"key":"value"}
370- data = re .split (r' (?:[\r\n]+|\A)(\S+)=' , data )
397+ data = re .split (r" (?:[\r\n]+|\A)(\S+)=" , data )
371398 ddict = {}
372- for i in range (len (data )/ 2 ):
373- ddict [data [2 * i + 1 ]] = data [2 * i + 2 ]
399+ for i in range (len (data ) / 2 ):
400+ ddict [data [2 * i + 1 ]] = data [2 * i + 2 ]
374401
375402 # dictionary of parameters
376403 pdict = {}
377- for ( k , v ) in ddict .items ():
404+ for k , v in ddict .items ():
378405 try :
379406 pdict [k ] = eval (v )
380- except Exception , e :
407+ except Exception as e :
381408 logger .exception (e )
382- emsg = ( "Invalid parameter: %s=%s" % (k ,v ) )
409+ emsg = "Invalid parameter: %s=%s" % (k , v )
383410 raise SrMiseDataFormatError (emsg )
384411
385412 function_name = pdict ["function" ]
@@ -438,9 +465,9 @@ def safefunction(f, fsafe):
438465 return
439466
440467
441- #end of class BaseFunction
468+ # end of class BaseFunction
442469
443- if __name__ == ' __main__' :
470+ if __name__ == " __main__" :
444471
445472 from diffpy .srmise .peaks import GaussianOverR , TerminationRipples
446473
@@ -451,7 +478,7 @@ def safefunction(f, fsafe):
451478
452479 pt = TerminationRipples (p , 20 )
453480 outstr2 = pt .writestr ([p ])
454- print outstr
481+ print ( outstr )
455482
456483 pt2 = BaseFunction .factory (outstr2 , [p ])
457- print type (pt2 )
484+ print ( type (pt2 ) )
0 commit comments