@@ -1569,6 +1569,7 @@ def generate(
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
+                sequential=generation_config.low_memory,
                 **model_kwargs,
             )
 
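For orientation, a minimal usage sketch of the new flag. The model name, prompt, and sampling values are illustrative assumptions, not part of the diff; it also assumes `low_memory` is forwarded into the `GenerationConfig` the same way other `generate()` kwargs are.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("DeepMind Company is", return_tensors="pt")

# penalty_alpha + top_k selects contrastive search; low_memory=True routes the
# top-k candidate scoring through the new sequential (one-candidate-at-a-time) path
output = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=64, low_memory=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```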
@@ -1832,6 +1833,7 @@ def contrastive_search(
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
+        sequential: Optional[bool] = None,
         **model_kwargs,
     ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
         r"""
@@ -1882,6 +1884,8 @@ def contrastive_search(
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            sequential (`bool`, *optional*):
+                If True, switches the top-k hidden-state computation from parallel to sequential to reduce memory.
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                 If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -1921,6 +1925,7 @@ def contrastive_search(
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
@@ -1986,6 +1991,7 @@ def contrastive_search(
                     last_hidden_states = outputs.decoder_hidden_states[-1]
                 else:
                     last_hidden_states = outputs.hidden_states[-1]
+
                 # next logit for contrastive search to select top-k candidate tokens
                 logit_for_next_step = outputs.logits[:, -1, :]
 
@@ -1995,11 +2001,11 @@ def contrastive_search(
                     is_encoder_decoder=self.config.is_encoder_decoder,
                     standardize_cache_format=True,
                 )
-
-                # Expands model inputs top_k times, for batched forward passes (akin to beam search).
-                _, model_kwargs = self._expand_inputs_for_generation(
-                    expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
-                )
+                if not sequential:
+                    # Expands model inputs top_k times, for batched forward passes (akin to beam search).
+                    _, model_kwargs = self._expand_inputs_for_generation(
+                        expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
+                    )
 
                 past_key_values = model_kwargs.get("past_key_values")
                 if past_key_values is None:
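The guarded call above is what inflates the batch for the parallel path. A rough sketch of its effect on `input_ids` (simplified: the real `_expand_inputs_for_generation` also expands tensors in `model_kwargs`, such as attention masks and encoder outputs):

```python
import torch

input_ids = torch.tensor([[10, 11], [20, 21]])  # [batch=2, seq_len=2]
top_k = 3

# each sequence is repeated top_k times so a single batched forward pass
# can score every candidate continuation; skipped when sequential=True
expanded = input_ids.repeat_interleave(top_k, dim=0)
print(expanded.shape)  # torch.Size([6, 2])
```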
@@ -2019,7 +2025,6 @@ def contrastive_search(
             # contrastive_search main logic start:
             # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
             # degeneration penalty
-
             logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
             logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
             next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
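Step (1), candidate recall, then amounts to taking the `top_k` most probable next tokens from these processed logits. A standalone sketch with made-up shapes:

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, top_k = 50, 4
logit_for_next_step = torch.randn(1, vocab_size)  # [batch, vocab] after processors/warpers

next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)  # the candidates step (2) re-ranks
print(top_k_ids)
```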
@@ -2049,25 +2054,74 @@ def contrastive_search(
20492054 items = []
20502055 # item is either the key or the value matrix
20512056 for item in layer :
2052- items .append (item .repeat_interleave (top_k , dim = 0 ))
2057+ if sequential :
2058+ items .append (item .repeat_interleave (1 , dim = 0 ))
2059+ else :
2060+ items .append (item .repeat_interleave (top_k , dim = 0 ))
20532061 new_key_values .append (items )
20542062 model_kwargs ["past_key_values" ] = new_key_values
20552063
2056- # compute the candidate tokens by the language model and collects their hidden_states
2057- next_model_inputs = self .prepare_inputs_for_generation (top_k_ids .view (- 1 , 1 ), ** model_kwargs )
2058- outputs = self (
2059- ** next_model_inputs , return_dict = True , output_hidden_states = True , output_attentions = output_attentions
2060- )
2061- next_past_key_values = self ._extract_past_from_model_output (outputs , standardize_cache_format = True )
2064+ if sequential :
2065+ all_outputs = {key : [] for key in outputs } # defined in first loop iteration
2066+ all_last_hstates , all_hstates , all_logits = [], [], []
2067+ for i in range (top_k ):
2068+ # compute the candidate tokens by the language model and collect their hidden_states
2069+ next_model_inputs = self .prepare_inputs_for_generation (top_k_ids [:, i ].view (- 1 , 1 ), ** model_kwargs )
2070+
2071+ outputs = self (
2072+ ** next_model_inputs ,
2073+ return_dict = True ,
2074+ output_hidden_states = True ,
2075+ output_attentions = output_attentions ,
2076+ )
2077+ for key in all_outputs :
2078+ all_outputs [key ].append (outputs [key ])
2079+
2080+ if self .config .is_encoder_decoder :
2081+ next_hidden = outputs .decoder_hidden_states [- 1 ]
2082+ full_hidden_states = outputs .decoder_hidden_states
2083+
2084+ else :
2085+ next_hidden = outputs .hidden_states [- 1 ]
2086+ full_hidden_states = outputs .hidden_states
2087+
2088+ all_last_hstates .append (torch .squeeze (next_hidden , 0 ))
2089+ all_hstates .append (full_hidden_states )
2090+ all_logits .append (outputs .logits [:, - 1 , :])
2091+
2092+ # stack hidden states
2093+ next_hidden = torch .stack ([all_last_hstates [i ] for i in range (top_k )], dim = 0 )
2094+ final_full_hstates = [0 for i in range (len (full_hidden_states ))]
2095+ for layer in range (len (full_hidden_states )):
2096+ final_full_hstates [layer ] = torch .stack (
2097+ [torch .squeeze (all_hstates [i ][layer ], 0 ) for i in range (top_k )], dim = 0
2098+ )
2099+ full_hidden_states = tuple (final_full_hstates )
2100+
2101+ # stack logits
2102+ logits = torch .cat (all_logits , dim = 0 )
20622103
2063- logits = outputs .logits [:, - 1 , :]
2064- # name is different for encoder-decoder and decoder-only models
2065- if self .config .is_encoder_decoder :
2066- next_hidden = outputs .decoder_hidden_states [- 1 ]
2067- full_hidden_states = outputs .decoder_hidden_states
20682104 else :
2069- next_hidden = outputs .hidden_states [- 1 ]
2070- full_hidden_states = outputs .hidden_states
2105+ # compute the candidate tokens by the language model and collect their hidden_states
2106+ # assembles top_k_ids into batch of size k
2107+ next_model_inputs = self .prepare_inputs_for_generation (top_k_ids .view (- 1 , 1 ), ** model_kwargs )
2108+
2109+ outputs = self (
2110+ ** next_model_inputs ,
2111+ return_dict = True ,
2112+ output_hidden_states = True ,
2113+ output_attentions = output_attentions ,
2114+ )
2115+ # name is different for encoder-decoder and decoder-only models
2116+ if self .config .is_encoder_decoder :
2117+ next_hidden = outputs .decoder_hidden_states [- 1 ]
2118+ full_hidden_states = outputs .decoder_hidden_states
2119+ else :
2120+ next_hidden = outputs .hidden_states [- 1 ]
2121+ full_hidden_states = outputs .hidden_states
2122+
2123+ logits = outputs .logits [:, - 1 , :]
2124+
20712125 context_hidden = last_hidden_states .repeat_interleave (top_k , dim = 0 )
20722126
20732127 # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
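The re-ranking this trailing comment refers to scores each candidate as a tradeoff between model confidence and a degeneration penalty: its maximum cosine similarity to the previous hidden states. A rough standalone sketch of that logic with made-up shapes; the names follow the surrounding code, but this is not the library's exact helper:

```python
import torch

torch.manual_seed(0)
batch_size, top_k, ctx_len, hidden = 1, 4, 5, 8
penalty_alpha = 0.6

context_hidden = torch.randn(batch_size * top_k, ctx_len, hidden)  # context states, repeated per candidate
next_hidden = torch.randn(batch_size * top_k, 1, hidden)           # candidate hidden states
top_k_probs = torch.softmax(torch.randn(batch_size, top_k), dim=-1)

# degeneration penalty: max cosine similarity of the candidate to any context state
norm_ctx = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine = torch.matmul(norm_ctx, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, ctx_len]
penalty = cosine.max(dim=-1).values                                     # [B*K]

# contrastive score: confident, yet not degenerate
scores = (1.0 - penalty_alpha) * top_k_probs.view(-1) - penalty_alpha * penalty
selected_idx = scores.view(batch_size, top_k).argmax(dim=-1)
print(selected_idx)
```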
@@ -2089,17 +2143,32 @@ def contrastive_search(
                 layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
                 next_decoder_hidden_states += (layer,)
 
-            # select the past_key_value
-            new_key_values = ()
-            for layer in next_past_key_values:
-                items = ()
-                # item is either the key or the value matrix
-                for item in layer:
-                    item = torch.stack(torch.split(item, top_k, dim=0))  # [B, K, num_head, seq_len, esz]
-                    item = item[range(batch_size), selected_idx, ...]  # [B, num_head, seq_len, esz]
-                    items += (item,)
-                new_key_values += (items,)
-            next_past_key_values = new_key_values
+            # generate past_key_values cache of only the selected token
+            if sequential:
+                next_model_input = self.prepare_inputs_for_generation(
+                    top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
+                )
+
+                selected_outputs = self(
+                    **next_model_input,
+                    return_dict=True,
+                    output_hidden_states=False,
+                    output_attentions=False,
+                )
+                next_past_key_values = selected_outputs["past_key_values"]
+
+            else:
+                next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
+                new_key_values = ()
+                for layer in next_past_key_values:
+                    items = ()
+                    # item is either the key or the value matrix
+                    for item in layer:
+                        item = torch.stack(torch.split(item, top_k, dim=0))  # [B, K, num_head, seq_len, esz]
+                        item = item[range(batch_size), selected_idx, ...]  # [B, num_head, seq_len, esz]
+                        items += (item,)
+                    new_key_values += (items,)
+                next_past_key_values = new_key_values
 
             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
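The `else` branch above carves the winning candidate's cache out of the `top_k`-expanded cache, which is exactly the tensor the sequential branch avoids ever materializing (at the cost of one extra forward pass on the selected token). A sketch of that slicing on dummy tensors, shapes made up:

```python
import torch

batch_size, top_k, num_head, seq_len, esz = 2, 3, 4, 5, 8
item = torch.randn(batch_size * top_k, num_head, seq_len, esz)  # one cache entry per candidate
selected_idx = torch.tensor([1, 2])  # winning candidate for each batch element

# regroup per batch element, then keep only the selected candidate's cache
stacked = torch.stack(torch.split(item, top_k, dim=0))    # [B, K, num_head, seq_len, esz]
selected = stacked[range(batch_size), selected_idx, ...]  # [B, num_head, seq_len, esz]
print(selected.shape)  # torch.Size([2, 4, 5, 8])
```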