8 changes: 8 additions & 0 deletions .gitignore
@@ -29,5 +29,13 @@ dist
build

.DS_store
CLIP

ljs
dataflow/example/video_caption/*.mp4
dataflow/example/video_cot_qa/*.mp4
dataflow/example/video_cot_qa/*.png
dataflow/example/video_split/*.mp4

requirements_zyh.txt
selected_samples_output
74 changes: 0 additions & 74 deletions dataflow/example/video/sample_data.json

This file was deleted.

11 changes: 11 additions & 0 deletions dataflow/example/video_caption/sample_data.json
@@ -0,0 +1,11 @@
[
{
"video": ["./DataFlow-MM/dataflow/example/video_caption/oceans.mp4"],
"conversation": [
{
"from": "human",
"value": "Please describe the video in detail."
}
]
}
]
47 changes: 47 additions & 0 deletions dataflow/example/video_cot_qa/sample_data.json
@@ -0,0 +1,47 @@
[
{
"video": ["./DataFlow-MM/dataflow/example/video_cot_qa/ytb_7nRmsEw7nsE.mp4"],
"problem_type": "multiple choice",
"problem": "What appears on the screen in Russian during the missile's ascent?",
"options": [
"A. A YouTube subscription notification",
"B. A military command",
"C. A warning message",
"D. A weather update"
],
"data_type": "video",
"solution": "<answer>A</answer>"
},
{
"video": ["./DataFlow-MM/dataflow/example/video_cot_qa/split_8.mp4"],
"problem_type": "free-form",
"problem": "What cooking action does the person perform with the black frying pan on the right burner?",
"options": [],
"data_type": "video",
"solution": "<answer>The person cracks an egg into the black frying pan on the right burner.</answer>"
},
{
"video": ["./DataFlow-MM/dataflow/example/video_cot_qa/video_05282.mp4"],
"problem_type": "numerical",
"problem": "How many moving rubber objects are there when the video ends?",
"options": [],
"data_type": "video",
"solution": "<answer>1</answer>"
},
{
"image": ["./DataFlow-MM/dataflow/example/video_cot_qa/identity_71274.png"],
"problem_type": "OCR",
"problem": "Please write out the expression of the formula in the image using LaTeX format.",
"options": [],
"data_type": "image",
"solution": "<answer>1 4 = 4 + 2 + 8</answer>"
},
{
"image": ["./DataFlow-MM/dataflow/example/video_cot_qa/image_21088.png"],
"problem_type": "regression",
"problem": "Estimate the vertical distance between the yellow forklift with black handle and the metal shelving with white shelves. Please measure it in inches.",
"options": [],
"data_type": "image",
"solution": "<answer>9.62</answer>"
}
]
11 changes: 11 additions & 0 deletions dataflow/example/video_split/sample_data.json
@@ -0,0 +1,11 @@
[
{
"video": ["./DataFlow-MM/dataflow/example/video_split/oceans.mp4"],
"conversation": [
{
"from": "human",
"value": "<video>\nPlease describe the video in detail."
}
]
}
]
7 changes: 5 additions & 2 deletions dataflow/operators/core_vision/__init__.py
@@ -12,8 +12,9 @@
from .generate.prompted_image_generator import PromptedImageGenerator
from .generate.prompted_vqa_generator import PromptedVQAGenerator
from .generate.video_clip_generator import VideoClipGenerator
from .generate.video_caption_to_qa_generator import VideoCaptionToQAGenerator
from .generate.video_video_to_caption_generator import VideoToCaptionGenerator
from .generate.video_qa_generator import VideoCaptionToQAGenerator
from .generate.video_caption_generator import VideoToCaptionGenerator
from .generate.video_cotqa_generator import VideoCOTQAGenerator
from .generate.sk_vqa_generator import ImageSKVQAGenerate
from .generate.image_caprl_mcq_generator import CapRLMCQGenerate, CapRLMCQConfig

@@ -32,6 +33,7 @@
from .filter.consistency_filter import ConsistencyFilter
from .filter.text_image_diversity_filter import TextImageDiversityFilter
from .filter.sensitive_filter import SensitiveFilter
from .filter.score_filter import ScoreFilter
from .refine.vision_seg_cutout_refine import VisionSegCutoutRefine
from .filter.rule_base_filter import RuleBaseFilter
from .filter.deduplication_filter import DeduplicateFilter
@@ -46,6 +48,7 @@
from .eval.video_luminance_evaluator import VideoLuminanceEvaluator
from .eval.video_ocr_evaluator import VideoOCREvaluator
from .eval.emscore_evaluator import EMScoreEval
from .eval.general_text_answer_evaluator import GeneralTextAnswerEvaluator
from .eval.image.image_evaluator import EvalImageGenerationGenerator
from .eval.image_text.clip_image_text_evaluator import CLIPEvaluator
from .eval.image_text.long_clip_image_text_evaluator import LongCLIPEvaluator
207 changes: 207 additions & 0 deletions dataflow/operators/core_vision/eval/general_text_answer_evaluator.py
@@ -0,0 +1,207 @@
from __future__ import annotations

import re
from typing import Optional

import pandas as pd
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

from dataflow import get_logger
from dataflow.core import OperatorABC
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow.utils.storage import DataFlowStorage


@OPERATOR_REGISTRY.register()
class GeneralTextAnswerEvaluator(OperatorABC):
"""Evaluate text-based answers across multiple question types.

This evaluator supports multiple question types with appropriate scoring metrics:
- multiple choice: Exact match (0 or 1)
- numerical: Numerical comparison with decimal handling
- OCR: Word Error Rate (WER) based scoring
- free-form: ROUGE score based evaluation
- regression: Relative difference based scoring
"""

def __init__(
self,
use_stemmer: bool = True,
) -> None:
"""Initialize the evaluator.

Args:
use_stemmer: Whether to use stemmer in ROUGE score calculation
"""
self.logger = get_logger()
self.use_stemmer = use_stemmer

@staticmethod
def get_desc(lang: str = "zh") -> str:
return "通用文本答案评估器" if lang == "zh" else "General text answer evaluator."

# ----------------------------- Helper Methods ----------------------------
@staticmethod
def extract_answer(text: str) -> str:
"""Extract content between <answer> and </answer> tags."""
pattern = r'<answer>\s*(.*?)\s*</answer>'
match = re.search(pattern, text, re.DOTALL)
if match:
return match.group(1).strip()
return ""

@staticmethod
def normalize_number(num_str: str) -> Optional[float]:
"""Convert string to float, handling commas."""
try:
num_str = num_str.replace(',', '')
return float(num_str)
except Exception:
return None

@staticmethod
def wer(reference: str, hypothesis: str) -> float:
"""Calculate Word Error Rate."""
ref_words = reference.split()
hyp_words = hypothesis.split()
m = len(ref_words)
n = len(hyp_words)
d = [[0]*(n+1) for _ in range(m+1)]
for i in range(m+1):
d[i][0] = i
for j in range(n+1):
d[0][j] = j
for i in range(1, m+1):
for j in range(1, n+1):
if ref_words[i-1] == hyp_words[j-1]:
d[i][j] = d[i-1][j-1]
else:
d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
return d[m][n] / max(1, m)

@staticmethod
def compute_bleu_score(reference: str, hypothesis: str) -> float:
"""Calculate BLEU score."""
try:
smoothing = SmoothingFunction().method1
ref_tokens = reference.split()
hyp_tokens = hypothesis.split()
score = sentence_bleu([ref_tokens], hyp_tokens, smoothing_function=smoothing)
return score
except Exception:
return 0.0

def compute_rouge_score(self, reference: str, hypothesis: str) -> float:
"""Calculate average ROUGE F-measure."""
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=self.use_stemmer)
scores = scorer.score(reference, hypothesis)
average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
return average_fmeasure

def calculate_reward(self, gt_solution: str, model_output: str, question_type: str) -> float:
"""Calculate reward based on question type and model output.

Args:
gt_solution: Ground truth solution (with <answer> tags)
model_output: Model generated output (with <answer> tags)
question_type: Type of question (multiple choice, numerical, OCR, free-form, regression)

Returns:
Reward score between 0.0 and 1.0
"""
try:
output_ans = self.extract_answer(model_output)
gt_ans = self.extract_answer(gt_solution)

if question_type == "multiple choice":
return 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
elif question_type == "numerical":
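                # The output must use the same decimal/comma formatting as the ground truth before values are compared.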
gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
out_has_decimal = ("." in output_ans) or ("," in output_ans)
if gt_has_decimal != out_has_decimal:
return 0.0
gt_number = self.normalize_number(gt_ans)
out_number = self.normalize_number(output_ans)
if gt_number is None or out_number is None:
return 0.0
return 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
elif question_type == "OCR":
error_rate = self.wer(gt_ans, output_ans)
reward = 1 - error_rate
return max(0.0, min(1.0, reward))
elif question_type == "free-form":
score = self.compute_rouge_score(gt_ans, output_ans)
return max(0.0, min(1.0, score))
elif question_type == "regression":
gt_number = self.normalize_number(gt_ans)
out_number = self.normalize_number(output_ans)
if gt_number is None or out_number is None:
return 0.0
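                # Reward falls off linearly with the relative error, clipped to [0, 1]; 1e-9 guards against division by zero.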
rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
rel_diff = min(1.0, max(0.0, rel_diff))
return 1 - rel_diff
else:
return 0.0
except Exception as e:
self.logger.error(f"Error in calculate_reward for question_type '{question_type}': {e}")
return 0.0

# ----------------------------- Execution ---------------------------------
def run(
self,
storage: DataFlowStorage,
input_model_output_key: str = "model_output",
input_gt_solution_key: str = "solution",
input_question_type_key: str = "problem_type",
output_reward_key: str = "reward",
) -> str:
"""Evaluate text answers and compute rewards.

Args:
storage: DataFlowStorage object
input_model_output_key: Column name for model outputs
input_gt_solution_key: Column name for ground truth solutions
input_question_type_key: Column name for question types
output_reward_key: Column name for output rewards

Returns:
The output_reward_key

Expected input columns:
- input_model_output_key: Model generated text with <answer> tags
- input_gt_solution_key: Ground truth with <answer> tags
- input_question_type_key: Question type (multiple choice, numerical, OCR, free-form, regression)
"""
self.logger.info("Running GeneralTextAnswerEvaluator...")

df: pd.DataFrame = storage.read("dataframe")
self.logger.info("Loaded dataframe with %d rows", len(df))

# Validate required columns
required_cols = [input_model_output_key, input_gt_solution_key, input_question_type_key]
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
raise KeyError(f"Missing required columns: {missing_cols}")

# Calculate rewards
rewards = []

for _, row in df.iterrows():
model_output = row.get(input_model_output_key, '')
gt_solution = row.get(input_gt_solution_key, '')
question_type = row.get(input_question_type_key, '')

reward = self.calculate_reward(gt_solution, model_output, question_type)
rewards.append(reward)

# Attach outputs
df[output_reward_key] = rewards

storage.write(df)

self.logger.info("Evaluation finished for %d rows", len(df))
        if rewards:
            self.logger.info(f"Average reward: {sum(rewards)/len(rewards):.4f}")

return output_reward_key
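
For reference, a minimal usage sketch of the new evaluator's scoring branches, calling calculate_reward directly. The "model outputs" below are invented for illustration; the ground-truth answers are taken from the sample_data.json files above, and the snippet assumes the package and its nltk / rouge_score dependencies are installed.

# Illustrative only: exercise the scoring branches of GeneralTextAnswerEvaluator.
from dataflow.operators.core_vision import GeneralTextAnswerEvaluator

evaluator = GeneralTextAnswerEvaluator(use_stemmer=True)

# Multiple choice: exact match on the extracted <answer> text -> 1.0 or 0.0
print(evaluator.calculate_reward("<answer>A</answer>", "<answer>A</answer>", "multiple choice"))

# Numerical: values compared after rounding to two decimals
print(evaluator.calculate_reward("<answer>1</answer>", "<answer>1</answer>", "numerical"))

# OCR: 1 - word error rate, clipped to [0, 1]
print(evaluator.calculate_reward(
    "<answer>1 4 = 4 + 2 + 8</answer>", "<answer>1 4 = 4 + 2 8</answer>", "OCR"))

# Free-form: mean of ROUGE-1/2/L F-measures
print(evaluator.calculate_reward(
    "<answer>The person cracks an egg into the black frying pan.</answer>",
    "<answer>An egg is cracked into the frying pan.</answer>",
    "free-form"))

# Regression: 1 - relative error, clipped to [0, 1]
print(evaluator.calculate_reward("<answer>9.62</answer>", "<answer>9.5</answer>", "regression"))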
