Allow for assessing performance in multi-class classification setting #176
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Reference Issues/PRs
No issue to reference.
What does this implement/fix? Explain your changes.
In the case of multi-class classification, each prediction task is not entirely separate. A user might want to compute the micro average of the ROC or PR curves, or compute the F1 score (micro or macro averaged). Computing these metrics requires knowledge of all targets and all predictions, at the same time. Currently the implementation does not allow this because it loops over each column of predictions separately, computes performance metrics, and then averages them together at the end.
I added a condition to the
compute_score
function that allows for this. In the config file, the user specifies the targets as single values (i.e. a column vector). The NN will outputK
predictions per example, corresponding to the probability that the example belongs to each ofK
classes.Performance of the original implementation can be rescued by either (1) one-hot encoding the targets or (2) using implementations of performance metrics that use macro averaging, and specifying them in the config file.
What testing did you do to verify the changes in this PR?
Ran Selene using data where
targets
is a column vector of integer values0,...,K-1
and NN architecture outputsK
values for each example, corresponding to probabilities of the example belonging to each of theK
classes. Wrote custom wrappers of performance metrics and confirmed that the metrics are only called once per epoch, rather thanK
times.