Skip to content

Commit c2c043e

Browse files
committed
Access WorkingData by attribute
1 parent 06a480b commit c2c043e

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

Stoner/analysis/fitting/functions.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ def __lmfit_one(
179179
raise RuntimeError("lmfit module not available.")
180180

181181
kwargs = {}
182-
kwargs[model.independent_vars[0]] = data[0]
183-
fit = model.fit(data[1], params, scale_covar=scale_covar, weights=1.0 / data[2], nan_policy=nan_policy, **kwargs)
182+
kwargs[model.independent_vars[0]] = data.x
183+
fit = model.fit(data.y, params, scale_covar=scale_covar, weights=1.0 / data.e, nan_policy=nan_policy, **kwargs)
184184
if fit.success:
185185
row = _record_curve_fit_result(
186186
datafile,
@@ -708,13 +708,13 @@ def p0_func(ydata,x=xdata):
708708
data, kwargs, cols = _assemnle_data_to_fit(datafile, xcol=xcol, ycol=ycol, yerr=sigma, **kwargs)
709709
_func, p0 = _get_curve_fit_func(func, kwargs)
710710

711-
if getattr(data[1], "ndim", 0) == 1:
711+
if getattr(data.y, "ndim", 0) == 1:
712712
for i in [1, 2, 3]:
713713
if isinstance(data[i], np.ndarray):
714714
data[i] = np.atleast_2d(data[i])
715715
if callable(p0): # Allow the user to supply p0 as a callanble function
716716
try: # Skip the guess if it fails
717-
p0 = p0(data[1].ravel(), np.tile(data[0], data[1].size // data[0].size))
717+
p0 = p0(data.y.ravel(), np.tile(data.x, data.y.size // data.x.size))
718718
except (
719719
ArithmeticError,
720720
RuntimeError,
@@ -726,12 +726,12 @@ def p0_func(ydata,x=xdata):
726726

727727
retvals = []
728728
i = None
729-
xdat = data[0]
729+
xdat = data.x
730730
if p0:
731731
kwargs["p0"] = p0
732-
for i, ydat in enumerate(data[1]):
733-
if data[2] is not None:
734-
kwargs["sigma"] = data[2][i]
732+
for i, ydat in enumerate(data.y):
733+
if data.e is not None:
734+
kwargs["sigma"] = data.e[i]
735735
else:
736736
sigma = None
737737
for var in ["xcol", "ycol", "zcol", "xerr", "yerr", "zerr", "scale_covar"]:
@@ -743,7 +743,7 @@ def p0_func(ydata,x=xdata):
743743
else:
744744
report.p0 = p0
745745
report.data = datafile
746-
report.residual_vals = data[1] - report.fvec
746+
report.residual_vals = data.y - report.fvec
747747
report.chisq = (report.residual_vals**2).sum()
748748
report.nfree = len(datafile) - len(report.popt)
749749
report.chisq /= report.nfree
@@ -848,9 +848,8 @@ def differential_evolution(datafile, model, xcol=None, ycol=None, p0=None, sigma
848848
output = kwargs.pop("output", "row" if asrow else "fit")
849849

850850
data, kwargs, _ = _assemnle_data_to_fit(datafile, xcol=xcol, ycol=ycol, sigma=sigma, bounds=bounds, **kwargs)
851-
data = data[0:3]
852851
model, prefix = _prep_lmfit_model(model, kwargs)
853-
p0, single_fit = _prep_lmfit_p0(model, data[1], data[0], p0, kwargs)
852+
p0, single_fit = _prep_lmfit_p0(model, data.y, data.x, p0, kwargs)
854853

855854
for k in model.param_names:
856855
kwargs.pop(k, None)
@@ -864,18 +863,27 @@ def differential_evolution(datafile, model, xcol=None, ycol=None, p0=None, sigma
864863

865864
if not single_fit:
866865
raise NotImplementedError("Sorry chi^2 mapping not implemented for differential evolution yet.")
867-
fit = _differential_evolution(diff_model.minimize_func, diff_model.bounds, data, **kwargs)
866+
fit = _differential_evolution(
867+
diff_model.minimize_func,
868+
diff_model.bounds,
869+
args=(
870+
data.x,
871+
data.y,
872+
data.e,
873+
),
874+
**kwargs,
875+
)
868876
if not fit.success:
869877
raise RuntimeError(fit.message)
870878
kwargs.pop("polish", None)
871879
kwargs["full_output"] = True
872880
kwargs["absolute_sigma"] = abs_sigma
873-
polish = _curve_fit_result(*_curve_fit(model.func, data[0], data[1][0], sigma=data[2][0], p0=fit.x, **kwargs))
881+
polish = _curve_fit_result(*_curve_fit(model.func, data.x, data.y[0], sigma=data.e[0], p0=fit.x, **kwargs))
874882

875883
polish.func = model.func
876884
polish.p0 = p0
877885
polish.data = datafile
878-
polish.residual_vals = data[1] - polish.fvec
886+
polish.residual_vals = data.y - polish.fvec
879887
polish.chisq = (polish.residual_vals**2).sum()
880888
polish.nfree = len(datafile) - len(polish.popt)
881889
polish.chisq /= polish.nfree
@@ -988,7 +996,7 @@ def lmfit(datafile, model, xcol=None, ycol=None, p0=None, sigma=None, **kwargs):
988996

989997
data, kwargs, _ = _assemnle_data_to_fit(datafile, xcol=xcol, ycol=ycol, yerr=sigma, **kwargs)
990998
model, prefix = _prep_lmfit_model(model, kwargs)
991-
p0, single_fit = _prep_lmfit_p0(model, data[1], data[0], p0, kwargs)
999+
p0, single_fit = _prep_lmfit_p0(model, data.y, data.x, p0, kwargs)
9921000
nan_policy = kwargs.pop("nan_policy", getattr(model, "nan_policy", "omit"))
9931001

9941002
if single_fit:
@@ -1012,7 +1020,7 @@ def lmfit(datafile, model, xcol=None, ycol=None, p0=None, sigma=None, **kwargs):
10121020
ret_val = np.zeros((pn.shape[0], pn.shape[1] * 2 + 1))
10131021
for i, pn_i in enumerate(pn): # iterate over every row in the supplied p0 values
10141022
p0, single_fit = _prep_lmfit_p0(
1015-
model, data[1], data[0], pn_i, kwargs
1023+
model, data.y, data.x, pn_i, kwargs
10161024
) # model, data, params, prefix, columns, scale_covar,**kwargs)
10171025
ret_val[i, :] = __lmfit_one(
10181026
datafile,
@@ -1211,13 +1219,13 @@ def odr(datafile, model, xcol=None, ycol=None, **kwargs):
12111219
model, prefix = _prep_lmfit_model(model, kwargs)
12121220
else:
12131221
prefix = kwargs.pop("prefix", getattr(model, "name", model.fcn.__name__))
1214-
p0, single_fit = _prep_lmfit_p0(model, data[1], data[0], p0, kwargs)
1222+
p0, single_fit = _prep_lmfit_p0(model, data.y, data.x, p0, kwargs)
12151223
kwargs["p0"] = p0
12161224
model = ODR_Model(model, p0=p0)
12171225
if kwargs.get("scale_covar", True):
1218-
data = sp.odr.Data(data[0], data[1], wd=1 / data[3] ** 2, we=1 / data[2] ** 2)
1226+
data = sp.odr.Data(data.x, data.y, wd=1 / data.d**2, we=1 / data.e**2)
12191227
else:
1220-
data = sp.odr.RealData(data[0], data[1], sx=data[3], sy=data[2])
1228+
data = sp.odr.RealData(data.x, data.y, sx=data.d, sy=data.e)
12211229

12221230
if single_fit:
12231231
ret_val = _odr_one(datafile, data, model, prefix, _, **kwargs)

0 commit comments

Comments
 (0)