Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test #37

Closed
wants to merge 16 commits into from
Closed

test #37

Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
lint check and fix print and exception python2 issue (#20)
  • Loading branch information
stevenhua0320 authored Jul 29, 2024
commit 866e71487963ac95b905adb9adca0a24b5f4e787
2 changes: 1 addition & 1 deletion devutils/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,6 @@ def rm(directory, filerestr):

### Convert output of example files to Unix-style endlines for sdist.
if os.linesep != '\n':
print "==== Scrubbing Endlines ===="
print"==== Scrubbing Endlines ===="
# All *.srmise and *.pwa files in examples directory.
scrubeol("../doc/examples/output", r".*(\.srmise|\.pwa)")
141 changes: 84 additions & 57 deletions diffpy/srmise/basefunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import sys

import numpy as np
from numpy.compat import unicode

from diffpy.srmise.modelparts import ModelPart, ModelParts
from diffpy.srmise.srmiseerrors import *

logger = logging.getLogger("diffpy.srmise")


class BaseFunction(object):
"""Base class for mathematical functions which model numeric sequences.

Expand Down Expand Up @@ -61,7 +63,15 @@ class BaseFunction(object):
transform_parameters()
"""

def __init__(self, parameterdict, parformats, default_formats, metadict, base=None, Cache=None):
def __init__(
self,
parameterdict,
parformats,
default_formats,
metadict,
base=None,
Cache=None,
):
"""Set parameterdict defined by subclass

Parameters
Expand Down Expand Up @@ -96,23 +106,31 @@ def __init__(self, parameterdict, parformats, default_formats, metadict, base=No
vals = self.parameterdict.values()
vals.sort()
if vals != range(self.npars):
emsg = "Argument parameterdict's values must uniquely specify "+\
"the index of each parameter defined by its keys."
emsg = (
"Argument parameterdict's values must uniquely specify "
+ "the index of each parameter defined by its keys."
)
raise ValueError(emsg)

self.parformats = parformats

# Check validity of default_formats
self.default_formats = default_formats
if not ("default_input" in self.default_formats and
"default_output" in self.default_formats):
emsg = "Argument default_formats must specify 'default_input' "+\
"and 'default_output' as keys."
if not (
"default_input" in self.default_formats
and "default_output" in self.default_formats
):
emsg = (
"Argument default_formats must specify 'default_input' "
+ "and 'default_output' as keys."
)
raise ValueError(emsg)
for f in self.default_formats.values():
if not f in self.parformats:
emsg = "Keys of argument default_formats must map to a "+\
"value within argument parformats."
emsg = (
"Keys of argument default_formats must map to a "
+ "value within argument parformats."
)
raise ValueError()

# Set metadictionary
Expand All @@ -126,12 +144,11 @@ def __init__(self, parameterdict, parformats, default_formats, metadict, base=No
# of PeakFunction.
# Object to cache: (basefunctioninstance, tuple of parameters)
if Cache is not None:
#self.value = Cache(self.value, "value")
#self.jacobian = Cache(self.jacobian, "jacobian")
# self.value = Cache(self.value, "value")
# self.jacobian = Cache(self.jacobian, "jacobian")
pass
return


#### "Virtual" class methods ####

def actualize(self, *args, **kwds):
Expand Down Expand Up @@ -164,7 +181,6 @@ def _valueraw(self, *args, **kwds):
emsg = "_valueraw must() be implemented in a BaseFunction subclass."
raise NotImplementedError(emsg)


#### Class methods ####

def jacobian(self, p, r, rng=None):
Expand All @@ -179,8 +195,10 @@ def jacobian(self, p, r, rng=None):
previously calculated values instead.
"""
if self is not p._owner:
emsg = "Argument 'p' must be evaluated by the BaseFunction "+\
"subclass which owns it."
emsg = (
"Argument 'p' must be evaluated by the BaseFunction "
+ "subclass which owns it."
)
raise ValueError(emsg)

# normally r will be a sequence, but also allow single numeric values
Expand All @@ -192,7 +210,7 @@ def jacobian(self, p, r, rng=None):
output = [None for j in jac]
for idx in range(len(output)):
if jac[idx] is not None:
output[idx] = r * 0.
output[idx] = r * 0.0
output[idx][rng] = jac[idx]
return output
except TypeError:
Expand All @@ -201,10 +219,10 @@ def jacobian(self, p, r, rng=None):
def transform_derivatives(self, pars, in_format=None, out_format=None):
"""Return gradient matrix for pars converted from in_format to out_format.

Parameters
pars - Sequence of parameters
in_format - A format defined for this class
out_format - A format defined for this class
Parameters
pars - Sequence of parameters
in_format - A format defined for this class
out_format - A format defined for this class
"""
# Map unspecified formats to specific formats defined in default_formats
if in_format is None:
Expand All @@ -223,25 +241,29 @@ def transform_derivatives(self, pars, in_format=None, out_format=None):
out_format = self.default_formats["default_input"]

if not in_format in self.parformats:
raise ValueError("Argument 'in_format' must be one of %s." \
% self.parformats)
raise ValueError(
"Argument 'in_format' must be one of %s." % self.parformats
)
if not out_format in self.parformats:
raise ValueError("Argument 'out_format' must be one of %s." \
% self.parformats)
raise ValueError(
"Argument 'out_format' must be one of %s." % self.parformats
)
if in_format == out_format:
return np.identity(self.npars)
return self._transform_derivativesraw(pars, in_format=in_format, out_format=out_format)
return self._transform_derivativesraw(
pars, in_format=in_format, out_format=out_format
)

def transform_parameters(self, pars, in_format=None, out_format=None):
"""Return new sequence with pars converted from in_format to out_format.

Also restores parameters to a preferred range if it permits multiple
values that correspond to the same physical result.
Also restores parameters to a preferred range if it permits multiple
values that correspond to the same physical result.

Parameters
pars - Sequence of parameters
in_format - A format defined for this class
out_format - A format defined for this class
Parameters
pars - Sequence of parameters
in_format - A format defined for this class
out_format - A format defined for this class
"""
# Map unspecified formats to specific formats defined in default_formats
if in_format is None:
Expand All @@ -260,15 +282,18 @@ def transform_parameters(self, pars, in_format=None, out_format=None):
out_format = self.default_formats["default_input"]

if not in_format in self.parformats:
raise ValueError("Argument 'in_format' must be one of %s." \
% self.parformats)
raise ValueError(
"Argument 'in_format' must be one of %s." % self.parformats
)
if not out_format in self.parformats:
raise ValueError("Argument 'out_format' must be one of %s." \
% self.parformats)
#if in_format == out_format:
raise ValueError(
"Argument 'out_format' must be one of %s." % self.parformats
)
# if in_format == out_format:
# return pars
return self._transform_parametersraw(pars, in_format=in_format, out_format=out_format)

return self._transform_parametersraw(
pars, in_format=in_format, out_format=out_format
)

def value(self, p, r, rng=None):
"""Calculate value of ModelPart over r, possibly restricted by range.
Expand All @@ -282,16 +307,18 @@ def value(self, p, r, rng=None):
previously calculated values instead.
"""
if self is not p._owner:
emsg = "Argument 'p' must be evaluated by the BaseFunction "+\
"subclass which owns it."
emsg = (
"Argument 'p' must be evaluated by the BaseFunction "
+ "subclass which owns it."
)
raise ValueError(emsg)

# normally r will be a sequence, but also allow single numeric values
try:
if rng is None:
rng = slice(0, len(r))
rpart = r[rng]
output = r * 0.
output = r * 0.0
output[rng] = self._valueraw(p.pars, rpart)
return output
except TypeError:
Expand Down Expand Up @@ -338,17 +365,17 @@ def writestr(self, baselist):
raise ValueError("emsg")
lines = []
# Write function type
lines.append("function=%s" %repr(self.__class__.__name__))
lines.append("module=%s" %repr(self.getmodule()))
lines.append("function=%s" % repr(self.__class__.__name__))
lines.append("module=%s" % repr(self.getmodule()))
# Write base
if self.base is not None:
lines.append("base=%s" %repr(baselist.index(self.base)))
lines.append("base=%s" % repr(baselist.index(self.base)))
else:
lines.append("base=%s" %repr(None))
lines.append("base=%s" % repr(None))
# Write all other metadata
for k, (v, f) in self.metadict.iteritems():
lines.append("%s=%s" %(k, f(v)))
datastring = "\n".join(lines)+"\n"
lines.append("%s=%s" % (k, f(v)))
datastring = "\n".join(lines) + "\n"
return datastring

@staticmethod
Expand All @@ -367,19 +394,19 @@ def factory(functionstr, baselist):

# populate dictionary with parameter definition
# "key=value"->{"key":"value"}
data = re.split(r'(?:[\r\n]+|\A)(\S+)=', data)
data = re.split(r"(?:[\r\n]+|\A)(\S+)=", data)
ddict = {}
for i in range(len(data)/2):
ddict[data[2*i+1]] = data[2*i+2]
for i in range(len(data) / 2):
ddict[data[2 * i + 1]] = data[2 * i + 2]

# dictionary of parameters
pdict = {}
for (k, v) in ddict.items():
for k, v in ddict.items():
try:
pdict[k] = eval(v)
except Exception, e:
except Exception as e:
logger.exception(e)
emsg = ("Invalid parameter: %s=%s" %(k,v))
emsg = "Invalid parameter: %s=%s" % (k, v)
raise SrMiseDataFormatError(emsg)

function_name = pdict["function"]
Expand Down Expand Up @@ -438,9 +465,9 @@ def safefunction(f, fsafe):
return


#end of class BaseFunction
# end of class BaseFunction

if __name__ == '__main__':
if __name__ == "__main__":

from diffpy.srmise.peaks import GaussianOverR, TerminationRipples

Expand All @@ -451,7 +478,7 @@ def safefunction(f, fsafe):

pt = TerminationRipples(p, 20)
outstr2 = pt.writestr([p])
print outstr
print(outstr)

pt2 = BaseFunction.factory(outstr2, [p])
print type(pt2)
print(type(pt2))
Loading
Loading