Skip to content

Commit c12396a

Browse files
Add files via upload
1 parent ece4a99 commit c12396a

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

main.py

Lines changed: 8 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,8 @@ def forward(self, x):
6060
'image_ends': torch.tensor([tokenizer('</image>')['input_ids']] * bs, dtype=torch.int),
6161
'audio_starts': torch.tensor([tokenizer('<audio>')['input_ids']] * bs, dtype=torch.int),
6262
'audio_ends': torch.tensor([tokenizer('</audio>')['input_ids']] * bs, dtype=torch.int),
63-
'input_ids' : torch.tensor([tokenizer('<text>')['input_ids']] * bs, dtype=torch.int)
63+
'input_ids' : torch.tensor([tokenizer('<text>')['input_ids']] * bs, dtype=torch.int),
64+
'input_ide': torch.tensor([tokenizer('</text>')['input_ids']] * bs, dtype=torch.int)
6465
}
6566

6667
inputs = {k: inputs[k].to(device) for k in inputs}
@@ -98,7 +99,7 @@ def forward(self, x):
9899

99100
embed_tokens = nn.Embedding(whisper_config.vocab_size, 256).to(device)
100101

101-
text_embeddings = embed_tokens(inputs['input_ids'])
102+
text_embeddings = embed_tokens(inputs['input_ids']) # (1,3,256)
102103

103104
token_embeddings = embed_tokens.weight.unsqueeze(0).repeat(
104105
text_embeddings.size(0), 1, 1).transpose(0, 1)
@@ -114,11 +115,9 @@ def forward(self, x):
114115

115116

116117
audio_inputs = torch.cat([torch.cat([audio_starts, audio_features], dim=1), audio_ends], dim=1)
118+
# (1,1504,256)
117119

118-
text_embeddings = torch.cat(
119-
[torch.cat([text_embeddings[:, 0, :].unsqueeze(1), audio_inputs], dim=1), text_embeddings[:, 1:, :]],
120-
dim=1)
121-
120+
text_embeddings = torch.cat([text_embeddings, audio_inputs], dim=1)
122121
# torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)
123122
# (1, 1507, 256)
124123

@@ -134,13 +133,11 @@ def forward(self, x):
134133
image_inputs = torch.cat([torch.cat([image_starts, image_features], dim=1), image_ends], dim=1)
135134

136135
text_embeddings = torch.cat(
137-
[torch.cat([text_embeddings[:, 0, :].unsqueeze(1), image_inputs], dim=1),
138-
text_embeddings[:, 1:, :]], dim=1)
136+
[torch.cat([text_embeddings, image_inputs], dim=1),
137+
embed_tokens(inputs['input_ide'])], dim=1)
139138

140139
# torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)
141-
# (1, 1561, 256)
142-
143-
pdb.set_trace()
140+
# (1, 1564, 256)
144141

145142
batch_size = 1
146143
sequence_length = text_embeddings.shape[1]
@@ -151,7 +148,3 @@ def forward(self, x):
151148
input_tensor = text_embeddings
152149
output = model(input_tensor)
153150
print(output.shape) # Should be [batch_size, num_classes]
154-
155-
156-
157-

0 commit comments

Comments (0)