Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def test_table_teds_metric(self):
self.assertTrue(teds_result.success)
self.assertIsInstance(teds_result.score, float)
# 验证固定内容的确定分数
self.assertAlmostEqual(teds_result.score, 0.300000, places=5,
msg=f"table_TEDS分数应该是0.300000,实际: {teds_result.score}")
self.assertAlmostEqual(teds_result.score, 0.5199999999999999, places=5,
msg=f"table_TEDS分数应该是0.5199999999999999,实际: {teds_result.score}")

# 验证详细信息
self.assertEqual(teds_result.details['content_type'], 'table')
Expand Down
27 changes: 26 additions & 1 deletion tests/test_teds.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,19 @@ def test_very_large_table(self):
self.assertTrue(result.success)
self.assertEqual(result.score, 1.0)

def test_teds_structure_same_content_different(self):
"""测试结构相同但内容不同的表格 - 验证修复后的TEDS不会返回0分"""
pred = "<table><tr><td>我不喜欢你</td></tr></table>"
gt = "<table><tr><td>我喜欢你</td></tr></table>"

result = self.teds_metric.calculate(
predicted=pred,
groundtruth=gt,
table_edit_result=self.valid_table_edit_result
)
assert result.score == 0.7999999999999999



class TestTEDSAdvanced(unittest.TestCase):
"""Advanced TEDS functionality tests - 高级功能测试"""
Expand Down Expand Up @@ -301,6 +314,18 @@ def test_teds_complex_table(self):
self.assertGreater(result.score, 0.0)
self.assertLess(result.score, 1.0)

def test_teds_content_similarity(self):
"""Test TEDS with similar content but different text - 测试内容相似度"""
table1 = "<table><tr><td>苹果很好吃</td><td>香蕉也不错</td></tr></table>"
table2 = "<table><tr><td>苹果很美味</td><td>香蕉也很好</td></tr></table>"

result = self.teds.calculate(
table1,
table2,
table_edit_result=self.valid_table_edit_result
)
assert result.score == 0.3999999999999999


class TestStructureTEDS(unittest.TestCase):
"""Structure-only TEDS tests - 结构化TEDS测试"""
Expand Down Expand Up @@ -457,7 +482,7 @@ def run_all_teds_tests():
# Add all test classes
test_classes = [
# 注意:确保TestTEDSBasic已定义或从其他文件导入
# TestTEDSBasic,
TestTEDSBasic,
TestTEDSAdvanced,
TestStructureTEDS,
TestTEDSEdgeCases
Expand Down
148 changes: 133 additions & 15 deletions webmainbench/metrics/teds_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _calculate_score(self, predicted: Any, groundtruth: Any, **kwargs) -> Metric

return MetricResult(
metric_name=self.name,
score=max(0.0, min(1.0, teds_score)),
score=max(0.0, min(1.0, teds_score)), # 删除多余的右括号
details=details
)

Expand Down Expand Up @@ -288,23 +288,141 @@ def _tree_edit_distance(self, tree1: Dict, tree2: Dict) -> float:

return self._list_edit_distance(children1, children2)
else:
# Nodes are different, calculate minimum cost
# Option 1: Replace tree1 with tree2
cost_replace = 1.0 + self._list_edit_distance(
tree1.get('children', []),
tree2.get('children', [])
)
# 检查结构是否相同(忽略文本内容)
if self._structure_equal(tree1, tree2):
# 结构相同,内容不同,使用内容编辑距离
content_distance = self._content_edit_distance(tree1, tree2)
children_cost = self._list_edit_distance(
tree1.get('children', []),
tree2.get('children', [])
)
return content_distance + children_cost
else:
# 结构不同,使用原有的删除插入策略
# Option 1: Replace tree1 with tree2
cost_replace = 1.0 + self._list_edit_distance(
tree1.get('children', []),
tree2.get('children', [])
)

# Option 2: Delete tree1 and insert tree2
cost_delete_insert = (
float(self._count_nodes(tree1)) +
float(self._count_nodes(tree2))
)

return min(cost_replace, cost_delete_insert)

def _structure_equal(self, tree1: Dict, tree2: Dict) -> bool:
"""Check if two trees have identical structure (same tag, attributes)"""
if tree1['tag'] != tree2['tag']:
return False

# Compare important attributes
attrs1 = tree1.get('attrs', {})
attrs2 = tree2.get('attrs', {})

# Check colspan and rowspan
important_attrs = ['colspan', 'rowspan']
for attr in important_attrs:
if attrs1.get(attr) != attrs2.get(attr):
return False

# 结构相同,忽略文本内容
return True

def _content_edit_distance(self, tree1: Dict, tree2: Dict) -> float:
"""Calculate content edit distance between two trees with same structure"""
if tree1['tag'] != tree2['tag']:
return 1.0 # 标签不同,惩罚1分

# 如果是叶子节点(如td),计算文本内容的编辑距离
if tree1['tag'] == 'td' or not tree1.get('children'):
text1 = tree1.get('text', '')
text2 = tree2.get('text', '')

# Option 2: Delete tree1 and insert tree2
cost_delete_insert = (
float(self._count_nodes(tree1)) +
float(self._count_nodes(tree2))
)
if text1 == text2:
return 0.0 # 内容相同

return min(cost_replace, cost_delete_insert)

# 计算文本编辑距离
return self._text_edit_distance(text1, text2)

# 非叶子节点,递归计算子节点的内容编辑距离
children1 = tree1.get('children', [])
children2 = tree2.get('children', [])

return self._list_content_edit_distance(children1, children2)

def _text_edit_distance(self, text1: str, text2: str) -> float:
"""Calculate normalized edit distance between two text strings"""
if not text1 and not text2:
return 0.0
if not text1 or not text2:
return 1.0

# 计算Levenshtein编辑距离
distance = self._levenshtein_distance(text1, text2)
max_len = max(len(text1), len(text2))

# 返回归一化的编辑距离(0-1之间)
return float(distance) / max_len if max_len > 0 else 0.0

def _levenshtein_distance(self, s1: str, s2: str) -> int:
"""Calculate Levenshtein distance between two strings"""
if len(s1) < len(s2):
return self._levenshtein_distance(s2, s1)

if len(s2) == 0:
return len(s1)

previous_row = list(range(len(s2) + 1))
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
insertions = previous_row[j + 1] + 1
deletions = current_row[j] + 1
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row

return previous_row[-1]

def _list_content_edit_distance(self, list1: List, list2: List) -> float:
"""Calculate content edit distance between two lists of trees"""
m, n = len(list1), len(list2)

# 初始化DP矩阵
dp = [[0.0] * (n + 1) for _ in range(m + 1)]

# 基础情况
for i in range(1, m + 1):
dp[i][0] = dp[i-1][0] + self._content_edit_distance(list1[i-1], list2[0]) if n > 0 else 1.0

for j in range(1, n + 1):
dp[0][j] = dp[0][j-1] + self._content_edit_distance(list1[0], list2[j-1]) if m > 0 else 1.0

# 填充DP矩阵
for i in range(1, m + 1):
for j in range(1, n + 1):
# 内容替换成本
subst_cost = self._content_edit_distance(list1[i-1], list2[j-1])

# 删除成本
del_cost = 1.0 # 删除一个节点的内容成本

# 插入成本
ins_cost = 1.0 # 插入一个节点的内容成本

dp[i][j] = min(
dp[i-1][j-1] + subst_cost, # 替换
dp[i-1][j] + del_cost, # 删除
dp[i][j-1] + ins_cost # 插入
)

return dp[m][n]

def _list_edit_distance(self, list1: List, list2: List) -> float:
"""Calculate edit distance between two lists of trees."""
"""Calculate edit distance between two lists of trees (for structure comparison)"""
m, n = len(list1), len(list2)

# Initialize DP matrix
Expand Down