Skip to content

Commit

Permalink
实验结果分析,所有电影
Browse files Browse the repository at this point in the history
  • Loading branch information
deipss committed May 27, 2021
1 parent d97d13f commit c8060c8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
30 changes: 28 additions & 2 deletions item_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def evaluate(emb, top_k, movie_lists):
def search_neighbor_item():
# 读取pandas
df = generate_meta_map()
m_list = [17, 71, 98, 44, 501]
m_list = [ 17, 71, 44, 98,501]
# # 加载模型
get_E()
m_list = [mid2idx[i] for i in m_list]
Expand All @@ -75,6 +75,27 @@ def search_neighbor_item():
item_mid = midx2id[k]
print(get_info_by_sid(df, item_mid))

def search_neighbor_item_all():
# 读取pandas
df = generate_meta_map()
# # 加载模型
get_E()
m_list = list(mid2idx.values())
rst, cos_v = evaluate(E, 11, m_list)
cnt = 0
print(rst)
for i, v in zip(rst, cos_v):
print('emb_id=%d\tmovie_id=%d' % (i[0], midx2id[i[0]]))
ori_set = get_meta_by_sid(df, midx2id[i[0]])
for k, c in zip(i[1:], v[1:]):
if k not in midx2id.keys():
continue
item_mid = midx2id[k]
sim_set = get_meta_by_sid(df, item_mid)
if len(ori_set & sim_set) > 0:
cnt = cnt + 1
print('all = %d' % cnt)


def generate_meta_map():
file_path = '/home/deipss/BERT4Rec-VAE-Pytorch-master/Data/ml-1m/movies.dat'
Expand All @@ -88,6 +109,11 @@ def get_info_by_sid(df, sid):
return '' + str(info[0]) + '\t' + str(info[1]) + '\t' + str(info[2])


def get_meta_by_sid(df, sid):
info = df.loc[df['sid'] == sid].values.tolist()[0]
return set((str(info[2])).split('|'))


if __name__ == '__main__':
# df = generate_meta_map()
# get_info_by_sid(df, 12)
Expand All @@ -97,4 +123,4 @@ def get_info_by_sid(df, sid):
args.dae_latent_dim = i
args.vae_latent_dim = i
args.dim = i
search_neighbor_item()
search_neighbor_item_all()
4 changes: 2 additions & 2 deletions templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def set_template(args):
args.kernel_size = 6

# args.dataset_code = 'ml-' + input('Input 1 for ml-1m, 20 for ml-20m: ') + 'm'
args.dataset_code = 'ml-20m'
args.dataset_code = 'ml-1m'
args.min_rating = 0 if args.dataset_code == 'ml-1m' else 4
args.min_uc = 5
args.min_sc = 0
Expand Down Expand Up @@ -54,7 +54,7 @@ def set_template(args):
elif args.template.startswith('train_bert'):
args.mode = 'train'
# args.dataset_code = 'ml-' + input('Input 1 for ml-1m, 20 for ml-20m: ') + 'm'
args.dataset_code = 'ml-20m'
args.dataset_code = 'ml-1m'
args.min_rating = 0 if args.dataset_code == 'ml-1m' else 4
args.min_uc = 5
args.min_sc = 0
Expand Down

0 comments on commit c8060c8

Please sign in to comment.