Skip to content

Commit

Permalink
updated charades_sta_dataset.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
saleemhamo committed Jul 23, 2024
1 parent ab64a61 commit dd153cb
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions models/fine_grained/data_loaders/charades_sta_dataset.py
Original file line number | Diff line number | Diff line change
@@ -1,43 +1,42 @@
# models/fine_grained/data_loaders/charades_sta_dataset.py
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset
from utils.logger import setup_logger

# Module-level logger used by the dataset class in this file; it is created
# by the project's `setup_logger` helper (imported from utils.logger).
logger = setup_logger('charades_sta_dataset_logger')


class CharadesSTADatasetFineGrained(Dataset):
    """Dataset pairing each annotated video with its query sentence.

    Every item is a ``(video_features, text_features, label)`` triple where
    ``video_features`` is the array of resized frames decoded from the
    annotation's video, ``text_features`` is the raw annotation sentence and
    ``label`` is a constant positive label.

    Args:
        annotations: Indexable, sized collection of mappings. Each mapping
            must provide the keys ``'video_name'`` and ``'sentence'``.
        video_dir: Directory containing the ``<video_name>.mp4`` files.
        target_size: Size handed to ``cv2.resize`` for every frame.
            NOTE: ``cv2.resize`` interprets this as ``(width, height)``.
    """

    def __init__(self, annotations, video_dir, target_size=(224, 224)):
        self.annotations = annotations
        self.video_dir = video_dir
        self.target_size = target_size
        logger.info(f"Initialized CharadesSTADataset with {len(annotations)} annotations.")

    def __len__(self):
        """Return the number of annotations (one sample per annotation)."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(video_features, text_features, label)`` for sample ``idx``.

        The whole video is decoded on every access; nothing is cached.
        """
        annotation = self.annotations[idx]
        video_path = self.get_video_path(annotation['video_name'])
        video_features = self.load_and_preprocess_video(video_path)
        text_features = annotation['sentence']
        label = torch.tensor(1)  # Assuming all pairs are positive examples for this task
        return video_features, text_features, label

    def get_video_path(self, video_name):
        """Return the path of ``<video_name>.mp4`` inside ``self.video_dir``."""
        return os.path.join(self.video_dir, f"{video_name}.mp4")

    def load_and_preprocess_video(self, video_path):
        """Decode every frame of ``video_path`` and resize it to ``self.target_size``.

        Args:
            video_path: Path of the video file to read.

        Returns:
            ``np.array`` built from the list of resized frames, in decoding
            order. If the video cannot be opened or contains no decodable
            frame, the result is the empty array ``np.array([])`` and a
            warning is logged.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            if not cap.isOpened():
                logger.warning(f"Could not open video: {video_path}")
            # read() reports failure both for an unopened capture and at end
            # of stream, so a single loop covers both cases.
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(cv2.resize(frame, self.target_size))
        finally:
            # Release the capture even if resizing a frame raises.
            cap.release()
        if not frames:
            logger.warning(f"No frames decoded from video: {video_path}")
        return np.array(frames)

0 comments on commit dd153cb

Please sign in to comment.