Skip to content

Commit

Permalink
fixed confusion plots
Browse files Browse the repository at this point in the history
  • Loading branch information
Felix Burkhardt committed Aug 19, 2021
1 parent b274e40 commit dba8654
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 35 deletions.
4 changes: 2 additions & 2 deletions exp_xgb.ini → exp_emodb.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[EXP]
root = /home/fburkhardt/ResearchProjects/nkululeko
root = /home/fburkhardt/ResearchProjects/nkululeko/
store = ./store/
name = xgb-exp
fig_dir = ./images/
Expand All @@ -8,7 +8,7 @@ databases = ['emodb']
emodb = /home/fburkhardt/audb/emodb/6/
emodb.mapping = {'anger':'angry', 'happiness':'happy', 'sadness':'sad', 'neutral':'neutral'}
emodb.split_strategy = speaker_split
emodb.testsplit = 10
emodb.testsplit = 40
target = emotion
labels = ['neutral', 'happy', 'sad', 'angry']
[FEATS]
Expand Down
2 changes: 1 addition & 1 deletion main.py → exp_emodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ def main(config_file):


if __name__ == "__main__":
main('/home/fburkhardt/ResearchProjects/nkululeko/exp_xgb.ini') #sys.argv[1])
main('/home/fburkhardt/ResearchProjects/nkululeko/exp_emodb.ini') #sys.argv[1])
31 changes: 0 additions & 31 deletions exp_svm.ini

This file was deleted.

23 changes: 22 additions & 1 deletion src/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import ast
import numpy as np
import glob_conf
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import confusion_matrix

class Reporter:

Expand All @@ -20,12 +22,25 @@ def continuous_to_categorical(self):
self.truths = np.digitize(self.truths, bins)-1
self.preds = np.digitize(self.preds, bins)-1

def plot_confmatrix_mew(self, plot_name):
fig_dir = self.util.get_path('fig_dir')
labels = ast.literal_eval(glob_conf.config['DATA']['labels'])
plt.figure() # figsize=[5, 5]
cm = confusion_matrix(self.truths, self.preds, normalize = 'true')
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(cmap='Blues')
print(f'plotting conf matrix to {fig_dir+plot_name}')
plt.title('Confusion Matrix')
plt.savefig(fig_dir+plot_name)
plt.close()


def plot_confmatrix(self, plot_name):
fig_dir = self.util.get_path('fig_dir')
sns.set() # get prettier plots
labels = ast.literal_eval(glob_conf.config['DATA']['labels'])
plt.figure(figsize=[5, 5])
plt.figure() # figsize=[5, 5]
plt.title('Confusion Matrix')
plt.ylabel('UAR')
audplot.confusion_matrix(self.truths, self.preds)
# replace labels
locs, _ = plt.xticks()
Expand All @@ -35,6 +50,12 @@ def plot_confmatrix(self, plot_name):
print(f'plotting conf matrix to {fig_dir+plot_name}')
plt.savefig(fig_dir+plot_name)
plt.close()
print('truths')
print(self.truths.values)
print('preds')
print(self.preds)
print('labels')
print(labels)

def uar(self):
return recall_score(self.truths, self.preds, average='macro')
Expand Down

0 comments on commit dba8654

Please sign in to comment.