plot function for token attributions #11
Nice addition! 👏
Left some comments to try to make things simpler, but overall looks good!
Also, move the file from `src/diffusers_interpret/plots/plots.py` to `src/diffusers_interpret/token_attributions.py`. The plan is to make `output.token_attributions` a class instead of being str - I can help with that 😄
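A minimal sketch of what such a class could look like, assuming the attributions are kept as a list of `(token, score)` tuples. The names `TokenAttributions` and `normalized` are hypothetical here and are not taken from the package:

```python
from typing import List, Tuple


class TokenAttributions(list):
    """Hypothetical container: a list of (token, attribution) tuples."""

    @property
    def normalized(self) -> List[Tuple[str, float]]:
        # Express each attribution as a percentage of the total.
        total = sum(score for _, score in self)
        if total == 0:
            # Avoid dividing by zero when every attribution is zero.
            return [(token, 0.0) for token, _ in self]
        return [(token, 100 * score / total) for token, score in self]


attrs = TokenAttributions([("a", 1.0), ("cat", 3.0)])
print(attrs.normalized)
```

Because the class subclasses `list`, code that already iterates over the tuples would keep working, and a `plot` method could be attached to it later.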
# Convert list of tuples to a dataframe
df = pd.DataFrame(self, columns=['Tokens', 'percent']).set_index('Tokens')

# Bar chart
if type == 'bar':
    df.plot.bar(ylabel = 'percent',
                title = title,
                legend = False,
                rot = rot);

# Horizontal bar chart
elif type == 'barh':
    df.plot.barh(ylabel = 'percent',
                 title = title,
                 legend = False,
                 rot = 0);

# Pie chart
elif type == 'pie':
    df.plot.pie(y = 'percent',
                title = title,
                legend = False,
                autopct = '%1.1f%%',
                figsize = (8, 8));
pandas is calling matplotlib in the background.
Let's call matplotlib directly and not add an extra dependency to the package
Note: haven't tested this code yet :)
import matplotlib.pyplot as plt  # at module level

tokens, attributions = list(zip(*self))  # TODO: this can be changed, depending how we construct the class
# pyplot's bar/barh/pie do not accept `title`, `legend` or `rot`,
# so the title and tick rotation are set separately
title = plot_kwargs.pop('title', 'Token Attributions')
if plot_type == 'bar':
    # Bar chart
    plt.bar(tokens, attributions, **plot_kwargs)
    plt.xticks(rotation=60)
elif plot_type == 'barh':
    # Horizontal bar chart
    plt.barh(tokens, attributions, **plot_kwargs)
elif plot_type == 'pie':
    # Pie chart
    plt.pie(attributions, **{'labels': tokens, **plot_kwargs})
else:
    raise NotImplementedError(f"`plot_type={plot_type}` is not implemented. Choose one of: ['bar', 'barh', 'pie']")
plt.title(title)
Continued at PR #13
I'm tackling the goal of token attributions visualization, and I've created a plot function that still needs to be integrated with explainer.py's `output.normalized_token_attributions`. I attempted to use dynamic assertion and category suggestion based on this Stack Overflow comment. I was very satisfied with the previous pull request process and would appreciate additional guidance and feedback on this one ☺️
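For reference, one way to validate the plot type and suggest the closest supported name with only the standard library is `difflib.get_close_matches`. This is a sketch of the general idea, not the code from the linked comment; `validate_plot_type` and `SUPPORTED_PLOT_TYPES` are hypothetical names:

```python
import difflib

SUPPORTED_PLOT_TYPES = ("bar", "barh", "pie")  # hypothetical constant


def validate_plot_type(plot_type: str) -> str:
    """Return plot_type if it is supported, else raise with a suggestion."""
    if plot_type in SUPPORTED_PLOT_TYPES:
        return plot_type
    # Look for the closest supported name to suggest to the caller.
    matches = difflib.get_close_matches(plot_type, SUPPORTED_PLOT_TYPES, n=1)
    hint = f" Did you mean '{matches[0]}'?" if matches else ""
    raise ValueError(
        f"`plot_type={plot_type!r}` is not supported. "
        f"Choose one of {list(SUPPORTED_PLOT_TYPES)}.{hint}"
    )
```

Calling `validate_plot_type("pi")` would raise a `ValueError` whose message ends with a suggestion for `'pie'`, while an input with no close match gets only the list of supported types.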