classification enhancement (langchain4j#2052)
## Issue

When the EmbeddingModelTextClassifier is configured with a high minScore,
many LabelWithScore objects that cannot pass the threshold are still added
to the labelsWithScores list, only to be filtered out afterwards.

This change moves the minScore check into the first loop, so that
below-threshold LabelWithScore objects are never created and the separate
filter step is no longer needed.
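
To illustrate the idea outside the classifier, here is a minimal sketch comparing the two approaches. The class `ThresholdSketch`, its methods `filterAfter` and `filterEarly`, and the precomputed score map are hypothetical names introduced for this example; only `LabelWithScore`, `minScore`, and `maxResults` echo names from the change, and the nested class here is a simplified stand-in rather than the library's own type.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical, simplified stand-in for the scoring step; not the library's code.
public class ThresholdSketch {

    // Simplified stand-in for the classifier's LabelWithScore: a label plus its score.
    static final class LabelWithScore {
        final String label;
        final double score;

        LabelWithScore(String label, double score) {
            this.label = label;
            this.score = score;
        }
    }

    // Before: wrap every label, then drop low scores in the stream pipeline.
    static List<String> filterAfter(Map<String, Double> scores, double minScore, int maxResults) {
        List<LabelWithScore> all = new ArrayList<>();
        scores.forEach((label, score) -> all.add(new LabelWithScore(label, score)));
        return all.stream()
                .filter(it -> it.score >= minScore)
                // sorting in descending order to return highest score first
                .sorted(Comparator.comparingDouble(it -> 1 - it.score))
                .limit(maxResults)
                .map(it -> it.label)
                .collect(Collectors.toList());
    }

    // After: check the threshold first, so below-threshold wrappers are never allocated.
    static List<String> filterEarly(Map<String, Double> scores, double minScore, int maxResults) {
        List<LabelWithScore> kept = new ArrayList<>();
        scores.forEach((label, score) -> {
            if (score >= minScore) {
                kept.add(new LabelWithScore(label, score));
            }
        });
        return kept.stream()
                // sorting in descending order to return highest score first
                .sorted(Comparator.comparingDouble(it -> 1 - it.score))
                .limit(maxResults)
                .map(it -> it.label)
                .collect(Collectors.toList());
    }
}
```

Both methods should return the same labels in the same order. The early check only avoids allocating and later discarding wrappers for labels that score below `minScore`, which matters more as the threshold rises.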


## Change

## General checklist
- [x] There are no breaking changes
- [ ] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
omarmahamid authored Nov 11, 2024
1 parent 84fce05 commit 54c2f39
Showing 1 changed file with 4 additions and 2 deletions.
@@ -128,11 +128,13 @@ public List<E> classify(String text) {
             }
             meanScore /= exampleEmbeddings.size();

-            labelsWithScores.add(new LabelWithScore(label, aggregatedScore(meanScore, maxScore)));
+            double aggregateScore = aggregatedScore(meanScore, maxScore);
+            if (aggregateScore >= minScore){
+                labelsWithScores.add(new LabelWithScore(label, aggregateScore));
+            }
         });

         return labelsWithScores.stream()
-                .filter(it -> it.score >= minScore)
                 // sorting in descending order to return highest score first
                 .sorted(comparingDouble(labelWithScore -> 1 - labelWithScore.score))
                 .limit(maxResults)