update attention tuning
tongplw committed Jul 10, 2020
1 parent 08b0367 · commit 3b6d489
Showing 3 changed files with 21 additions and 13 deletions.
dev.py (8 changes: 5 additions & 3 deletions)
@@ -13,10 +13,12 @@
 df = pd.read_csv('res/data.csv', header=None).iloc[:, 1:]
 df.columns = ['delta', 'theta', 'low-alpha', 'high-alpha', 'low-beta', 'high-beta', 'low-gamma', 'mid-gamma']
 
+atts = []
+
 for i in tqdm(range(len(df))):
     waves = df.iloc[i].to_dict()
     attention = headset._attention(waves)
-    attention = int(np.round(float(attention), 2) * 100)
-    plot_attention(attention)
-    time.sleep(0.1)
+    atts += [attention]
 
+plt.hist(atts, bins=200)
+plt.show()
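
dev.py drops the live per-sample animation in favour of replaying every recorded row of band powers through headset._attention, collecting the raw scores, and plotting their distribution once at the end; the 200-bin histogram is what makes the recalibration in src/muse.py below easy to eyeball. A minimal sketch of that distribution check with synthetic scores standing in for the model outputs (only the rescale formula comes from this commit; the input data here is made up):

import numpy as np
import matplotlib.pyplot as plt

# Synthetic stand-in for the raw scores gathered in `atts`; the real
# values come from headset._attention(waves), one CSV row at a time.
raw = np.random.normal(loc=0.8, scale=0.1, size=2000)

# The rescale _calibrate applies: centre on the running mean, spread
# to a standard deviation of 0.25 around 0.5, and clamp into (0, 1].
calibrated = np.clip((raw - raw.mean()) / raw.std() * 0.25 + 0.5, 1e-5, 1)

plt.hist(calibrated, bins=200)  # should cluster around 0.5
plt.show()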
src/muse.py (21 changes: 12 additions & 9 deletions)
@@ -62,12 +62,16 @@ def _reject_outliers(self, data, m=3):
         IQR = (Q3 - Q1) * m
         return data[(data > Q1 - IQR) & (data < Q3 + IQR)]
 
-    def _adjust_attention(self, att):
-        if len(self._attention_history) < 10:
-            return att
+    def _calibrate(self, att):
+        if len(self._attention_history) < 5:
+            self._attention_history += [att]
+            return min(1, max(1e-5, att))
+        self._attention_history += [att]
         atts = self._reject_outliers(self._attention_history)
-        # return att
-        return (att - 0.5) / np.std(atts) * 0.25 + 0.5
+        print(np.mean(atts))
+        att = (att - np.mean(atts)) / np.std(atts) * 0.25 + 0.5
+        return min(1, max(1e-5, att))
+        return att
 
     def _convert_to_mindwave(self, band, value):
         d_map = {'delta': [7.32900391, 7.47392578, 5.576955, 5.687801],
@@ -114,10 +118,9 @@ def _attention(self, waves):
         wave_transformed = self.scaler.transform(wave_array)
         att = np.sum(wave_transformed * coef_) + intercept_
 
-        self._attention_history += [att]
-        att = self._adjust_attention(att)
-        att = min(1, max(1e-5, att))
-        self._attention_buff = [att] + self._attention_buff[:-1]
+        att = self._calibrate(att)
+        if 0 < att < 1:
+            self._attention_buff = [att] + self._attention_buff[:-1]
         return att
 
     def _eeg_handler(self, unused_addr, args, TP9, AF7, AF8, TP10, AUX):
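
The renamed _calibrate is the heart of the commit. For the first five samples it only records the raw score and clamps it into [1e-5, 1] while the history fills; after that it appends the score, strips outliers from the history with the IQR rule in _reject_outliers, re-centres the score on the cleaned history's mean rather than the hard-coded 0.5, spreads it to a standard deviation of 0.25 around 0.5, and clamps again. The new guard in _attention then buffers only scores strictly inside (0, 1). A self-contained sketch of the same procedure, rewritten as a free function over an explicit history list (the constants and formula mirror the diff; the function name and the np.percentile-based IQR step are assumptions):

import numpy as np

def calibrate(att, history, m=3):
    # Warm-up: with fewer than 5 scores banked, just record and clamp.
    if len(history) < 5:
        history.append(att)
        return min(1, max(1e-5, att))
    history.append(att)
    atts = np.asarray(history)
    # IQR outlier rejection, mirroring _reject_outliers(data, m=3).
    q1, q3 = np.percentile(atts, [25, 75])
    iqr = (q3 - q1) * m
    atts = atts[(atts > q1 - iqr) & (atts < q3 + iqr)]
    # Z-score against the cleaned history, then rescale so a typical
    # score for this session lands at 0.5 with a spread of 0.25.
    att = (att - atts.mean()) / atts.std() * 0.25 + 0.5
    return min(1, max(1e-5, att))

Note that the rescale divides by the history's standard deviation, so a perfectly flat early history would still divide by zero: the warm-up branch guards length, not spread.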
src/pylive.py (5 changes: 4 additions & 1 deletion)
@@ -25,4 +25,7 @@ def live_plotter(x_vec, y1_data, line1, title='', pause_time=1e-2):
 def plot_attention(att):
     global x_vec, y_vec, line1
     y_vec = np.append(y_vec[1:], att)
-    line1 = live_plotter(x_vec, y_vec, line1)
+    try:
+        line1 = live_plotter(x_vec, y_vec, line1)
+    except:
+        pass
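
Wrapping the redraw in try/except keeps a long replay loop alive when the figure window is closed mid-run, at the price of a bare except that silences every other failure as well. A slightly narrower variant (a sketch of an alternative, not what the commit ships; it assumes the module's x_vec, y_vec, and line1 globals) that still skips failed redraws but reports them:

def plot_attention(att):
    global x_vec, y_vec, line1
    y_vec = np.append(y_vec[1:], att)
    try:
        line1 = live_plotter(x_vec, y_vec, line1)
    except Exception as exc:  # e.g. the figure window was closed
        print(f'live plot skipped: {exc}')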
