Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multilingual NLI Tasks #329

Merged
merged 43 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
5c69eb0
add multilingual dynamic generative metrics
hynky1999 Sep 5, 2024
39c4220
Merge branch 'main' into geneartive_dynamic_metrics
hynky1999 Sep 5, 2024
2a5cdca
Merge branch 'geneartive_dynamic_metrics' into config_templates
hynky1999 Sep 5, 2024
2df9a08
draft
hynky1999 Sep 6, 2024
95729ee
finish multichoice config
hynky1999 Sep 9, 2024
3aa0579
Merge branch 'main' into geneartive_dynamic_metrics
hynky1999 Sep 9, 2024
b8f90a9
update tokenizers + install nltk reqs
hynky1999 Sep 9, 2024
f5a8717
use punkt tab
hynky1999 Sep 9, 2024
227f572
Update src/lighteval/utils/imports.py
hynky1999 Sep 13, 2024
d80b3ba
Update src/lighteval/metrics/normalizations.py
hynky1999 Sep 13, 2024
532bdad
fix imports
Sep 13, 2024
75f7ac5
remove unused import
Sep 13, 2024
f99e330
Merge branch 'main' into geneartive_dynamic_metrics
NathanHB Sep 13, 2024
92daf90
Merge branch 'main' into geneartive_dynamic_metrics
clefourrier Sep 14, 2024
f2a801d
Merge branch 'main' into geneartive_dynamic_metrics
NathanHB Sep 17, 2024
91d9d4f
finish implementation of templates + move stuff around
Sep 23, 2024
9356cc6
resolve nits
Sep 23, 2024
0fbc731
when in rome do as romans do (handle error messages the same way)
Sep 23, 2024
fa1fa83
fix utils
hynky1999 Sep 23, 2024
db36e16
Merge branch 'geneartive_dynamic_metrics' into config_templates
hynky1999 Sep 23, 2024
44aeecf
nicer tests + fix them
hynky1999 Sep 23, 2024
2bff963
nicer todo
hynky1999 Sep 23, 2024
3c9eb21
add nice docstrings 📃
hynky1999 Sep 23, 2024
4216ae2
add even more docstring
hynky1999 Sep 23, 2024
d8f56b8
nit
hynky1999 Sep 23, 2024
f26e88c
fix test
hynky1999 Sep 23, 2024
111d615
add multilingual to dev group
hynky1999 Sep 24, 2024
7ca4239
merge nli, add languages to literals
hynky1999 Sep 25, 2024
22eeddb
translation literals
hynky1999 Sep 25, 2024
7faaa8a
add nli
hynky1999 Sep 25, 2024
ba44fe9
add rcb + chinese nli
hynky1999 Sep 26, 2024
2d09256
Merge branch 'geneartive_dynamic_metrics' into config_templates
hynky1999 Sep 26, 2024
7324e89
Merge branch 'config_templates' into multilnag_nli_tasks
hynky1999 Sep 26, 2024
ca865bd
Update src/lighteval/tasks/multilingual/tasks.py
hynky1999 Sep 30, 2024
1cc1187
Update src/lighteval/tasks/multilingual/tasks.py
hynky1999 Sep 30, 2024
d64251f
Update src/lighteval/tasks/multilingual/tasks.py
hynky1999 Sep 30, 2024
9806fab
Update src/lighteval/tasks/multilingual/tasks.py
hynky1999 Sep 30, 2024
35d7e6d
Update src/lighteval/tasks/multilingual/tasks.py
hynky1999 Sep 30, 2024
e560738
Update src/lighteval/tasks/multilingual/tasks.py
hynky1999 Sep 30, 2024
99524c5
Update src/lighteval/tasks/multilingual/tasks.py
hynky1999 Sep 30, 2024
150c76f
add two new tasks + docs
hynky1999 Sep 30, 2024
4e6100d
Merge branch 'multilnag_nli_tasks' of github.com:huggingface/lighteva…
hynky1999 Sep 30, 2024
7b561fe
Merge remote-tracking branch 'origin/main' into multilnag_nli_tasks
hynky1999 Sep 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion community_tasks/_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from lighteval.metrics import Metrics
from lighteval.metrics.metrics import SampleLevelMetric
from lighteval.metrics.utils import MetricCategory, MetricUseCase
from lighteval.metrics.utils.utils import MetricCategory, MetricUseCase
from lighteval.tasks.default_prompts import LETTER_INDICES
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
Expand Down
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,18 @@ tensorboardX = ["tensorboardX"]
vllm = ["vllm", "ray", "more_itertools"]
quality = ["ruff==v0.2.2","pre-commit"]
tests = ["pytest==7.4.0"]
dev = ["lighteval[accelerate,quality,tests]"]
dev = ["lighteval[accelerate,quality,tests,multilingual]"]
extended_tasks = [
"langdetect", # ifeval
"openai", # llm as a judge using openai models
]
s3 = ["s3fs"]
multilingual = [
"stanza",
"spacy[ja,ko]",
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
Expand Down
84 changes: 80 additions & 4 deletions src/lighteval/metrics/dynamic_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,25 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Callable
from typing import Callable, Literal

import numpy as np

from lighteval.metrics.metrics_sample import LoglikelihoodAcc, NormalizedMultiChoiceProbability, Probability
from lighteval.metrics.normalizations import LogProbNormalization, LogProbPMINorm, LogProbTokenNorm
from lighteval.metrics.utils import MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.metrics.metrics_sample import (
ExactMatches,
F1_score,
LoglikelihoodAcc,
NormalizedMultiChoiceProbability,
Probability,
)
from lighteval.metrics.normalizations import (
LogProbNormalization,
LogProbPMINorm,
LogProbTokenNorm,
get_multilingual_normalizer,
)
from lighteval.metrics.utils.utils import MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.utils.language import Language


def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None) -> SampleLevelMetric:
Expand Down Expand Up @@ -92,3 +104,67 @@ def probability_metric(
corpus_level_fn=np.mean,
higher_is_better=True,
)


def multilingual_quasi_f1_score_metric(
    language: Language, aggregation_function: Callable[[list[float]], float] = max
) -> SampleLevelMetric:
    """Build a language-aware F1 score metric.

    Both the gold references and the predictions are passed through the
    language-specific normalizer before the bag-of-words F1 is computed.

    Args:
        language: Language of the samples; selects the normalizer.
        aggregation_function: How to reduce per-gold scores when a sample has
            several golds (defaults to taking the best match).

    Returns:
        A sample-level F1 metric, averaged over the corpus with np.mean.
    """
    normalizer = get_multilingual_normalizer(language)
    scorer = F1_score(
        normalize_gold=normalizer,
        normalize_pred=normalizer,
        aggregation_function=aggregation_function,
    )
    return SampleLevelMetric(
        metric_name=f"f1_{language}",
        sample_level_fn=scorer.compute,
        category=MetricCategory.GENERATIVE,
        use_case=MetricUseCase.ACCURACY,
        corpus_level_fn=np.mean,
        higher_is_better=True,
    )


def multilingual_quasi_exact_match_metric(
    language: Language,
    match_type: Literal["prefix", "suffix", "full"] = "full",
    aggregation_function: Callable[[list[float]], float] = max,
) -> SampleLevelMetric:
    """Build a language-aware exact-match metric.

    Gold references and predictions are normalized with the language-specific
    normalizer before comparison.

    Args:
        language: Language of the samples; selects the normalizer.
        match_type: Comparison mode —
            - "prefix": prediction must start with the gold,
            - "suffix": prediction must end with the gold,
            - "full": prediction and gold must be equal.
        aggregation_function: How to reduce per-gold scores when a sample has
            several golds (defaults to taking the best match).

    Returns:
        A sample-level exact-match metric, averaged over the corpus with np.mean.
    """
    normalizer = get_multilingual_normalizer(language)
    scorer = ExactMatches(
        normalize_gold=normalizer,
        normalize_pred=normalizer,
        aggregation_function=aggregation_function,
        type_exact_match=match_type,
    )
    return SampleLevelMetric(
        metric_name=f"exact_match_{language}_{match_type}",
        sample_level_fn=scorer.compute,
        category=MetricCategory.GENERATIVE,
        use_case=MetricUseCase.ACCURACY,
        corpus_level_fn=np.mean,
        higher_is_better=True,
    )
2 changes: 1 addition & 1 deletion src/lighteval/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
remove_braces_and_strip,
)
from lighteval.metrics.sample_preparator import GenerativePreparator, LoglikelihoodPreparator, PerplexityPreparator
from lighteval.metrics.utils import (
from lighteval.metrics.utils.utils import (
CorpusLevelMetric,
CorpusLevelMetricGrouping,
Metric,
Expand Down
18 changes: 7 additions & 11 deletions src/lighteval/metrics/metrics_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@
class ExactMatches:
def __init__(
self,
aggregation_function: callable = None,
normalize_gold: callable = None,
normalize_pred: callable = None,
aggregation_function: Callable[[list[float]], float] = max,
normalize_gold: Callable[[str], str] | None = None,
normalize_pred: Callable[[str], str] | None = None,
strip_strings: bool = False,
type_exact_match: str = "full",
):
Expand All @@ -78,8 +78,6 @@ def __init__(
`suffix` if the prediction ends with the gold,
`full` if the prediction and gold are equal
"""
if aggregation_function is None:
aggregation_function = max
self.aggregation_function = aggregation_function
self.normalize_gold = normalize_gold
self.normalize_pred = normalize_pred
Expand Down Expand Up @@ -145,9 +143,9 @@ def compute_one_item(
class F1_score:
def __init__(
self,
aggregation_function: callable = None,
normalize_gold: callable = None,
normalize_pred: callable = None,
aggregation_function: Callable[[list[float]], float] = max,
normalize_gold: Callable[[str], str] | None = None,
normalize_pred: Callable[[str], str] | None = None,
strip_strings: bool = False,
):
"""An F1 score class. F1 is computed over the bag of words of the golds and predictions.
Expand All @@ -161,10 +159,8 @@ def __init__(
Defaults to None if no normalization is applied.
strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
"""
if aggregation_function is None:
aggregation_function = max
self.aggregation_function = aggregation_function

self.aggregation_function = aggregation_function
self.normalize_gold = normalize_gold
self.normalize_pred = normalize_pred
self.strip_strings = strip_strings
Expand Down
60 changes: 60 additions & 0 deletions src/lighteval/metrics/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@

import re
import string
import sys
import unicodedata
from dataclasses import dataclass
from typing import Callable

from lighteval.metrics.utils.linguistic_tokenizers import get_word_tokenizer
from lighteval.utils.language import Language


# From HELM
Expand Down Expand Up @@ -352,6 +358,60 @@ def gsm8k_normalizer(text: str) -> str:
return INVALID_ANS


# Every Unicode code point whose general category is "Punctuation" (P*),
# plus ASCII string.punctuation for good measure; built once at import time.
PUNCT = {chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")}.union(
    string.punctuation
)

# Per-language regexes matching free-standing definite/indefinite articles.
# Generated with an LLM, then manually checked by a non-expert (see
# remove_articles); languages that fuse articles onto the word are omitted.
_ARTICLE_PATTERNS = {
    Language.ENGLISH: r"\b(a|an|the)\b",
    Language.SPANISH: r"\b(el|la|los|las|un|una|unos|unas)\b",
    Language.PORTUGUESE: r"\b(o|a|os|as|um|uma|uns|umas)\b",
    Language.ITALIAN: r"\b(il|lo|la|i|gli|le|un|uno|una)\b",
    Language.FRENCH: r"\b(le|la|les|l'|un|une|des)\b",
    Language.GERMAN: r"\b(der|die|das|den|dem|des|ein|eine|einer|eines|einem|einen)\b",
    Language.FINNISH: r"\b(se|yksi|yks)\b",
    # NOTE(review): several alternatives are duplicated (τοῦ, τῶν, τό, τά) —
    # harmless for matching, but could be deduplicated.
    Language.GREEK: r"\b(ὁ|οἱ|τοῦ|τῶν|τόν|τούς|ὦ|ἡ|αἱ|τῆς|τῶν|τήν|τάς|τό|τά|τοῦ|τῶν|τό|τά)\b",
    Language.NORWEGIAN: r"\b(en|ei|et|den|det|de)\b",
    Language.SWEDISH: r"\b(en|ett|den|det|de)\b",
    Language.TURKISH: r"\b(bir)\b",
    Language.DUTCH: r"\b(de|het|een)\b",
    Language.HUNGARIAN: r"\b(a|az|egy)\b",
    Language.CATALAN: r"\b(el|la|els|les|un|una|uns|unes)\b",
    # NOTE(review): the Hebrew article ה is a prefix attached to the word, so
    # \b(ה)\b presumably only matches a standalone ה — confirm this is intended.
    Language.HEBREW: r"\b(ה)\b",
    Language.GALICIAN: r"\b(o|a|os|as|un|unha|uns|unhas)\b",
}


def remove_articles(text: str, lang: Language) -> str:
    """
    Replace definite and indefinite articles in *text* with a space.

    Patterns were generated using an LLM then manually checked by a
    non-expert, and only cover languages whose articles stand alone as
    words; native speakers are welcome to contribute languages that
    blend the articles. Languages without a pattern pass through unchanged.
    """
    article_regex = _ARTICLE_PATTERNS.get(lang)
    if article_regex is None:
        return text
    return re.sub(article_regex, " ", text)


def remove_punc(text: str) -> str:
    """Return *text* with every character in the module-level PUNCT set dropped."""
    return "".join(filter(lambda ch: ch not in PUNCT, text))


def get_multilingual_normalizer(lang: Language, lower: bool = True) -> Callable[[str], str]:
    """
    Build a text normalizer for *lang*.

    The returned callable strips articles, removes punctuation, optionally
    lowercases, then re-joins the language-aware word tokens with single spaces.
    """
    tokenizer = get_word_tokenizer(lang)

    def _normalize(text: str) -> str:
        cleaned = remove_punc(remove_articles(text, lang))
        if lower:
            cleaned = cleaned.lower()
        return " ".join(tokenizer.word_tokenize(cleaned))

    return _normalize


# Loglikelihood normalization
@dataclass
class LogProbPMINorm:
Expand Down
Loading