Skip to content

Commit

Permalink
Update Evaluate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hyunjp authored Jul 17, 2020
1 parent f0a7369 commit f47ec1a
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions Evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
import copy
import time
from model.utils import DataLoader
from model.Reconstruction import *
from model.final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1 import *
from sklearn.metrics import roc_auc_score
from utils import *
import random
import glob

import argparse

Expand Down Expand Up @@ -79,7 +80,7 @@
# Loading the trained model
model = torch.load(args.model_dir)
model.cuda()
m_items = torch.load(args.m_itmes_dir)
m_items = torch.load(args.m_items_dir)
labels = np.load('./data/frame_labels_'+args.dataset_type+'.npy')

videos = OrderedDict()
Expand Down Expand Up @@ -122,7 +123,7 @@

imgs = Variable(imgs).cuda()

outputs, feas, updated_feas, m_items_test, softmax_score_query, softmax_score_memory, compactness_loss = model.forward(imgs[:,0:3*4], m_items_test, False)
outputs, feas, updated_feas, m_items_test, softmax_score_query, softmax_score_memory, _, _, _, compactness_loss = model.forward(imgs[:,0:3*4], m_items_test, False)
mse_imgs = torch.mean(loss_func_mse((outputs[0]+1)/2, (imgs[0,3*4:]+1)/2)).item()
mse_feas = compactness_loss.item()

Expand Down

0 comments on commit f47ec1a

Please sign in to comment.