Skip to content

Commit

Permalink
eval on subset of converted data
Browse files Browse the repository at this point in the history
  • Loading branch information
rosafish committed Feb 7, 2020
1 parent d4a8e17 commit 15c7cd6
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions baseline/train_mf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def parse_args():
action="store_true")
parser.add_argument("--eval_data",
default='kprn_test',
choices=['kprn_test', '10users'],
choices=['kprn_test', 'kprn_test_subset_1000','10users'],
help='Evaluation data')
args = parser.parse_args()
return args
Expand Down Expand Up @@ -189,7 +189,7 @@ def evaluate(args, model, user_ix, song_ix, test_data):
rank_tuples = []
for i in instance:
tag = i[1]
if args.eval_data == 'kprn_test':
if args.eval_data in ['kprn_test_subset_1000', 'kprn_test']:
#convert kprn indices to mf indices (user and song)
user_ix_kprn = i[0][0]
song_ix_kprn = i[0][1]
Expand All @@ -214,12 +214,16 @@ def evaluate(args, model, user_ix, song_ix, test_data):


def load_test_data(args):
if args.subnetwork == 'dense' and args.eval_data == 'kprn_test':
test_data = None
if args.subnetwork == 'dense' and args.eval_data in ['kprn_test_subset_1000', 'kprn_test']:
with open("../data/song_test_data/bpr_matrix_test_dense_py2.pkl", 'rb') as handle:
test_data = cPickle.load(handle)
elif args.subnetwork == 'rs' and args.eval_data == 'kprn_test':
elif args.subnetwork == 'rs' and args.eval_data in ['kprn_test_subset_1000', 'kprn_test']:
with open("../data/song_test_data/bpr_matrix_test_rs_py2.pkl", 'rb') as handle:
test_data = cPickle.load(handle)

if args.eval_data == 'kprn_test_subset_1000':
return random.sample(test_data, 1000)
return test_data


Expand Down Expand Up @@ -262,7 +266,7 @@ def main():
print 'prepare test data...'
# note: the user and song indices have not been converted to the mf indices
# the conversion will be done in the evaluate function
if args.eval_data == 'kprn_test':
if args.eval_data in ['kprn_test_subset_1000', 'kprn_test']:
test_data = load_test_data(args)
elif args.eval_data == '10users':
test_data = prep_test_data(test_user_song, train_user_song, \
Expand Down

0 comments on commit 15c7cd6

Please sign in to comment.