 )
 
 from ._operation import gather_forward_split_backward, reduce_forward
-from .parallel_module import ParallelModule
+from .parallel_module import PaddingParallelModule, ParallelModule
 from .utils import create_randomizer_with_offset
 
-__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]
+__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]
 
 
 class Embedding1D(ParallelModule):
@@ -161,7 +161,80 @@ def forward(self, input_: Tensor) -> Tensor:
         return output_parallel
 
 
-class VocabParallelEmbedding1D(ParallelModule):
+class PaddingEmbedding(PaddingParallelModule):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int = None,
+        dtype: torch.dtype = None,
+        device: torch.device = None,
+        weight: Optional[nn.Parameter] = None,
+        make_vocab_size_divisible_by: int = 64,
+        *args,
+        **kwargs,
+    ):
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.embed_args = args
+        self.embed_kwargs = kwargs
+        self.padding_idx = padding_idx
+        if num_embeddings % make_vocab_size_divisible_by != 0:
+            self.num_embeddings = (
+                num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
+            )
+        # create weight and bias
+        if weight is None:
+            factory_kwargs = {"device": device, "dtype": dtype}
+            weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+
+        super().__init__(self.num_embeddings, num_embeddings, weight)
+
+        if weight is None:
+            self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        init.normal_(self.weight)
+        self._fill_padding_idx_with_zero()
+
+    def _fill_padding_idx_with_zero(self) -> None:
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> PaddingParallelModule:
+        r"""
+        Convert a native pytorch embedding module to a parallel module.
+        """
+        LazyInitContext.materialize(module)
+        # get the origin attributes
+        num_embeddings = module.num_embeddings
+        embedding_dim = module.embedding_dim
+        padding_idx = module.padding_idx
+        device = module.weight.device
+        # create the parallel module
+        padding_embedding = PaddingEmbedding(
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+            padding_idx=padding_idx,
+            device=device,
+            weight=module.weight,
+            *args,
+            **kwargs,
+        )
+
+        return padding_embedding
+
+
+class VocabParallelEmbedding1D(PaddingParallelModule):
     r"""Embedding parallelized in the vocabulary dimension.
 
     Args:
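The PaddingEmbedding added in the hunk above only rounds the vocabulary size up to the next multiple of make_vocab_size_divisible_by and otherwise behaves like a plain nn.Embedding. The rounding rule can be reproduced standalone; the sketch below is illustrative and not part of the PR, and the helper name pad_vocab_size is hypothetical.

# Hypothetical standalone sketch of PaddingEmbedding's vocabulary rounding.
def pad_vocab_size(num_embeddings: int, make_vocab_size_divisible_by: int = 64) -> int:
    if num_embeddings % make_vocab_size_divisible_by == 0:
        return num_embeddings
    return num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)

# e.g. a GPT-2-style vocabulary of 50257 entries is padded to 50304 (= 64 * 786)
assert pad_vocab_size(50257) == 50304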
@@ -201,10 +274,10 @@ def __init__(
         process_group: ProcessGroup = None,
         weight: Optional[nn.Parameter] = None,
         weight_initializer: Callable = init.normal_(),
+        make_vocab_size_divisible_by: int = 64,
         *args,
         **kwargs,
     ):
-        super().__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.embed_args = args
@@ -214,8 +287,23 @@ def __init__(
         tensor_parallel_size = dist.get_world_size(group=process_group)
         tensor_parallel_rank = dist.get_rank(group=process_group)
 
-        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
-        self.num_embeddings = self.num_embeddings_per_partition
+        # generate weight and bias
+        if weight is None:
+            factory_kwargs = {"device": device, "dtype": dtype}
+            weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+
+        # calculate new padding size
+        multiple = make_vocab_size_divisible_by * tensor_parallel_size
+        if num_embeddings % multiple != 0:
+            self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)
+
+        # resize vocabulary size
+        super().__init__(self.num_embeddings, num_embeddings, weight)
+
+        # deal with tensor parallelism
+        self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
 
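In the hunk above, the padding multiple becomes make_vocab_size_divisible_by * tensor_parallel_size, so the padded vocabulary divides evenly across ranks. A small worked example with illustrative numbers (not taken from the PR):

# Illustrative only: padded split for a 50257-entry vocabulary,
# make_vocab_size_divisible_by = 64 and tensor_parallel_size = 4.
num_embeddings = 50257
tensor_parallel_size = 4
multiple = 64 * tensor_parallel_size                               # 256
padded = num_embeddings + multiple - (num_embeddings % multiple)   # 50432
per_partition = padded // tensor_parallel_size                     # 12608
# rank r then owns vocabulary indices [r * 12608, (r + 1) * 12608)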
@@ -226,13 +314,6 @@ def __init__(
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
 
-        # parameter
-        if weight is None:
-            factory_kwargs = {"device": device, "dtype": dtype}
-            self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
-        else:
-            weight.data = weight.data.to(device=device, dtype=dtype)
-            self.weight = weight
         if not is_distributed_tensor(self.weight):
             sharded_weight = shard_rowwise(self.weight.data, process_group)
             sharded_tensor_to_existing_param(sharded_weight, self.weight)
@@ -243,7 +324,7 @@ def __init__(
     @staticmethod
     def from_native_module(
         module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
-    ) -> ParallelModule:
+    ) -> PaddingParallelModule:
         r"""
         Convert a native pytorch embedding module to a parallel module.
         """
@@ -303,11 +384,9 @@ def forward(self, input_: Tensor) -> Tensor:
         # Mask the input.
         masked_input = input_.clone() - self.vocab_start_index
         masked_input[input_mask] = 0
-
         output_parallel = F.embedding(
             masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
         )
-
         # Mask the output embedding.
         embedding_output = output_parallel.clone()
         embedding_output[input_mask, :] = 0.0
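The forward pass shown in this hunk relies on each rank embedding only the indices it owns: out-of-range indices are shifted to 0, looked up, and then zeroed, so that a later cross-rank sum (the role of reduce_forward) reconstructs the full lookup. A standalone single-process simulation of that trick; it is illustrative only and the variable names are not from the PR.

import torch
import torch.nn.functional as F

# Padded vocabulary of 8 rows split across two simulated "ranks" of 4 rows each.
full_weight = torch.arange(16, dtype=torch.float32).view(8, 2)
input_ = torch.tensor([1, 5, 3, 6])

partials = []
for rank in range(2):
    start, end = rank * 4, (rank + 1) * 4
    weight_shard = full_weight[start:end]
    input_mask = (input_ < start) | (input_ >= end)
    masked_input = input_.clone() - start
    masked_input[input_mask] = 0          # avoid out-of-range lookups
    out = F.embedding(masked_input, weight_shard)
    out[input_mask, :] = 0.0              # zero rows this rank does not own
    partials.append(out)

# Summing the partial outputs recovers the lookup against the full table.
assert torch.equal(partials[0] + partials[1], F.embedding(input_, full_weight))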