Overview
This is the definition of the `SpeechT5RelativePositionalEncoding` class in transformers/src/transformers/models/speecht5/modeling_speecht5.py:
```python
class SpeechT5RelativePositionalEncoding(torch.nn.Module):
    def __init__(self, dim, max_length=1000):
        super().__init__()
        self.dim = dim
        self.max_length = max_length
        self.pe_k = torch.nn.Embedding(2 * max_length, dim)

    def forward(self, hidden_states):
        seq_len = hidden_states.shape[1]
        pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long)
        pos_seq = pos_seq[:, None] - pos_seq[None, :]

        pos_seq[pos_seq < -self.max_length] = -self.max_length
        pos_seq[pos_seq >= self.max_length] = self.max_length - 1
        pos_seq = pos_seq + self.max_length

        return self.pe_k(pos_seq)
```
In the forward pass, the lines that can create empty tensors are the ones that use boolean-mask (advanced) indexing for in-place assignment:

```python
pos_seq[pos_seq < -self.max_length] = -self.max_length
pos_seq[pos_seq >= self.max_length] = self.max_length - 1
```
While this works with torch's `cuda` backend, it can be problematic for other AI accelerators, such as Tenstorrent chips: the number of elements each mask selects is data-dependent, and whenever no element satisfies the condition (the case for any sequence shorter than `max_length`), the assignment operates through an empty index tensor.
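To make this concrete, here is a minimal standalone sketch (assuming the default `max_length=1000`; not taken from the model code) showing that for a typical sequence length both masks are all-False, so the indexing selects an empty tensor:

```python
import torch

max_length = 1000  # default value in SpeechT5RelativePositionalEncoding
seq_len = 16       # a typical sequence length, well below max_length

pos_seq = torch.arange(seq_len)[:, None] - torch.arange(seq_len)[None, :]
mask = pos_seq < -max_length

print(mask.any())            # tensor(False) -- no element is out of range
print(pos_seq[mask].shape)   # torch.Size([0]) -- the selection is empty
pos_seq[mask] = -max_length  # in-place write through an empty index tensor
```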
Proposed solution
This issue can be easily avoided by using the equivalent `torch.where` formulation, which yields the same results but never materializes empty index tensors. The problematic lines above can be replaced with:
```python
pos_seq = torch.where(pos_seq < -self.max_length, -self.max_length, pos_seq)
pos_seq = torch.where(pos_seq >= self.max_length, self.max_length - 1, pos_seq)
```
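As a quick sanity check (a standalone sketch, not part of the proposed patch), the two formulations can be verified to agree on random integer inputs:

```python
import torch

torch.manual_seed(0)
max_length = 4
pos = torch.randint(-10, 10, (6, 6))

# Current implementation: in-place masked assignment.
a = pos.clone()
a[a < -max_length] = -max_length
a[a >= max_length] = max_length - 1

# Proposed implementation: out-of-place torch.where, no empty index tensors.
b = torch.where(pos < -max_length, -max_length, pos)
b = torch.where(b >= max_length, max_length - 1, b)

assert torch.equal(a, b)
```

For integer position tensors, `torch.clamp(pos_seq, min=-self.max_length, max=self.max_length - 1)` should also be equivalent and avoids mask indexing in a single call, though the `torch.where` form mirrors the original lines most directly.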