
Add T5 model #916

Merged: 17 commits, Nov 22, 2021
2 changes: 2 additions & 0 deletions paddlenlp/transformers/t5/modeling.py
@@ -866,6 +866,7 @@ def forward(self,
use_cache=False,
output_attentions=False,
output_hidden_states=False):
assert input_ids is not None, "input_ids can not be None"
input_shape = input_ids.shape
input_ids = input_ids.reshape(shape=[-1, input_shape[-1]])

@@ -1622,6 +1623,7 @@ def __init__(self, t5):
super().__init__()
self.t5 = t5
del self.t5.decoder
paddle.device.cuda.empty_cache()
Contributor

T5EncoderModel should still follow the Hugging Face implementation. The current approach produces a transient spike in GPU memory usage, which can easily cause an OOM.
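
A rough illustration of the spike being described, as a usage sketch rather than code from this PR (the model name "t5-base" and the `from_pretrained` call are assumptions about how the wrapper would typically be constructed):

```python
# Hypothetical usage, not part of the PR: shows why peak memory briefly
# covers both stacks when the wrapper deletes the decoder after the fact.
from paddlenlp.transformers import T5Model, T5EncoderModel

t5 = T5Model.from_pretrained("t5-base")  # encoder AND decoder weights allocated here
encoder_only = T5EncoderModel(t5)        # wrapper then runs `del self.t5.decoder`
# Peak GPU memory before the delete includes the decoder, hence the possible OOM.
```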

Member Author

I've now switched to this approach (see the diff above). Is that acceptable?

Contributor

I'd suggest initializing this class the same way T5Model is initialized, i.e. by passing in the concrete parameters needed to construct T5EncoderModel directly.
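
A minimal sketch of that suggestion, assuming a constructor modeled on T5Model's explicit hyperparameters and an encoder stack class named `T5Stack` (the class names, parameter list, and keyword arguments below are assumptions, not the PR's final code):

```python
import paddle.nn as nn


class T5EncoderModel(T5PretrainedModel):  # T5PretrainedModel assumed from modeling.py
    def __init__(self,
                 vocab_size=32128,
                 d_model=768,
                 num_layers=12,
                 num_heads=12,
                 d_kv=64,
                 d_ff=3072,
                 dropout_rate=0.1,
                 layer_norm_epsilon=1e-6,
                 relative_attention_num_buckets=32,
                 feed_forward_proj="relu"):
        super().__init__()
        # Shared token embedding handed to the encoder stack.
        self.shared = nn.Embedding(vocab_size, d_model)
        # Only the encoder is ever built, so decoder weights are never
        # allocated and there is nothing to delete (no memory spike).
        self.encoder = T5Stack(  # T5Stack signature assumed; may differ
            d_model=d_model,
            num_layers=num_layers,
            num_heads=num_heads,
            d_kv=d_kv,
            d_ff=d_ff,
            dropout_rate=dropout_rate,
            layer_norm_epsilon=layer_norm_epsilon,
            relative_attention_num_buckets=relative_attention_num_buckets,
            feed_forward_proj=feed_forward_proj,
            embed_tokens=self.shared,
            is_decoder=False)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Delegate directly to the encoder stack.
        return self.encoder(input_ids=input_ids,
                            attention_mask=attention_mask,
                            **kwargs)
```

This mirrors the reviewer's point: the construction parameters are passed explicitly, the way T5Model does it, instead of wrapping an already-built full model and discarding its decoder.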


self.init_weights()
