Skip to content

Commit

Permalink
Update test_config.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Zplusdragon authored May 31, 2023
1 parent da92d93 commit 735200f
Showing 1 changed file with 20 additions and 29 deletions.
49 changes: 20 additions & 29 deletions test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,25 @@

def Test_parse_args():
parser = argparse.ArgumentParser()

parser.add_argument('--model_path', type=str, default='/home/dancer/Papercode/PLIPv2/log_AA/MA__VAP')
parser.add_argument('--image_path', type=str, default='/home/dancer/Papercode/FISNet_deit_bert/data/cuhk')
parser.add_argument('--model_path', type=str, default='checkpoints/PLIP_MRN50.pth.tar')
parser.add_argument('--image_path', type=str, default='data/cuhk')
parser.add_argument('--test_path', type=str,
default='/home/dancer/Papercode/FISNet_deit_bert/data/cuhk_annotations/CUHK-PEDES-test.json',
default='data/CUHK-PEDES-test.json',
help='path for test annotation json file')
# ***********************************************************************************************************************
# 设置模型backbone的类型和参数
parser.add_argument('--plip_model', type=str, default='MResNet_BERT')
parser.add_argument('--img_backbone', type=str, default='Swin_Base_BERT',
parser.add_argument('--img_backbone', type=str, default='ModifiedResNet',
help="ResNet:xxx, ModifiedResNet, ViT:xxx")
parser.add_argument('--txt_backbone', type=str, default="bert-base-uncased")
parser.add_argument('--img_dim', type=int, default=768, help='dimension of image embedding vectors')
parser.add_argument('--text_dim', type=int, default=768, help='dimension of text embedding vectors')
parser.add_argument('--patch_size', type=int, default=16, help='Just for ViT model')
parser.add_argument('--layers', type=list, default=[3, 4, 6, 3], help='Just for ModifiedResNet model')
parser.add_argument('--heads', type=int, default=8, help='Just for ModifiedResNet model')

parser.add_argument('--content_mask_ratio', type=float, default=0.5, help='dimension of lstm hidden states')
parser.add_argument('--others_mask_ratio', type=float, default=0.1, help='number of layers in lstm')
parser.add_argument('--height', type=int, default=256)
parser.add_argument('--width', type=int, default=128)

# 设置超参数
parser.add_argument('--num_epoches', type=int, default=70)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--device',type=str,default="cuda:0")
Expand Down Expand Up @@ -91,31 +85,28 @@ def Test_main(args):
ac_t2i_top10_best = 0.0
mAP_best = 0.0
best = 0
#dst_best = args.best_dir + "/model_best" + ".pth"
model = Create_PLIP_Model(args).to(device)

for i in range(29, 41):
model_file = os.path.join(args.model_path, str(i+1))+".pth.tar"
print(model_file)
if os.path.isdir(model_file):
continue
checkpoint = torch.load(model_file,map_location='cpu')
model.image_encoder.load_state_dict(checkpoint["ImgEncoder_state_dict"],strict=False)
model.text_encoder.load_state_dict(checkpoint["TxtEncoder_state_dict"])
ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP = test(image_test_loader,text_test_loader, model)
if ac_top1_t2i > ac_t2i_top1_best:
ac_t2i_top1_best = ac_top1_t2i
ac_t2i_top5_best = ac_top5_t2i
ac_t2i_top10_best = ac_top10_t2i
mAP_best = mAP
best = i
#shutil.copyfile(model_file, dst_best)

model_file = args.model_path
print(model_file)
if os.path.isdir(model_file):
continue
checkpoint = torch.load(model_file,map_location='cpu')
model.image_encoder.load_state_dict(checkpoint["ImgEncoder_state_dict"])
model.text_encoder.load_state_dict(checkpoint["TxtEncoder_state_dict"])
ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP = test(image_test_loader,text_test_loader, model)
if ac_top1_t2i > ac_t2i_top1_best:
ac_t2i_top1_best = ac_top1_t2i
ac_t2i_top5_best = ac_top5_t2i
ac_t2i_top10_best = ac_top10_t2i
mAP_best = mAP
best = i

print('Epo{}: {:.5f} {:.5f} {:.5f} {:.5f}'.format(
best+1, ac_t2i_top1_best, ac_t2i_top5_best, ac_t2i_top10_best, mAP_best))

import warnings
warnings.filterwarnings("ignore")
if __name__ == '__main__':
args = Test_parse_args()
Test_main(args)
Test_main(args)

0 comments on commit 735200f

Please sign in to comment.