From 2616712802af6703b2ed0979543a07a19486d58b Mon Sep 17 00:00:00 2001 From: hyunjp Date: Thu, 4 Feb 2021 13:17:21 +0900 Subject: [PATCH] rev --- Evaluate.py | 45 ++++--- Train.py | 38 ++++-- model/Memory.py | 260 ++++++++++++++++++++++++++++++++++++++++ model/Reconstruction.py | 156 ++++++++++++++++++++++++ 4 files changed, 471 insertions(+), 28 deletions(-) create mode 100644 model/Memory.py create mode 100644 model/Reconstruction.py diff --git a/Evaluate.py b/Evaluate.py index 85642cf7..9b40f8e4 100644 --- a/Evaluate.py +++ b/Evaluate.py @@ -21,6 +21,7 @@ import time from model.utils import DataLoader from model.final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1 import * +from model.Reconstruction import * from sklearn.metrics import roc_auc_score from utils import * import random @@ -36,6 +37,7 @@ parser.add_argument('--h', type=int, default=256, help='height of input images') parser.add_argument('--w', type=int, default=256, help='width of input images') parser.add_argument('--c', type=int, default=3, help='channel of input images') +parser.add_argument('--method', type=str, default='prediction', help='The target task for anoamly detection') parser.add_argument('--t_length', type=int, default=5, help='length of the frame sequences') parser.add_argument('--fdim', type=int, default=512, help='channel dimension of the features') parser.add_argument('--mdim', type=int, default=512, help='channel dimension of the memory items') @@ -51,8 +53,6 @@ args = parser.parse_args() -torch.manual_seed(2020) - os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" if args.gpus is None: gpus = "0" @@ -83,11 +83,7 @@ model = torch.load(args.model_dir) model.cuda() m_items = torch.load(args.m_items_dir) - - labels = np.load('./data/frame_labels_'+args.dataset_type+'.npy') -if args.dataset_type == 'shanghai': - labels = np.expand_dims(labels, 0) videos = OrderedDict() videos_list = sorted(glob.glob(os.path.join(test_folder, '*'))) @@ -109,7 +105,10 @@ # Setting for video anomaly detection for video in sorted(videos_list): video_name = video.split('/')[-1] - labels_list = np.append(labels_list, labels[0][4+label_length:videos[video_name]['length']+label_length]) + if args.method == 'pred': + labels_list = np.append(labels_list, labels[0][4+label_length:videos[video_name]['length']+label_length]) + else: + labels_list = np.append(labels_list, labels[0][label_length:videos[video_name]['length']+label_length]) label_length += videos[video_name]['length'] psnr_list[video_name] = [] feature_distance_list[video_name] = [] @@ -122,19 +121,33 @@ model.eval() for k,(imgs) in enumerate(test_batch): - - if k == label_length-4*(video_num+1): - video_num += 1 - label_length += videos[videos_list[video_num].split('/')[-1]]['length'] + + if args.method == 'pred': + if k == label_length-4*(video_num+1): + video_num += 1 + label_length += videos[videos_list[video_num].split('/')[-1]]['length'] + else: + if k == label_length: + video_num += 1 + label_length += videos[videos_list[video_num].split('/')[-1]]['length'] imgs = Variable(imgs).cuda() + + if args.method == 'pred': + outputs, feas, updated_feas, m_items_test, softmax_score_query, softmax_score_memory, _, _, _, compactness_loss = model.forward(imgs[:,0:3*4], m_items_test, False) + mse_imgs = torch.mean(loss_func_mse((outputs[0]+1)/2, (imgs[0,3*4:]+1)/2)).item() + mse_feas = compactness_loss.item() - outputs, feas, updated_feas, m_items_test, softmax_score_query, softmax_score_memory, _, _, _, compactness_loss = model.forward(imgs[:,0:3*4], m_items_test, False) - mse_imgs = torch.mean(loss_func_mse((outputs[0]+1)/2, (imgs[0,3*4:]+1)/2)).item() - mse_feas = compactness_loss.item() + # Calculating the threshold for updating at the test time + point_sc = point_score(outputs, imgs[:,3*4:]) - # Calculating the threshold for updating at the test time - point_sc = point_score(outputs, imgs[:,3*4:]) + else: + outputs, feas, updated_feas, m_items_test, softmax_score_query, softmax_score_memory, compactness_loss = model.forward(imgs, m_items_test, False) + mse_imgs = torch.mean(loss_func_mse((outputs[0]+1)/2, (imgs[0]+1)/2)).item() + mse_feas = compactness_loss.item() + + # Calculating the threshold for updating at the test time + point_sc = point_score(outputs, imgs) if point_sc < args.th: query = F.normalize(feas, dim=1) diff --git a/Train.py b/Train.py index f6827c43..25e701ad 100644 --- a/Train.py +++ b/Train.py @@ -20,7 +20,6 @@ import copy import time from model.utils import DataLoader -from model.final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1 import * from sklearn.metrics import roc_auc_score from utils import * import random @@ -39,11 +38,11 @@ parser.add_argument('--w', type=int, default=256, help='width of input images') parser.add_argument('--c', type=int, default=3, help='channel of input images') parser.add_argument('--lr', type=float, default=2e-4, help='initial learning rate') +parser.add_argument('--method', type=str, default='prediction', help='The target task for anoamly detection') parser.add_argument('--t_length', type=int, default=5, help='length of the frame sequences') parser.add_argument('--fdim', type=int, default=512, help='channel dimension of the features') parser.add_argument('--mdim', type=int, default=512, help='channel dimension of the memory items') parser.add_argument('--msize', type=int, default=10, help='number of the memory items') -parser.add_argument('--alpha', type=float, default=0.6, help='weight for the anomality score') parser.add_argument('--num_workers', type=int, default=2, help='number of workers for the train loader') parser.add_argument('--num_workers_test', type=int, default=1, help='number of workers for the test loader') parser.add_argument('--dataset_type', type=str, default='ped2', help='type of dataset: ped2, avenue, shanghai') @@ -52,8 +51,6 @@ args = parser.parse_args() -torch.manual_seed(2020) - os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" if args.gpus is None: gpus = "0" @@ -88,7 +85,13 @@ # Model setting -model = convAE(args.c, args.t_length, args.msize, args.fdim, args.mdim) +assert args.method == 'pred' or args.method == 'recon', 'Wrong task name' +if args.method == 'pred': + from model.final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1 import * + model = convAE(args.c, args.t_length, args.msize, args.fdim, args.mdim) +else: + from model.Reconstruction import * + model = convAE(args.c, memory_size = args.msize, feature_dim = args.fdim, key_dim = args.mdim) params_encoder = list(model.encoder.parameters()) params_decoder = list(model.decoder.parameters()) params = params_encoder + params_decoder @@ -98,7 +101,7 @@ # Report the training process -log_dir = os.path.join('./exp', args.dataset_type, args.exp_dir) +log_dir = os.path.join('./exp', args.dataset_type, args.method, args.exp_dir) if not os.path.exists(log_dir): os.makedirs(log_dir) orig_stdout = sys.stdout @@ -120,11 +123,19 @@ imgs = Variable(imgs).cuda() - outputs, _, _, m_items, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss = model.forward(imgs[:,0:12], m_items, True) + if args.method == 'pred': + outputs, _, _, m_items, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss = model.forward(imgs[:,0:12], m_items, True) + + else: + outputs, _, _, m_items, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss = model.forward(imgs, m_items, True) optimizer.zero_grad() - loss_pixel = torch.mean(loss_func_mse(outputs, imgs[:,12:])) + if args.method == 'pred': + loss_pixel = torch.mean(loss_func_mse(outputs, imgs[:,12:])) + else: + loss_pixel = torch.mean(loss_func_mse(outputs, imgs)) + loss = loss_pixel + args.loss_compact * compactness_loss + args.loss_separate * separateness_loss loss.backward(retain_graph=True) optimizer.step() @@ -133,15 +144,18 @@ print('----------------------------------------') print('Epoch:', epoch+1) - print('Loss: Reconstruction {:.6f}/ Compactness {:.6f}/ Separateness {:.6f}'.format(loss_pixel.item(), compactness_loss.item(), separateness_loss.item())) + if args.method == 'pred': + print('Loss: Prediction {:.6f}/ Compactness {:.6f}/ Separateness {:.6f}'.format(loss_pixel.item(), compactness_loss.item(), separateness_loss.item())) + else: + print('Loss: Reconstruction {:.6f}/ Compactness {:.6f}/ Separateness {:.6f}'.format(loss_pixel.item(), compactness_loss.item(), separateness_loss.item())) print('Memory_items:') print(m_items) print('----------------------------------------') -print('Training is finished') +# print('Training is finished') # Save the model and the memory items -torch.save(model, os.path.join(log_dir, 'model.pth')) -torch.save(m_items, os.path.join(log_dir, 'keys.pt')) + torch.save(model, os.path.join(log_dir, 'model_%02d.pth'%(epoch))) + torch.save(m_items, os.path.join(log_dir, 'keys_%02d.pt'%(epoch))) sys.stdout = orig_stdout f.close() diff --git a/model/Memory.py b/model/Memory.py new file mode 100644 index 00000000..da5e88a5 --- /dev/null +++ b/model/Memory.py @@ -0,0 +1,260 @@ +import torch +import torch.autograd as ag +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +import functools +import random +from torch.nn import functional as F + +def random_uniform(shape, low, high, cuda): + x = torch.rand(*shape) + result_cpu = (high - low) * x + low + if cuda: + return result_cpu.cuda() + else: + return result_cpu + +def distance(a, b): + return torch.sqrt(((a - b) ** 2).sum()).unsqueeze(0) + +def distance_batch(a, b): + bs, _ = a.shape + result = distance(a[0], b) + for i in range(bs-1): + result = torch.cat((result, distance(a[i], b)), 0) + + return result + +def multiply(x): #to flatten matrix into a vector + return functools.reduce(lambda x,y: x*y, x, 1) + +def flatten(x): + """ Flatten matrix into a vector """ + count = multiply(x.size()) + return x.resize_(count) + +def index(batch_size, x): + idx = torch.arange(0, batch_size).long() + idx = torch.unsqueeze(idx, -1) + return torch.cat((idx, x), dim=1) + +def MemoryLoss(memory): + + m, d = memory.size() + memory_t = torch.t(memory) + similarity = (torch.matmul(memory, memory_t))/2 + 1/2 # 30X30 + identity_mask = torch.eye(m).cuda() + sim = torch.abs(similarity - identity_mask) + + return torch.sum(sim)/(m*(m-1)) + + +class Memory(nn.Module): + def __init__(self, memory_size, feature_dim, key_dim, temp_update, temp_gather): + super(Memory, self).__init__() + # Constants + self.memory_size = memory_size + self.feature_dim = feature_dim + self.key_dim = key_dim + self.temp_update = temp_update + self.temp_gather = temp_gather + + def hard_neg_mem(self, mem, i): + similarity = torch.matmul(mem,torch.t(self.keys_var)) + similarity[:,i] = -1 + _, max_idx = torch.topk(similarity, 1, dim=1) + + + return self.keys_var[max_idx] + + def random_pick_memory(self, mem, max_indices): + + m, d = mem.size() + output = [] + for i in range(m): + flattened_indices = (max_indices==i).nonzero() + a, _ = flattened_indices.size() + if a != 0: + number = np.random.choice(a, 1) + output.append(flattened_indices[number, 0]) + else: + output.append(-1) + + return torch.tensor(output) + + def get_update_query(self, mem, max_indices, update_indices, score, query, train): + + m, d = mem.size() + if train: + query_update = torch.zeros((m,d)).cuda() + random_update = torch.zeros((m,d)).cuda() + for i in range(m): + idx = torch.nonzero(max_indices.squeeze(1)==i) + a, _ = idx.size() + #ex = update_indices[0][i] + if a != 0: + #random_idx = torch.randperm(a)[0] + #idx = idx[idx != ex] +# query_update[i] = torch.sum(query[idx].squeeze(1), dim=0) + query_update[i] = torch.sum(((score[idx,i] / torch.max(score[:,i])) *query[idx].squeeze(1)), dim=0) + #random_update[i] = query[random_idx] * (score[random_idx,i] / torch.max(score[:,i])) + else: + query_update[i] = 0 + #random_update[i] = 0 + + + return query_update + + else: + query_update = torch.zeros((m,d)).cuda() + for i in range(m): + idx = torch.nonzero(max_indices.squeeze(1)==i) + a, _ = idx.size() + #ex = update_indices[0][i] + if a != 0: + #idx = idx[idx != ex] + query_update[i] = torch.sum(((score[idx,i] / torch.max(score[:,i])) *query[idx].squeeze(1)), dim=0) +# query_update[i] = torch.sum(query[idx].squeeze(1), dim=0) + else: + query_update[i] = 0 + + return query_update + + def get_score(self, mem, query): + bs, h,w,d = query.size() + m, d = mem.size() + + score = torch.matmul(query, torch.t(mem))# b X h X w X m + score = score.view(bs*h*w, m)# (b X h X w) X m + + score_query = F.softmax(score, dim=0) + score_memory = F.softmax(score,dim=1) + + return score_query, score_memory + + def forward(self, query, keys, train=True): + + batch_size, dims,h,w = query.size() # b X d X h X w + query = F.normalize(query, dim=1) + query = query.permute(0,2,3,1) # b X h X w X d + + #train + if train: + #gathering loss + gathering_loss = self.gather_loss(query,keys, train) + #spreading_loss + spreading_loss = self.spread_loss(query, keys, train) + # read + updated_query, softmax_score_query,softmax_score_memory = self.read(query, keys) + #update + updated_memory = self.update(query, keys, train) + + return updated_query, updated_memory, softmax_score_query, softmax_score_memory, gathering_loss, spreading_loss + + #test + else: + #gathering loss + gathering_loss = self.gather_loss(query,keys, train) + + # read + updated_query, softmax_score_query,softmax_score_memory = self.read(query, keys) + + #update + updated_memory = keys + + + return updated_query, updated_memory, softmax_score_query, softmax_score_memory, gathering_loss + + + + def update(self, query, keys,train): + + batch_size, h,w,dims = query.size() # b X h X w X d + + softmax_score_query, softmax_score_memory = self.get_score(keys, query) + + query_reshape = query.contiguous().view(batch_size*h*w, dims) + + _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1) + _, updating_indices = torch.topk(softmax_score_query, 1, dim=0) + + if train: + # top-1 queries (of each memory) update (weighted sum) & random pick + query_update = self.get_update_query(keys, gathering_indices, updating_indices, softmax_score_query, query_reshape,train) + updated_memory = F.normalize(query_update + keys, dim=1) + + else: + # only weighted sum update when test + query_update = self.get_update_query(keys, gathering_indices, updating_indices, softmax_score_query, query_reshape, train) + updated_memory = F.normalize(query_update + keys, dim=1) + + # top-1 update + #query_update = query_reshape[updating_indices][0] + #updated_memory = F.normalize(query_update + keys, dim=1) + + return updated_memory.detach() + + + def pointwise_gather_loss(self, query_reshape, keys, gathering_indices, train): + n,dims = query_reshape.size() # (b X h X w) X d + loss_mse = torch.nn.MSELoss(reduction='none') + + pointwise_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach()) + + return pointwise_loss + + def spread_loss(self,query, keys, train): + batch_size, h,w,dims = query.size() # b X h X w X d + + loss = torch.nn.TripletMarginLoss(margin=1.0) + + softmax_score_query, softmax_score_memory = self.get_score(keys, query) + + query_reshape = query.contiguous().view(batch_size*h*w, dims) + + _, gathering_indices = torch.topk(softmax_score_memory, 2, dim=1) + + #1st, 2nd closest memories + pos = keys[gathering_indices[:,0]] + neg = keys[gathering_indices[:,1]] + + spreading_loss = loss(query_reshape,pos.detach(), neg.detach()) + + return spreading_loss + + def gather_loss(self, query, keys, train): + + batch_size, h,w,dims = query.size() # b X h X w X d + + loss_mse = torch.nn.MSELoss() + + softmax_score_query, softmax_score_memory = self.get_score(keys, query) + + query_reshape = query.contiguous().view(batch_size*h*w, dims) + + _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1) + + gathering_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach()) + + return gathering_loss + + + + + def read(self, query, updated_memory): + batch_size, h,w,dims = query.size() # b X h X w X d + + softmax_score_query, softmax_score_memory = self.get_score(updated_memory, query) + + query_reshape = query.contiguous().view(batch_size*h*w, dims) + + concat_memory = torch.matmul(softmax_score_memory.detach(), updated_memory) # (b X h X w) X d + updated_query = torch.cat((query_reshape, concat_memory), dim = 1) # (b X h X w) X 2d + updated_query = updated_query.view(batch_size, h, w, 2*dims) + updated_query = updated_query.permute(0,3,1,2) + + return updated_query, softmax_score_query, softmax_score_memory + + \ No newline at end of file diff --git a/model/Reconstruction.py b/model/Reconstruction.py new file mode 100644 index 00000000..08836339 --- /dev/null +++ b/model/Reconstruction.py @@ -0,0 +1,156 @@ +import numpy as np +import os +import sys +import torch +import torch.nn as nn +import torch.nn.functional as F +from .Memory import * + +class Encoder(torch.nn.Module): + def __init__(self, t_length = 2, n_channel =3): + super(Encoder, self).__init__() + + def Basic(intInput, intOutput): + return torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intInput, out_channels=intOutput, kernel_size=3, stride=1, padding=1), + torch.nn.BatchNorm2d(intOutput), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1), + torch.nn.BatchNorm2d(intOutput), + torch.nn.ReLU(inplace=False) + ) + + def Basic_(intInput, intOutput): + return torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intInput, out_channels=intOutput, kernel_size=3, stride=1, padding=1), + torch.nn.BatchNorm2d(intOutput), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1), + ) + + self.moduleConv1 = Basic(n_channel*(t_length-1), 64) + self.modulePool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2) + + self.moduleConv2 = Basic(64, 128) + self.modulePool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2) + + self.moduleConv3 = Basic(128, 256) + self.modulePool3 = torch.nn.MaxPool2d(kernel_size=2, stride=2) + + self.moduleConv4 = Basic_(256, 512) + self.moduleBatchNorm = torch.nn.BatchNorm2d(512) + self.moduleReLU = torch.nn.ReLU(inplace=False) + + def forward(self, x): + + tensorConv1 = self.moduleConv1(x) + tensorPool1 = self.modulePool1(tensorConv1) + + tensorConv2 = self.moduleConv2(tensorPool1) + tensorPool2 = self.modulePool2(tensorConv2) + + tensorConv3 = self.moduleConv3(tensorPool2) + tensorPool3 = self.modulePool3(tensorConv3) + + tensorConv4 = self.moduleConv4(tensorPool3) + + return tensorConv4 + + + +class Decoder(torch.nn.Module): + def __init__(self, t_length = 2, n_channel =3): + super(Decoder, self).__init__() + + def Basic(intInput, intOutput): + return torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intInput, out_channels=intOutput, kernel_size=3, stride=1, padding=1), + torch.nn.BatchNorm2d(intOutput), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1), + torch.nn.BatchNorm2d(intOutput), + torch.nn.ReLU(inplace=False) + ) + + + def Gen(intInput, intOutput, nc): + return torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intInput, out_channels=nc, kernel_size=3, stride=1, padding=1), + torch.nn.BatchNorm2d(nc), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=1), + torch.nn.BatchNorm2d(nc), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=nc, out_channels=intOutput, kernel_size=3, stride=1, padding=1), + torch.nn.Tanh() + ) + + def Upsample(nc, intOutput): + return torch.nn.Sequential( + torch.nn.ConvTranspose2d(in_channels = nc, out_channels=intOutput, kernel_size = 3, stride = 2, padding = 1, output_padding = 1), + torch.nn.BatchNorm2d(intOutput), + torch.nn.ReLU(inplace=False) + ) + + self.moduleConv = Basic(1024, 512) + self.moduleUpsample4 = Upsample(512, 512) + + self.moduleDeconv3 = Basic(512, 256) + self.moduleUpsample3 = Upsample(256, 256) + + self.moduleDeconv2 = Basic(256, 128) + self.moduleUpsample2 = Upsample(128, 128) + + self.moduleDeconv1 = Gen(128,n_channel,64) + + + + def forward(self, x): + + tensorConv = self.moduleConv(x) + + tensorUpsample4 = self.moduleUpsample4(tensorConv) + + tensorDeconv3 = self.moduleDeconv3(tensorUpsample4) + tensorUpsample3 = self.moduleUpsample3(tensorDeconv3) + + tensorDeconv2 = self.moduleDeconv2(tensorUpsample3) + tensorUpsample2 = self.moduleUpsample2(tensorDeconv2) + + output = self.moduleDeconv1(tensorUpsample2) + + + return output + + + +class convAE(torch.nn.Module): + def __init__(self, n_channel =3, t_length = 2, memory_size = 10, feature_dim = 512, key_dim = 512, temp_update = 0.1, temp_gather=0.1): + super(convAE, self).__init__() + + self.encoder = Encoder(t_length, n_channel) + self.decoder = Decoder(t_length, n_channel) + self.memory = Memory(memory_size,feature_dim, key_dim, temp_update, temp_gather) + + + def forward(self, x, keys,train=True): + + fea = self.encoder(x) + if train: + updated_fea, keys, softmax_score_query, softmax_score_memory, gathering_loss, spreading_loss = self.memory(fea, keys, train) + output = self.decoder(updated_fea) + + return output, fea, updated_fea, keys, softmax_score_query, softmax_score_memory, gathering_loss, spreading_loss + + #test + else: + updated_fea, keys, softmax_score_query, softmax_score_memory, gathering_loss = self.memory(fea, keys, train) + output = self.decoder(updated_fea) + + return output, fea, updated_fea, keys, softmax_score_query, softmax_score_memory, gathering_loss + + + + + + \ No newline at end of file