Skip to content

Commit

Permalink
updated train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
saleemhamo committed Jul 23, 2024
1 parent 5707106 commit 9f95e4a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions models/fine_grained/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,11 @@ def collate_fn(batch):

# Tokenize text sentences
if config.fine_grained_text_extractor == 'bert':
text_sentences = [tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=512)['input_ids'].squeeze(0) for text in text_sentences]
text_sentences = [tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=512)[
'input_ids'].squeeze(0) for text in text_sentences]
elif config.fine_grained_text_extractor == 'clip':
text_sentences = [tokenizer(texts=[text], return_tensors='pt')['input_ids'].squeeze(0) for text in text_sentences]
text_sentences = [tokenizer(text=[text], return_tensors='pt')['input_ids'].squeeze(0) for text in
text_sentences]

# Convert video frames and text sentences to numpy arrays first
video_frames_padded = pad_sequence([torch.tensor(np.array(v)) for v in video_frames], batch_first=True)
Expand Down

0 comments on commit 9f95e4a

Please sign in to comment.