Skip to content

Commit

Permalink
coarse grained train and evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
saleemhamo committed Jul 24, 2024
1 parent 892a8a7 commit 5dfc75d
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 5 deletions.
8 changes: 4 additions & 4 deletions config.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
{
"coarse_grained": {
"clip_arch": "ViT-B/32",
"frame_extraction_interval": 30,
"learning_rate": 0.001,
"batch_size": 2,
"num_epochs": 1
"frame_extraction_interval": 5,
"learning_rate": 0.0001,
"batch_size": 16,
"num_epochs": 10
},
"fine_grained": {
"batch_size": 2,
Expand Down
72 changes: 72 additions & 0 deletions models/coarse_grained/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# evaluate.py
import torch
from torch.utils.data import DataLoader
from data.charades_sta import CharadesSTA
from models.coarse_grained.components.feature_extractor import FeatureExtractor
from models.coarse_grained.model import CoarseGrainedModel
from utils.config import Config
from utils.model_utils import load_model, get_device
from utils.constants import CHARADES_VIDEOS_DIR, CHARADES_ANNOTATIONS_TEST
from utils.logger import setup_logger
from models.coarse_grained.data_loaders.charades_sta_dataset import CharadesSTADataset # Updated import
from sklearn.metrics import accuracy_score, f1_score
import argparse

# Setup logger
logger = setup_logger('evaluate_logger')


def evaluate_model(model, test_loader, device):
model.eval()
all_labels = []
all_preds = []

with torch.no_grad():
for video_features, text_features, labels in test_loader:
video_features, text_features, labels = video_features.to(device), text_features.to(device), labels.to(
device)
outputs = model(video_features, text_features).squeeze(-1)
preds = (torch.sigmoid(outputs) > 0.5).float()
all_labels.extend(labels.cpu().numpy())
all_preds.extend(preds.cpu().numpy())

accuracy = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)
logger.info(f"Accuracy: {accuracy}, F1 Score: {f1}")
return accuracy, f1


def main():
parser = argparse.ArgumentParser(description="Evaluate Coarse-Grained Model")
parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model file")
args = parser.parse_args()

logger.info("Loading configuration.")
config = Config()
charades_sta = CharadesSTA(
video_dir=CHARADES_VIDEOS_DIR,
test_file=CHARADES_ANNOTATIONS_TEST
)
annotations = charades_sta.get_test_data()

# Load feature extractor
logger.info("Loading feature extractor.")
feature_extractor = FeatureExtractor()

dataset = CharadesSTADataset(annotations, CHARADES_VIDEOS_DIR, feature_extractor)
test_loader = DataLoader(dataset, batch_size=config.coarse_grained['batch_size'], shuffle=False)
logger.info("Data loader created.")

device = get_device(logger)
model = CoarseGrainedModel(video_dim=512, text_dim=512, hidden_dim=512, output_dim=1).to(device)

# Load trained model weights from argument path
logger.info(f"Loading model from {args.model_path}")
load_model(model, args.model_path)

# Evaluate the model
evaluate_model(model, test_loader, device)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion models/coarse_grained/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def main():
test_file=CHARADES_ANNOTATIONS_TEST
)
annotations = charades_sta.get_train_data()
annotations = annotations[:5] # For testing purposes
# annotations = annotations[:5] # For testing purposes

# Load feature extractor
logger.info("Loading feature extractor.")
Expand Down

0 comments on commit 5dfc75d

Please sign in to comment.