-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
93 lines (78 loc) · 2.57 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import inspect
import numpy as np
from time import time
from tqdm import tqdm
from build_graph import build_graph
from randomwalk.rw import RandomWalk as RW
from randomwalk.rw_cache import RandomWalk as RWcache
from randomwalk.rw_parallel import RandomWalk as RWparallel
from randomwalk.rw_cy import RandomWalk as RWcy
from randomwalk.rw_cython import RandomWalk as RWcython
def run_evaluate(model, latest_prefs, top_k):
    """Evaluate a random-walk recommender and print timing and ranking scores.

    For every user in ``latest_prefs`` that has at least one held-out item and
    at least one known neighbor in ``model``, a random walk is started from the
    user's known items. Visited items are ranked by visit count (ties broken
    by the larger key first), already-known items are dropped, and the first
    ``top_k`` remaining items are scored against the held-out items.

    Args:
        model: Object exposing ``get_neighbors(user)`` and
            ``run_random_walk(query_items=...)``; the latter must return a
            mapping of item -> visit count.
        latest_prefs: Mapping of user -> collection of held-out items.
        top_k: Number of recommendations scored per user.

    Returns:
        None. Three lines are printed: the model's source file, the mean
        wall-clock time of ``run_random_walk``, and the mean score together
        with the number of evaluated users. When no user could be evaluated
        the two means are reported as ``nan`` instead of raising
        ``ZeroDivisionError``.
    """
    scores = []
    elapsed_times = []
    # The positional discount depends only on top_k, so build it once.
    # NOTE(review): the score below is never divided by an ideal DCG and uses
    # the natural log, so it is an unnormalized DCG even though it is printed
    # as "nDCG" -- confirm whether that is intended before comparing numbers
    # with other implementations.
    denom = np.log(np.arange(2, top_k + 2))

    for u, prefs in tqdm(latest_prefs.items()):
        # len() instead of truthiness: these collections may be array-likes.
        if len(prefs) == 0:
            continue
        known_prefs = model.get_neighbors(u)
        if len(known_prefs) == 0:
            # new user
            continue

        # Only the random walk itself is timed.
        t1 = time()
        visit_count = model.run_random_walk(query_items=known_prefs)
        t2 = time()

        # Sets make the per-candidate membership tests O(1).
        known = set(known_prefs)
        relevant = set(prefs)

        # Highest visit count first; equal counts fall back to the larger key.
        ranked = sorted(
            visit_count.items(),
            key=lambda k_v: (k_v[1], k_v[0]),
            reverse=True,
        )
        recommended_items = [
            key for key, _ in ranked if key not in known
        ][:top_k]

        # Binary relevance per rank, zero-padded up to top_k positions.
        ind = np.array([1 if i in relevant else 0 for i in recommended_items])
        if len(ind) < top_k:
            ind = np.concatenate((ind, np.zeros(top_k - len(ind))))
        numer = 2 ** ind - 1
        scores.append(np.sum(numer / denom))
        elapsed_times.append(t2 - t1)

    # inspect.getfile raises for classes without a backing file; do not lose
    # the evaluation results over a label.
    model_cls = model.__class__
    try:
        model_label = inspect.getfile(model_cls)
    except (TypeError, OSError):
        model_label = "{}.{}".format(
            model_cls.__module__, model_cls.__qualname__
        )

    # scores and elapsed_times always grow together, so one count serves both.
    n_samples = len(scores)
    if n_samples:
        avg_elapsed = sum(elapsed_times) / n_samples
        avg_score = np.mean(scores)
    else:
        # Nothing was evaluated (every user was skipped).
        avg_elapsed = avg_score = float("nan")

    print("model: {}".format(model_label))
    print("avg.elapsed time: {}[s]".format(avg_elapsed))
    print(
        "avg.nDCG@{}(#samples): {:.3f}({})"
        "".format(top_k, avg_score, n_samples)
    )
def main(
    path,
    mode,
    *,
    alpha=0.1,
    n_total_steps=100000,
    top_k=32,
    test_size=32,
):
    """Build the graph from ``path`` and evaluate one random-walk variant.

    Args:
        path: Input location handed to ``build_graph`` unchanged.
        mode: Which implementation to run; one of ``"vanilla"``, ``"cache"``,
            ``"parallel"``, ``"cy"`` or ``"cython"``.
        alpha: Forwarded to the random-walk constructor (default 0.1).
        n_total_steps: Forwarded to the random-walk constructor
            (default 100000).
        top_k: Number of recommendations scored per user (default 32).
        test_size: Forwarded to ``build_graph`` as ``test_size``
            (default 32).

    Raises:
        KeyError: If ``mode`` is not one of the supported names.
    """
    adjacency, offsets, test_df, user2uid, item2iid = build_graph(
        path, test_size=test_size, split="time", out_dir="model"
    )
    n_users, n_items = len(user2uid), len(item2iid)

    # Held-out items per user: {user_id: [item_id, ...]}.
    latest_prefs = test_df.groupby("user_id").item_id.apply(list).to_dict()

    rw_class = {
        "vanilla": RW,
        "cache": RWcache,
        "parallel": RWparallel,
        "cy": RWcy,
        "cython": RWcython,
    }
    # The value returned by load_graph() is what gets evaluated.
    # NOTE(review): presumably load_graph returns the walker itself -- confirm.
    rw = rw_class[mode](alpha=alpha, n_total_steps=n_total_steps).load_graph(
        adjacency, offsets, n_users, n_items
    )
    run_evaluate(rw, latest_prefs, top_k)
if __name__ == "__main__":
    # Command-line entry point: one positional input path plus an optional
    # --mode/-m switch selecting the random-walk implementation to evaluate.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("inp")
    cli.add_argument(
        "--mode",
        "-m",
        default="vanilla",
        choices=["vanilla", "cache", "parallel", "cy", "cython"],
    )
    options = cli.parse_args()
    main(options.inp, options.mode)