Skip to content

Commit b265c3d

Browse files
committed
bisection method + associated fancy charts
1 parent b49ace6 commit b265c3d

File tree

4 files changed

+77
-14
lines changed

4 files changed

+77
-14
lines changed

calculate.py

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
from operator import itemgetter # https://stackoverflow.com/a/52083390/10372825
88
from ErrorProp import ErroredValue as EV
99

10+
from config import PRINT_PRECISION
11+
1012
CM_PER_INCH = 2.54
1113
MAX_COUNTRATE = 3500
14+
GOOD_ENOUGH = 1e-9
1215

1316
def process(dataitem):
1417
meta, data, background = itemgetter('meta', 'data', 'background')(dataitem)
@@ -46,31 +49,25 @@ def SSE(indicies, prediction, logits, logits_error, function):
4649
ri = function(prediction, indicies)
4750
return ((logits-ri)**2/logits_error**2).sum()
4851

49-
def sMinFit(datatable, function, param=1, lr=1e-4, epsilon=1e-8, epochs=10000):
50-
logits = datatable.normalized_count_rate.apply(lambda x:x.value)
51-
logits_err = datatable.normalized_count_rate.apply(lambda x:x.delta)
52-
inches = datatable.inches
52+
def unwrap(datatable):
53+
return datatable.inches, datatable.normalized_count_rate.apply(lambda x:x.value), datatable.normalized_count_rate.apply(lambda x:x.delta)
54+
55+
def sMinFit(datatable, function, param=1, lr=1e-4, epsilon=1e-8, epochs=100000):
56+
inches, logits, logits_err = unwrap(datatable)
5357
dydx = epsilon
5458

5559
dydx = epsilon
5660
bar = tqdm.tqdm(range(epochs))
5761

58-
minParam = param
59-
minS = int(1e10)
60-
6162
for _ in bar:
6263
param_next = param-(dydx*lr+epsilon)
6364
loss = SSE(inches, param+epsilon, logits, logits_err, function)
6465
dydx = (loss-SSE(inches, param_next, logits, logits_err, function))/(param-param_next)
6566
param = param_next
6667

67-
bar.set_description(f'Current fit: {param:2f}, Best fit: {minParam:2f}, Best loss: {minS:2f}')
68+
bar.set_description(f'Current fit: {param:2f}, Update: {dydx:2f}, Loss: {loss:2f}')
6869

69-
if loss < minS and param > 0:
70-
minParam = param
71-
minS = loss
72-
73-
return minParam, minS
70+
return param, SSE(inches, param+epsilon, logits, logits_err, function)
7471

7572
# breakpoint()
7673

@@ -79,3 +76,43 @@ def sMinFit(datatable, function, param=1, lr=1e-4, epsilon=1e-8, epochs=10000):
7976
# breakpoint()
8077
# print(dydx)
8178
# # weight update
79+
80+
def calculateSfitUncert(bestx, besty, targety, function, ax=None, low=0, high=100):
81+
def bisect(low, high, target, function):
82+
triedx, triedy = [], []
83+
if (low > high): low, high = high, low
84+
while (high-low > GOOD_ENOUGH):
85+
mid = (low+high)/2
86+
triedx.append(mid)
87+
triedy.append(function(mid))
88+
if (function(mid) < target) == (function(low) < target):
89+
low = mid
90+
else:
91+
high = mid
92+
return low, (triedx, triedy)
93+
94+
# binary search on min and max
95+
param_min, tries_min = bisect(low, bestx, targety, function)
96+
param_max, tries_max = bisect(bestx, high, targety, function)
97+
98+
# plot everything
99+
if ax is not None:
100+
# plot neighborhood
101+
abs_err = max(bestx - param_min, param_max-bestx)
102+
neighborhood = np.arange(bestx-abs_err*2, bestx+abs_err*2, abs_err*4/200) # 200 evenly spaced points
103+
ax.scatter(neighborhood, list(map(function, neighborhood)), color='black', label='S(T)')
104+
105+
# # plot binary search tries
106+
# ax.scatter(*tries_min, color='grey')
107+
# ax.scatter(*tries_max, color='grey')
108+
109+
# plot given information
110+
ax.axhline(y=besty, label=f"S min = {besty:.6f}", color='green')
111+
ax.axhline(y=targety, label=f"S min + chi^2 = {targety:.6f}", color='lightgreen')
112+
ax.axvline(x=bestx, label=f"T_best = {bestx:.6f}", color='green')
113+
114+
# plot found information
115+
ax.axvline(x=param_min, label=f"T_min = {param_min:.6f}", color='red')
116+
ax.axvline(x=param_max, label=f"T_min = {param_max:.6f}", color='blue')
117+
118+
return param_min, param_max

config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
PRINT_PRECISION = 6
2+

drive.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datacleaning import globit
22
from multiprocessing import Pool
3-
from calculate import process, sMinFit, RelativeIntersity
3+
from calculate import SSE, process, unwrap, sMinFit, RelativeIntersity, calculateSfitUncert
44

55
import matplotlib.pyplot as plt
66

@@ -62,3 +62,16 @@ def process(indx):
6262
# breakpoint()
6363

6464
# >>>>>>> 68cb3cc3a83d6f09391e99a4e1cc04d712bebe16
65+
66+
if __name__ == '__main__':
67+
# plot(7)
68+
fig, ax = plt.subplots()
69+
inches, logits, logits_err = unwrap(results[7])
70+
param = 0.946
71+
smin = 82
72+
calculateSfitUncert(param, smin, smin+1, lambda T: SSE(inches, T, logits, logits_err, RelativeIntersity), ax=ax, low=0.1, high=30)
73+
74+
ax.set_xlabel(f"T ({results[7].attrs['material']})")
75+
ax.set_ylabel("S(T)")
76+
ax.legend()
77+
plt.show()

fits.org

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
| job number | fit |
2+
+------------+-----+
3+
| 2 | 0.426406 |
4+
| 3 | 0.016489 |
5+
| 4 | 0.347708 |
6+
| 5 | 0.019284 |
7+
| 6 | nan |
8+
| 7 | 7.097304 |
9+
| 8 | nan |
10+
| 9 | 0.945550 |
11+
| 10 | 0.019839 |

0 commit comments

Comments
 (0)