
Commit e304f97

apoorvumang, gante, and ArthurZucker authored
Adding Prompt lookup decoding (#27775)
* MVP
* fix ci
* more ci
* remove redundant kwarg
* added and wired up PromptLookupCandidateGenerator
* rebased with main, working
* removed print
* style fixes
* fix test
* fixed tests
* added test for prompt lookup decoding
* fixed circleci
* fixed test issue
* Update src/transformers/generation/candidate_generator.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/generation/candidate_generator.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/generation/candidate_generator.py
* Update src/transformers/generation/candidate_generator.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent 29a2b14 commit e304f97

File tree

4 files changed (+170, -9 lines)


src/transformers/generation/candidate_generator.py

Lines changed: 92 additions & 0 deletions
@@ -226,6 +226,98 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
                 self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
 
 
+class PromptLookupCandidateGenerator(CandidateGenerator):
+    """
+    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
+    likely continuations in the provided prompt (input_ids) itself.
+    Read the following write-up for more information: https://github.com/apoorvumang/prompt-lookup-decoding
+
+    Args:
+        max_matching_ngram_size (`int`):
+            The maximum ngram size to be considered for matching in the prompt
+        num_output_tokens (`int`):
+            The number of tokens to be output as candidate tokens.
+    """
+
+    def __init__(
+        self,
+        num_output_tokens: int = 10,
+        max_matching_ngram_size: int = 2,
+    ):
+        self.num_output_tokens = num_output_tokens
+        self.max_matching_ngram_size = max_matching_ngram_size
+
+        if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
+            raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
+
+    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
+        """
+        Fetches the candidates to be tried for the current input.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+
+        Return:
+            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
+        """
+        input_length = input_ids.size(1)
+
+        chosen_ids = None
+        match_found = False
+        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
+            # Create sliding windows of size ngram_size
+            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
+
+            # Convert ngram to a tensor for comparison
+            ngram_tensor = input_ids[0, -ngram_size:]
+
+            # Find where the windows match the ngram
+            matches = (windows == ngram_tensor).all(dim=2)
+
+            # Get the indices of matches
+            match_indices = matches.nonzero(as_tuple=True)[1]
+
+            # Iterate through match indices to find a valid continuation
+            for idx in match_indices:
+                start_idx = idx + ngram_size
+                end_idx = start_idx + self.num_output_tokens
+                end_idx = min(end_idx, input_length)
+
+                if start_idx < end_idx:
+                    chosen_ids = input_ids[0, start_idx:end_idx]
+                    match_found = True
+                    break
+            if match_found:
+                break
+
+        if chosen_ids is None or len(chosen_ids) == 0:
+            # Need to make a dummy tensor to avoid errors
+            chosen_ids = torch.zeros((1), dtype=torch.long, device=input_ids.device)
+
+        # Now we need to extend input_ids with chosen_ids
+        chosen_ids = chosen_ids.unsqueeze(0)
+        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
+        # assisted_generation expects logits as well, but we don't have those here, so returning None
+        return candidate_input_ids, None
+
+    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
+        """
+        Updates the candidate generation strategy based on the outcomes.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
+                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
+                using beam search or log softmax for each vocabulary token when using beam search
+            num_matches (`int`):
+                The number of matches between the candidate sequences and the model predictions.
+        """
+        # Currently does nothing
+        return
+
+
 def _crop_past_key_values(model, past_key_values, maximum_length):
     """Crops the past key values up to a certain maximum length."""
     new_past = []
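For intuition, here is a minimal standalone sketch of the same n-gram lookup idea for a batch of one sequence. The helper name `find_candidate_ids` and the toy tensor are illustrative only and not part of the library; the logic mirrors `get_candidates` above.

    import torch

    def find_candidate_ids(input_ids, max_ngram=2, num_output_tokens=10):
        # Match the trailing ngram of the prompt against earlier positions and
        # return the tokens that followed the earlier occurrence.
        input_length = input_ids.size(1)
        for ngram_size in range(min(max_ngram, input_length - 1), 0, -1):
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)  # (1, n_windows, ngram_size)
            ngram = input_ids[0, -ngram_size:]
            match_indices = (windows == ngram).all(dim=2).nonzero(as_tuple=True)[1]
            for idx in match_indices:
                start, end = idx + ngram_size, min(idx + ngram_size + num_output_tokens, input_length)
                if start < end:
                    return input_ids[0, start:end]
        # Dummy fallback when no match (or no continuation) is found
        return torch.zeros((1,), dtype=torch.long, device=input_ids.device)

    # The trailing bigram (4, 5) also appears earlier in the prompt, so its continuation is proposed.
    toy = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 4, 5]])
    print(find_candidate_ids(toy, max_ngram=2, num_output_tokens=3))  # tensor([6, 7, 8])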

src/transformers/generation/configuration_utils.py

Lines changed: 3 additions & 0 deletions
@@ -320,6 +320,9 @@ def __init__(self, **kwargs):
         self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
         self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
 
+        # Prompt lookup decoding
+        self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
+
         # Wild card
         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
 
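Since the new field is popped from `kwargs` in `GenerationConfig.__init__`, it can be set like any other generation option. A small sketch based only on the change above; the values are arbitrary:

    from transformers import GenerationConfig

    # None (the default) leaves prompt lookup decoding disabled.
    generation_config = GenerationConfig(max_new_tokens=20, prompt_lookup_num_tokens=10)
    print(generation_config.prompt_lookup_num_tokens)  # 10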

src/transformers/generation/utils.py

Lines changed: 15 additions & 9 deletions
@@ -40,6 +40,7 @@
 from .candidate_generator import (
     AssistedCandidateGenerator,
     CandidateGenerator,
+    PromptLookupCandidateGenerator,
     _crop_past_key_values,
     _prepare_attention_mask,
     _prepare_token_type_ids,
@@ -908,14 +909,19 @@ def _get_candidate_generator(
         """
         Returns the candidate generator to be used in `assisted_generation`
         """
-        candidate_generator = AssistedCandidateGenerator(
-            input_ids=input_ids,
-            assistant_model=assistant_model,
-            generation_config=generation_config,
-            logits_processor=logits_processor,
-            model_kwargs=model_kwargs,
-            inputs_tensor=inputs_tensor,
-        )
+        if generation_config.prompt_lookup_num_tokens is not None:
+            candidate_generator = PromptLookupCandidateGenerator(
+                num_output_tokens=generation_config.prompt_lookup_num_tokens,
+            )
+        else:
+            candidate_generator = AssistedCandidateGenerator(
+                input_ids=input_ids,
+                assistant_model=assistant_model,
+                generation_config=generation_config,
+                logits_processor=logits_processor,
+                model_kwargs=model_kwargs,
+                inputs_tensor=inputs_tensor,
+            )
         return candidate_generator
 
     def _get_logits_warper(
@@ -995,7 +1001,7 @@ def _get_generation_mode(
             generation_mode = GenerationMode.BEAM_SEARCH
 
         # Assisted generation may extend some generation modes
-        if assistant_model is not None:
+        if assistant_model is not None or generation_config.prompt_lookup_num_tokens is not None:
             if generation_mode in ("greedy_search", "sample"):
                 generation_mode = GenerationMode.ASSISTED_GENERATION
             else:
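Putting the pieces together: passing `prompt_lookup_num_tokens` to `generate` routes assisted generation to `PromptLookupCandidateGenerator`, so no assistant model is needed. An end-to-end sketch; the checkpoint and prompt are placeholders:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Repetitive prompts benefit most, since candidate tokens are copied from earlier text.
    inputs = tokenizer("The quick brown fox jumps over the lazy dog. The quick brown", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=20, prompt_lookup_num_tokens=10)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))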

tests/generation/test_utils.py

Lines changed: 60 additions & 0 deletions
@@ -1569,6 +1569,66 @@ def test_assisted_decoding_matches_greedy_search(self):
             for output in (output_greedy, output_assisted):
                 self._check_outputs(output, input_ids, model.config, use_cache=True)
 
+    @is_flaky()
+    def test_prompt_lookup_decoding_matches_greedy_search(self):
+        # This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
+        # This test is mostly a copy of test_assisted_decoding_matches_greedy_search
+
+        for model_class in self.all_generative_model_classes:
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+                self.skipTest("Won't fix: old model with different cache format")
+            if any(
+                model_name in model_class.__name__.lower()
+                for model_name in [
+                    "bigbirdpegasus",
+                    "led",
+                    "mega",
+                    "speech2text",
+                    "git",
+                    "prophetnet",
+                    "seamlessm4t",
+                    "clvp",
+                ]
+            ):
+                self.skipTest("May fix in the future: need model-specific fixes")
+
+            # enable cache
+            config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
+
+            # NOTE: assisted generation only works with cache on at the moment.
+            if not hasattr(config, "use_cache"):
+                self.skipTest("This model doesn't support caching")
+
+            config.use_cache = True
+            config.is_decoder = True
+            model = model_class(config).to(torch_device).eval()
+            # Sets assisted generation arguments such that:
+            # a) no EOS is generated, to ensure generation doesn't break early
+            # b) the prompt lookup tries to give the model 2 tokens, to ensure the input preparation of
+            #    prompt lookup is correct
+            # c) there are at least two forward passes in the main model, to ensure the input preparation of
+            #    the main model is correct
+            generation_kwargs = {
+                "eos_token_id": -1,  # see a)
+                "max_new_tokens": 4,  # see c)
+                "num_beams": 1,
+                "do_sample": False,
+                "output_scores": True,
+                "output_hidden_states": True,
+                "output_attentions": True,
+                "return_dict_in_generate": True,
+            }
+
+            output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
+
+            generation_kwargs.update({"prompt_lookup_num_tokens": 2})  # see b)
+            output_prompt_lookup = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
+
+            # The two outputs must match and their shape must be as expected
+            self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
+            for output in (output_greedy, output_prompt_lookup):
+                self._check_outputs(output, input_ids, model.config, use_cache=True)
+
     def test_assisted_decoding_sample(self):
         # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
         # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
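To run only the new test locally, a sketch of the usual pytest invocation, assuming the repository layout shown above:

    python -m pytest tests/generation/test_utils.py -k test_prompt_lookup_decoding_matches_greedy_search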
