From c8060c85a664fe51890fa70f3d5b5b7a955f9a09 Mon Sep 17 00:00:00 2001 From: deipss Date: Thu, 27 May 2021 17:00:37 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E9=AA=8C=E7=BB=93=E6=9E=9C=E5=88=86?= =?UTF-8?q?=E6=9E=90=EF=BC=8C=E6=89=80=E6=9C=89=E7=94=B5=E5=BD=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- item_util.py | 30 ++++++++++++++++++++++++++++-- templates.py | 4 ++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/item_util.py b/item_util.py index b5d786f..a1ec957 100644 --- a/item_util.py +++ b/item_util.py @@ -59,7 +59,7 @@ def evaluate(emb, top_k, movie_lists): def search_neighbor_item(): # 读取pandas df = generate_meta_map() - m_list = [17, 71, 98, 44, 501] + m_list = [ 17, 71, 44, 98,501] # # 加载模型 get_E() m_list = [mid2idx[i] for i in m_list] @@ -75,6 +75,27 @@ def search_neighbor_item(): item_mid = midx2id[k] print(get_info_by_sid(df, item_mid)) +def search_neighbor_item_all(): + # 读取pandas + df = generate_meta_map() + # # 加载模型 + get_E() + m_list = list(mid2idx.values()) + rst, cos_v = evaluate(E, 11, m_list) + cnt = 0 + print(rst) + for i, v in zip(rst, cos_v): + print('emb_id=%d\tmovie_id=%d' % (i[0], midx2id[i[0]])) + ori_set = get_meta_by_sid(df, midx2id[i[0]]) + for k, c in zip(i[1:], v[1:]): + if k not in midx2id.keys(): + continue + item_mid = midx2id[k] + sim_set = get_meta_by_sid(df, item_mid) + if len(ori_set & sim_set) > 0: + cnt = cnt + 1 + print('all = %d' % cnt) + def generate_meta_map(): file_path = '/home/deipss/BERT4Rec-VAE-Pytorch-master/Data/ml-1m/movies.dat' @@ -88,6 +109,11 @@ def get_info_by_sid(df, sid): return '' + str(info[0]) + '\t' + str(info[1]) + '\t' + str(info[2]) +def get_meta_by_sid(df, sid): + info = df.loc[df['sid'] == sid].values.tolist()[0] + return set((str(info[2])).split('|')) + + if __name__ == '__main__': # df = generate_meta_map() # get_info_by_sid(df, 12) @@ -97,4 +123,4 @@ def get_info_by_sid(df, sid): args.dae_latent_dim = i args.vae_latent_dim = i args.dim = i - search_neighbor_item() + search_neighbor_item_all() diff --git a/templates.py b/templates.py index fa655ee..fe1c81c 100644 --- a/templates.py +++ b/templates.py @@ -8,7 +8,7 @@ def set_template(args): args.kernel_size = 6 # args.dataset_code = 'ml-' + input('Input 1 for ml-1m, 20 for ml-20m: ') + 'm' - args.dataset_code = 'ml-20m' + args.dataset_code = 'ml-1m' args.min_rating = 0 if args.dataset_code == 'ml-1m' else 4 args.min_uc = 5 args.min_sc = 0 @@ -54,7 +54,7 @@ def set_template(args): elif args.template.startswith('train_bert'): args.mode = 'train' # args.dataset_code = 'ml-' + input('Input 1 for ml-1m, 20 for ml-20m: ') + 'm' - args.dataset_code = 'ml-20m' + args.dataset_code = 'ml-1m' args.min_rating = 0 if args.dataset_code == 'ml-1m' else 4 args.min_uc = 5 args.min_sc = 0