forked from mouna99/dien
mouna.mn committed on Sep 7, 2018 (0 parents) · commit 0328f8c
Showing 22 changed files with 3,075 additions and 0 deletions.
Binary file not shown.
@@ -0,0 +1,9 @@
export PATH="~/anaconda4/bin:$PATH"
wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics.json.gz
wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Electronics.json.gz
gunzip reviews_Electronics.json.gz
gunzip meta_Electronics.json.gz
python script/process_data.py
python script/local_aggretor.py
python script/split_by_user.py
python script/generate_voc.py
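If the pipeline completes, a quick way to verify its outputs is to check for the files listed in the README below (a minimal sketch; the file names come from the README, nothing else is assumed):

```python
import os

# Files the data pipeline is expected to produce (per the README).
for f in ["cat_voc.pkl", "mid_voc.pkl", "uid_voc.pkl",
          "local_train_splitByUser", "local_test_splitByUser"]:
    print f, os.path.exists(f)
```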
@@ -0,0 +1,28 @@
## prepare data
### method 1
You can get the data from the Amazon website and process it using the script:
```
sh prepare_data.sh
```
### method 2 (recommended)
Because getting and processing the data is time consuming, we have already processed it and uploaded it for you. You can unpack it and use it directly:
```
tar -zxvf data.tar.gz
mv data/* .
```
Once you see the files below, you can move on to the next step:
> cat_voc.pkl
> mid_voc.pkl
> uid_voc.pkl
> local_train_splitByUser
> local_test_splitByUser
## train model
```
python train.py train [model name]
```
The models below are supported:
> DNN
> PNN
> Wide (Wide&Deep NN)
> DIN (https://arxiv.org/abs/1706.06978)
> DIEN (Our model)
@@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=0 /usr/bin/python2.7 script/train.py train DIEN >train_dein2.log 2>&1 &
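This launches DIEN training on GPU 0 as a background job, with stdout and stderr redirected to train_dein2.log.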
@@ -0,0 +1,35 @@
import tensorflow as tf


def dice(_x, axis=-1, epsilon=0.000000001, name=''):
    # Dice activation: a learnable per-channel alpha blends the identity with
    # a sigmoid gate computed from the batch-normalized input.
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        alphas = tf.get_variable('alpha'+name, _x.get_shape()[-1],
                                 initializer=tf.constant_initializer(0.0),
                                 dtype=tf.float32)
        input_shape = list(_x.get_shape())

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[axis] = input_shape[axis]

        # case: train mode (uses stats of the current batch)
        mean = tf.reduce_mean(_x, axis=reduction_axes)
        broadcast_mean = tf.reshape(mean, broadcast_shape)
        std = tf.reduce_mean(tf.square(_x - broadcast_mean) + epsilon, axis=reduction_axes)
        std = tf.sqrt(std)
        broadcast_std = tf.reshape(std, broadcast_shape)
        x_normed = (_x - broadcast_mean) / (broadcast_std + epsilon)
        # x_normed = tf.layers.batch_normalization(_x, center=False, scale=False)
        x_p = tf.sigmoid(x_normed)

        return alphas * (1.0 - x_p) * _x + x_p * _x


def parametric_relu(_x):
    # PReLU: standard ReLU on the positive part plus a learnable per-channel
    # slope on the negative part.
    alphas = tf.get_variable('alpha', _x.get_shape()[-1],
                             initializer=tf.constant_initializer(0.0),
                             dtype=tf.float32)
    pos = tf.nn.relu(_x)
    neg = alphas * (_x - abs(_x)) * 0.5

    return pos + neg
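A minimal usage sketch (TensorFlow 1.x graph mode; the tensor shape and layer names here are illustrative assumptions, not taken from this commit): both functions are applied to a dense layer's pre-activation output, in place of a built-in activation.

```python
import tensorflow as tf

# Hypothetical example: use dice and parametric_relu as activations.
x = tf.placeholder(tf.float32, [None, 36], name='inputs')   # assumed input shape
dnn1 = tf.layers.dense(x, 200, activation=None, name='f1')  # no built-in activation
dnn1 = dice(dnn1, name='dice_1')                            # Dice as the activation
dnn2 = tf.layers.dense(dnn1, 80, activation=None, name='f2')
dnn2 = parametric_relu(dnn2)                                # or PReLU
```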
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,201 @@
import numpy
import json
import cPickle as pkl
import random

import gzip

import shuffle

def unicode_to_utf8(d):
    return dict((key.encode("UTF-8"), value) for (key, value) in d.items())

def load_dict(filename):
    # The vocabulary may be stored either as JSON or as a pickle.
    try:
        with open(filename, 'rb') as f:
            return unicode_to_utf8(json.load(f))
    except:
        with open(filename, 'rb') as f:
            return unicode_to_utf8(pkl.load(f))


def fopen(filename, mode='r'):
    if filename.endswith('.gz'):
        return gzip.open(filename, mode)
    return open(filename, mode)


class DataIterator:

    def __init__(self, source,
                 uid_voc,
                 mid_voc,
                 cat_voc,
                 batch_size=128,
                 maxlen=100,
                 skip_empty=False,
                 shuffle_each_epoch=False,
                 sort_by_length=True,
                 max_batch_size=20):
        if shuffle_each_epoch:
            self.source_orig = source
            self.source = shuffle.main(self.source_orig, temporary=True)
        else:
            self.source = fopen(source, 'r')
        self.source_dicts = []
        for source_dict in [uid_voc, mid_voc, cat_voc]:
            self.source_dicts.append(load_dict(source_dict))

        # Map each item id to its category id (index 0 is the fallback).
        f_meta = open("item-info", "r")
        meta_map = {}
        for line in f_meta:
            arr = line.strip().split("\t")
            if arr[0] not in meta_map:
                meta_map[arr[0]] = arr[1]
        self.meta_id_map = {}
        for key in meta_map:
            val = meta_map[key]
            if key in self.source_dicts[1]:
                mid_idx = self.source_dicts[1][key]
            else:
                mid_idx = 0
            if val in self.source_dicts[2]:
                cat_idx = self.source_dicts[2][val]
            else:
                cat_idx = 0
            self.meta_id_map[mid_idx] = cat_idx

        # Pool of item ids used to sample negative (not-clicked) items.
        f_review = open("reviews-info", "r")
        self.mid_list_for_random = []
        for line in f_review:
            arr = line.strip().split("\t")
            tmp_idx = 0
            if arr[1] in self.source_dicts[1]:
                tmp_idx = self.source_dicts[1][arr[1]]
            self.mid_list_for_random.append(tmp_idx)

        self.batch_size = batch_size
        self.maxlen = maxlen
        self.skip_empty = skip_empty

        self.n_uid = len(self.source_dicts[0])
        self.n_mid = len(self.source_dicts[1])
        self.n_cat = len(self.source_dicts[2])

        self.shuffle = shuffle_each_epoch
        self.sort_by_length = sort_by_length

        self.source_buffer = []
        self.k = batch_size * max_batch_size

        self.end_of_data = False

    def get_n(self):
        return self.n_uid, self.n_mid, self.n_cat

    def __iter__(self):
        return self

    def reset(self):
        if self.shuffle:
            self.source = shuffle.main(self.source_orig, temporary=True)
        else:
            self.source.seek(0)

    def next(self):
        if self.end_of_data:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        source = []
        target = []

        # Refill the buffer with up to k lines, then order it.
        if len(self.source_buffer) == 0:
            for k_ in xrange(self.k):
                ss = self.source.readline()
                if ss == "":
                    break
                self.source_buffer.append(ss.strip("\n").split("\t"))

            # sort by history behavior length
            if self.sort_by_length:
                his_length = numpy.array([len(s[4].split("\x02")) for s in self.source_buffer])
                tidx = his_length.argsort()

                _sbuf = [self.source_buffer[i] for i in tidx]
                self.source_buffer = _sbuf
            else:
                self.source_buffer.reverse()

        if len(self.source_buffer) == 0:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        try:

            # actual work here
            while True:

                # read from source file and map to word index
                try:
                    ss = self.source_buffer.pop()
                except IndexError:
                    break

                uid = self.source_dicts[0][ss[1]] if ss[1] in self.source_dicts[0] else 0
                mid = self.source_dicts[1][ss[2]] if ss[2] in self.source_dicts[1] else 0
                cat = self.source_dicts[2][ss[3]] if ss[3] in self.source_dicts[2] else 0

                # History behaviors are "\x02"-separated item and category ids.
                tmp = []
                for fea in ss[4].split("\x02"):
                    m = self.source_dicts[1][fea] if fea in self.source_dicts[1] else 0
                    tmp.append(m)
                mid_list = tmp

                tmp1 = []
                for fea in ss[5].split("\x02"):
                    c = self.source_dicts[2][fea] if fea in self.source_dicts[2] else 0
                    tmp1.append(c)
                cat_list = tmp1

                #if len(mid_list) > self.maxlen:
                #    continue
                if self.skip_empty and (not mid_list):
                    continue

                # For each clicked item, sample 5 random negative items
                # (and their categories) that differ from the clicked one.
                noclk_mid_list = []
                noclk_cat_list = []
                for pos_mid in mid_list:
                    noclk_tmp_mid = []
                    noclk_tmp_cat = []
                    noclk_index = 0
                    while True:
                        noclk_mid_indx = random.randint(0, len(self.mid_list_for_random)-1)
                        noclk_mid = self.mid_list_for_random[noclk_mid_indx]
                        if noclk_mid == pos_mid:
                            continue
                        noclk_tmp_mid.append(noclk_mid)
                        noclk_tmp_cat.append(self.meta_id_map[noclk_mid])
                        noclk_index += 1
                        if noclk_index >= 5:
                            break
                    noclk_mid_list.append(noclk_tmp_mid)
                    noclk_cat_list.append(noclk_tmp_cat)
                source.append([uid, mid, cat, mid_list, cat_list, noclk_mid_list, noclk_cat_list])
                target.append([float(ss[0]), 1-float(ss[0])])

                if len(source) >= self.batch_size or len(target) >= self.batch_size:
                    break
        except IOError:
            self.end_of_data = True

        # all sentence pairs in maxibatch filtered out because of length
        if len(source) == 0 or len(target) == 0:
            source, target = self.next()

        return source, target
Binary file not shown.
@@ -0,0 +1,65 @@
import cPickle

f_train = open("local_train_splitByUser", "r")
uid_dict = {}
mid_dict = {}
cat_dict = {}

iddd = 0
for line in f_train:
    arr = line.strip("\n").split("\t")
    clk = arr[0]
    uid = arr[1]
    mid = arr[2]
    cat = arr[3]
    mid_list = arr[4]
    cat_list = arr[5]
    # Count occurrences of every user, item, and category.
    if uid not in uid_dict:
        uid_dict[uid] = 0
    uid_dict[uid] += 1
    if mid not in mid_dict:
        mid_dict[mid] = 0
    mid_dict[mid] += 1
    if cat not in cat_dict:
        cat_dict[cat] = 0
    cat_dict[cat] += 1
    if len(mid_list) == 0:
        continue
    # History items and categories are "\x02"-separated.
    for m in mid_list.split("\x02"):
        if m not in mid_dict:
            mid_dict[m] = 0
        mid_dict[m] += 1
    #print iddd
    iddd += 1
    for c in cat_list.split("\x02"):
        if c not in cat_dict:
            cat_dict[c] = 0
        cat_dict[c] += 1

# Sort by frequency so more frequent ids get smaller indices.
sorted_uid_dict = sorted(uid_dict.iteritems(), key=lambda x: x[1], reverse=True)
sorted_mid_dict = sorted(mid_dict.iteritems(), key=lambda x: x[1], reverse=True)
sorted_cat_dict = sorted(cat_dict.iteritems(), key=lambda x: x[1], reverse=True)

uid_voc = {}
index = 0
for key, value in sorted_uid_dict:
    uid_voc[key] = index
    index += 1

mid_voc = {}
mid_voc["default_mid"] = 0
index = 1
for key, value in sorted_mid_dict:
    mid_voc[key] = index
    index += 1

cat_voc = {}
cat_voc["default_cat"] = 0
index = 1
for key, value in sorted_cat_dict:
    cat_voc[key] = index
    index += 1

cPickle.dump(uid_voc, open("uid_voc.pkl", "w"))
cPickle.dump(mid_voc, open("mid_voc.pkl", "w"))
cPickle.dump(cat_voc, open("cat_voc.pkl", "w"))
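To sanity-check the generated vocabularies (a minimal sketch, Python 2, in the same working directory as above):

```python
import cPickle

# Load one generated vocabulary and inspect the lowest-index entries,
# which correspond to the most frequent ids.
uid_voc = cPickle.load(open("uid_voc.pkl", "rb"))
print len(uid_voc)                                     # number of distinct users
print sorted(uid_voc.items(), key=lambda x: x[1])[:5]  # most frequent users first
```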
@@ -0,0 +1,45 @@
import sys
import hashlib
import random

fin = open("jointed-new-split-info", "r")
ftrain = open("local_train", "w")
ftest = open("local_test", "w")

last_user = "0"
common_fea = ""
line_idx = 0
for line in fin:
    items = line.strip().split("\t")
    ds = items[0]
    clk = int(items[1])
    user = items[2]
    movie_id = items[3]
    dt = items[5]
    cat1 = items[6]

    # The "20180118" tag marks training samples; everything else goes to test.
    if ds == "20180118":
        fo = ftrain
    else:
        fo = ftest
    if user != last_user:
        # A new user: start with an empty behavior history.
        movie_id_list = []
        cate1_list = []
        #print >> fo, items[1] + "\t" + user + "\t" + movie_id + "\t" + cat1 + "\t" + "" + "\t" + ""
    else:
        history_clk_num = len(movie_id_list)
        cat_str = ""
        mid_str = ""
        # Join the history with the "\x02" separator used throughout the scripts.
        for c1 in cate1_list:
            cat_str += c1 + "\x02"
        for mid in movie_id_list:
            mid_str += mid + "\x02"
        if len(cat_str) > 0: cat_str = cat_str[:-1]
        if len(mid_str) > 0: mid_str = mid_str[:-1]
        if history_clk_num >= 1:  # 8 is the average length of user behavior
            print >> fo, items[1] + "\t" + user + "\t" + movie_id + "\t" + cat1 + "\t" + mid_str + "\t" + cat_str
    last_user = user
    if clk:
        # Only clicked items are appended to the user's behavior history.
        movie_id_list.append(movie_id)
        cate1_list.append(cat1)
    line_idx += 1
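For illustration, a hypothetical emitted line (the ids are invented; the field order follows the print statement above):

```python
# click \t user \t target item \t target category \t item history \t category history
# e.g. "1\tAZXY12\tB00123\tCameras\tB00AAA\x02B00BBB\tElectronics\x02Electronics"
```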