Support for multilingual generative metrics (#293)
---------

Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
Co-authored-by: Hynek Kydlicek <kydliceh.hynek@gmail.com>
Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com>
4 people authored Sep 27, 2024
1 parent 7e80aaa commit 1bb0dff
Showing 17 changed files with 624 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -26,7 +26,7 @@ jobs:
           cache: 'pip'
       - name: Install lighteval in editable mode
         run: |
-          pip install -e .[dev,extended_tasks]
+          pip install -e .[dev,extended_tasks,multilingual]
       - name: Get cached files
         uses: actions/cache@v2
         id: get-cache
2 changes: 1 addition & 1 deletion community_tasks/_template.py
@@ -33,7 +33,7 @@
 
 from lighteval.metrics import Metrics
 from lighteval.metrics.metrics import SampleLevelMetric
-from lighteval.metrics.utils import MetricCategory, MetricUseCase
+from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase
 from lighteval.tasks.default_prompts import LETTER_INDICES
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -97,6 +97,12 @@ extended_tasks = [
     "openai", # llm as a judge using openai models
 ]
 s3 = ["s3fs"]
+multilingual = [
+    "stanza",
+    "spacy[ja,ko,th]",
+    "jieba", # for chinese tokenizer
+    "pyvi", # for vietnamese tokenizer
+]
 
 [project.urls]
 Homepage = "https://github.com/huggingface/lighteval"
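The new extra keeps the heavy, language-specific tokenizer backends (stanza, spacy, jieba, pyvi) out of the default install. For a local checkout that should run the multilingual metrics, the editable install from the CI workflow above is the reference invocation:

pip install -e .[dev,extended_tasks,multilingual]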
84 changes: 80 additions & 4 deletions src/lighteval/metrics/dynamic_metrics.py
@@ -20,13 +20,25 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-from typing import Callable
+from typing import Callable, Literal
 
 import numpy as np
 
-from lighteval.metrics.metrics_sample import LoglikelihoodAcc, NormalizedMultiChoiceProbability, Probability
-from lighteval.metrics.normalizations import LogProbNormalization, LogProbPMINorm, LogProbTokenNorm
-from lighteval.metrics.utils import MetricCategory, MetricUseCase, SampleLevelMetric
+from lighteval.metrics.metrics_sample import (
+    ExactMatches,
+    F1_score,
+    LoglikelihoodAcc,
+    NormalizedMultiChoiceProbability,
+    Probability,
+)
+from lighteval.metrics.normalizations import (
+    LogProbNormalization,
+    LogProbPMINorm,
+    LogProbTokenNorm,
+    get_multilingual_normalizer,
+)
+from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
+from lighteval.utils.language import Language
 
 
 def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None) -> SampleLevelMetric:
@@ -92,3 +104,67 @@ def probability_metric(
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
+
+
+def multilingual_quasi_f1_score_metric(
+    language: Language, aggregation_function: Callable[[list[float]], float] = max
+) -> SampleLevelMetric:
+    """
+    Creates a language-aware F1 score metric that normalizes golds and
+    predictions with a language-specific normalizer before computing F1.
+    Args:
+        language: The language of the samples.
+        aggregation_function: Function used to aggregate scores when multiple golds are present.
+    Returns:
+        F1 score metric.
+    """
+    metric_name = f"f1_{language}"
+
+    multilang_normalizer = get_multilingual_normalizer(language)
+    return SampleLevelMetric(
+        metric_name=metric_name,
+        sample_level_fn=F1_score(
+            normalize_gold=multilang_normalizer,
+            normalize_pred=multilang_normalizer,
+            aggregation_function=aggregation_function,
+        ).compute,
+        category=MetricCategory.GENERATIVE,
+        use_case=MetricUseCase.ACCURACY,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
+
+
+def multilingual_quasi_exact_match_metric(
+    language: Language,
+    match_type: Literal["prefix", "suffix", "full"] = "full",
+    aggregation_function: Callable[[list[float]], float] = max,
+) -> SampleLevelMetric:
+    """
+    Creates a language-aware exact match metric that normalizes golds and
+    predictions before matching.
+    Args:
+        language: The language of the samples.
+        match_type: The type of match to use:
+            - "prefix": the prediction must start with the gold
+            - "suffix": the prediction must end with the gold
+            - "full": the prediction and gold must match exactly
+        aggregation_function: Function used to aggregate scores when multiple golds are present.
+    Returns:
+        Exact match metric.
+    """
+    metric_name = f"exact_match_{language}_{match_type}"
+    multilang_normalizer = get_multilingual_normalizer(language)
+    return SampleLevelMetric(
+        metric_name=metric_name,
+        sample_level_fn=ExactMatches(
+            normalize_gold=multilang_normalizer,
+            normalize_pred=multilang_normalizer,
+            aggregation_function=aggregation_function,
+            type_exact_match=match_type,
+        ).compute,
+        category=MetricCategory.GENERATIVE,
+        use_case=MetricUseCase.ACCURACY,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
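To make the new factories concrete, here is a minimal usage sketch. It only relies on names introduced in this diff (Language.FRENCH appears in the article patterns further below); the exact metric_name strings depend on how Language values render in f-strings, so the comments are indicative rather than authoritative.

from lighteval.metrics.dynamic_metrics import (
    multilingual_quasi_exact_match_metric,
    multilingual_quasi_f1_score_metric,
)
from lighteval.utils.language import Language

# Bag-of-words F1 over normalized text: articles and punctuation are
# stripped, text is lowercased and word-tokenized before comparison.
f1_fr = multilingual_quasi_f1_score_metric(Language.FRENCH)

# Exact match on normalized text; "prefix" scores 1.0 when the normalized
# prediction starts with a normalized gold.
em_fr = multilingual_quasi_exact_match_metric(Language.FRENCH, match_type="prefix")

Both factories return a SampleLevelMetric, so they slot into a task configuration the same way the existing loglikelihood metrics from this module do.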
2 changes: 1 addition & 1 deletion src/lighteval/metrics/metrics.py
@@ -61,7 +61,7 @@
     remove_braces_and_strip,
 )
 from lighteval.metrics.sample_preparator import GenerativePreparator, LoglikelihoodPreparator, PerplexityPreparator
-from lighteval.metrics.utils import (
+from lighteval.metrics.utils.metric_utils import (
     CorpusLevelMetric,
     CorpusLevelMetricGrouping,
     Metric,
16 changes: 7 additions & 9 deletions src/lighteval/metrics/metrics_sample.py
@@ -56,9 +56,9 @@
 class ExactMatches:
     def __init__(
         self,
-        aggregation_function: callable = None,
-        normalize_gold: callable = None,
-        normalize_pred: callable = None,
+        aggregation_function: Callable[[list[float]], float] = max,
+        normalize_gold: Callable[[str], str] | None = None,
+        normalize_pred: Callable[[str], str] | None = None,
         strip_strings: bool = False,
         type_exact_match: str = "full",
     ):
@@ -78,8 +78,6 @@ def __init__(
             `suffix` if the prediction ends with the gold,
             `full` if the prediction and gold are equal
         """
-        if aggregation_function is None:
-            aggregation_function = max
         self.aggregation_function = aggregation_function
         self.normalize_gold = normalize_gold
         self.normalize_pred = normalize_pred
@@ -145,9 +143,9 @@ def compute_one_item(
 class F1_score:
     def __init__(
         self,
-        aggregation_function: callable = None,
-        normalize_gold: callable = None,
-        normalize_pred: callable = None,
+        aggregation_function: Callable[[list[float]], float] = max,
+        normalize_gold: Callable[[str], str] | None = None,
+        normalize_pred: Callable[[str], str] | None = None,
         strip_strings: bool = False,
     ):
         """An F1 score class. F1 is computed over the bag of words of the golds and predictions.
@@ -163,8 +161,8 @@ def __init__(
         """
-        if aggregation_function is None:
-            aggregation_function = max
-        self.aggregation_function = aggregation_function
 
+        self.aggregation_function = aggregation_function
         self.normalize_gold = normalize_gold
         self.normalize_pred = normalize_pred
         self.strip_strings = strip_strings
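The annotation change above also moves the max default out of the constructor body and into the signature, so callers no longer pass None to get the old behavior. A small sketch of the tightened interface, assuming the compute(golds, predictions) entry point these sample-level classes expose (the example strings are illustrative):

from lighteval.metrics.metrics_sample import ExactMatches, F1_score

# max over multiple golds is now the declared default, not a None fallback.
f1 = F1_score(normalize_gold=str.lower, normalize_pred=str.lower)
f1_value = f1.compute(golds=["Paris", "paris"], predictions=["paris"])

# The aggregation is still overridable, e.g. to average over golds instead.
mean_em = ExactMatches(aggregation_function=lambda scores: sum(scores) / len(scores))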
61 changes: 61 additions & 0 deletions src/lighteval/metrics/normalizations.py
@@ -22,7 +22,13 @@
 
 import re
 import string
+import sys
+import unicodedata
 from dataclasses import dataclass
+from typing import Callable
+
+from lighteval.metrics.utils.linguistic_tokenizers import get_word_tokenizer
+from lighteval.utils.language import Language
 
 
 # From HELM
@@ -355,6 +361,61 @@ def gsm8k_normalizer(text: str) -> str:
     return INVALID_ANS
 
 
+PUNCT = {chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")}.union(
+    string.punctuation
+)
+
+_ARTICLE_PATTERNS = {
+    Language.ENGLISH: r"\b(a|an|the)\b",
+    Language.SPANISH: r"\b(el|la|los|las|un|una|unos|unas)\b",
+    Language.PORTUGUESE: r"\b(o|a|os|as|um|uma|uns|umas)\b",
+    Language.ITALIAN: r"\b(il|lo|la|i|gli|le|un|uno|una)\b",
+    Language.FRENCH: r"\b(le|la|les|l'|un|une|des)\b",
+    Language.GERMAN: r"\b(der|die|das|den|dem|des|ein|eine|einer|eines|einem|einen)\b",
+    Language.FINNISH: r"\b(se|yksi|yks)\b",
+    Language.GREEK: r"\b(ὁ|οἱ|τοῦ|τῶν|τόν|τούς|ὦ|ἡ|αἱ|τῆς|τῶν|τήν|τάς|τό|τά|τοῦ|τῶν|τό|τά)\b",
+    Language.NORWEGIAN: r"\b(en|ei|et|den|det|de)\b",
+    Language.SWEDISH: r"\b(en|ett|den|det|de)\b",
+    Language.TURKISH: r"\b(bir)\b",
+    Language.DUTCH: r"\b(de|het|een)\b",
+    Language.HUNGARIAN: r"\b(a|az|egy)\b",
+    Language.CATALAN: r"\b(el|la|els|les|un|una|uns|unes)\b",
+    Language.HEBREW: r"\b(ה)\b",
+    Language.GALICIAN: r"\b(o|a|os|as|un|unha|uns|unhas)\b",
+}
+
+
+def remove_articles(text: str, lang: Language) -> str:
+    """
+    Removes definite and indefinite articles from the text.
+    Generated with an LLM, then manually checked by a non-expert.
+    We currently only support languages that don't blend articles.
+    If you are a native speaker of a language where articles are blended,
+    we would appreciate your contribution!
+    """
+    pattern = _ARTICLE_PATTERNS.get(lang)
+    return re.sub(pattern, " ", text) if pattern else text
+
+
+def remove_punc(text: str) -> str:
+    return "".join(ch for ch in text if ch not in PUNCT)
+
+
+def get_multilingual_normalizer(lang: Language, lower: bool = True) -> Callable[[str], str]:
+    tokenizer = get_word_tokenizer(lang)
+
+    def _inner_normalizer(text: str) -> str:
+        text = remove_articles(text, lang)
+        text = remove_punc(text)
+        if lower:
+            text = text.lower()
+
+        tokens = tokenizer.word_tokenize(text)
+        return " ".join(tokens)
+
+    return _inner_normalizer
+
+
 # Loglikelihood normalization
 @dataclass
 class LogProbPMINorm:
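get_multilingual_normalizer chains the helpers above: articles out, punctuation out, optional lowercasing, then language-aware tokenization with the tokens re-joined by single spaces. A rough sketch of the intended behavior; the exact tokens depend on which backend get_word_tokenizer selects for the language, so the output in the comment is indicative only:

from lighteval.metrics.normalizations import get_multilingual_normalizer
from lighteval.utils.language import Language

normalize = get_multilingual_normalizer(Language.ENGLISH)

# Articles and punctuation are removed, text is lowercased and re-tokenized,
# so something like "answer is paris" comes back (tokenizer-dependent).
print(normalize("the answer is: Paris!"))

Note that article removal runs before lowercasing, so the patterns only match articles already written in lowercase.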