Fix LSTM LRP bug
ichalkiad committed Aug 22, 2017
1 parent 14a21f9 commit 6d70f15
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions bokeh_vis/main.py
@@ -25,8 +25,9 @@


 def button_callback():
-    text_src = re.sub('/home/icha/','/home/yannis/Desktop/tRustNN/',rawInput_selections.value)
-    text_banner.text = open(text_src,"r").read()
+    text_review,words,word_embeddings = get_rawText_data(rawInput_selections.value,keys_raw,data_raw,testX,embed_mat)
+    text_banner.text = text_review

 def get_wc_colourGroups(rawInput_source):
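The rewritten callback pulls the review text out of the already-loaded pickled data via get_rawText_data instead of re-reading the source file from a remapped absolute path, so the app no longer depends on the original /home/icha/ filesystem layout. A minimal sketch of the same lookup pattern, with hypothetical stand-ins for the app's widgets and data:

    from bokeh.models import Div, Select

    # Hypothetical stand-ins for the pickled review data and app widgets.
    reviews = {0: "great movie , loved it", 1: "terrible plot"}
    rawInput_selections = Select(title="Input review", value="0",
                                 options=[str(k) for k in reviews])
    text_banner = Div(text="")

    def button_callback():
        # Look the review up in memory; no filesystem path remapping needed.
        text_banner.text = reviews[int(rawInput_selections.value)]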

@@ -61,13 +62,15 @@ def get_clustering_selections(algorithms_neurons):

     return (algorithm_select_neuron,cluster_slider)

+
 def get_rawInput_selections(keys_raw):

     review = [str(r) for r in list(keys_raw)]
     select_rawInput = Select(title="Input review", value=review[0], options=review)

     return select_rawInput

+
 def get_projection_selections(algorithms):

     algorithm_select = Select(value="PCA",title="Select projection algorithm:",width=250, options=algorithms)
@@ -77,12 +80,13 @@ def get_projection_selections(algorithms):

 def get_rawText_data(rawInput_selections,keys_raw,data_raw,feed,embed_mat):

-    text_review = str(data_raw[int(rawInput_selections)])
-    words = str(text_review).split()
+    text_review = np.array_str(data_raw[int(rawInput_selections)])
+    txt_rev = text_review.replace('<UNK>','UNK')
+
+    words = text_review.split()
     word_embeddings = [embed_mat[i,:] for i in list(feed[int(rawInput_selections),:].astype(int))]

-    return text_review,words,word_embeddings
+    return txt_rev,words,word_embeddings


 """
@@ -148,7 +152,7 @@ def update_source(attrname, old, new):
     if algorithm_cl_neurons=="KMeans - selected gate":
         text_set.text = "KMeans: Clusters neurons based on their gate values after training."
     elif algorithm_cl_neurons=="DBSCAN - selected review":
-        text_set.text = "DBSCAN - selected review: Clusters neurons based on how related their most activating words are. List of activating words generated from seleceted review."
+        text_set.text = "DBSCAN - selected review: Clusters neurons based on how related their most activating words are. List of activating words generated from selected review."
         neuronData = similarityMatrix_PerReview
         cluster_labels, colors, _ = clustering.apply_cluster(x,algorithm_cl_neurons,n_clusters,review=int(rawInput_selections.value),neuronData=neuronData,mode="nn")

@@ -166,7 +170,7 @@ def update_source(attrname, old, new):
     elif gate_value=="output_gate":
         wc_filename,wc_img,wc_words = get_wcloud(LRP,int(rawInput_selections.value),load_dir,color_dict=color_dict,gate="out")

-    words_to_be_highlighted = [i for i in wc_words and totalLRP[int(rawInput_selections.value)]['words']]
+    words_to_be_highlighted = list(set(wc_words).intersection(totalLRP[int(rawInput_selections.value)]['words']))
     lrp_source.data['lrp'] = scaler.fit_transform(np.array(totalLRP[int(rawInput_selections.value)]['lrp'].tolist()).reshape(-1,1))
     tap_source.data['wc_words'] = words_to_be_highlighted
     wc_plot.add_glyph(img_source, ImageURL(url=dict(value=load_dir+wc_filename), x=0, y=0, anchor="bottom_left"))
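This change is likely the bug the commit title refers to: in Python, `wc_words and other_list` is boolean short-circuiting, not intersection — it evaluates to the second operand whenever wc_words is non-empty, so the wordcloud words never constrained the highlight set. A quick demonstration of old versus fixed behaviour:

    wc_words = ['great', 'awful']
    lrp_words = ['great', 'boring', 'plot']

    # Old: 'and' returns its second operand when the first is truthy,
    # so this just copies lrp_words and ignores wc_words entirely.
    buggy = [i for i in wc_words and lrp_words]
    print(buggy)   # ['great', 'boring', 'plot']

    # New: a genuine set intersection (note: order is not preserved).
    fixed = list(set(wc_words).intersection(lrp_words))
    print(fixed)   # ['great']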
@@ -196,10 +200,7 @@ def update_source(attrname, old, new):
 with open(load_dir+"exploratoryDataFull.pickle", 'rb') as f:
     (testX,embed_mat,excitingWords_fullSet,similarityMatrix_AllReviews,similarityMatrix_PerReview,neuron_types,totalLRP,LRP) = pickle.load(f)

-
-#neuronExcitingWords_AllReviews = list((excitingWords_fullSet.values()))
 _,lstm_hidden = data_format.get_data(load_dir+"test_model_internals_lstm_hidden.pickle")
-#_,learned_embeddings = data_format.get_data(load_dir+"test_model_internals_ebd.pickle")

 #Get preset buttons' selections

@@ -228,7 +229,7 @@ def update_source(attrname, old, new):
 project_plot.add_tools(taptool)

 #Input text
-text_review,words,word_embeddings = get_rawText_data(rawInput_selections.value,keys_raw,data_raw,testX,embed_mat)
+text_review,words,word_embeddings = get_rawText_data(rawInput_selections.value,keys_raw,data_raw,testX,embed_mat)
 w2v_labels, w2v_colors, _ = clustering.apply_cluster(np.array(word_embeddings),algorithm="KMeans - selected gate",n_clusters=int(clustering_selections[1].value),mode="wc")
 rawInput_source = ColumnDataSource(dict(z=w2v_colors,w=words))

@@ -248,6 +249,7 @@ def update_source(attrname, old, new):
 tap_source = ColumnDataSource(dict(wc_words=words_to_be_highlighted))
 scaler = MinMaxScaler(copy=True, feature_range=(-1, 1))
 lrp_source = ColumnDataSource(dict(lrp=scaler.fit_transform(np.array(totalLRP[int(rawInput_selections.value)]['lrp'].tolist()).reshape(-1,1))))

+#totalLRP : how relevant is each LSTM neuron
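For context, MinMaxScaler rescales the per-neuron relevance scores into [-1, 1] before they are attached to the data source; scikit-learn's scaler expects a 2-D array, hence the reshape(-1, 1) on the flat LRP vector. A minimal sketch with fabricated scores:

    import numpy as np
    from sklearn.preprocessing import MinMaxScaler

    lrp = np.array([0.2, 3.5, -1.0, 0.7])   # fabricated neuron relevances

    scaler = MinMaxScaler(copy=True, feature_range=(-1, 1))
    scaled = scaler.fit_transform(lrp.reshape(-1, 1))

    print(scaled.ravel())   # approx. [-0.467  1.000 -1.000 -0.244]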


@@ -286,7 +288,7 @@ def update_source(attrname, old, new):
 xdr = Range1d(start=0, end=600)
 ydr = Range1d(start=0, end=600)
 wc_plot = Plot(title=None, x_range=xdr, y_range=ydr, plot_width=500, plot_height=550, min_border=0)
-image = ImageURL(url=dict(value=load_dir+wc_filename), x=0, y=0, anchor="bottom_left", retry_attempts=5, retry_timeout=1500)
+image = ImageURL(url=dict(value=load_dir+wc_filename), x=0, y=0, anchor="bottom_left", retry_attempts=3, retry_timeout=1000)
 wc_plot.add_glyph(img_source, image)


@@ -297,10 +299,10 @@ def update_source(attrname, old, new):
 lrp_timedata = get_lrp_timedata(LRP)
 time = [i for i in range(len(lrp_timedata))]
 lrptime_source = ColumnDataSource(dict(lrptime = lrp_timedata,time=time))
-lrp_plot = figure(title="Total normalized LRP per timestep",plot_width=300, plot_height=50)
+lrp_plot = figure(title="Network focus per timestep",plot_width=300, plot_height=50)
 lrp_plot.scatter('time','lrptime', marker='circle', size=5, alpha=0.5, source=lrptime_source)
 lrp_plot.xaxis.axis_label = 'Time'
-lrp_plot.yaxis.axis_label = 'Total normalized LRP'
+lrp_plot.yaxis.axis_label = 'Normalized relevance score'


 #Layout
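get_lrp_timedata's implementation is not shown in this diff; conceptually, the retitled plot tracks how much total relevance the network assigns to each input position. One plausible way to build such a curve (an illustration under that assumption, not the repository's actual code):

    import numpy as np

    def lrp_per_timestep(lrp_matrix):
        # lrp_matrix: (timesteps, neurons) relevance for one review.
        totals = np.abs(lrp_matrix).sum(axis=1)   # total relevance per step
        return totals / totals.max()              # normalize to [0, 1]

    # Fabricated relevance for a 4-word review with 3 neurons.
    demo = np.array([[0.10, -0.20, 0.00],
                     [0.80,  0.50, -0.30],
                     [0.05,  0.00, 0.10],
                     [0.30, -0.40, 0.20]])
    print(lrp_per_timestep(demo))   # peaks at the second word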
