@@ -60,7 +60,8 @@ def forward(self, x):
     'image_ends': torch.tensor([tokenizer('</image>')['input_ids']] * bs, dtype=torch.int),
     'audio_starts': torch.tensor([tokenizer('<audio>')['input_ids']] * bs, dtype=torch.int),
     'audio_ends': torch.tensor([tokenizer('</audio>')['input_ids']] * bs, dtype=torch.int),
-    'input_ids': torch.tensor([tokenizer('<text>')['input_ids']] * bs, dtype=torch.int)
+    'input_ids': torch.tensor([tokenizer('<text>')['input_ids']] * bs, dtype=torch.int),
+    'input_ide': torch.tensor([tokenizer('</text>')['input_ids']] * bs, dtype=torch.int)
 }

 inputs = {k: inputs[k].to(device) for k in inputs}
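Side note: this hunk tokenizes each modality's boundary marker once and tiles the resulting ids across the batch. A minimal runnable sketch of that pattern, using a stock GPT-2 tokenizer purely for illustration (the repo's actual tokenizer and marker vocabulary are not assumed here):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.add_tokens(['<audio>', '</audio>', '<text>', '</text>'])  # illustrative markers

bs = 2
ids = tokenizer('<audio>')['input_ids']                  # list of token ids, e.g. [50257]
audio_starts = torch.tensor([ids] * bs, dtype=torch.int)
print(audio_starts.shape)                                # torch.Size([bs, len(ids)])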
@@ -98,7 +99,7 @@ def forward(self, x):

 embed_tokens = nn.Embedding(whisper_config.vocab_size, 256).to(device)

-text_embeddings = embed_tokens(inputs['input_ids'])
+text_embeddings = embed_tokens(inputs['input_ids'])  # (1, 3, 256)

 token_embeddings = embed_tokens.weight.unsqueeze(0).repeat(
     text_embeddings.size(0), 1, 1).transpose(0, 1)
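The token_embeddings expression above is easy to misread, so here is a shape walkthrough as a standalone sketch (vocab size is a toy value, and this hunk does not show how the tensor is consumed downstream):

import torch
import torch.nn as nn

V, bs, hidden = 1000, 1, 256                 # toy sizes; Whisper's real vocab is larger
embed_tokens = nn.Embedding(V, hidden)

t = embed_tokens.weight.unsqueeze(0)         # (1, V, hidden)
t = t.repeat(bs, 1, 1)                       # (bs, V, hidden)
token_embeddings = t.transpose(0, 1)         # (V, bs, hidden)
print(token_embeddings.shape)                # torch.Size([1000, 1, 256])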
@@ -114,11 +115,9 @@ def forward(self, x):


 audio_inputs = torch.cat([torch.cat([audio_starts, audio_features], dim=1), audio_ends], dim=1)
+# (1, 1504, 256)

-text_embeddings = torch.cat(
-    [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), audio_inputs], dim=1), text_embeddings[:, 1:, :]],
-    dim=1)
-
+text_embeddings = torch.cat([text_embeddings, audio_inputs], dim=1)
 # torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)
 # (1, 1507, 256)

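The behavioral change in this hunk: the old code spliced the audio segment in after the first text token, while the new code appends it after the whole '<text>' prefix. A toy shape check of the new concatenation, with dims made up to match the comments (a 3-token '<text>' prefix, a 1504-step audio segment including its start/end markers, hidden size 256):

import torch

text_embeddings = torch.randn(1, 3, 256)     # embedded '<text>' prefix
audio_inputs = torch.randn(1, 1504, 256)     # <audio> + audio features + </audio>
fused = torch.cat([text_embeddings, audio_inputs], dim=1)
print(fused.shape)                           # torch.Size([1, 1507, 256])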
@@ -134,13 +133,11 @@ def forward(self, x):
 image_inputs = torch.cat([torch.cat([image_starts, image_features], dim=1), image_ends], dim=1)

 text_embeddings = torch.cat(
-    [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), image_inputs], dim=1),
-     text_embeddings[:, 1:, :]], dim=1)
+    [torch.cat([text_embeddings, image_inputs], dim=1),
+     embed_tokens(inputs['input_ide'])], dim=1)

 # torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)
-# (1, 1561, 256)
-
-pdb.set_trace()
+# (1, 1564, 256)

 batch_size = 1
 sequence_length = text_embeddings.shape[1]
@@ -151,7 +148,3 @@ def forward(self, x):
 input_tensor = text_embeddings
 output = model(input_tensor)
 print(output.shape)  # Should be [batch_size, num_classes]
-
-
-
-
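Putting the commit together, the final sequence fed to the model is laid out as [<text> prefix | audio segment | image segment | </text> suffix]. An end-to-end toy sketch of that layout, with sizes chosen only to reproduce the shape comments above (none of these tensors are the repo's real features):

import torch

hidden = 256
text_prefix = torch.randn(1, 3, hidden)      # embed_tokens(inputs['input_ids'])
audio_seg = torch.randn(1, 1504, hidden)     # <audio> + audio features + </audio>
image_seg = torch.randn(1, 54, hidden)       # <image> + image features + </image>
text_suffix = torch.randn(1, 3, hidden)      # embed_tokens(inputs['input_ide'])

seq = torch.cat([text_prefix, audio_seg, image_seg, text_suffix], dim=1)
print(seq.shape)                             # torch.Size([1, 1564, 256])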