Commit

updated text_feature_extractor.py and train.py
saleemhamo committed Jul 23, 2024
1 parent (c8a8300) · commit 0cfe658
Showing 2 changed files with 28 additions and 54 deletions.

models/fine_grained/components/text_feature_extractor.py (6 changes: 4 additions & 2 deletions)

@@ -20,8 +20,10 @@ def __init__(self):
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

     def extract_features(self, text, device):
+        # Check if input is already tokenized
         if isinstance(text, torch.Tensor):
-            text = text.tolist()
-        inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True).to(device)
+            inputs = text.to(device)
+        else:
+            inputs = self.processor(text=[text], return_tensors="pt", padding=True, truncation=True).to(device)
         outputs = self.model.get_text_features(**inputs)
         return outputs
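
For context, a minimal, self-contained sketch of the updated extract_features logic is below. The class name, the self.model attribute, and the CLIPModel wrapper are assumptions inferred from this diff rather than the repository's exact code. Note that the sketch passes pre-tokenized IDs as input_ids=... because CLIPModel.get_text_features expects keyword arguments, whereas the committed branch hands the tensor to **inputs, which only works for a mapping such as the processor's BatchEncoding.

# Standalone sketch; class name and attributes are assumptions, not the repo's exact API.
import torch
from transformers import CLIPModel, CLIPProcessor

class TextFeatureExtractorSketch:
    def __init__(self):
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    def extract_features(self, text, device):
        self.model.to(device)
        if isinstance(text, torch.Tensor):
            # Pre-tokenized input: a (batch, seq_len) tensor of token IDs.
            outputs = self.model.get_text_features(input_ids=text.to(device))
        else:
            # Raw string input: tokenize with the CLIP processor first.
            inputs = self.processor(text=[text], return_tensors="pt",
                                    padding=True, truncation=True).to(device)
            outputs = self.model.get_text_features(**inputs)
        return outputs

# Example: a single caption yields a (1, 512) embedding for clip-vit-base-patch32.
# extractor = TextFeatureExtractorSketch()
# feats = extractor.extract_features("a person opens the door", torch.device("cpu"))
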
models/fine_grained/train.py (76 changes: 24 additions & 52 deletions)

@@ -1,6 +1,8 @@
 # models/fine_grained/train.py
 import torch
 from torch.utils.data import DataLoader
+from torch.nn.utils.rnn import pad_sequence
+import numpy as np
 from data.charades_sta import CharadesSTA
 from utils.config import Config
 from utils.model_utils import save_model, get_device
@@ -15,21 +17,32 @@
 from models.fine_grained.components.qd_detr import QDDETRModel
 from models.fine_grained.data_loaders.charades_sta_dataset import CharadesSTADatasetFineGrained
 from transformers import BertTokenizer, CLIPProcessor
-import numpy as np
 import os

 # Setup logger
 logger = setup_logger('train_logger')

 # Initialize config
 config = Config()
-# Initialize tokenizer based on config
-if config.fine_grained_text_extractor == 'bert':
-    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-elif config.fine_grained_text_extractor == 'clip':
-    tokenizer = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
-else:
-    raise ValueError("Invalid text_extractor value in config")
+# Initialize tokenizers
+bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


+def collate_fn(batch):
+    video_frames, text_sentences, labels = zip(*batch)

+    # Pad video frames
+    video_frames_padded = pad_sequence([torch.tensor(np.array(v)) for v in video_frames], batch_first=True)

+    # Convert text sentences to strings if they are tensors
+    if isinstance(text_sentences[0], torch.Tensor):
+        text_sentences = [text_sentence.tolist() for text_sentence in text_sentences]
+        text_sentences = [" ".join(map(str, text_sentence)) for text_sentence in text_sentences]

+    # Pad text sentences
+    text_sentences_padded = pad_sequence([torch.tensor(np.array(t)) for t in text_sentences], batch_first=True)

+    labels = torch.tensor(labels)
+    return video_frames_padded, text_sentences_padded, labels


 def fine_grained_retrieval(train_loader, config):
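
As an aside, the new collate_fn relies on torch.nn.utils.rnn.pad_sequence (imported in the first hunk) instead of the hand-rolled helper removed further down. The snippet below is only an illustration of what the built-in does with batch_first=True; it is not code from this repository.

import torch
from torch.nn.utils.rnn import pad_sequence

# Three variable-length sequences with the same trailing feature dimension.
seqs = [torch.ones(3, 2), torch.ones(5, 2), torch.ones(1, 2)]
padded = pad_sequence(seqs, batch_first=True)  # shorter sequences are padded with 0.0
print(padded.shape)  # torch.Size([3, 5, 2]), i.e. (batch, longest length, feature dim)
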
@@ -64,8 +77,6 @@ def fine_grained_retrieval(train_loader, config):
         for video_frames, text_sentence, labels in train_loader:
             video_frames, text_sentence, labels = video_frames.to(device), text_sentence.to(device), labels.to(device)

-            logger.info(f"Processing text_sentence: {text_sentence}")
-
             # Extract features
             enhanced_text_features = text_extractor.extract_features(text_sentence, device)
             enhanced_video_features = video_extractor.extract_features(video_frames, device)
@@ -88,48 +99,9 @@
     return detector


-def pad_sequence(sequences, batch_first=False, padding_value=0.0):
-    """Pad a list of sequences to the same length."""
-    max_size = sequences[0].size()
-    trailing_dims = max_size[1:]
-    max_len = max([s.size(0) for s in sequences])
-    if batch_first:
-        out_dims = (len(sequences), max_len) + trailing_dims
-    else:
-        out_dims = (max_len, len(sequences)) + trailing_dims
-
-    out_tensor = sequences[0].new_full(out_dims, padding_value)
-    for i, tensor in enumerate(sequences):
-        length = tensor.size(0)
-        # Use index notation to prevent duplicate references to the tensor
-        if batch_first:
-            out_tensor[i, :length, ...] = tensor
-        else:
-            out_tensor[:length, i, ...] = tensor
-
-    return out_tensor
-
-
-def collate_fn(batch):
-    video_frames, text_sentences, labels = zip(*batch)
-
-    # Tokenize text sentences
-    if config.fine_grained_text_extractor == 'bert':
-        text_sentences = [tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=512)[
-            'input_ids'].squeeze(0) for text in text_sentences]
-    elif config.fine_grained_text_extractor == 'clip':
-        text_sentences = [tokenizer(text=[text], return_tensors='pt')['input_ids'].squeeze(0) for text in
-                          text_sentences]
-
-    # Convert video frames and text sentences to numpy arrays first
-    video_frames_padded = pad_sequence([torch.tensor(np.array(v)) for v in video_frames], batch_first=True)
-    text_sentences_padded = pad_sequence(text_sentences, batch_first=True)
-    labels = torch.tensor(labels)
-    return video_frames_padded, text_sentences_padded, labels
-
-
 def main():
     logger.info("Loading configuration.")
     config = Config()
     charades_sta = CharadesSTA(
         video_dir=CHARADES_VIDEOS_DIR,
         train_file=CHARADES_ANNOTATIONS_TRAIN,

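Finally, a self-contained toy example of how the collate_fn added in train.py above plugs into a DataLoader. The toy dataset merely stands in for CharadesSTADatasetFineGrained, whose constructor and items are not shown in this diff, and this collate variant deliberately keeps the captions as raw strings so the text extractor's processor branch can tokenize them, rather than padding the text the way the committed collate_fn does.

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


class ToyClipDataset(Dataset):
    """Yields (video_frames, text_sentence, label) with variable-length clips."""

    def __init__(self):
        self.samples = [
            (torch.randn(4, 3, 224, 224), "a person opens the door", 1),
            (torch.randn(7, 3, 224, 224), "someone sits on the couch", 0),
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def toy_collate_fn(batch):
    video_frames, text_sentences, labels = zip(*batch)
    # Pad clips to the longest one in the batch; keep captions as raw strings.
    video_frames_padded = pad_sequence(list(video_frames), batch_first=True)
    return video_frames_padded, list(text_sentences), torch.tensor(labels)


loader = DataLoader(ToyClipDataset(), batch_size=2, collate_fn=toy_collate_fn)
video_frames, text_sentences, labels = next(iter(loader))
print(video_frames.shape)  # torch.Size([2, 7, 3, 224, 224])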