Skip to content

Allow for assessing performance in multi-class classification setting #176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 4, 2021

Conversation

rfriedman22
Copy link
Contributor

Reference Issues/PRs

No issue to reference.

What does this implement/fix? Explain your changes.

In the case of multi-class classification, each prediction task is not entirely separate. A user might want to compute the micro average of the ROC or PR curves, or compute the F1 score (micro or macro averaged). Computing these metrics requires knowledge of all targets and all predictions, at the same time. Currently the implementation does not allow this because it loops over each column of predictions separately, computes performance metrics, and then averages them together at the end.

I added a condition to the compute_score function that allows for this. In the config file, the user specifies the targets as single values (i.e. a column vector). The NN will output K predictions per example, corresponding to the probability that the example belongs to each of K classes.

Performance of the original implementation can be rescued by either (1) one-hot encoding the targets or (2) using implementations of performance metrics that use macro averaging, and specifying them in the config file.

What testing did you do to verify the changes in this PR?

Ran Selene using data where targets is a column vector of integer values 0,...,K-1 and NN architecture outputs K values for each example, corresponding to probabilities of the example belonging to each of the K classes. Wrote custom wrappers of performance metrics and confirmed that the metrics are only called once per epoch, rather than K times.

@rfriedman22 rfriedman22 marked this pull request as ready for review September 16, 2021 22:22
@kathyxchen
Copy link
Collaborator

Hi Ryan, thanks again for all your PRs and bugfixes as you continue to use Selene, it's really great to have your contributions. :)

I had a few questions about this PR -

  1. Just to confirm, this won't disrupt our existing functionality at all right? (It looks like the if-else accounts for that pretty seamlessly, I just want to double check that that's your understanding too)
  2. We should probably document that this is an option! Could you provide a code snippet showing how you use the new functionality (e.g. how do you specify it in the config file format)?
    I'm thinking we should put it somewhere on the documentation website or at least give an example in the config_examples folder

@rfriedman22
Copy link
Contributor Author

Hi Kathy,

It's been fun contributing to the code as it develops and adding small features/bugfixes as I use Selene in my own work!

To answer your questions:

  1. That is correct, it does not disrupt existing functionality, it just adds a new feature, but I can double check that existing functionality is maintained just to be certain. It came about from working with my own data. Basically, I'm doing multi-class classification where my NN has K classification tasks, but the targets can only take on one of K values (i.e. to make a prediction you would take the argmax over the K classification tasks). I wanted to track the F1 score during the fitting procedure because the ROC and PR curves can sometimes be misleading, but the existing functionality doesn't allow for me to do that.
  2. Yes, this is great idea! Attached are a few files with the relevant code:
    1. The config file with the NN architecture and the custom implementation of performance metrics. I'm not sure if the features field is necessarily correct with how it is intended, but it works for me without any issues. Critically, the targets are encoded as an (N, 1) column vector, i.e. for any sequence n the target is a list of length one, corresponding to a value 0, 1, ..., K - 1.
    2. The file with a model architecture.
    3. A custom implementation of the loss function to handle the fact that the targets needed to be a column vector for Selene, but NLLLoss expects something different.
    4. The custom implementation of the metrics I'm using.

Let me know if anything is unclear or if you need anything else from me! Happy to help document this feature as needed.

ForKathy.tar.gz

@kathyxchen
Copy link
Collaborator

Thanks for this example! It's really useful/interesting to know how you've been using the library - and I'm really looking forward to seeing what you do with this model! (I'm guessing you're aiming to publish a paper+code for this at some point?)

I'll plan to add some info / examples for this to our docs website soon (though "soon" for me is usually quite slow these days... haha)

At some point, if you have thoughts on if there are other changes we could incorporate to make Selene more flexible to your use case, I'd love to chat! We have been talking about substantially updating the API, among other things, for quite a while now & would welcome more input :)

@kathyxchen kathyxchen merged commit c67b8b0 into FunctionLab:master Oct 4, 2021
@rfriedman22
Copy link
Contributor Author

Yes, I'm aiming to use work from my model to publish eventually! Glad you're interested in the work :)

Let me know if you need anything as you add details to the docs. I know it's a thankless job.

I'm happy to chat and discuss my experience using Selene in more detail and talk a bit more about how I'm using my model. So far things have been reasonably straightforward, although I have a few thoughts on how it can be more flexible etc. Shoot me an email at ryan.friedman@wustl.edu and we can find a time to chat!

@rfriedman22 rfriedman22 deleted the multiclass-metrics branch October 5, 2021 15:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants