7
7
from operator import itemgetter # https://stackoverflow.com/a/52083390/10372825
8
8
from ErrorProp import ErroredValue as EV
9
9
10
+ from config import PRINT_PRECISION
11
+
10
12
CM_PER_INCH = 2.54
11
13
MAX_COUNTRATE = 3500
14
+ GOOD_ENOUGH = 1e-9
12
15
13
16
def process (dataitem ):
14
17
meta , data , background = itemgetter ('meta' , 'data' , 'background' )(dataitem )
@@ -46,31 +49,25 @@ def SSE(indicies, prediction, logits, logits_error, function):
46
49
ri = function (prediction , indicies )
47
50
return ((logits - ri )** 2 / logits_error ** 2 ).sum ()
48
51
49
- def sMinFit (datatable , function , param = 1 , lr = 1e-4 , epsilon = 1e-8 , epochs = 10000 ):
50
- logits = datatable .normalized_count_rate .apply (lambda x :x .value )
51
- logits_err = datatable .normalized_count_rate .apply (lambda x :x .delta )
52
- inches = datatable .inches
52
+ def unwrap (datatable ):
53
+ return datatable .inches , datatable .normalized_count_rate .apply (lambda x :x .value ), datatable .normalized_count_rate .apply (lambda x :x .delta )
54
+
55
+ def sMinFit (datatable , function , param = 1 , lr = 1e-4 , epsilon = 1e-8 , epochs = 100000 ):
56
+ inches , logits , logits_err = unwrap (datatable )
53
57
dydx = epsilon
54
58
55
59
dydx = epsilon
56
60
bar = tqdm .tqdm (range (epochs ))
57
61
58
- minParam = param
59
- minS = int (1e10 )
60
-
61
62
for _ in bar :
62
63
param_next = param - (dydx * lr + epsilon )
63
64
loss = SSE (inches , param + epsilon , logits , logits_err , function )
64
65
dydx = (loss - SSE (inches , param_next , logits , logits_err , function ))/ (param - param_next )
65
66
param = param_next
66
67
67
- bar .set_description (f'Current fit: { param :2f} , Best fit : { minParam :2f} , Best loss : { minS :2f} ' )
68
+ bar .set_description (f'Current fit: { param :2f} , Update : { dydx :2f} , Loss : { loss :2f} ' )
68
69
69
- if loss < minS and param > 0 :
70
- minParam = param
71
- minS = loss
72
-
73
- return minParam , minS
70
+ return param , SSE (inches , param + epsilon , logits , logits_err , function )
74
71
75
72
# breakpoint()
76
73
@@ -79,3 +76,43 @@ def sMinFit(datatable, function, param=1, lr=1e-4, epsilon=1e-8, epochs=10000):
79
76
# breakpoint()
80
77
# print(dydx)
81
78
# # weight update
79
+
80
+ def calculateSfitUncert (bestx , besty , targety , function , ax = None , low = 0 , high = 100 ):
81
+ def bisect (low , high , target , function ):
82
+ triedx , triedy = [], []
83
+ if (low > high ): low , high = high , low
84
+ while (high - low > GOOD_ENOUGH ):
85
+ mid = (low + high )/ 2
86
+ triedx .append (mid )
87
+ triedy .append (function (mid ))
88
+ if (function (mid ) < target ) == (function (low ) < target ):
89
+ low = mid
90
+ else :
91
+ high = mid
92
+ return low , (triedx , triedy )
93
+
94
+ # binary search on min and max
95
+ param_min , tries_min = bisect (low , bestx , targety , function )
96
+ param_max , tries_max = bisect (bestx , high , targety , function )
97
+
98
+ # plot everything
99
+ if ax is not None :
100
+ # plot neighborhood
101
+ abs_err = max (bestx - param_min , param_max - bestx )
102
+ neighborhood = np .arange (bestx - abs_err * 2 , bestx + abs_err * 2 , abs_err * 4 / 200 ) # 200 evenly spaced points
103
+ ax .scatter (neighborhood , list (map (function , neighborhood )), color = 'black' , label = 'S(T)' )
104
+
105
+ # # plot binary search tries
106
+ # ax.scatter(*tries_min, color='grey')
107
+ # ax.scatter(*tries_max, color='grey')
108
+
109
+ # plot given information
110
+ ax .axhline (y = besty , label = f"S min = { besty :.6f} " , color = 'green' )
111
+ ax .axhline (y = targety , label = f"S min + chi^2 = { targety :.6f} " , color = 'lightgreen' )
112
+ ax .axvline (x = bestx , label = f"T_best = { bestx :.6f} " , color = 'green' )
113
+
114
+ # plot found information
115
+ ax .axvline (x = param_min , label = f"T_min = { param_min :.6f} " , color = 'red' )
116
+ ax .axvline (x = param_max , label = f"T_min = { param_max :.6f} " , color = 'blue' )
117
+
118
+ return param_min , param_max
0 commit comments