Skip to content

Commit

Permalink
finish parenthesizing print statements
Browse files Browse the repository at this point in the history
  • Loading branch information
sbillinge committed Jul 29, 2024
1 parent c28b97b commit 155a85d
Show file tree
Hide file tree
Showing 9 changed files with 606 additions and 476 deletions.
44 changes: 23 additions & 21 deletions devutils/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,21 @@ def __init__(self):

def test(self, call, *args, **kwds):
m = sys.modules[call.__module__]
testname = m.__name__+'.'+call.__name__
testname = m.__name__ + "." + call.__name__
path = os.path.dirname(m.__file__)
os.chdir(path)
try:
call(*args, **kwds)
self.messages.append("%s: success" %testname)
except Exception, e:
self.messages.append("%s: error, details below.\n%s" %(testname, e))
self.messages.append("%s: success" % testname)
except Exception as e:
self.messages.append("%s: error, details below.\n%s" % (testname, e))
finally:
os.chdir(__basedir__)

def report(self):
print '==== Results of Tests ===='
print '\n'.join(self.messages)
print("==== Results of Tests ====")
print("\n".join(self.messages))


def scrubeol(directory, filerestr):
"""Use unix-style endlines for files in directory matched by regex string.
Expand All @@ -50,11 +51,11 @@ def scrubeol(directory, filerestr):
text = unicode(original.read())
original.close()

updated = io.open(f, 'w', newline='\n')
updated = io.open(f, "w", newline="\n")
updated.write(text)
updated.close()

print "Updated %s to unix-style endlines." %f
print("Updated %s to unix-style endlines." % f)


def rm(directory, filerestr):
Expand All @@ -72,14 +73,13 @@ def rm(directory, filerestr):
for f in files:
os.remove(f)

print "Deleted %s." %f

print("Deleted %s." % f)


if __name__ == "__main__":

# Temporarily add examples to path
lib_path = os.path.abspath(os.path.join('..','doc','examples'))
lib_path = os.path.abspath(os.path.join("..", "doc", "examples"))
sys.path.append(lib_path)

# Delete existing files that don't necessarily have a fixed name.
Expand All @@ -88,14 +88,16 @@ def rm(directory, filerestr):

### Testing examples
examples = Test()
test_names = ["extract_single_peak",
"parameter_summary",
"fit_initial",
"query_results",
"multimodel_known_dG1",
"multimodel_known_dG2",
"multimodel_unknown_dG1",
"multimodel_unknown_dG2"]
test_names = [
"extract_single_peak",
"parameter_summary",
"fit_initial",
"query_results",
"multimodel_known_dG1",
"multimodel_known_dG2",
"multimodel_unknown_dG1",
"multimodel_unknown_dG2",
]

test_modules = []
for test in test_names:
Expand All @@ -107,7 +109,7 @@ def rm(directory, filerestr):
examples.report()

### Convert output of example files to Unix-style endlines for sdist.
if os.linesep != '\n':
print"==== Scrubbing Endlines ===="
if os.linesep != "\n":
print("==== Scrubbing Endlines ====")
# All *.srmise and *.pwa files in examples directory.
scrubeol("../doc/examples/output", r".*(\.srmise|\.pwa)")
75 changes: 47 additions & 28 deletions diffpy/srmise/peaks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

logger = logging.getLogger("diffpy.srmise")


class PeakFunction(BaseFunction):
"""Base class for functions which represent peaks.
Expand Down Expand Up @@ -60,7 +61,15 @@ class PeakFunction(BaseFunction):
transform_parameters()
"""

def __init__(self, parameterdict, parformats, default_formats, metadict, base=None, Cache=None):
def __init__(
self,
parameterdict,
parformats,
default_formats,
metadict,
base=None,
Cache=None,
):
"""Set parameterdict defined by subclass
parameterdict: A dictionary mapping string keys to their index in a
Expand All @@ -82,24 +91,31 @@ def __init__(self, parameterdict, parformats, default_formats, metadict, base=No
raise ValueError(emsg)
BaseFunction.__init__(self, parameterdict, parformats, default_formats, metadict, base, Cache)


#### "Virtual" class methods ####

def scale_at(self, peak, x, scale):
emsg = "scale_at must be implemented in a PeakFunction subclass."
raise NotImplementedError(emsg)


#### Methods required by BaseFunction ####

def actualize(self, pars, in_format="default_input", free=None, removable=True, static_owner=False):
def actualize(
self,
pars,
in_format="default_input",
free=None,
removable=True,
static_owner=False,
):
converted = self.transform_parameters(pars, in_format, out_format="internal")
return Peak(self, converted, free, removable, static_owner)

def getmodule(self):
return __name__

#end of class PeakFunction

# end of class PeakFunction


class Peaks(ModelParts):
"""A collection for Peak objects."""
Expand All @@ -110,12 +126,12 @@ def __init__(self, *args, **kwds):

def argsort(self, key="position"):
"""Return sequence of indices which sort peaks in order specified by key."""
keypars=np.array([p[key] for p in self])
keypars = np.array([p[key] for p in self])
# In normal use the peaks will already be sorted, so check for it.
sorted=True
for i in range(len(keypars)-1):
if keypars[i] > keypars[i+1]:
sorted=False
sorted = True
for i in range(len(keypars) - 1):
if keypars[i] > keypars[i + 1]:
sorted = False
break
if not sorted:
return keypars.argsort().tolist()
Expand All @@ -142,14 +158,14 @@ def match_at(self, x, y):
orig = self.copy()

try:
scale = y/height
scale = y / height

# First attempt at scaling peaks. Record which peaks, if any,
# were not scaled in case a second attempt is required.
scaled = []
all_scaled = True
any_scaled = False
fixed_height = 0.
fixed_height = 0.0
for peak in self:
scaled.append(peak.scale_at(x, scale))
all_scaled = all_scaled and scaled[-1]
Expand All @@ -161,27 +177,29 @@ def match_at(self, x, y):
if not all_scaled and fixed_height < y and fixed_height < height:
self[:] = orig[:]
any_scaled = False
scale = (y - fixed_height)/(height - fixed_height)
scale = (y - fixed_height) / (height - fixed_height)
for peak, s in (self, scaled):
if s:
# "or" is short-circuited, so scale_at() must be first
# to guarantee it is called.
any_scaled = peak.scale_at(x, scale) or any_scaled
except Exception, e:
except Exception as e:
logger.debug("An exception prevented matching -- %s", e)
self[:] = orig[:]
return False
return any_scaled

def sort(self, key="position"):
"""Sort peaks in order specified by key."""
keypars=np.array([p[key] for p in self])
keypars = np.array([p[key] for p in self])
order = keypars.argsort()
self[:] = [self[idx] for idx in order]
return


# End of class Peaks


class Peak(ModelPart):
"""Represents a single peak associated with a PeakFunction subclass."""

Expand Down Expand Up @@ -225,7 +243,7 @@ def scale_at(self, x, scale):

try:
adj_pars = self._owner.scale_at(self.pars, x, scale)
except SrMiseScalingError, err:
except SrMiseScalingError as err:
logger.debug("Cannot scale peak:", err)
return False

Expand Down Expand Up @@ -256,10 +274,10 @@ def factory(peakstr, ownerlist):
try:
pdict[l[0]] = eval(l[1])
except Exception:
emsg = ("Invalid parameter: %s" %d)
emsg = "Invalid parameter: %s" % d
raise SrMiseDataFormatError(emsg)
else:
emsg = ("Invalid parameter: %s" %d)
emsg = "Invalid parameter: %s" % d
raise SrMiseDataFormatError(emsg)

# Correctly initialize the base function, if one exists.
Expand All @@ -271,10 +289,11 @@ def factory(peakstr, ownerlist):

return Peak(**pdict)


# End of class Peak

# simple test code
if __name__ == '__main__':
if __name__ == "__main__":

import matplotlib.pyplot as plt
from numpy.random import randn
Expand All @@ -283,26 +302,26 @@ def factory(peakstr, ownerlist):
from diffpy.srmise.modelevaluators import AICc
from diffpy.srmise.peaks import GaussianOverR

res = .01
r = np.arange(2,4,res)
err = np.ones(len(r)) #default unknown errors
pf = GaussianOverR(.7)
res = 0.01
r = np.arange(2, 4, res)
err = np.ones(len(r)) # default unknown errors
pf = GaussianOverR(0.7)
evaluator = AICc()

pars = [[3, .2, 10], [3.5, .2, 10]]
pars = [[3, 0.2, 10], [3.5, 0.2, 10]]
ideal_peaks = Peaks([pf.actualize(p, "pwa") for p in pars])
y = ideal_peaks.value(r) + .1*randn(len(r))
y = ideal_peaks.value(r) + 0.1 * randn(len(r))

guesspars = [[2.7, .15, 5], [3.7, .3, 5]]
guesspars = [[2.7, 0.15, 5], [3.7, 0.3, 5]]
guess_peaks = Peaks([pf.actualize(p, "pwa") for p in guesspars])
cluster = ModelCluster(guess_peaks, r, y, err, None, AICc, [pf])

qual1 = cluster.quality()
print qual1.stat
print(qual1.stat)
cluster.fit()
yfit = cluster.calc()
qual2 = cluster.quality()
print qual2.stat
print(qual2.stat)

plt.figure(1)
plt.plot(r, y, r, yfit)
Expand Down
Loading

0 comments on commit 155a85d

Please sign in to comment.