
Commit 2e4d93e

make tensors buffers
1 parent: 4b978ba

4 files changed: +42 additions, -33 deletions

src/transformers/models/codegen/modeling_codegen.py

Lines changed: 11 additions & 3 deletions
@@ -19,6 +19,7 @@
 import torch
 from torch import nn
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
@@ -69,7 +70,7 @@ class CodeGenAttention(nn.Module):
     def __init__(self, config, layer_idx=None):
         super().__init__()
 
-        max_positions = config.max_position_embeddings
+        self.max_positions = config.max_position_embeddings
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
         self.layer_idx = layer_idx
@@ -93,8 +94,10 @@ def __init__(self, config, layer_idx=None):
 
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.rotary_dim = config.rotary_dim
-        pos_embd_dim = self.rotary_dim or self.embed_dim
-        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
+        self.pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.register_buffer(
+            "embed_positions", create_sinusoidal_positions(self.max_positions, self.pos_embd_dim), persistent=False
+        )
 
     def _split_heads(self, x, n_head, dim_head, mp_num):
         reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
@@ -279,6 +282,11 @@ class CodeGenPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _can_compile_fullgraph = True
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, CodeGenAttention):
+            init.copy_(module.embed_positions, create_sinusoidal_positions(module.max_positions, module.pos_embd_dim))
+
 
 @auto_docstring
 class CodeGenModel(CodeGenPreTrainedModel):
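
For context on the change above: a tensor registered with register_buffer(..., persistent=False) follows the module's .to()/.cuda() moves but is left out of state_dict(), so it can be recomputed at load time (here via _init_weights) rather than stored in checkpoints. A minimal, self-contained sketch of that behavior; SinusoidalTable and make_table below are illustrative stand-ins, not code from this commit:

import torch
from torch import nn


def make_table(max_positions: int, dim: int) -> torch.Tensor:
    # Toy stand-in for create_sinusoidal_positions: a fixed sin/cos table.
    pos = torch.arange(max_positions, dtype=torch.float32).unsqueeze(1)
    freqs = torch.exp(-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = pos * freqs
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class SinusoidalTable(nn.Module):
    def __init__(self, max_positions: int = 16, dim: int = 8):
        super().__init__()
        # Non-persistent buffer: moves with the module, excluded from checkpoints.
        self.register_buffer("embed_positions", make_table(max_positions, dim), persistent=False)


m = SinusoidalTable()
print("embed_positions" in m.state_dict())  # False: the table is not serialized
m = m.to(torch.float16)
print(m.embed_positions.dtype)              # torch.float16: the buffer tracked the move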

src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py

Lines changed: 9 additions & 10 deletions
@@ -74,18 +74,17 @@ def __init__(self, config):
         super().__init__()
         self.max_len = config.max_source_positions
         self.d_model = config.hidden_size
-        self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+        self.register_buffer("pe", self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)), persistent=False)
 
-    def extend_pe(self, x):
+    def extend_pe(self, x, pe=None):
         # Reset the positional encodings
-        if self.pe is not None:
+        if pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
-                if self.pe.dtype != x.dtype or self.pe.device != x.device:
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return
+            if pe.size(1) >= x.size(1) * 2 - 1:
+                if pe.dtype != x.dtype or pe.device != x.device:
+                    pe = pe.to(dtype=x.dtype, device=x.device)
+                return pe
         # Suppose `i` is the position of query vector and `j` is the
         # position of key vector. We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -106,10 +105,10 @@ def extend_pe(self, x):
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
-        self.pe = pe.to(device=x.device, dtype=x.dtype)
+        return pe.to(device=x.device, dtype=x.dtype)
 
     def forward(self, hidden_states: torch.Tensor):
-        self.extend_pe(hidden_states)
+        self.pe = self.extend_pe(hidden_states, self.pe)
         start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
         end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
         relative_position_embeddings = self.pe[:, start_idx:end_idx]
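
The refactor above turns extend_pe from a method that mutated self.pe into one that returns the (possibly regrown) table, and the caller writes the result back with self.pe = ...; because pe is a registered buffer, that assignment goes through nn.Module.__setattr__ and updates the buffer entry rather than shadowing it. A rough sketch of the same return-and-reassign pattern; GrowingCache below is illustrative, not code from this commit:

import torch
from torch import nn


class GrowingCache(nn.Module):
    """Keeps a cached table as a non-persistent buffer and regrows it on demand."""

    def __init__(self, initial_len: int = 4):
        super().__init__()
        self.register_buffer("pe", self._build(initial_len), persistent=False)

    @staticmethod
    def _build(length: int) -> torch.Tensor:
        return torch.arange(length, dtype=torch.float32).unsqueeze(0)

    def extend(self, x: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
        # Return the existing table if it is already long enough (matching dtype/device),
        # otherwise build a larger one, mirroring the return-instead-of-mutate style.
        if pe.size(1) >= x.size(1):
            return pe.to(dtype=x.dtype, device=x.device)
        return self._build(x.size(1)).to(dtype=x.dtype, device=x.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assigning to a registered buffer name updates the buffer entry.
        self.pe = self.extend(x, self.pe)
        return x + self.pe[:, : x.size(1)]


cache = GrowingCache()
out = cache(torch.zeros(1, 10))  # buffer regrown from length 4 to 10
print(cache.pe.shape)            # torch.Size([1, 10])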

src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py

Lines changed: 11 additions & 10 deletions
@@ -164,18 +164,17 @@ def __init__(self, config):
         super().__init__()
         self.max_len = config.max_source_positions
         self.d_model = config.hidden_size
-        self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+        self.register_buffer("pe", self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)), persistent=False)
 
-    def extend_pe(self, x):
+    def extend_pe(self, x, pe=None):
         # Reset the positional encodings
-        if self.pe is not None:
+        if pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
-                if self.pe.dtype != x.dtype or self.pe.device != x.device:
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return
+            if pe.size(1) >= x.size(1) * 2 - 1:
+                if pe.dtype != x.dtype or pe.device != x.device:
+                    pe = pe.to(dtype=x.dtype, device=x.device)
+                return pe
         # Suppose `i` is the position of query vector and `j` is the
         # position of key vector. We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -196,10 +195,10 @@ def extend_pe(self, x):
         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
-        self.pe = pe.to(device=x.device, dtype=x.dtype)
+        return pe.to(device=x.device, dtype=x.dtype)
 
     def forward(self, hidden_states: torch.Tensor):
-        self.extend_pe(hidden_states)
+        self.pe = self.extend_pe(hidden_states, self.pe)
         start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
         end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
         relative_position_embeddings = self.pe[:, start_idx:end_idx]
@@ -903,6 +902,8 @@ def _init_weights(self, module):
             base = self.config.rotary_embedding_base
             inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
             init.copy_(module.inv_freq, inv_freq)
+        elif isinstance(module, Wav2Vec2ConformerRelPositionalEmbedding):
+            init.copy_(module.pe, module.extend_pe(torch.tensor(0.0).expand(1, self.max_len)))
 
     def _get_feat_extract_output_lengths(
         self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
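
Since the pe table is no longer stored in checkpoints, the _init_weights hook above recomputes it and copies it into the existing buffer via the transformers-internal init.copy_ helper. A standalone approximation of that re-initialization step using a plain in-place Tensor.copy_ under torch.no_grad(); RelPosEmbedding and reinit_buffers are illustrative names, and the assumption that init.copy_ amounts to an in-place copy is mine:

import torch
from torch import nn


class RelPosEmbedding(nn.Module):
    # Illustrative stand-in for a module holding a recomputable, non-persistent buffer.
    def __init__(self, max_len: int = 8, d_model: int = 4):
        super().__init__()
        self.max_len = max_len
        self.d_model = d_model
        self.register_buffer("pe", self.build_pe(max_len, d_model), persistent=False)

    @staticmethod
    def build_pe(max_len: int, d_model: int) -> torch.Tensor:
        pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        inv = torch.exp(-torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
        ang = pos * inv
        return torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1).unsqueeze(0)


def reinit_buffers(model: nn.Module) -> None:
    # Approximation of the _init_weights hook: refill computed buffers in place.
    for module in model.modules():
        if isinstance(module, RelPosEmbedding):
            with torch.no_grad():
                module.pe.copy_(RelPosEmbedding.build_pe(module.max_len, module.d_model))


model = nn.Sequential(RelPosEmbedding())
reinit_buffers(model)  # buffers are recomputed rather than loaded from the checkpoint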

src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py

Lines changed: 11 additions & 10 deletions
@@ -116,18 +116,17 @@ def __init__(self, config):
         super().__init__()
         self.max_len = config.max_source_positions
         self.d_model = config.hidden_size
-        self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+        self.register_buffer("pe", self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)), persistent=False)
 
-    def extend_pe(self, x):
+    def extend_pe(self, x, pe=None):
         # Reset the positional encodings
-        if self.pe is not None:
+        if pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
-                if self.pe.dtype != x.dtype or self.pe.device != x.device:
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return
+            if pe.size(1) >= x.size(1) * 2 - 1:
+                if pe.dtype != x.dtype or pe.device != x.device:
+                    pe = pe.to(dtype=x.dtype, device=x.device)
+                return pe
         # Suppose `i` is the position of query vector and `j` is the
         # position of key vector. We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -148,10 +147,10 @@ def extend_pe(self, x):
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         pe_negative = pe_negative[1:].unsqueeze(0)
         pe = torch.cat([pe_positive, pe_negative], dim=1)
-        self.pe = pe.to(device=x.device, dtype=x.dtype)
+        return pe.to(device=x.device, dtype=x.dtype)
 
     def forward(self, hidden_states: torch.Tensor):
-        self.extend_pe(hidden_states)
+        self.pe = self.extend_pe(hidden_states, self.pe)
         start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
         end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
         relative_position_embeddings = self.pe[:, start_idx:end_idx]
@@ -602,6 +601,8 @@ def _init_weights(self, module):
             base = self.config.rotary_embedding_base
             inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
             init.copy_(module.inv_freq, inv_freq)
+        elif isinstance(module, Wav2Vec2ConformerRelPositionalEmbedding):
+            init.copy_(module.pe, module.extend_pe(torch.tensor(0.0).expand(1, self.max_len)))
 
     def _get_feat_extract_output_lengths(
         self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
