From 15c7cd6924a8a6db0e98712aae5313efcc24bda5 Mon Sep 17 00:00:00 2001 From: rosafish Date: Thu, 6 Feb 2020 19:29:07 -0600 Subject: [PATCH] eval on subset of converted data --- baseline/train_mf.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/baseline/train_mf.py b/baseline/train_mf.py index f08f9cc..83a7c67 100644 --- a/baseline/train_mf.py +++ b/baseline/train_mf.py @@ -60,7 +60,7 @@ def parse_args(): action="store_true") parser.add_argument("--eval_data", default='kprn_test', - choices=['kprn_test', '10users'], + choices=['kprn_test', 'kprn_test_subset_1000','10users'], help='Evaluation data') args = parser.parse_args() return args @@ -189,7 +189,7 @@ def evaluate(args, model, user_ix, song_ix, test_data): rank_tuples = [] for i in instance: tag = i[1] - if args.eval_data == 'kprn_test': + if args.eval_data in ['kprn_test_subset_1000', 'kprn_test']: #convert kprn indices to mf indices (user and song) user_ix_kprn = i[0][0] song_ix_kprn = i[0][1] @@ -214,12 +214,16 @@ def evaluate(args, model, user_ix, song_ix, test_data): def load_test_data(args): - if args.subnetwork == 'dense' and args.eval_data == 'kprn_test': + test_data = None + if args.subnetwork == 'dense' and args.eval_data in ['kprn_test_subset_1000', 'kprn_test']: with open("../data/song_test_data/bpr_matrix_test_dense_py2.pkl", 'rb') as handle: test_data = cPickle.load(handle) - elif args.subnetwork == 'rs' and args.eval_data == 'kprn_test': + elif args.subnetwork == 'rs' and args.eval_data in ['kprn_test_subset_1000', 'kprn_test']: with open("../data/song_test_data/bpr_matrix_test_rs_py2.pkl", 'rb') as handle: test_data = cPickle.load(handle) + + if args.eval_data == 'kprn_test_subset_1000': + return random.sample(test_data, 1000) return test_data @@ -262,7 +266,7 @@ def main(): print 'prepare test data...' # note: the user and song indices have not been converted to the mf indices # the conversion will be done in the evaluate function - if args.eval_data == 'kprn_test': + if args.eval_data in ['kprn_test_subset_1000', 'kprn_test']: test_data = load_test_data(args) elif args.eval_data == '10users': test_data = prep_test_data(test_user_song, train_user_song, \