Skip to content

Commit

Permalink
🔨 add meditation and refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
tongplw committed Jul 17, 2020
1 parent 496372d commit 920f874
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 45 deletions.
9 changes: 7 additions & 2 deletions dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
df.columns = ['delta', 'theta', 'low-alpha', 'high-alpha', 'low-beta', 'high-beta', 'low-gamma', 'mid-gamma']

atts = []
meds = []

for i in tqdm(range(len(df))):
waves = df.iloc[i].to_dict()
attention = headset._attention(waves)
atts += [attention]
meditation = headset._meditation(waves)
meds += [meditation]

plt.hist(atts, bins=200)
# plt.plot(atts)
# plt.hist(atts, bins=200)
# plt.hist(meds, bins=200, alpha=0.5)
plt.plot(atts)
plt.plot(meds, alpha=0.5)
plt.show()
84 changes: 41 additions & 43 deletions src/muse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from scipy.stats import normaltest
from pythonosc import dispatcher, osc_server
from multiprocessing import Process, Manager, Value
from src.utils import *
from src.RunningStats import RunningStats


Expand All @@ -19,11 +20,13 @@ def __init__(self, server=None, port=None, debug=False):
self.sample_rate = 256
self._buffer = []
self._attention_buff = [.5, .5, .5, .5, .5]
self._running_stats = RunningStats()
self._meditation_buff = [.5, .5, .5, .5, .5]
self._running_stats = {'att': RunningStats(), 'med': RunningStats()}
self.scaler = joblib.load('res/scaler')

self.raw = Value('d', 0)
self.attention = Value('d', 0)
self.meditation = Value('d', 0)
self.waves = Manager().dict()

if not debug:
Expand All @@ -43,13 +46,11 @@ def _get_fft(self, raw_list):
return [freqs, fft]

def _get_bands(self, raw_list):
bands = {'delta': (1, 3), 'theta': (4, 7), 'low-alpha': (8, 9), 'high-alpha': (10, 12),
'low-beta': (13, 17), 'high-beta': (18, 30), 'low-gamma': (30, 40), 'mid-gamma': (41, 50)}
band_list = {b: [] for b in bands}
band_list = {b: [] for b in BAND_RANGE}
freqs, fft = self._get_fft(raw_list)
for freq, amps in zip(freqs, fft):
for b in bands:
low, high = bands[b]
for b in BAND_RANGE:
low, high = BAND_RANGE[b]
if low <= freq < high:
band_list[b] += [amps]
for b in band_list:
Expand All @@ -63,53 +64,27 @@ def _reject_outliers(self, data, m=3):
IQR = (Q3 - Q1) * m
return data[(data > Q1 - IQR) & (data < Q3 + IQR)]

def _calibrate(self, att):
def _calibrate(self, key, val):
# recalibrate at p-value 0.01 (two-tailed)
if self._running_stats.get_count() > 30:
z = (att - self._running_stats.get_mean()) / self._running_stats.get_std()
if self._running_stats[key].get_count() > 30:
z = (val - self._running_stats[key].get_mean()) / self._running_stats[key].get_std()
if abs(z) >= 2.58:
self._running_stats.clear()
self._running_stats[key].clear()

self._running_stats.update(att)
if self._running_stats.get_count() > 5:
att = (att - self._running_stats.get_mean()) / self._running_stats.get_std() * 0.25 + 0.5
return min(1, max(1e-5, att))
self._running_stats[key].update(val)
if self._running_stats[key].get_count() > 5:
val = (val - self._running_stats[key].get_mean()) / self._running_stats[key].get_std() * 0.25 + 0.5
return min(1, max(1e-5, val))

def _convert_to_mindwave(self, band, value):
d_map = {'delta': [7.32900391, 7.47392578, 5.576955, 5.687801],
'theta': [6.39179688, 6.41220703, 5.832594, 6.030335],
'low-alpha': [5.95166016, 5.65654297, 5.188705, 5.819261],
'high-alpha': [5.67324219, 5.06787109, 5.440118, 6.053460],
'low-beta': [5.53769531, 4.69228516, 5.509398, 5.892162],
'high-beta': [5.55654297, 4.56396484, 5.261583, 5.753560],
'low-gamma': [5.14970703, 4.81093750, 5.088524, 5.302043],
'mid-gamma': [7.08144531, 4.92177734, 4.860328, 5.343249]}
mind_c, muse_c, mind_mean, muse_mean = d_map[band]
mind_c, muse_c, mind_mean, muse_mean = CONVERT_MAP[band]
value = value / 1.8 * 4096 * 2
return ((value ** (1 / muse_c)) - muse_mean + mind_mean) ** mind_c

def _attention(self, waves):
waves = waves.copy()
for band in waves:
waves[band] = self._convert_to_mindwave(band, waves[band])
index = ['delta', 'theta', 'low-alpha', 'high-alpha',
'low-beta', 'high-beta', 'low-gamma', 'mid-gamma',
'attention-1', 'attention-2', 'attention-3', 'attention-4',
'attention-5', 'log2-delta', 'log2-theta', 'log2-low-alpha',
'log2-high-alpha', 'log2-low-beta', 'log2-high-beta', 'log2-low-gamma',
'log2-mid-gamma', 'log2-attention-1', 'log2-attention-2', 'log2-attention-3',
'log2-attention-4', 'log2-attention-5', 'log2-theta-alpha']
coef_ = [1.85597993e-03, -3.89405744e-02, -2.17976458e-02,
-4.94719580e-03, 5.92689481e-02, -2.66903157e-02,
2.30846084e-02, 6.82606511e-02, 8.54525920e-01,
2.10894178e-02, -1.06262949e-01, -2.69200545e-01,
2.50926910e-01, 6.62901487e-03, 5.85212860e-03,
6.45113734e-03, 1.61475866e-04, -1.34012739e-03,
7.47788932e-01, -8.24321394e-04, -1.39632993e-03,
-9.23905598e-03, -9.27560591e-03, 2.64911164e-02,
2.01285851e-03, -2.64492016e-03, -7.13587632e-01]
intercept_ = 0.18241877

for i in range(5):
waves[f'attention-{i+1}'] = self._attention_buff[i]
for i in list(waves):
Expand All @@ -118,12 +93,31 @@ def _attention(self, waves):

wave_array = np.array([[val for val in waves.values()]])
wave_transformed = self.scaler.transform(wave_array)
att = np.sum(wave_transformed * coef_) + intercept_
att = np.sum(wave_transformed * ATT_COEF) + ATT_INTERCEPT

att = self._calibrate(att)
att = self._calibrate('att', att)
if 0 < att < 1:
self._attention_buff = [att] + self._attention_buff[:-1]
return att

def _meditation(self, waves):
waves = waves.copy()
for band in waves:
waves[band] = self._convert_to_mindwave(band, waves[band])
for i in range(5):
waves[f'meditation-{i+1}'] = self._meditation_buff[i]
for i in list(waves):
waves[f'log2-{i}'] = np.log2(waves[i])
waves['log2-theta-alpha'] = np.log2(waves['theta'] + waves['low-alpha'] + waves['high-alpha'])

wave_array = np.array([[val for val in waves.values()]])
wave_transformed = self.scaler.transform(wave_array)
med = np.sum(wave_transformed * MED_COEF) + MED_INTERCEPT

med = self._calibrate('med', med)
if 0 < med < 1:
self._meditation_buff = [med] + self._meditation_buff[:-1]
return med

def _eeg_handler(self, unused_addr, args, TP9, AF7, AF8, TP10, AUX):
self.raw.acquire()
Expand All @@ -139,6 +133,10 @@ def _eeg_handler(self, unused_addr, args, TP9, AF7, AF8, TP10, AUX):
self.attention.acquire()
self.attention.value = np.round(new_attention, 2) * 100
self.attention.release()
new_meditation = self._meditation(self.waves)
self.meditation.acquire()
self.meditation.value = np.round(new_meditation, 2) * 100
self.meditation.release()

def _run(self):
server = osc_server.BlockingOSCUDPServer((self.server, self.port), self._get_dispatcher())
Expand Down
39 changes: 39 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
BAND_RANGE = {'delta': (1, 3), 'theta': (4, 7), 'low-alpha': (8, 9), 'high-alpha': (10, 12),
'low-beta': (13, 17), 'high-beta': (18, 30), 'low-gamma': (30, 40), 'mid-gamma': (41, 50)}

CONVERT_MAP = {'delta': [7.32900391, 7.47392578, 5.576955, 5.687801],
'theta': [6.39179688, 6.41220703, 5.832594, 6.030335],
'low-alpha': [5.95166016, 5.65654297, 5.188705, 5.819261],
'high-alpha': [5.67324219, 5.06787109, 5.440118, 6.053460],
'low-beta': [5.53769531, 4.69228516, 5.509398, 5.892162],
'high-beta': [5.55654297, 4.56396484, 5.261583, 5.753560],
'low-gamma': [5.14970703, 4.81093750, 5.088524, 5.302043],
'mid-gamma': [7.08144531, 4.92177734, 4.860328, 5.343249]}

INDEX = ['delta', 'theta', 'low-alpha', 'high-alpha',
'low-beta', 'high-beta', 'low-gamma', 'mid-gamma',
'attention-1', 'attention-2', 'attention-3', 'attention-4',
'attention-5', 'log2-delta', 'log2-theta', 'log2-low-alpha',
'log2-high-alpha', 'log2-low-beta', 'log2-high-beta', 'log2-low-gamma',
'log2-mid-gamma', 'log2-attention-1', 'log2-attention-2', 'log2-attention-3',
'log2-attention-4', 'log2-attention-5', 'log2-theta-alpha']

ATT_COEF = [1.85597993e-03, -3.89405744e-02, -2.17976458e-02,
-4.94719580e-03, 5.92689481e-02, -2.66903157e-02,
2.30846084e-02, 6.82606511e-02, 8.54525920e-01,
2.10894178e-02, -1.06262949e-01, -2.69200545e-01,
2.50926910e-01, 6.62901487e-03, 5.85212860e-03,
6.45113734e-03, 1.61475866e-04, -1.34012739e-03,
7.47788932e-01, -8.24321394e-04, -1.39632993e-03,
-9.23905598e-03, -9.27560591e-03, 2.64911164e-02,
2.01285851e-03, -2.64492016e-03, -7.13587632e-01]

MED_COEF = [0.00911533, -0.13181668, 0.19096046, 0.12466165, -0.08485195,
-0.03896587, 0.04479122, 0.04009775, 0.85471768, 0.0059345 ,
-0.08760081, -0.22276642, 0.26068081, 0.0075658 , -0.65031589,
0.407374 , 0.34096692, -0.13563227, -0.0948423 , 0.01681845,
-0.00673638, -0.01348625, 0.03122508, -0.02148527, -0.09031348,
0.03204765, 0.20908319]

ATT_INTERCEPT = 0.18241877
MED_INTERCEPT = 0.09677243

0 comments on commit 920f874

Please sign in to comment.