Skip to content

Commit

Permalink
rev
Browse files Browse the repository at this point in the history
  • Loading branch information
hyunjp committed Feb 4, 2021
1 parent a198158 commit 2616712
Show file tree
Hide file tree
Showing 4 changed files with 471 additions and 28 deletions.
45 changes: 29 additions & 16 deletions Evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import time
from model.utils import DataLoader
from model.final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1 import *
from model.Reconstruction import *
from sklearn.metrics import roc_auc_score
from utils import *
import random
Expand All @@ -36,6 +37,7 @@
parser.add_argument('--h', type=int, default=256, help='height of input images')
parser.add_argument('--w', type=int, default=256, help='width of input images')
parser.add_argument('--c', type=int, default=3, help='channel of input images')
parser.add_argument('--method', type=str, default='prediction', help='The target task for anoamly detection')
parser.add_argument('--t_length', type=int, default=5, help='length of the frame sequences')
parser.add_argument('--fdim', type=int, default=512, help='channel dimension of the features')
parser.add_argument('--mdim', type=int, default=512, help='channel dimension of the memory items')
Expand All @@ -51,8 +53,6 @@

args = parser.parse_args()

torch.manual_seed(2020)

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
if args.gpus is None:
gpus = "0"
Expand Down Expand Up @@ -83,11 +83,7 @@
model = torch.load(args.model_dir)
model.cuda()
m_items = torch.load(args.m_items_dir)


labels = np.load('./data/frame_labels_'+args.dataset_type+'.npy')
if args.dataset_type == 'shanghai':
labels = np.expand_dims(labels, 0)

videos = OrderedDict()
videos_list = sorted(glob.glob(os.path.join(test_folder, '*')))
Expand All @@ -109,7 +105,10 @@
# Setting for video anomaly detection
for video in sorted(videos_list):
video_name = video.split('/')[-1]
labels_list = np.append(labels_list, labels[0][4+label_length:videos[video_name]['length']+label_length])
if args.method == 'pred':
labels_list = np.append(labels_list, labels[0][4+label_length:videos[video_name]['length']+label_length])
else:
labels_list = np.append(labels_list, labels[0][label_length:videos[video_name]['length']+label_length])
label_length += videos[video_name]['length']
psnr_list[video_name] = []
feature_distance_list[video_name] = []
Expand All @@ -122,19 +121,33 @@
model.eval()

for k,(imgs) in enumerate(test_batch):

if k == label_length-4*(video_num+1):
video_num += 1
label_length += videos[videos_list[video_num].split('/')[-1]]['length']

if args.method == 'pred':
if k == label_length-4*(video_num+1):
video_num += 1
label_length += videos[videos_list[video_num].split('/')[-1]]['length']
else:
if k == label_length:
video_num += 1
label_length += videos[videos_list[video_num].split('/')[-1]]['length']

imgs = Variable(imgs).cuda()

if args.method == 'pred':
outputs, feas, updated_feas, m_items_test, softmax_score_query, softmax_score_memory, _, _, _, compactness_loss = model.forward(imgs[:,0:3*4], m_items_test, False)
mse_imgs = torch.mean(loss_func_mse((outputs[0]+1)/2, (imgs[0,3*4:]+1)/2)).item()
mse_feas = compactness_loss.item()

outputs, feas, updated_feas, m_items_test, softmax_score_query, softmax_score_memory, _, _, _, compactness_loss = model.forward(imgs[:,0:3*4], m_items_test, False)
mse_imgs = torch.mean(loss_func_mse((outputs[0]+1)/2, (imgs[0,3*4:]+1)/2)).item()
mse_feas = compactness_loss.item()
# Calculating the threshold for updating at the test time
point_sc = point_score(outputs, imgs[:,3*4:])

# Calculating the threshold for updating at the test time
point_sc = point_score(outputs, imgs[:,3*4:])
else:
outputs, feas, updated_feas, m_items_test, softmax_score_query, softmax_score_memory, compactness_loss = model.forward(imgs, m_items_test, False)
mse_imgs = torch.mean(loss_func_mse((outputs[0]+1)/2, (imgs[0]+1)/2)).item()
mse_feas = compactness_loss.item()

# Calculating the threshold for updating at the test time
point_sc = point_score(outputs, imgs)

if point_sc < args.th:
query = F.normalize(feas, dim=1)
Expand Down
38 changes: 26 additions & 12 deletions Train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import copy
import time
from model.utils import DataLoader
from model.final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1 import *
from sklearn.metrics import roc_auc_score
from utils import *
import random
Expand All @@ -39,11 +38,11 @@
parser.add_argument('--w', type=int, default=256, help='width of input images')
parser.add_argument('--c', type=int, default=3, help='channel of input images')
parser.add_argument('--lr', type=float, default=2e-4, help='initial learning rate')
parser.add_argument('--method', type=str, default='prediction', help='The target task for anoamly detection')
parser.add_argument('--t_length', type=int, default=5, help='length of the frame sequences')
parser.add_argument('--fdim', type=int, default=512, help='channel dimension of the features')
parser.add_argument('--mdim', type=int, default=512, help='channel dimension of the memory items')
parser.add_argument('--msize', type=int, default=10, help='number of the memory items')
parser.add_argument('--alpha', type=float, default=0.6, help='weight for the anomality score')
parser.add_argument('--num_workers', type=int, default=2, help='number of workers for the train loader')
parser.add_argument('--num_workers_test', type=int, default=1, help='number of workers for the test loader')
parser.add_argument('--dataset_type', type=str, default='ped2', help='type of dataset: ped2, avenue, shanghai')
Expand All @@ -52,8 +51,6 @@

args = parser.parse_args()

torch.manual_seed(2020)

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
if args.gpus is None:
gpus = "0"
Expand Down Expand Up @@ -88,7 +85,13 @@


# Model setting
model = convAE(args.c, args.t_length, args.msize, args.fdim, args.mdim)
assert args.method == 'pred' or args.method == 'recon', 'Wrong task name'
if args.method == 'pred':
from model.final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1 import *
model = convAE(args.c, args.t_length, args.msize, args.fdim, args.mdim)
else:
from model.Reconstruction import *
model = convAE(args.c, memory_size = args.msize, feature_dim = args.fdim, key_dim = args.mdim)
params_encoder = list(model.encoder.parameters())
params_decoder = list(model.decoder.parameters())
params = params_encoder + params_decoder
Expand All @@ -98,7 +101,7 @@


# Report the training process
log_dir = os.path.join('./exp', args.dataset_type, args.exp_dir)
log_dir = os.path.join('./exp', args.dataset_type, args.method, args.exp_dir)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
orig_stdout = sys.stdout
Expand All @@ -120,11 +123,19 @@

imgs = Variable(imgs).cuda()

outputs, _, _, m_items, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss = model.forward(imgs[:,0:12], m_items, True)
if args.method == 'pred':
outputs, _, _, m_items, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss = model.forward(imgs[:,0:12], m_items, True)

else:
outputs, _, _, m_items, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss = model.forward(imgs, m_items, True)


optimizer.zero_grad()
loss_pixel = torch.mean(loss_func_mse(outputs, imgs[:,12:]))
if args.method == 'pred':
loss_pixel = torch.mean(loss_func_mse(outputs, imgs[:,12:]))
else:
loss_pixel = torch.mean(loss_func_mse(outputs, imgs))

loss = loss_pixel + args.loss_compact * compactness_loss + args.loss_separate * separateness_loss
loss.backward(retain_graph=True)
optimizer.step()
Expand All @@ -133,15 +144,18 @@

print('----------------------------------------')
print('Epoch:', epoch+1)
print('Loss: Reconstruction {:.6f}/ Compactness {:.6f}/ Separateness {:.6f}'.format(loss_pixel.item(), compactness_loss.item(), separateness_loss.item()))
if args.method == 'pred':
print('Loss: Prediction {:.6f}/ Compactness {:.6f}/ Separateness {:.6f}'.format(loss_pixel.item(), compactness_loss.item(), separateness_loss.item()))
else:
print('Loss: Reconstruction {:.6f}/ Compactness {:.6f}/ Separateness {:.6f}'.format(loss_pixel.item(), compactness_loss.item(), separateness_loss.item()))
print('Memory_items:')
print(m_items)
print('----------------------------------------')

print('Training is finished')
# print('Training is finished')
# Save the model and the memory items
torch.save(model, os.path.join(log_dir, 'model.pth'))
torch.save(m_items, os.path.join(log_dir, 'keys.pt'))
torch.save(model, os.path.join(log_dir, 'model_%02d.pth'%(epoch)))
torch.save(m_items, os.path.join(log_dir, 'keys_%02d.pt'%(epoch)))

sys.stdout = orig_stdout
f.close()
Expand Down
Loading

0 comments on commit 2616712

Please sign in to comment.