
Commit caf5e36

blbadger and gante authored
Contrastive Search peak memory reduction (#24120)
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
1 parent aa1b09c commit caf5e36

File tree

3 files changed: +147 -31 lines changed


src/transformers/generation/configuration_utils.py

Lines changed: 4 additions & 0 deletions
@@ -189,6 +189,9 @@ class GenerationConfig(PushToHubMixin):
             The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
             Higher guidance scale encourages the model to generate samples that are more closely linked to the input
             prompt, usually at the expense of poorer quality.
+        low_memory (`bool`, *optional*):
+            Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search.
+

         > Parameters that define the output variables of `generate`

@@ -270,6 +273,7 @@ def __init__(self, **kwargs):
         self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
         self.sequence_bias = kwargs.pop("sequence_bias", None)
         self.guidance_scale = kwargs.pop("guidance_scale", None)
+        self.low_memory = kwargs.pop("low_memory", None)

         # Parameters that define the output variables of `generate`
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
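
Not part of the diff: a minimal usage sketch of the new flag, assuming a standard decoder-only checkpoint (`gpt2` here is only an illustration). `low_memory` only takes effect together with contrastive search, i.e. when `penalty_alpha > 0` and `top_k > 1`:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("DeepMind Company is", return_tensors="pt")

# Contrastive search with the sequential (low-memory) top-k path enabled.
generation_config = GenerationConfig(penalty_alpha=0.6, top_k=4, max_length=64, low_memory=True)
output = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(output[0], skip_special_tokens=True))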

src/transformers/generation/utils.py

Lines changed: 100 additions & 31 deletions
@@ -1569,6 +1569,7 @@ def generate(
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
+                sequential=generation_config.low_memory,
                 **model_kwargs,
             )

@@ -1832,6 +1833,7 @@ def contrastive_search(
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
+        sequential: Optional[bool] = None,
         **model_kwargs,
     ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
         r"""
@@ -1882,6 +1884,8 @@ def contrastive_search(
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            sequential (`bool`, *optional*):
+                Switches topk hidden state computation from parallel to sequential to reduce memory if True.
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                 If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -1921,6 +1925,7 @@ def contrastive_search(
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
@@ -1986,6 +1991,7 @@ def contrastive_search(
                     last_hidden_states = outputs.decoder_hidden_states[-1]
                 else:
                     last_hidden_states = outputs.hidden_states[-1]
+
                 # next logit for contrastive search to select top-k candidate tokens
                 logit_for_next_step = outputs.logits[:, -1, :]

@@ -1995,11 +2001,11 @@ def contrastive_search(
                     is_encoder_decoder=self.config.is_encoder_decoder,
                     standardize_cache_format=True,
                 )
-
-                # Expands model inputs top_k times, for batched forward passes (akin to beam search).
-                _, model_kwargs = self._expand_inputs_for_generation(
-                    expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
-                )
+                if not sequential:
+                    # Expands model inputs top_k times, for batched forward passes (akin to beam search).
+                    _, model_kwargs = self._expand_inputs_for_generation(
+                        expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
+                    )

                 past_key_values = model_kwargs.get("past_key_values")
                 if past_key_values is None:
@@ -2019,7 +2025,6 @@ def contrastive_search(
             # contrastive_search main logic start:
             # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
             # degeneration penalty
-
             logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
             logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
             next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
@@ -2049,25 +2054,74 @@ def contrastive_search(
                 items = []
                 # item is either the key or the value matrix
                 for item in layer:
-                    items.append(item.repeat_interleave(top_k, dim=0))
+                    if sequential:
+                        items.append(item.repeat_interleave(1, dim=0))
+                    else:
+                        items.append(item.repeat_interleave(top_k, dim=0))
                 new_key_values.append(items)
             model_kwargs["past_key_values"] = new_key_values

-            # compute the candidate tokens by the language model and collects their hidden_states
-            next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
-            outputs = self(
-                **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
-            )
-            next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
+            if sequential:
+                all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
+                all_last_hstates, all_hstates, all_logits = [], [], []
+                for i in range(top_k):
+                    # compute the candidate tokens by the language model and collect their hidden_states
+                    next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
+
+                    outputs = self(
+                        **next_model_inputs,
+                        return_dict=True,
+                        output_hidden_states=True,
+                        output_attentions=output_attentions,
+                    )
+                    for key in all_outputs:
+                        all_outputs[key].append(outputs[key])
+
+                    if self.config.is_encoder_decoder:
+                        next_hidden = outputs.decoder_hidden_states[-1]
+                        full_hidden_states = outputs.decoder_hidden_states
+
+                    else:
+                        next_hidden = outputs.hidden_states[-1]
+                        full_hidden_states = outputs.hidden_states
+
+                    all_last_hstates.append(torch.squeeze(next_hidden, 0))
+                    all_hstates.append(full_hidden_states)
+                    all_logits.append(outputs.logits[:, -1, :])
+
+                # stack hidden states
+                next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
+                final_full_hstates = [0 for i in range(len(full_hidden_states))]
+                for layer in range(len(full_hidden_states)):
+                    final_full_hstates[layer] = torch.stack(
+                        [torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
+                    )
+                full_hidden_states = tuple(final_full_hstates)
+
+                # stack logits
+                logits = torch.cat(all_logits, dim=0)

-            logits = outputs.logits[:, -1, :]
-            # name is different for encoder-decoder and decoder-only models
-            if self.config.is_encoder_decoder:
-                next_hidden = outputs.decoder_hidden_states[-1]
-                full_hidden_states = outputs.decoder_hidden_states
             else:
-                next_hidden = outputs.hidden_states[-1]
-                full_hidden_states = outputs.hidden_states
+                # compute the candidate tokens by the language model and collect their hidden_states
+                # assembles top_k_ids into batch of size k
+                next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
+
+                outputs = self(
+                    **next_model_inputs,
+                    return_dict=True,
+                    output_hidden_states=True,
+                    output_attentions=output_attentions,
+                )
+                # name is different for encoder-decoder and decoder-only models
+                if self.config.is_encoder_decoder:
+                    next_hidden = outputs.decoder_hidden_states[-1]
+                    full_hidden_states = outputs.decoder_hidden_states
+                else:
+                    next_hidden = outputs.hidden_states[-1]
+                    full_hidden_states = outputs.hidden_states
+
+                logits = outputs.logits[:, -1, :]
+
             context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

             # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
@@ -2089,17 +2143,32 @@ def contrastive_search(
                 layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
                 next_decoder_hidden_states += (layer,)

-            # select the past_key_value
-            new_key_values = ()
-            for layer in next_past_key_values:
-                items = ()
-                # item is either the key or the value matrix
-                for item in layer:
-                    item = torch.stack(torch.split(item, top_k, dim=0))  # [B, K, num_head, seq_len, esz]
-                    item = item[range(batch_size), selected_idx, ...]  # [B, num_head, seq_len, esz]
-                    items += (item,)
-                new_key_values += (items,)
-            next_past_key_values = new_key_values
+            # generate past_key_values cache of only the selected token
+            if sequential:
+                next_model_input = self.prepare_inputs_for_generation(
+                    top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
+                )
+
+                selected_outputs = self(
+                    **next_model_input,
+                    return_dict=True,
+                    output_hidden_states=False,
+                    output_attentions=False,
+                )
+                next_past_key_values = selected_outputs["past_key_values"]
+
+            else:
+                next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
+                new_key_values = ()
+                for layer in next_past_key_values:
+                    items = ()
+                    # item is either the key or the value matrix
+                    for item in layer:
+                        item = torch.stack(torch.split(item, top_k, dim=0))  # [B, K, num_head, seq_len, esz]
+                        item = item[range(batch_size), selected_idx, ...]  # [B, num_head, seq_len, esz]
+                        items += (item,)
+                    new_key_values += (items,)
+                next_past_key_values = new_key_values

             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
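
For intuition (not library code): the batched path above runs a single forward pass over all `top_k` candidates at once, while the new sequential path runs `top_k` passes of batch size 1 and stitches the results back together, trading extra forward passes for a smaller peak activation and cache footprint. A toy sketch of why the two paths agree, with a stand-in `forward` (names here are illustrative only):

import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a model forward pass; any per-row function illustrates the point.
    return torch.tanh(x * 2.0)

top_k, hidden = 4, 8
candidates = torch.randn(top_k, hidden)  # one row per top-k candidate token

# Batched path: one forward over all candidates (peak memory grows roughly top_k times).
batched = forward(candidates)

# Sequential path: one candidate at a time, then concatenate (smaller peak, more passes).
sequential = torch.cat([forward(candidates[i : i + 1]) for i in range(top_k)], dim=0)

assert torch.allclose(batched, sequential)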

tests/generation/test_utils.py

Lines changed: 43 additions & 0 deletions
@@ -1457,6 +1457,49 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
             for output in (output_contrastive, output_generate):
                 self._check_outputs(output, input_ids, model.config, use_cache=True)

+    def test_contrastive_generate_low_memory(self):
+        # Check that choosing 'low_memory' does not change the model output
+        for model_class in self.all_generative_model_classes:
+            # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
+            if any(
+                model_name in model_class.__name__.lower()
+                for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
+            ):
+                return
+
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
+
+            # NOTE: contrastive search only works with cache on at the moment.
+            if not hasattr(config, "use_cache"):
+                return
+
+            config.use_cache = True
+            config.is_decoder = True
+
+            # test output equality of low versus high memory
+            model = model_class(config).to(torch_device).eval()
+
+            low_output = model.generate(
+                input_ids,
+                top_k=4,
+                penalty_alpha=0.6,
+                low_memory=True,
+                max_length=max_length,
+                attention_mask=attention_mask,
+            )
+
+            high_output = model.generate(
+                input_ids,
+                top_k=4,
+                penalty_alpha=0.6,
+                low_memory=False,
+                max_length=max_length,
+                attention_mask=attention_mask,
+            )
+            self.assertListEqual(low_output.tolist(), high_output.tolist())
+
+            return
+
     @slow  # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
     def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
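
Outside the test harness, the same low- versus high-memory equivalence can be spot-checked against any small causal LM. A hedged standalone sketch (the `gpt2` checkpoint and the prompt are placeholders, not taken from the commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
inputs = tokenizer("The quick brown fox", return_tensors="pt")

with torch.no_grad():
    low = model.generate(**inputs, top_k=4, penalty_alpha=0.6, max_length=32, low_memory=True)
    high = model.generate(**inputs, top_k=4, penalty_alpha=0.6, max_length=32, low_memory=False)

# Both decoding paths should select exactly the same tokens.
assert torch.equal(low, high)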
