
fix format (#6865)
MARD1NO authored Aug 30, 2023
1 parent 36ebbf5 commit 43d7b87
Showing 4 changed files with 37 additions and 35 deletions.
2 changes: 1 addition & 1 deletion llm/predictor.py
@@ -417,7 +417,7 @@ def __init__(

self.cache_kvs = [
paddle.zeros(shape, dtype=dtype)
- for shape in self.model.get_cache_kvs_shape(self.model.config, config.max_batch_size)
+ for shape in self.model.get_cache_kvs_shape(self.model.config, config.max_batch_size, config.max_length)
]
self.pre_ids = paddle.full([config.max_batch_size, config.max_length], -1, dtype="int64")
if "chatglm" in self.architectures:
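Note on the hunk above: get_cache_kvs_shape now also receives config.max_length, so the preallocated key/value cache buffers can presumably be sized by the maximum generation length as well as the batch size. The real shape logic lives on each PaddleNLP model class; purely as an illustration (the helper name and cache layout below are assumptions, not the library's code), a helper of that kind might look like:

import paddle


def get_cache_kvs_shape_sketch(cfg, max_batch_size, max_length):
    # Hypothetical per-layer layout: [2, batch, heads, max_length, head_dim],
    # with index 0 holding keys and index 1 holding values.
    head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]
    return [
        [2, max_batch_size, cfg["num_attention_heads"], max_length, head_dim]
        for _ in range(cfg["num_hidden_layers"])
    ]


cfg = {"hidden_size": 64, "num_attention_heads": 8, "num_hidden_layers": 2}
cache_kvs = [paddle.zeros(shape, dtype="float32") for shape in get_cache_kvs_shape_sketch(cfg, 2, 128)]
print([tuple(t.shape) for t in cache_kvs])  # [(2, 2, 8, 128, 8), (2, 2, 8, 128, 8)]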
25 changes: 13 additions & 12 deletions paddlenlp/experimental/transformers/chatglm/modeling.py
@@ -294,18 +294,19 @@ def forward(
new_cache = [None]
hidden_states = self.input_layernorm(hidden_states)

- hidden_states, new_cache = self.transformer_block(
-     input_ids,
-     hidden_states,
-     cum_offsets=cum_offsets,
-     padding_offset=padding_offset,
-     attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
-     caches=cache_kvs,
-     rotary_embs=paddle.cast(rotary_embeds, "float32"),
-     rotary_emb_dims=2 if self.config.position_encoding_2d else 1,
-     seq_lens=seq_lens,
-     time_step=time_step,
- )
+ with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
+     hidden_states, new_cache = self.transformer_block(
+         input_ids,
+         hidden_states,
+         cum_offsets=cum_offsets,
+         padding_offset=padding_offset,
+         attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
+         caches=cache_kvs,
+         rotary_embs=paddle.cast(rotary_embeds, "float32"),
+         rotary_emb_dims=2 if self.config.position_encoding_2d else 1,
+         seq_lens=seq_lens,
+         time_step=time_step,
+     )
return (hidden_states, new_cache)

@paddle.no_grad()
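The only functional change in this file is that the fused transformer_block call now runs inside paddle.fluid.framework._stride_in_no_check_dy2st_diff(). That guard is a private Paddle helper; the sketch below only shows the calling pattern, with a dummy matmul standing in for the fused block and a no-op fallback that is an assumption for Paddle builds lacking the helper (the commit itself has no such fallback):

import contextlib

import paddle


def fused_block(x):
    # Stand-in for self.transformer_block(...); a plain matmul keeps the sketch runnable.
    return paddle.matmul(x, x, transpose_y=True)


try:
    guard = paddle.fluid.framework._stride_in_no_check_dy2st_diff
except AttributeError:
    guard = contextlib.nullcontext  # assumed fallback, not part of the commit

x = paddle.rand([4, 8])
with guard():
    y = fused_block(x)
print(y.shape)  # [4, 4]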
10 changes: 5 additions & 5 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -162,7 +162,7 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e
if cache is None:
# encoder's generation
model_kwargs["tgt_ids"] = paddle.where(just_decoder, model_kwargs["tgt_ids"], next_tokens)
if self.model.config["position_encoding_2d"] and self.model.config.position_encoding_2d is True:
if self.config["position_encoding_2d"] and self.config.position_encoding_2d is True:
tgt_pos = model_kwargs["tgt_pos"]
new_position_id = tgt_pos[:, 0, :].clone()
new_block_id = tgt_pos[:, 1, :].clone()
@@ -182,7 +182,7 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e
)
else:
model_kwargs["tgt_ids"] = next_tokens
if self.model.config["position_encoding_2d"] and self.model.config.position_encoding_2d is True:
if self.config["position_encoding_2d"] and self.config.position_encoding_2d is True:
tgt_pos = model_kwargs["tgt_pos"]
new_position_id = tgt_pos[:, 0, :].clone()
new_block_id = tgt_pos[:, 1, :].clone()
@@ -261,9 +261,9 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
# compute next_tokens, use paddle.top_p_sampling
logits = logits / temperature

- _, next_tokens = top_p_sampling(probs, top_p)
+ _, next_tokens = top_p_sampling(probs, top_p, -1)

- if self.model.config.tensor_parallel_degree > 1:
+ if self.config.tensor_parallel_degree > 1:
paddle.distributed.broadcast(next_tokens, 0)

model_kwargs = self.update_model_kwargs_for_generation(
@@ -275,7 +275,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
batch_idx,
step_idx_ori,
"real_time_save.temp_ids",
- self.model.config.tensor_parallel_rank,
+ self.config.tensor_parallel_rank,
)

return next_tokens, model_kwargs
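Two things change in this file: configuration lookups go through self.config instead of self.model.config, and the fused top_p_sampling custom op is called with an extra positional argument (-1), whose meaning is not shown in this diff. The custom kernel is not reproduced here; as a reference-only sketch of nucleus (top-p) sampling in plain Paddle for a [batch, vocab] probability tensor:

import paddle


def top_p_sample(probs, top_p):
    # Reference nucleus sampling; the production path uses the fused top_p_sampling op.
    sorted_probs = paddle.sort(probs, axis=-1, descending=True)
    sorted_idx = paddle.argsort(probs, axis=-1, descending=True)
    cumulative = paddle.cumsum(sorted_probs, axis=-1)
    # Keep a token while the cumulative mass before it is still within top_p,
    # so the highest-probability token is always retained.
    keep = (cumulative - sorted_probs) <= top_p
    filtered = paddle.where(keep, sorted_probs, paddle.zeros_like(sorted_probs))
    filtered = filtered / filtered.sum(axis=-1, keepdim=True)
    choice = paddle.multinomial(filtered, num_samples=1)  # [batch, 1]
    return paddle.index_sample(sorted_idx, choice)        # map back to vocab ids


probs = paddle.nn.functional.softmax(paddle.rand([2, 16]), axis=-1)
print(top_p_sample(probs, top_p=0.8).shape)  # [2, 1]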
35 changes: 18 additions & 17 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -25,7 +25,7 @@
from paddlenlp.experimental.transformers.generation_utils import (
GenerationInferenceModel,
)
- from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM, LlamaPretrainedModel
+ from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel
from paddlenlp.transformers.llama.modeling import LlamaLMHead
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@@ -200,18 +200,19 @@ def forward(

new_rope = fused_get_rotary_embedding(input_ids, position_ids, self.head_dim_shape_tensor, 0, True)

- hidden_states, _ = self.transformer_block(
-     input_ids,
-     hidden_states,
-     cum_offsets=cum_offsets,
-     padding_offset=padding_offset,
-     attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
-     caches=cache_kvs,
-     seq_lens=seq_lens,
-     rotary_embs=new_rope,
-     rotary_emb_dims=1,
-     time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None,
- )
+ with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
+     hidden_states, _ = self.transformer_block(
+         input_ids,
+         hidden_states,
+         cum_offsets=cum_offsets,
+         padding_offset=padding_offset,
+         attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
+         caches=cache_kvs,
+         seq_lens=seq_lens,
+         rotary_embs=new_rope,
+         rotary_emb_dims=1,
+         time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None,
+     )
hidden_states = self.norm(hidden_states)

if output_hidden_states:
@@ -289,7 +290,7 @@ def set_state_dict(self, state_dict):
)


- class LlamaForCausalLMInferenceModel(GenerationInferenceModel, LlamaForCausalLM):
+ class LlamaForCausalLMInferenceModel(GenerationInferenceModel, LlamaPretrainedModel):
"""
Dynamic Batching for LLaMA Model with pretraining tasks on top.
"""
@@ -298,7 +299,7 @@ class LlamaForCausalLMInferenceModel(GenerationInferenceModel, LlamaForCausalLM)

def __init__(self, config):
super().__init__(config)
- self.model = LlamaInferenceModel(config)
+ self.llama = LlamaInferenceModel(config)
self.lm_head = LlamaLMHead(config)

@classmethod
@@ -384,7 +385,7 @@ def forward(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

- outputs = self.model(
+ outputs = self.llama(
input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
@@ -430,4 +431,4 @@ def forward(
def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
self.lm_head.weight.set_value(state_dict["lm_head.weight"])
- self.model.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
+ self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
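Besides wrapping the fused transformer_block call in the same _stride_in_no_check_dy2st_diff guard as the ChatGLM file, this diff bases LlamaForCausalLMInferenceModel on LlamaPretrainedModel instead of LlamaForCausalLM and renames the decoder attribute from self.model to self.llama. The attribute name matters because Paddle builds state-dict keys from attribute paths, so the rename changes every decoder weight key prefix from model. to llama. (presumably to line up with the llama.* keys of PaddleNLP's pretrained LLaMA checkpoints). A minimal sketch of that key-prefix behaviour, with made-up layers rather than PaddleNLP's real ones:

from paddle import nn


class TinyDecoder(nn.Layer):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)


class TinyForCausalLM(nn.Layer):
    def __init__(self):
        super().__init__()
        self.llama = TinyDecoder()  # this attribute name becomes the state-dict key prefix
        self.lm_head = nn.Linear(4, 10)


model = TinyForCausalLM()
print(sorted(model.state_dict().keys()))
# ['llama.embed.weight', 'lm_head.bias', 'lm_head.weight']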
