Commit e49842c
chatglm2 beam search fix (#7012)
* chatglm2 beam search fix

* changes

---------

Co-authored-by: εˆ˜ζ±€ <wtmlon@foxmail.com>
sijunhe and wtmlon authored Sep 12, 2023
1 parent 45d4ee8 commit e49842c
Showing 3 changed files with 17 additions and 23 deletions.
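
The change is small but easy to misread. During beam search, the generic loop must reorder each layer's KV cache so that every surviving beam continues from the cache of the beam it forked from. The generic code gathers with `paddle.index_select` along the default axis 0 (the batch axis for most models), but the axis=1 gather added for ChatGLMv2 below implies its cache keeps the batch-of-beams dimension on axis 1 (a sequence-major layout), so the default gather scrambles it. A minimal sketch of the difference; the cache shape `[seq_len, batch * num_beams, num_heads, head_dim]` is an assumption for illustration, not taken from the commit:

```python
import paddle

# assumed ChatGLMv2-style cache layout: [seq_len, batch * num_beams, num_heads, head_dim]
cache = paddle.randn([8, 4, 2, 16])
# beam_idx[i] = index of the beam that slot i should continue from
beam_idx = paddle.to_tensor([2, 2, 0, 1])

ok = paddle.index_select(cache, beam_idx, axis=1)   # gathers beams; shape stays [8, 4, 2, 16]
bad = paddle.index_select(cache, beam_idx, axis=0)  # gathers sequence positions instead
print(ok.shape, bad.shape)  # [8, 4, 2, 16] vs [4, 4, 2, 16]
```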
11 changes: 5 additions & 6 deletions paddlenlp/generation/utils.py
@@ -1495,6 +1495,9 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
 
             return input_ids[:, origin_len:], scores
 
+    def reorder_cache(self, cache, beam_idx):
+        return map_structure(lambda x: paddle.index_select(x, beam_idx), cache)
+
     def beam_search(
         self,
         input_ids,
@@ -1623,9 +1626,7 @@ def beam_search(
             cache_name = "cache" if "cache" in model_kwargs else "past_key_values"
             if model_kwargs[cache_name] is not None:
                 # reorder the cache
-                model_kwargs[cache_name] = map_structure(
-                    lambda x: paddle.index_select(x, beam_idx), model_kwargs[cache_name]
-                )
+                model_kwargs[cache_name] = self.reorder_cache(model_kwargs[cache_name], beam_idx)
 
         pred_ids, scores = beam_scorer.finalize(
             input_ids,
@@ -1773,9 +1774,7 @@ def group_beam_search(
             cache_name = "cache" if "cache" in model_kwargs else "past_key_values"
             if model_kwargs[cache_name] is not None:
                 # reorder the cache
-                model_kwargs[cache_name] = map_structure(
-                    lambda x: paddle.index_select(x, reordering_indices), model_kwargs[cache_name]
-                )
+                model_kwargs[cache_name] = self.reorder_cache(model_kwargs[cache_name], reordering_indices)
 
         pred_ids, scores = beam_scorer.finalize(
             input_ids,
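
Hoisting the reorder into a `reorder_cache` method turns it into an overridable hook: the beam-search loop calls `self.reorder_cache(...)`, so a model with a nonstandard cache layout only overrides that one method instead of patching the generic loop. A stripped-down sketch of the dispatch; the class names here are illustrative stand-ins, not the repository's:

```python
import paddle
from paddle.utils import map_structure

class BaseGeneration:  # stands in for the generation mixin in paddlenlp/generation/utils.py
    def reorder_cache(self, cache, beam_idx):
        # default: beams live on axis 0
        return map_structure(lambda x: paddle.index_select(x, beam_idx), cache)

class SeqMajorModel(BaseGeneration):  # stands in for ChatGLMv2ForCausalLM
    def reorder_cache(self, cache, beam_idx):
        # sequence-major cache: beams live on axis 1
        return map_structure(lambda x: paddle.index_select(x, beam_idx, axis=1), cache)

cache = [paddle.randn([8, 4, 2, 16])]        # one layer's cache, wrapped in a structure
beam_idx = paddle.to_tensor([0, 0, 1, 3])
reordered = SeqMajorModel().reorder_cache(cache, beam_idx)  # the axis=1 override wins
```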
4 changes: 4 additions & 0 deletions paddlenlp/transformers/chatglm_v2/modeling.py
@@ -21,6 +21,7 @@
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.distributed.fleet.utils import recompute
+from paddle.utils import map_structure
 
 from .. import PretrainedModel, register_base_model
 from ..model_outputs import (
@@ -765,6 +766,9 @@ def __init__(self, config: ChatGLMv2Config):
         self.max_sequence_length = config.max_sequence_length
         self.chatglm_v2 = ChatGLMv2Model(config)
 
+    def reorder_cache(self, cache: paddle.Tensor, beam_idx):
+        return map_structure(lambda x: paddle.index_select(x, beam_idx, axis=1), cache)
+
     def update_model_kwargs_for_generation(
         self,
         outputs: ModelOutput,
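
With the override in place, beam search on ChatGLMv2 goes through the axis=1 path automatically. A hedged usage sketch: the checkpoint name comes from the test file below, and the `generate()` arguments are standard PaddleNLP options rather than anything this commit adds:

```python
import paddle
from paddlenlp.transformers import ChatGLMv2ForCausalLM

model = ChatGLMv2ForCausalLM.from_pretrained("__internal_testing__/tiny-random-chatglm2")
model.eval()

input_ids = paddle.to_tensor([[10, 11, 12]], dtype="int64")
# decode_strategy="beam_search" routes through beam_search() above, which now
# calls ChatGLMv2's reorder_cache (axis=1) at every step
out_ids, scores = model.generate(input_ids=input_ids, decode_strategy="beam_search", num_beams=2)
```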
25 changes: 8 additions & 17 deletions tests/transformers/chatglm_v2/test_modeling.py
@@ -15,6 +15,7 @@
 import unittest
 
 import paddle
+from parameterized import parameterized_class
 
 from paddlenlp.transformers import ChatGLMv2Config, ChatGLMv2ForCausalLM, ChatGLMv2Model
 from tests.transformers.test_generation_utils import GenerationTesterMixin
@@ -24,8 +25,6 @@
     random_attention_mask,
 )
 
-# from parameterized import parameterized_class
-
 
 class ChatGLMv2Tester:
     def __init__(
@@ -172,13 +171,13 @@ def create_and_check_model_attention_mask(self, config: ChatGLMv2Config, input_i
         self.parent.assertTrue((result_2d[attn_mask_2d] == result_no_attention_mask[attn_mask_2d]).all())
 
 
-# @parameterized_class(
-#     ("return_dict", "use_labels"),
-#     [
-#         [False, True],
-#         [True, False],
-#     ],
-# )
+@parameterized_class(
+    ("return_dict", "use_labels"),
+    [
+        [False, True],
+        [True, False],
+    ],
+)
 class ChatGLMv2Test(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     base_model_class = ChatGLMv2Model
     return_dict: bool = True
@@ -220,14 +219,6 @@ def test_model_attention_mask(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_attention_mask(*config_and_inputs)
 
-    # chatglm_v2 cannot use beam search temporarily
-    def test_beam_search_generate(self):
-        pass
-
-    # chatglm_v2 cannot use group beam search temporarily
-    def test_group_beam_search_generate(self):
-        pass
-
 
 # class ChatGLMV2GenerationD2STest(GenerationD2STestMixin, unittest.TestCase):
 #     internal_testing_model = "__internal_testing__/tiny-random-chatglm2"
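
Re-enabling `@parameterized_class` is what actually puts beam search back under test: the decorator clones the test class once per parameter row, so the whole `GenerationTesterMixin` suite, including the beam-search tests, runs under both configurations instead of being skipped by the removed stubs. A self-contained sketch of the decorator's behavior, independent of this repository:

```python
import unittest
from parameterized import parameterized_class

@parameterized_class(
    ("return_dict", "use_labels"),
    [
        [False, True],
        [True, False],
    ],
)
class ExampleTest(unittest.TestCase):
    # runs twice: once per row, with the class attributes set accordingly
    def test_flags_are_set(self):
        self.assertIsInstance(self.return_dict, bool)
        self.assertIsInstance(self.use_labels, bool)
```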
