From 0b75aa02aedfb1c9dfed97d34490da845e623575 Mon Sep 17 00:00:00 2001 From: duyiqi17 Date: Sun, 22 May 2022 16:54:44 +0800 Subject: [PATCH 1/4] fix random.seed problem && make capsual_layer's implementation same to the one in Paper ComiRec --- models/recall/mind/infer.py | 13 +++-- models/recall/mind/mind_reader.py | 6 +-- models/recall/mind/net.py | 82 ++++++++++++++++++------------- 3 files changed, 59 insertions(+), 42 deletions(-) diff --git a/models/recall/mind/infer.py b/models/recall/mind/infer.py index 1ccc07762..c1dece2a0 100644 --- a/models/recall/mind/infer.py +++ b/models/recall/mind/infer.py @@ -110,6 +110,7 @@ def main(args): batch_data, config) user_embs = user_embs.numpy() + # print(user_embs) target_items = np.squeeze(batch_data[-1].numpy(), axis=1) if len(user_embs.shape) == 2: @@ -119,8 +120,9 @@ def main(args): dcg = 0.0 item_list = set(I[i]) iid_list = list(filter(lambda x: x != 0, list(iid_list))) - for no, iid in enumerate(iid_list): - if iid in item_list: + true_item_set = set(iid_list) + for no, iid in enumerate(I[i]): + if iid in true_item_set: recall += 1 dcg += 1.0 / math.log(no + 2, 2) idcg = 0.0 @@ -138,6 +140,7 @@ def main(args): recall = 0 dcg = 0.0 item_list_set = set() + item_cor_list = [] item_list = list( zip( np.reshape(I[i * ni:(i + 1) * ni], -1), @@ -147,13 +150,15 @@ def main(args): if item_list[j][0] not in item_list_set and item_list[ j][0] != 0: item_list_set.add(item_list[j][0]) + item_cor_list.append(item_list[j][0]) if len(item_list_set) >= args.top_n: break iid_list = list(filter(lambda x: x != 0, list(iid_list))) - for no, iid in enumerate(iid_list): + true_item_set = set(iid_list) + for no, iid in enumerate(item_cor_list): if iid == 0: break - if iid in item_list_set: + if iid in true_item_set: recall += 1 dcg += 1.0 / math.log(no + 2, 2) idcg = 0.0 diff --git a/models/recall/mind/mind_reader.py b/models/recall/mind/mind_reader.py index 0d1b78df4..ca0dc363c 100644 --- a/models/recall/mind/mind_reader.py +++ b/models/recall/mind/mind_reader.py @@ -16,7 +16,7 @@ import numpy as np from paddle.io import IterableDataset import random - +random.seed(12345) class RecDataset(IterableDataset): def __init__(self, file_list, config): @@ -51,7 +51,7 @@ def init(self): self.items = list(self.items) def __iter__(self): - random.seed(12345) + # random.seed(12345) while True: user_id_list = random.sample(self.users, self.batch_size) if self.count >= self.batches_per_epoch * self.batch_size: @@ -61,7 +61,7 @@ def __iter__(self): item_list = self.graph[user_id] if len(item_list) <= 4: continue - random.seed(12345) + # random.seed(12345) k = random.choice(range(4, len(item_list))) item_id = item_list[k] diff --git a/models/recall/mind/net.py b/models/recall/mind/net.py index d131cb74d..f2fbbdaf6 100644 --- a/models/recall/mind/net.py +++ b/models/recall/mind/net.py @@ -21,7 +21,6 @@ class Mind_SampledSoftmaxLoss_Layer(nn.Layer): """SampledSoftmaxLoss with LogUniformSampler """ - def __init__(self, num_classes, n_sample, @@ -45,6 +44,7 @@ def __init__(self, self.new_prob = paddle.assign(self.prob.astype("float32")) self.log_q = paddle.log(-(paddle.exp((-paddle.log1p(self.new_prob) * 2 * n_sample)) - 1.0)) + self.loss = nn.CrossEntropyLoss(soft_label=True) def sample(self, labels): """Random sample neg_samples @@ -65,6 +65,7 @@ def forward(self, inputs, labels, weights, bias): # weights.stop_gradient = False embedding_dim = paddle.shape(weights)[-1] true_log_probs, samp_log_probs, neg_samples = self.sample(labels) + # print(neg_samples) n_sample = neg_samples.shape[0] b1 = paddle.shape(labels)[0] @@ -82,22 +83,22 @@ def forward(self, inputs, labels, weights, bias): sample_b = all_b[-n_sample:] # [B, D] * [B, 1,D] - true_logist = paddle.matmul( - true_w, inputs.unsqueeze(1), transpose_y=True).squeeze(1) + true_b - + true_logist = paddle.sum(paddle.multiply( + true_w, inputs.unsqueeze(1)), axis=-1) + true_b + # print(true_logist) + sample_logist = paddle.matmul( - inputs.unsqueeze(1), sample_w, transpose_y=True) + sample_b + inputs, sample_w, transpose_y=True) + sample_b + + if self.remove_accidental_hits: + hit = (paddle.equal(labels[:, :], neg_samples)) + padding = paddle.ones_like(sample_logist) * -1e30 + sample_logist = paddle.where(hit, padding, sample_logist) if self.subtract_log_q: true_logist = true_logist - true_log_probs.unsqueeze(1) sample_logist = sample_logist - samp_log_probs - if self.remove_accidental_hits: - hit = (paddle.equal(labels[:, :], neg_samples)).unsqueeze(1) - padding = paddle.ones_like(sample_logist) * -1e30 - sample_logist = paddle.where(hit, padding, sample_logist) - - sample_logist = sample_logist.squeeze(1) out_logist = paddle.concat([true_logist, sample_logist], axis=1) out_label = paddle.concat( [ @@ -105,16 +106,15 @@ def forward(self, inputs, labels, weights, bias): paddle.zeros_like(sample_logist) ], axis=1) + out_label.stop_gradient = True - sampled_loss = F.softmax_with_cross_entropy( - logits=out_logist, label=out_label, soft_label=True) - return sampled_loss, out_logist, out_label + loss = self.loss(out_logist, out_label) + return loss, out_logist, out_label class Mind_Capsual_Layer(nn.Layer): """Mind_Capsual_Layer """ - def __init__(self, input_units, output_units, @@ -148,6 +148,7 @@ def __init__(self, name="bilinear_mapping_matrix", trainable=True), default_initializer=nn.initializer.Normal( mean=0.0, std=self.init_std)) + self.relu_layer = nn.Linear(self.output_units, self.output_units) def squash(self, Z): """squash @@ -182,39 +183,47 @@ def forward(self, item_his_emb, seq_len): mask = self.sequence_mask(seq_len_tile, self.maxlen) pad = paddle.ones_like(mask, dtype="float32") * (-2**32 + 1) - # S*e low_capsule_new = paddle.matmul(item_his_emb, self.bilinear_mapping_matrix) - low_capsule_new_nograd = paddle.assign(low_capsule_new) + low_capsule_new_tile = paddle.tile(low_capsule_new, [1, 1, self.k_max]) + low_capsule_new_tile = paddle.reshape( + low_capsule_new_tile, [-1, self.maxlen, self.k_max, self.output_units]) + low_capsule_new_tile = paddle.transpose( + low_capsule_new_tile, [0, 2, 1, 3]) + low_capsule_new_tile = paddle.reshape( + low_capsule_new_tile, [-1, self.k_max, self.maxlen, self.output_units]) + low_capsule_new_nograd = paddle.assign(low_capsule_new_tile) low_capsule_new_nograd.stop_gradient = True B = paddle.tile(self.routing_logits, [paddle.shape(item_his_emb)[0], 1, 1]) + B.stop_gradient = True for i in range(self.iters - 1): B_mask = paddle.where(mask, B, pad) # print(B_mask) W = F.softmax(B_mask, axis=1) + W = paddle.unsqueeze(W, axis=2) high_capsule_tmp = paddle.matmul(W, low_capsule_new_nograd) + # print(low_capsule_new_nograd.shape) high_capsule = self.squash(high_capsule_tmp) - B_delta = paddle.matmul( - high_capsule, low_capsule_new_nograd, transpose_y=True) - B += B_delta / paddle.maximum( - paddle.norm( - B_delta, p=2, axis=-1, keepdim=True), - paddle.ones_like(B_delta)) + B_delta = paddle.matmul(low_capsule_new_nograd, + paddle.transpose(high_capsule, [0, 1, 3, 2])) + B_delta = paddle.reshape( + B_delta, shape=[-1, self.k_max, self.maxlen]) + B += B_delta B_mask = paddle.where(mask, B, pad) W = F.softmax(B_mask, axis=1) - # paddle.static.Print(W) - high_capsule_tmp = paddle.matmul(W, low_capsule_new) - # high_capsule_tmp.stop_gradient = False - - high_capsule = self.squash(high_capsule_tmp) - # high_capsule.stop_gradient = False + W = paddle.unsqueeze(W, axis=2) + interest_capsule = paddle.matmul(W, low_capsule_new_tile) + interest_capsule = self.squash(interest_capsule) + high_capsule = paddle.reshape( + interest_capsule, [-1, self.k_max, self.output_units]) + high_capsule = F.relu(self.relu_layer(high_capsule)) return high_capsule, W, seq_len @@ -246,6 +255,7 @@ def __init__(self, name="item_emb", initializer=nn.initializer.XavierUniform( fan_in=item_count, fan_out=embedding_dim))) + # print(self.item_emb.weight) self.embedding_bias = self.create_parameter( shape=(item_count, ), is_bias=True, @@ -267,11 +277,13 @@ def __init__(self, def label_aware_attention(self, keys, query): """label_aware_attention """ - weight = paddle.sum(keys * query, axis=-1, keepdim=True) - weight = paddle.pow(weight, self.pow_p) # [x,k_max,1] - weight = F.softmax(weight, axis=1) - output = paddle.sum(keys * weight, axis=1) - return output, weight + weight = paddle.matmul(keys, paddle.reshape(query, [-1, paddle.shape(query)[-1], 1])) #[B, K, dim] * [B, dim, 1] == [B, k, 1] + weight = paddle.squeeze(weight, axis=-1) + weight = paddle.pow(weight, self.pow_p) # [x,k_max] + weight = F.softmax(weight) #[x, k_max] + weight = paddle.unsqueeze(weight, 1) #[B, 1, k_max] + output = paddle.matmul(weight, keys) #[B, 1, k_max] * [B, k_max, dim] => [B, 1, dim] + return output.squeeze(1), weight def forward(self, hist_item, seqlen, labels=None): """forward @@ -281,7 +293,7 @@ def forward(self, hist_item, seqlen, labels=None): seqlen : [B, 1] target : [B, 1] """ - + # print(hist_item) hit_item_emb = self.item_emb(hist_item) # [B, seqlen, embed_dim] user_cap, cap_weights, cap_mask = self.capsual_layer(hit_item_emb, seqlen) From c587dba9c4e11361d97afd70245550d12b42e11d Mon Sep 17 00:00:00 2001 From: duyiqi Date: Mon, 23 May 2022 12:09:34 +0800 Subject: [PATCH 2/4] update readme && fix codestyle problem --- models/recall/mind/README.md | 4 +-- models/recall/mind/config_bigdata.yaml | 6 ++-- models/recall/mind/net.py | 43 ++++++++++++++++---------- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/models/recall/mind/README.md b/models/recall/mind/README.md index b93aecd3e..74ffe7094 100644 --- a/models/recall/mind/README.md +++ b/models/recall/mind/README.md @@ -107,8 +107,7 @@ python -u static_infer.py -m config.yaml -top_n 50 #对测试数据进行预测 在全量数据下模型的指标如下: | 模型 | batch_size | epoch_num| Recall@50 | NDCG@50 | HitRate@50 |Time of each epoch | | :------| :------ | :------ | :------| :------ | :------| :------ | -| mind | 128 | 20 | 8.43% | 13.28% | 17.22% | 398.64s(CPU) | - +| mind(paddle实现) | 128 | 50 | 5.52% | 4.31% | 11.49% | 356.43s(CPU) | 1. 确认您当前所在目录为PaddleRec/models/recall/mind 2. 进入paddlerec/datasets/AmazonBook目录下执行run.sh脚本,会下载处理完成的AmazonBook数据集,并解压到指定目录 @@ -116,6 +115,7 @@ python -u static_infer.py -m config.yaml -top_n 50 #对测试数据进行预测 cd ../../../datasets/AmazonBook sh run.sh ``` + 3. 安装依赖,我们使用[faiss](https://github.com/facebookresearch/faiss)来进行向量召回 ```bash # CPU-only version(pip) diff --git a/models/recall/mind/config_bigdata.yaml b/models/recall/mind/config_bigdata.yaml index 6e7a63c22..1d16bbcda 100644 --- a/models/recall/mind/config_bigdata.yaml +++ b/models/recall/mind/config_bigdata.yaml @@ -18,15 +18,15 @@ runner: use_gpu: False use_auc: False train_batch_size: 128 - epochs: 20 + epochs: 50 print_interval: 500 model_save_path: "output_model_mind_all" infer_batch_size: 128 infer_reader_path: "mind_infer_reader" # importlib format test_data_dir: "../../../datasets/AmazonBook/valid" infer_load_path: "output_model_mind_all" - infer_start_epoch: 19 - infer_end_epoch: 20 + infer_start_epoch: 49 + infer_end_epoch: 50 # distribute_config # sync_mode: "async" diff --git a/models/recall/mind/net.py b/models/recall/mind/net.py index f2fbbdaf6..1c1c0a57f 100644 --- a/models/recall/mind/net.py +++ b/models/recall/mind/net.py @@ -21,6 +21,7 @@ class Mind_SampledSoftmaxLoss_Layer(nn.Layer): """SampledSoftmaxLoss with LogUniformSampler """ + def __init__(self, num_classes, n_sample, @@ -83,13 +84,13 @@ def forward(self, inputs, labels, weights, bias): sample_b = all_b[-n_sample:] # [B, D] * [B, 1,D] - true_logist = paddle.sum(paddle.multiply( - true_w, inputs.unsqueeze(1)), axis=-1) + true_b + true_logist = paddle.sum(paddle.multiply(true_w, inputs.unsqueeze(1)), + axis=-1) + true_b # print(true_logist) - + sample_logist = paddle.matmul( - inputs, sample_w, transpose_y=True) + sample_b - + inputs, sample_w, transpose_y=True) + sample_b + if self.remove_accidental_hits: hit = (paddle.equal(labels[:, :], neg_samples)) padding = paddle.ones_like(sample_logist) * -1e30 @@ -115,6 +116,7 @@ def forward(self, inputs, labels, weights, bias): class Mind_Capsual_Layer(nn.Layer): """Mind_Capsual_Layer """ + def __init__(self, input_units, output_units, @@ -189,11 +191,13 @@ def forward(self, item_his_emb, seq_len): low_capsule_new_tile = paddle.tile(low_capsule_new, [1, 1, self.k_max]) low_capsule_new_tile = paddle.reshape( - low_capsule_new_tile, [-1, self.maxlen, self.k_max, self.output_units]) - low_capsule_new_tile = paddle.transpose( - low_capsule_new_tile, [0, 2, 1, 3]) + low_capsule_new_tile, + [-1, self.maxlen, self.k_max, self.output_units]) + low_capsule_new_tile = paddle.transpose(low_capsule_new_tile, + [0, 2, 1, 3]) low_capsule_new_tile = paddle.reshape( - low_capsule_new_tile, [-1, self.k_max, self.maxlen, self.output_units]) + low_capsule_new_tile, + [-1, self.k_max, self.maxlen, self.output_units]) low_capsule_new_nograd = paddle.assign(low_capsule_new_tile) low_capsule_new_nograd.stop_gradient = True @@ -209,8 +213,9 @@ def forward(self, item_his_emb, seq_len): high_capsule_tmp = paddle.matmul(W, low_capsule_new_nograd) # print(low_capsule_new_nograd.shape) high_capsule = self.squash(high_capsule_tmp) - B_delta = paddle.matmul(low_capsule_new_nograd, - paddle.transpose(high_capsule, [0, 1, 3, 2])) + B_delta = paddle.matmul( + low_capsule_new_nograd, + paddle.transpose(high_capsule, [0, 1, 3, 2])) B_delta = paddle.reshape( B_delta, shape=[-1, self.k_max, self.maxlen]) B += B_delta @@ -220,8 +225,8 @@ def forward(self, item_his_emb, seq_len): W = paddle.unsqueeze(W, axis=2) interest_capsule = paddle.matmul(W, low_capsule_new_tile) interest_capsule = self.squash(interest_capsule) - high_capsule = paddle.reshape( - interest_capsule, [-1, self.k_max, self.output_units]) + high_capsule = paddle.reshape(interest_capsule, + [-1, self.k_max, self.output_units]) high_capsule = F.relu(self.relu_layer(high_capsule)) return high_capsule, W, seq_len @@ -277,12 +282,16 @@ def __init__(self, def label_aware_attention(self, keys, query): """label_aware_attention """ - weight = paddle.matmul(keys, paddle.reshape(query, [-1, paddle.shape(query)[-1], 1])) #[B, K, dim] * [B, dim, 1] == [B, k, 1] + weight = paddle.matmul(keys, + paddle.reshape(query, [ + -1, paddle.shape(query)[-1], 1 + ])) #[B, K, dim] * [B, dim, 1] == [B, k, 1] weight = paddle.squeeze(weight, axis=-1) weight = paddle.pow(weight, self.pow_p) # [x,k_max] - weight = F.softmax(weight) #[x, k_max] - weight = paddle.unsqueeze(weight, 1) #[B, 1, k_max] - output = paddle.matmul(weight, keys) #[B, 1, k_max] * [B, k_max, dim] => [B, 1, dim] + weight = F.softmax(weight) #[x, k_max] + weight = paddle.unsqueeze(weight, 1) #[B, 1, k_max] + output = paddle.matmul( + weight, keys) #[B, 1, k_max] * [B, k_max, dim] => [B, 1, dim] return output.squeeze(1), weight def forward(self, hist_item, seqlen, labels=None): From 24a9d18eb05b53486acdeccfd64e9c2bf18cc130 Mon Sep 17 00:00:00 2001 From: duyiqi Date: Tue, 24 May 2022 16:09:58 +0800 Subject: [PATCH 3/4] fix codestyle --- models/recall/mind/mind_reader.py | 2 ++ models/recall/mind/net.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/models/recall/mind/mind_reader.py b/models/recall/mind/mind_reader.py index ca0dc363c..18be7d2b6 100644 --- a/models/recall/mind/mind_reader.py +++ b/models/recall/mind/mind_reader.py @@ -16,8 +16,10 @@ import numpy as np from paddle.io import IterableDataset import random + random.seed(12345) + class RecDataset(IterableDataset): def __init__(self, file_list, config): super(RecDataset, self).__init__() diff --git a/models/recall/mind/net.py b/models/recall/mind/net.py index 1c1c0a57f..56faa15e0 100644 --- a/models/recall/mind/net.py +++ b/models/recall/mind/net.py @@ -167,8 +167,10 @@ def sequence_mask(self, lengths, maxlen=None, dtype="bool"): batch_size = paddle.shape(lengths)[0] if maxlen is None: maxlen = lengths.max() - row_vector = paddle.arange(0, maxlen, 1).unsqueeze(0).expand( - shape=(batch_size, maxlen)).reshape((batch_size, -1, maxlen)) + row_vector = paddle.arange( + 0, maxlen, + 1).unsqueeze(0).expand(shape=(batch_size, maxlen)).reshape( + (batch_size, -1, maxlen)) lengths = lengths.unsqueeze(-1) mask = row_vector < lengths return mask.astype(dtype) From 89570ad9b6d0aaeb7d0a8c8d888dfea0ff6c7946 Mon Sep 17 00:00:00 2001 From: duyiqi Date: Sat, 28 May 2022 12:01:36 +0800 Subject: [PATCH 4/4] change lr --- models/recall/mind/config_bigdata.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/recall/mind/config_bigdata.yaml b/models/recall/mind/config_bigdata.yaml index 1d16bbcda..0ce329ebf 100644 --- a/models/recall/mind/config_bigdata.yaml +++ b/models/recall/mind/config_bigdata.yaml @@ -39,7 +39,7 @@ hyper_parameters: # optimizer config optimizer: class: Adam - learning_rate: 0.005 + learning_rate: 0.001 # strategy: async # user-defined pairs item_count: 367983