ModelForLM #7

Open · wants to merge 3 commits into base: main
Changes from 1 commit
fix bugs in ModelForLM and add Roberta model
乔子卿 committed Apr 26, 2022
commit bfdda9675c82648af0df2a6095ea0ef0ea36d035
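
Taken together, the commit moves the `tied`/`cls_head` bookkeeping and the output heads (`cls_projection`, `lm_head`, `output_projection`) out of the base encoders (`Bert`, `CPM1`, `CPM2`, `GPT2`) and into the downstream wrapper classes, and it fixes the tied-embedding logits path so it reaches the embedding through the wrapped module (e.g. `self.bert.input_embedding` instead of `self.input_embedding`). Below is a minimal sketch of that pattern with hypothetical class names, not the actual model_center modules:

```python
import torch

class TinyEncoder(torch.nn.Module):
    """Stand-in for Bert/GPT2/CPM1: owns the embedding, no output head."""
    def __init__(self, vocab_size: int, dim_model: int):
        super().__init__()
        self.input_embedding = torch.nn.Embedding(vocab_size, dim_model)

    def forward(self, input_ids):
        # the real encoders run attention/FFN blocks here; an embedding
        # lookup is enough to illustrate the wiring
        return self.input_embedding(input_ids)

class TinyForLM(torch.nn.Module):
    """Stand-in for the *ForLM wrappers: owns the head, wraps the encoder."""
    def __init__(self, vocab_size: int, dim_model: int, tied: bool = True):
        super().__init__()
        self.encoder = TinyEncoder(vocab_size, dim_model)
        self.tied = tied
        if not tied:
            self.lm_head = torch.nn.Linear(dim_model, vocab_size, bias=False)

    def forward(self, input_ids):
        hidden = self.encoder(input_ids)
        if self.tied:
            # tied: reuse the encoder's embedding matrix, reached through
            # the wrapped module, mirroring the attribute-path fix below
            return hidden @ self.encoder.input_embedding.weight.t()
        return self.lm_head(hidden)

logits = TinyForLM(vocab_size=100, dim_model=32)(torch.randint(0, 100, (2, 8)))
print(logits.shape)  # torch.Size([2, 8, 100])
```

Keeping the heads on the wrapper lets the bare encoders be reused for other tasks without carrying unused head parameters.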
50 changes: 20 additions & 30 deletions model_center/model/bert.py
@@ -122,25 +122,6 @@ def __init__(self, config: BertConfig):
post_layer_norm=config.post_layer_norm,
)

self.tied = config.tied
self.cls_head = config.cls_head
if self.cls_head:
self.cls_projection = Linear(
dim_out=self.cls_head,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)
if not self.tied:
self.lm_head = BertLMHead(
dim_model=config.dim_model,
vocab_size=config.vocab_size,
norm_eps=config.norm_eps,
)

self.pooler = BertPooler(config.dim_model)

@@ -221,16 +202,6 @@ def forward(self,

hidden_states = self.encoder(hidden_states, attention_mask)

"""if self.cls_head:
logits = self.cls_projection(hidden_states)
elif self.tied:
logits = self.input_embedding.projection(hidden_states)
elif not self.tied:
logits = self.lm_head(hidden_states)

if return_logits:
return logits"""

pooled_output = self.pooler(hidden_states)

if not return_dict:
@@ -253,6 +224,25 @@ def __init__(self, config: BertConfig):
super().__init__()
self.bert = Bert(config)
self.seq_cls = torch.nn.Linear(config.hidden_size, 2)
self.tied = config.tied
self.cls_head = config.cls_head
if self.cls_head:
self.cls_projection = Linear(
dim_out=self.cls_head,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)
if not self.tied:
self.lm_head = BertLMHead(
dim_model=config.dim_model,
vocab_size=config.vocab_size,
norm_eps=config.norm_eps,
)

def forward(self,
input_ids=None,
@@ -285,7 +275,7 @@ def forward(self,
if self.cls_head:
logits = self.cls_projection(hidden_states)
elif self.tied:
logits = self.input_embedding.projection(hidden_states)
logits = self.bert.input_embedding.projection(hidden_states)
elif not self.tied:
logits = self.lm_head(hidden_states)

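
The one-line fix in `forward` above is an attribute path: once the heads live on the wrapper, `self.input_embedding` no longer exists there, so the tied branch must reach the embedding through the wrapped module as `self.bert.input_embedding.projection(hidden_states)`. A hedged sketch of what a tied `projection` method typically does, assuming model_center's embedding reuses its weight matrix for the output projection (the usual weight-tying scheme):

```python
import torch
import torch.nn.functional as F

class TiedEmbedding(torch.nn.Module):
    """Illustrative embedding with a reverse projection; not the real model_center class."""
    def __init__(self, vocab_size: int, dim_model: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(vocab_size, dim_model) * 0.02)

    def forward(self, input_ids):
        # (batch, seq) -> (batch, seq, dim_model): embedding lookup
        return F.embedding(input_ids, self.weight)

    def projection(self, hidden_states):
        # (batch, seq, dim_model) -> (batch, seq, vocab): reuse the same weight
        return F.linear(hidden_states, self.weight)
```

Sharing the matrix in both directions saves `vocab_size x dim_model` parameters and keeps the input and output vocabularies consistent.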
107 changes: 107 additions & 0 deletions model_center/model/config/roberta_config.py
@@ -0,0 +1,107 @@
# coding=utf-8
# Copyright 2022 The OpenBMB team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config import Config
import torch

class RobertaConfig(Config):
"""
This is a configuration class that stores the configuration of the RoBERTa model; it inherits from the Config class.
It is used to instantiate the Roberta model according to the specified parameters and define the model architecture.
You can set specific parameters to control the output of the model.

For example:
[`dim_model`] determines the dimension of the encoder layers and the pooler layer.
You can keep the default value of 1024 or customize it.

"""

def __init__(self, vocab_size=50265,
type_size=1,
position_size=514,
dim_model=1024,
num_heads=16,
dim_head=64,
dim_ff=4096,
num_layers=24,
dropout_p=0.1,
emb_init_mean = 0.0,
emb_init_std = 1,
pos_bias_type = "none",
position_bias_max_distance = 1024,
norm_init_var = 1.0,
norm_bias = True,
norm_eps = 1e-05,
att_init_mean = 0.0,
att_init_std = 0.02,
att_bias = True,
att_mask_value = float("-1e4"),
ffn_init_mean = 0.0,
ffn_init_std = 0.02,
ffn_bias = True,
ffn_activate_fn = "gelu",
proj_init_mean = 0.0,
proj_init_std = 1,
proj_bias = True,
length_scale = False,
attn_scale = True,
half = True,
int8 = False,
tied = False,
cls_head = None,
post_layer_norm = True,
pad_token_id = 1,
):

super().__init__()

self.vocab_size = vocab_size
self.type_size = type_size
self.position_size = position_size
self.dim_model = dim_model
self.num_heads = num_heads
self.dim_head = dim_head
self.dim_ff = dim_ff
self.num_layers = num_layers
self.dropout_p = dropout_p
self.emb_init_mean = emb_init_mean
self.emb_init_std = emb_init_std
self.pos_bias_type = pos_bias_type
self.position_bias_max_distance = position_bias_max_distance
self.norm_init_var = norm_init_var
self.norm_bias = norm_bias
self.norm_eps = norm_eps
self.att_init_mean = att_init_mean
self.att_init_std = att_init_std
self.att_bias = att_bias
self.att_mask_value = att_mask_value
self.ffn_init_mean = ffn_init_mean
self.ffn_init_std = ffn_init_std
self.ffn_bias = ffn_bias
self.ffn_activate_fn = ffn_activate_fn
self.proj_init_mean = proj_init_mean
self.proj_init_std = proj_init_std
self.proj_bias = proj_bias
self.length_scale = length_scale
self.attn_scale = attn_scale
self.int8 = int8
self.tied = tied
if half:
self.dtype = torch.half
else:
self.dtype = torch.float
self.cls_head = cls_head
self.post_layer_norm = post_layer_norm
self.pad_token_id = pad_token_id
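
A quick usage sketch for the new config, assuming the base `Config` needs no extra constructor arguments; the defaults above are roberta-large-sized, and `half` only controls which `torch` dtype gets stored:

```python
from model_center.model.config.roberta_config import RobertaConfig  # module path as added in this commit

config = RobertaConfig()                        # defaults: dim_model=1024, num_layers=24, vocab_size=50265
print(config.dtype)                             # torch.float16, because half=True by default

config_fp32 = RobertaConfig(half=False, cls_head=2)
print(config_fp32.dtype, config_fp32.cls_head)  # torch.float32 2
```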
50 changes: 25 additions & 25 deletions model_center/model/cpm1.py
@@ -73,30 +73,6 @@ def __init__(self, config: CPM1Config):
init_std=config.pos_init_std,
)

self.tied = config.tied
self.cls_head = config.cls_head
if self.cls_head:
self.output_projection = Linear(
dim_out=config.cls_head,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)
elif not self.tied:
self.output_projection = Linear(
dim_out=config.vocab_size,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)

def forward(self, input: torch.Tensor, # (batch, seqlen)
length: torch.Tensor, # (batch)
@@ -149,6 +125,30 @@ class CPM1ForLM(BaseModel):
def __init__(self, config: CPM1Config):
super().__init__()
self.cpm1 = CPM1(config)
self.tied = config.tied
self.cls_head = config.cls_head
if self.cls_head:
self.output_projection = Linear(
dim_out=config.cls_head,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)
elif not self.tied:
self.output_projection = Linear(
dim_out=config.vocab_size,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)

def forward(self,
input: torch.Tensor, # (batch, seqlen)
@@ -169,7 +169,7 @@ def forward(self,
elif not self.tied:
logits = self.output_projection(hidden_states)
else:
logits = self.input_embedding.projection(hidden_states)
logits = self.cpm1.input_embedding.projection(hidden_states)
if labels:
_logits = logits[..., :-1, :].contiguous()
_labels = labels[..., 1:].contiguous()
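
The loss branch shown above in `CPM1ForLM.forward` shifts logits and labels by one position, which is the standard next-token objective. A self-contained sketch of that shift with made-up shapes (the lines after the shift are not shown in this hunk; `F.cross_entropy` here is just one way to finish the computation):

```python
import torch
import torch.nn.functional as F

batch, seqlen, vocab = 2, 8, 100
logits = torch.randn(batch, seqlen, vocab)
labels = torch.randint(0, vocab, (batch, seqlen))

# predict token t+1 from position t: drop the last logit and the first label
_logits = logits[..., :-1, :].contiguous()    # (batch, seqlen-1, vocab)
_labels = labels[..., 1:].contiguous()        # (batch, seqlen-1)

loss = F.cross_entropy(_logits.view(-1, vocab), _labels.view(-1))
print(loss.item())
```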
23 changes: 12 additions & 11 deletions model_center/model/cpm2.py
@@ -107,17 +107,7 @@ def __init__(self, config: CPM2Config):
init_std=config.pos_init_std,
)

self.cls_head = config.cls_head
self.output_projection = Linear(
dim_out=self.cls_head if self.cls_head else config.vocab_size,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)


def forward(self,
enc_input: torch.Tensor, # (batch, seq_enc)
@@ -195,6 +185,17 @@ class CPM2ForLM(BaseModel):
def __init__(self, config: CPM2Config):
super().__init__()
self.cpm2 = CPM2(config)
self.cls_head = config.cls_head
self.output_projection = Linear(
dim_out=self.cls_head if self.cls_head else config.vocab_size,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)

def forward(self,
enc_input: torch.Tensor, # (batch, seq_enc)
50 changes: 25 additions & 25 deletions model_center/model/gpt2.py
@@ -75,30 +75,6 @@ def __init__(self, config: GPT2Config):
init_std=config.emb_init_std,
)

self.tied = config.tied
self.cls_head = config.cls_head
if self.cls_head:
self.cls_projection = Linear(
dim_out=self.cls_head,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)
if not self.tied:
self.output_projection = Linear(
dim_out=config.vocab_size,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)

def forward(self,
input_ids=None, # (batch, seqlen)
@@ -192,6 +168,30 @@ class GPT2ForLM(BaseModel):
def __init__(self, config: GPT2Config):
super().__init__()
self.gpt2 = GPT2(config)
self.tied = config.tied
self.cls_head = config.cls_head
if self.cls_head:
self.cls_projection = Linear(
dim_out=self.cls_head,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)
if not self.tied:
self.output_projection = Linear(
dim_out=config.vocab_size,
dim_in=config.dim_model,
length_scale=config.length_scale,
dtype=config.dtype,
int8=config.int8,
init_mean=config.proj_init_mean,
init_std=config.proj_init_std,
bias=config.proj_bias,
)

def forward(self,
input_ids=None, # (batch, seqlen)
@@ -226,7 +226,7 @@ def forward(self,
if self.cls_head:
logits = self.cls_projection(hidden_states)
elif self.tied:
logits = self.input_embedding.projection(hidden_states)
logits = self.gpt2.input_embedding.projection(hidden_states)
logits[:, :, -1] = -float(
"inf") # TODO not an elegant implementation, gpt2 vocab is odd number, expand to even and ignore last
elif not self.tied:
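
The tied branch in `GPT2ForLM.forward` above also writes `-inf` into the last vocabulary column because, as the TODO notes, the GPT-2 vocabulary was padded from an odd size to an even one and the padding slot must never be predicted. A small sketch of the effect, with illustrative sizes:

```python
import torch

padded_vocab, dim = 50258, 16                 # e.g. 50257 real tokens padded to an even 50258
hidden = torch.randn(2, 4, dim)
embedding_weight = torch.randn(padded_vocab, dim)

logits = hidden @ embedding_weight.t()        # (2, 4, padded_vocab), tied projection
logits[:, :, -1] = float("-inf")              # the padding slot can never win
probs = torch.softmax(logits, dim=-1)
assert torch.all(probs[:, :, -1] == 0)        # its probability is exactly zero
```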