Skip to content

Commit 74c2a4e

Browse files
authored
1 parent 1dcb967 commit 74c2a4e

File tree

4 files changed

+52
-9
lines changed

4 files changed

+52
-9
lines changed

tests/metrics/test_bleu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
class TestBLEU(unittest.TestCase):
2121
def test_metrics(self):
2222
bleu = BLEU()
23+
bleu.reset()
2324
cand = ["The", "cat", "The", "cat", "on", "the", "mat"]
2425
ref_list = [["The", "cat", "is", "on", "the", "mat"], ["There", "is", "a", "cat", "on", "the", "mat"]]
2526
bleu.add_inst(cand, ref_list)

tests/metrics/test_chunk.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,24 @@
1414

1515
import unittest
1616

17+
import paddle
18+
1719
from paddlenlp.metrics import ChunkEvaluator
1820

1921

2022
class TestChunk(unittest.TestCase):
2123
def test_metrics(self):
22-
num_infer_chunks = 10
23-
num_label_chunks = 9
24-
num_correct_chunks = 8
25-
26-
label_list = [1, 1, 0, 0, 1, 0, 1]
24+
label_list = ["O", "B-Person", "I-Person"]
2725
evaluator = ChunkEvaluator(label_list)
28-
evaluator.update(num_infer_chunks, num_label_chunks, num_correct_chunks)
26+
evaluator.reset()
27+
lengths = paddle.to_tensor([5])
28+
predictions = paddle.to_tensor([[0, 1, 2, 1, 2]])
29+
labels = paddle.to_tensor([[0, 1, 2, 1, 1]])
30+
num_infer_chunks, num_label_chunks, num_correct_chunks = evaluator.compute(
31+
lengths=lengths, predictions=predictions, labels=labels
32+
)
33+
evaluator.update(num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
2934
precision, recall, f1 = evaluator.accumulate()
30-
self.assertEqual(precision, 0.8)
31-
self.assertEqual(recall, 0.8888888888888888)
32-
self.assertEqual(f1, 0.8421052631578948)
35+
self.assertEqual(precision, 0.5)
36+
self.assertEqual(recall, 0.3333333333333333)
37+
self.assertEqual(f1, 0.4)

tests/metrics/test_rouge.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@
2020
class TestRouge(unittest.TestCase):
2121
def test_rogue1(self):
2222
rouge1 = Rouge1()
23+
rouge1.reset()
2324
cand = ["The", "cat", "The", "cat", "on", "the", "mat"]
2425
ref_list = [["The", "cat", "is", "on", "the", "mat"], ["There", "is", "a", "cat", "on", "the", "mat"]]
2526
self.assertEqual(rouge1.score(cand, ref_list), 0.07692307692307693)
2627

2728
def test_roguel(self):
2829
rougel = RougeL()
30+
rougel.reset()
2931
cand = ["The", "cat", "The", "cat", "on", "the", "mat"]
3032
ref_list = [["The", "cat", "is", "on", "the", "mat"], ["There", "is", "a", "cat", "on", "the", "mat"]]
3133
rougel.add_inst(cand, ref_list)

tests/metrics/test_span.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle
18+
19+
from paddlenlp.metrics import SpanEvaluator
20+
21+
22+
class TestSpanEvaluator(unittest.TestCase):
23+
def test_metrics(self):
24+
metric = SpanEvaluator()
25+
metric.reset()
26+
start_prob = paddle.to_tensor([[0.1, 0.1, 0.6, 0.2], [0.0, 0.9, 0.1, 0.0]])
27+
end_prob = paddle.to_tensor([[0.1, 0.1, 0.2, 0.6], [0.0, 0.9, 0.1, 0.0]])
28+
start_ids = paddle.to_tensor([[0, 0, 1, 0], [0, 0, 1, 0]])
29+
end_ids = paddle.to_tensor([[0, 0, 0, 1], [0, 0, 1, 0]])
30+
num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
31+
metric.update(num_correct, num_infer, num_label)
32+
precision, recall, f1 = metric.accumulate()
33+
self.assertEqual(precision, 0.5)
34+
self.assertEqual(recall, 0.5)
35+
self.assertEqual(f1, 0.5)

0 commit comments

Comments
 (0)