Skip to content

plot function for token attributions #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

TomPham97
Copy link
Contributor

I'm tackling the goal of token attributions visualization, and I've created a plot function that still needs to be intergrated with explainer.py's output.normalized_token_attributions.

I attempted to use dynamic assertation and category suggestion based on this Stackoverflow comment. I was very satisfied with the previous pull request process and would gladly appreciate additional guidance and feedback on this one ☺️

@JoaoLages JoaoLages self-requested a review September 13, 2022 09:00
Copy link
Owner

@JoaoLages JoaoLages left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice addition! 👏
Left some comments to try to make things simpler, but overall looks good!
Also, move the file from src/diffusers_interpret/plots/plots.py to src/diffusers_interpret/token_attributions.py. The plan is to make output.token_attributions a class instead of being str - I can help with that 😄

Comment on lines +31 to +54
# Convert list of tuples to a dataframe
df = pd.DataFrame(self, columns=['Tokens', 'percent']).set_index('Tokens')

# Bar chart
if type == 'bar':
df.plot.bar(ylabel = 'percent',
title = title,
legend = False,
rot = rot);

# Horizontal bar chart
elif type == 'barh':
df.plot.barh(ylabel = 'percent',
title = title,
legend = False,
rot = 0);

# Pie chart
elif type == 'pie':
df.plot.pie(y = 'percent',
title = title,
legend = False,
autopct = '%1.1f%%',
figsize = (8, 8));
Copy link
Owner

@JoaoLages JoaoLages Sep 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pandas is calling matplotlib in the background.
Let's call matplotlib directly and not add an extra dependency to the package

Note: haven't tested this code yet :)

Suggested change
# Convert list of tuples to a dataframe
df = pd.DataFrame(self, columns=['Tokens', 'percent']).set_index('Tokens')
# Bar chart
if type == 'bar':
df.plot.bar(ylabel = 'percent',
title = title,
legend = False,
rot = rot);
# Horizontal bar chart
elif type == 'barh':
df.plot.barh(ylabel = 'percent',
title = title,
legend = False,
rot = 0);
# Pie chart
elif type == 'pie':
df.plot.pie(y = 'percent',
title = title,
legend = False,
autopct = '%1.1f%%',
figsize = (8, 8));
tokens, attributions = list(zip(*self)) # TODO: this can be changed, depending how we construct the class
plot_kwargs = {'title': 'Token Attributions', **plot_kwargs}
if plot_type == 'bar':
# Bar chart
plt.bar(tokens, attributions, **{'legend': False, 'rot': 60, **plot_kwargs})
elif plot_type == 'barh':
# Horizontal bar chart
plt.barh(tokens, attributions, **{'legend': False, 'rot': 0, **plot_kwargs})
elif type == 'pie':
# Pie chart
plt.pie(attributions, **{'labels': tokens, **plot_kwargs})
else:
raise NotImplementedError(f"`plot_type={plot_type}` is not implemented. Choose one of: ['bar', 'barh', 'pie']")

@TomPham97
Copy link
Contributor Author

Continued at PR #13

@TomPham97 TomPham97 closed this Sep 13, 2022
JoaoLages added a commit that referenced this pull request Sep 14, 2022
* apply suggestions and customize pyplot

* Apply suggestions from code review

Co-authored-by: João Lages <joaop.glages@gmail.com>

Co-authored-by: João Lages <joaop.glages@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants