# test_base.py (forked from Moonlight-Syntax/LUNA)
import dataclasses
import typing as tp
import pytest
from luna.base import Metrics
from tests.conftest import get_mock_data
class ReferenceBasedMetric(Metrics):
    """Toy metric for the tests: hypothesis length minus reference length."""

    def evaluate_example(self, hyp: str, ref: tp.Optional[str]) -> float:
        """Score one example as ``len(hyp) - len(ref)``.

        Args:
            hyp: Hypothesis text.
            ref: Reference text. The annotation allows ``None``, but this
                metric cannot be computed without a reference.

        Returns:
            The length difference, as a float to match the annotation.

        Raises:
            TypeError: If ``ref`` is ``None``.
        """
        if ref is None:
            # Same exception type that ``len(None)`` would raise, but with a
            # message that names the actual problem.
            raise TypeError("ReferenceBasedMetric requires a reference, got None")
        return float(len(hyp) - len(ref))
class ReferenceFreeMetric(Metrics):
    """Toy metric for the tests: hypothesis length minus a constant 10."""

    def evaluate_example(self, hyp: str, ref: tp.Optional[str]) -> float:
        """Return ``len(hyp) - 10.0``; ``ref`` is accepted but never read."""
        hyp_length = len(hyp)
        return hyp_length - 10.0
@dataclasses.dataclass
class Case:
    """One parametrized scenario: a metric, its inputs and the expected value."""

    # Metric instance whose ``evaluate_batch`` is exercised by the tests.
    metric: Metrics
    # Hypothesis strings handed to the metric.
    hyps: tp.List[str]
    # Reference strings, or ``None`` when the case supplies none.
    refs: tp.Optional[tp.List[str]]
    # Expected scores. The mock-data cases in this module instead store the
    # expected number of results as the single element of this list.
    expected_result: tp.List[float]
# Small hand-computed cases; the expected values follow from the two toy
# metrics: len(hyp) - len(ref) and len(hyp) - 10.0 respectively.
BASIC_TEST_CASES = [
    Case(
        metric=ReferenceBasedMetric(),
        hyps=['cc', 'bb'],
        refs=['aa', 'bbb'],
        expected_result=[0.0, -1.0],  # 2 - 2, 2 - 3
    ),
    Case(
        metric=ReferenceFreeMetric(),
        hyps=['cc', 'bbb'],
        refs=None,
        expected_result=[-8.0, -7.0],  # 2 - 10, 3 - 10
    ),
]
@pytest.mark.parametrize('test_case', BASIC_TEST_CASES)
def test_base_class_compute(test_case: Case) -> None:
    """Check that ``evaluate_batch`` returns exactly the expected scores."""
    metric = test_case.metric
    actual = metric.evaluate_batch(test_case.hyps, test_case.refs)
    expected = test_case.expected_result
    assert actual == expected
# Shared mock hypotheses/references from the test suite's conftest, loaded
# once when this module is imported.
HYPS, REFS = get_mock_data()

# For these cases ``expected_result`` holds a single element: the expected
# number of scores returned for the mock data, not the scores themselves.
BATCH_TEST_CASES = [
    Case(metric=metric_cls(), hyps=HYPS, refs=REFS, expected_result=[6])
    for metric_cls in (ReferenceBasedMetric, ReferenceFreeMetric)
]
@pytest.mark.parametrize('test_case', BATCH_TEST_CASES)
def test_mock_data_base_compute(test_case: Case) -> None:
    """Check that ``evaluate_batch`` yields the expected number of scores."""
    scores = test_case.metric.evaluate_batch(test_case.hyps, test_case.refs)
    # Only the number of returned scores is verified, not their values; the
    # expected count is stored as the first element of ``expected_result``.
    expected_count = test_case.expected_result[0]
    assert len(scores) == expected_count