Commit

Merge branch 'main' into clean-config
qgallouedec authored Aug 28, 2024
2 parents 97738c8 + 10f70fa commit 098ca6a
Showing 3 changed files with 468 additions and 242 deletions.
163 changes: 160 additions & 3 deletions tests/test_dpo_trainer.py
@@ -30,10 +30,170 @@
)

from trl import DPOConfig, DPOTrainer, FDivergenceType
from trl.trainer.dpo_trainer import _build_tokenized_answer, _truncate_tokens

from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft


class TestBuildTokenizedAnswer(unittest.TestCase):
    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def test_basic_functionality(self):
        prompt = "Hello, how are you?"
        answer = "I'm doing well, thank you!"

        result = _build_tokenized_answer(prompt, answer, tokenizer=self.tokenizer)

        self.assertIn("prompt_input_ids", result)
        self.assertIn("prompt_attention_mask", result)
        self.assertIn("input_ids", result)
        self.assertIn("attention_mask", result)

        self.assertEqual(len(result["prompt_input_ids"]), len(result["prompt_attention_mask"]))
        self.assertEqual(len(result["input_ids"]), len(result["attention_mask"]))

        decoded_prompt = self.tokenizer.decode(result["prompt_input_ids"])
        self.assertTrue(prompt in decoded_prompt)

        decoded_answer = self.tokenizer.decode(result["input_ids"])
        self.assertTrue(answer in decoded_answer)

    def test_with_processor(self):
        def mock_processor(text, images=None, add_special_tokens=True):
            return {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.tensor([[1, 1, 1]])}

        prompt = "Describe this image:"
        answer = "A beautiful sunset over the ocean."

        result = _build_tokenized_answer(prompt, answer, processor=mock_processor)

        self.assertIn("prompt_input_ids", result)
        self.assertIn("prompt_attention_mask", result)
        self.assertIn("input_ids", result)
        self.assertIn("attention_mask", result)

        self.assertEqual(result["prompt_input_ids"], [1, 2, 3])
        self.assertEqual(result["prompt_attention_mask"], [1, 1, 1])

    def test_token_merging(self):
        prompt = "The quick brown"
        answer = " fox jumps over the lazy dog."

        result = _build_tokenized_answer(prompt, answer, tokenizer=self.tokenizer)

        full_text = prompt + answer
        full_tokenized = self.tokenizer(full_text, add_special_tokens=False)

        self.assertEqual(result["prompt_input_ids"] + result["input_ids"], full_tokenized["input_ids"])

    def test_vision_model(self):
        def mock_vision_processor(text, images=None, add_special_tokens=True):
            return {
                "input_ids": torch.tensor([[1, 2, 3]]),
                "attention_mask": torch.tensor([[1, 1, 1]]),
                "pixel_values": torch.rand(1, 3, 224, 224),
                "pixel_attention_mask": torch.ones(1, 224, 224),
            }

        prompt = "Describe this image:"
        answer = "A cat sitting on a windowsill."

        result = _build_tokenized_answer(prompt, answer, processor=mock_vision_processor)

        self.assertIn("prompt_pixel_values", result)
        self.assertIn("prompt_pixel_attention_mask", result)
        self.assertTrue(torch.is_tensor(result["prompt_pixel_values"]))
        self.assertTrue(torch.is_tensor(result["prompt_pixel_attention_mask"]))


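A note not part of the commit, on why the helper exists: with BPE tokenizers, encoding the prompt and the answer in two independent calls can yield different token ids around the prompt/answer boundary than encoding the concatenated text, and this boundary mismatch is what _build_tokenized_answer is meant to reconcile (test_token_merging above asserts the concatenation property directly). A minimal sketch of the mismatch, assuming the same "gpt2" tokenizer used in setUp:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompt, answer = "The quick brown", " fox jumps over the lazy dog."

# Tokenizing the two pieces independently...
naive = (
    tokenizer(prompt, add_special_tokens=False)["input_ids"]
    + tokenizer(answer, add_special_tokens=False)["input_ids"]
)
# ...may disagree with tokenizing the concatenation around the boundary token.
merged = tokenizer(prompt + answer, add_special_tokens=False)["input_ids"]

print(naive == merged)  # not guaranteed to be True for every prompt/answer pair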
class TestTruncateTokens(unittest.TestCase):
    def setUp(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.args = DPOConfig(
                max_length=20, max_prompt_length=10, truncation_mode="keep_start", output_dir=tmp_dir
            )

    def test_truncate_tokens(self):
        chosen_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(10)),
                "attention_mask": [1] * 10,
            }
        ]
        rejected_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(12)),
                "attention_mask": [1] * 12,
            }
        ]
        prompt_tokens = [{"prompt_input_ids": list(range(15)), "prompt_attention_mask": [1] * 15}]

        _truncate_tokens(chosen_tokens, rejected_tokens, prompt_tokens, self.args)

        # Check if prompt is truncated correctly
        self.assertEqual(len(chosen_tokens[0]["prompt_input_ids"]), 10)
        self.assertEqual(len(chosen_tokens[0]["prompt_attention_mask"]), 10)
        self.assertEqual(len(rejected_tokens[0]["prompt_input_ids"]), 10)
        self.assertEqual(len(rejected_tokens[0]["prompt_attention_mask"]), 10)
        self.assertEqual(len(prompt_tokens[0]["prompt_input_ids"]), 10)
        self.assertEqual(len(prompt_tokens[0]["prompt_attention_mask"]), 10)

        # Check if responses are truncated correctly
        self.assertEqual(len(chosen_tokens[0]["input_ids"]), 10)
        self.assertEqual(len(chosen_tokens[0]["attention_mask"]), 10)
        self.assertEqual(len(rejected_tokens[0]["input_ids"]), 10)
        self.assertEqual(len(rejected_tokens[0]["attention_mask"]), 10)

    def test_truncation_mode_keep_end(self):
        self.args.truncation_mode = "keep_end"
        chosen_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(15, 25)),
                "attention_mask": [1] * 10,
            }
        ]
        rejected_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(15, 28)),
                "attention_mask": [1] * 13,
            }
        ]
        prompt_tokens = [{"prompt_input_ids": list(range(15)), "prompt_attention_mask": [1] * 15}]

        _truncate_tokens(chosen_tokens, rejected_tokens, prompt_tokens, self.args)

        # Check if prompt is truncated correctly from the end
        self.assertEqual(prompt_tokens[0]["prompt_input_ids"], list(range(5, 15)))
        self.assertEqual(prompt_tokens[0]["prompt_attention_mask"], [1] * 10)

        # Check if chosen tokens are truncated correctly
        self.assertEqual(chosen_tokens[0]["prompt_input_ids"], list(range(5, 15)))
        self.assertEqual(chosen_tokens[0]["prompt_attention_mask"], [1] * 10)
        self.assertEqual(chosen_tokens[0]["input_ids"], list(range(15, 25)))
        self.assertEqual(chosen_tokens[0]["attention_mask"], [1] * 10)

        # Check if rejected tokens are truncated correctly
        self.assertEqual(rejected_tokens[0]["prompt_input_ids"], list(range(5, 15)))
        self.assertEqual(rejected_tokens[0]["prompt_attention_mask"], [1] * 10)
        self.assertEqual(rejected_tokens[0]["input_ids"], list(range(15, 25)))
        self.assertEqual(rejected_tokens[0]["attention_mask"], [1] * 10)

    def test_invalid_truncation_mode(self):
        self.args.truncation_mode = "invalid_mode"
        with self.assertRaises(ValueError):
            _truncate_tokens([], [], [], self.args)

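For reference, and not part of the commit: the outcomes asserted above can be summarised by a simplified sketch. The real _truncate_tokens only truncates when the combined prompt/response length exceeds max_length and uses the configured max_prompt_length for the response budget, but for these inputs the end result is the same. With max_prompt_length=10 and max_length=20 as in setUp, a 15-token prompt keeps its first 10 ids under "keep_start" and its last 10 under "keep_end", and the 12/13-token responses are cut to 10 tokens:

def sketch_truncate(prompt_ids, response_ids, max_prompt_length, max_length, truncation_mode):
    # Prompt: keep the start or the end, depending on truncation_mode.
    if truncation_mode == "keep_start":
        prompt_ids = prompt_ids[:max_prompt_length]
    elif truncation_mode == "keep_end":
        prompt_ids = prompt_ids[-max_prompt_length:]
    else:
        raise ValueError(f"Unknown truncation mode: {truncation_mode}")
    # Response: cut from the end to whatever budget remains after the prompt.
    return prompt_ids, response_ids[: max_length - max_prompt_length]

# Matches the assertions in test_truncation_mode_keep_end:
# prompt -> list(range(5, 15)), rejected response -> list(range(15, 25))
print(sketch_truncate(list(range(15)), list(range(15, 28)), 10, 20, "keep_end"))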

class DPOTrainerTester(unittest.TestCase):
    def setUp(self):
        self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
@@ -138,9 +298,6 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):
                ref_model = self.t5_ref_model
                tokenizer = self.t5_tokenizer

            if name == "t5":
                self.skipTest("For some reason t5 does not compute gradients properly on tiny models")

            trainer = DPOTrainer(
                model=model,
                ref_model=ref_model,