SpeechT5 RelativePositionalEncoding can create empty tensors #42087

Overview

This is the definition of the SpeechT5RelativePositionalEncoding class in the transformers/src/transformers/models/speecht5/modeling_speecht5.py file:

class SpeechT5RelativePositionalEncoding(torch.nn.Module):
    def __init__(self, dim, max_length=1000):
        super().__init__()
        self.dim = dim
        self.max_length = max_length
        self.pe_k = torch.nn.Embedding(2 * max_length, dim)

    def forward(self, hidden_states):
        seq_len = hidden_states.shape[1]
        pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long)
        pos_seq = pos_seq[:, None] - pos_seq[None, :]

        pos_seq[pos_seq < -self.max_length] = -self.max_length
        pos_seq[pos_seq >= self.max_length] = self.max_length - 1
        pos_seq = pos_seq + self.max_length

        return self.pe_k(pos_seq)
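
For context, here is a minimal usage sketch (the dimensions below are illustrative assumptions, not taken from the issue). The module returns one embedding per relative offset, i.e. a tensor of shape (seq_len, seq_len, dim):

    import torch

    # Illustrative shapes only; dim=64 and seq_len=16 are assumptions.
    rel_pos = SpeechT5RelativePositionalEncoding(dim=64, max_length=1000)
    hidden_states = torch.randn(2, 16, 64)  # (batch, seq_len, dim)
    emb = rel_pos(hidden_states)
    print(emb.shape)  # torch.Size([16, 16, 64])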

In the forward pass, the lines that can create empty tensors are the ones that use boolean-mask (advanced) indexing:

        pos_seq[pos_seq < -self.max_length] = -self.max_length
        pos_seq[pos_seq >= self.max_length] = self.max_length - 1

While this works with torch on 'cuda' devices, it can be problematic for other AI accelerators, such as Tenstorrent AI accelerator chips.
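
To illustrate the point (a standalone sketch, not part of the original report): whenever the boolean mask is all False, which is the typical case for sequences shorter than max_length, the indexing expression itself evaluates to an empty tensor, and the in-place assignment is driven by that empty index.

    import torch

    # Standalone illustration: an all-False mask yields an empty tensor.
    max_length = 1000
    pos_seq = torch.arange(5)[:, None] - torch.arange(5)[None, :]
    mask = pos_seq < -max_length      # all False for seq_len << max_length
    print(pos_seq[mask].shape)        # torch.Size([0]) -- empty tensor
    pos_seq[mask] = -max_length       # in-place write indexed by an empty mask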

Proposed solution

This issue can easily be avoided by using the equivalent torch.where formulation, which yields the same results but does not create empty tensors. The problematic lines above can be replaced with:

        pos_seq = torch.where(pos_seq < -self.max_length, -self.max_length, pos_seq)
        pos_seq = torch.where(pos_seq >= self.max_length, self.max_length - 1, pos_seq)
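
A quick equivalence check (a standalone sketch using the default max_length=1000, not part of the original report) confirms that the two formulations produce identical results:

    import torch

    max_length = 1000
    seq_len = 16
    pos_seq = torch.arange(seq_len)[:, None] - torch.arange(seq_len)[None, :]

    # Original formulation: in-place boolean-mask assignment.
    a = pos_seq.clone()
    a[a < -max_length] = -max_length
    a[a >= max_length] = max_length - 1

    # Proposed formulation: torch.where, no empty tensors involved.
    b = torch.where(pos_seq < -max_length, -max_length, pos_seq)
    b = torch.where(b >= max_length, max_length - 1, b)

    assert torch.equal(a, b)

For what it's worth, torch.clamp(pos_seq, -self.max_length, self.max_length - 1) would express the same clamping in one call; the issue specifically proposes the torch.where form.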
