Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#779 from duyiqi17/mind_fix
Browse files Browse the repository at this point in the history
fix random.seed problem && make capsual_layer's implementation same t…
  • Loading branch information
frankwhzhang authored Jun 7, 2022
2 parents c3f4873 + f1d786f commit 43314c8
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 45 deletions.
4 changes: 2 additions & 2 deletions models/recall/mind/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,15 @@ python -u static_infer.py -m config.yaml -top_n 50 #对测试数据进行预测
在全量数据下模型的指标如下:
| 模型 | batch_size | epoch_num| Recall@50 | NDCG@50 | HitRate@50 |Time of each epoch |
| :------| :------ | :------ | :------| :------ | :------| :------ |
| mind | 128 | 20 | 8.43% | 13.28% | 17.22% | 398.64s(CPU) |

| mind(paddle实现) | 128 | 50 | 5.52% | 4.31% | 11.49% | 356.43s(CPU) |

1. 确认您当前所在目录为PaddleRec/models/recall/mind
2. 进入paddlerec/datasets/AmazonBook目录下执行run.sh脚本,会下载处理完成的AmazonBook数据集,并解压到指定目录
```bash
cd ../../../datasets/AmazonBook
sh run.sh
```

3. 安装依赖,我们使用[faiss](https://github.com/facebookresearch/faiss)来进行向量召回
```bash
# CPU-only version(pip)
Expand Down
8 changes: 4 additions & 4 deletions models/recall/mind/config_bigdata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ runner:
use_gpu: False
use_auc: False
train_batch_size: 128
epochs: 20
epochs: 50
print_interval: 500
model_save_path: "output_model_mind_all"
infer_batch_size: 128
infer_reader_path: "mind_infer_reader" # importlib format
test_data_dir: "../../../datasets/AmazonBook/valid"
infer_load_path: "output_model_mind_all"
infer_start_epoch: 19
infer_end_epoch: 20
infer_start_epoch: 49
infer_end_epoch: 50

# distribute_config
# sync_mode: "async"
Expand All @@ -39,7 +39,7 @@ hyper_parameters:
# optimizer config
optimizer:
class: Adam
learning_rate: 0.005
learning_rate: 0.001
# strategy: async
# user-defined <key, value> pairs
item_count: 367983
Expand Down
13 changes: 9 additions & 4 deletions models/recall/mind/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def main(args):
batch_data, config)

user_embs = user_embs.numpy()
# print(user_embs)
target_items = np.squeeze(batch_data[-1].numpy(), axis=1)

if len(user_embs.shape) == 2:
Expand All @@ -119,8 +120,9 @@ def main(args):
dcg = 0.0
item_list = set(I[i])
iid_list = list(filter(lambda x: x != 0, list(iid_list)))
for no, iid in enumerate(iid_list):
if iid in item_list:
true_item_set = set(iid_list)
for no, iid in enumerate(I[i]):
if iid in true_item_set:
recall += 1
dcg += 1.0 / math.log(no + 2, 2)
idcg = 0.0
Expand All @@ -138,6 +140,7 @@ def main(args):
recall = 0
dcg = 0.0
item_list_set = set()
item_cor_list = []
item_list = list(
zip(
np.reshape(I[i * ni:(i + 1) * ni], -1),
Expand All @@ -147,13 +150,15 @@ def main(args):
if item_list[j][0] not in item_list_set and item_list[
j][0] != 0:
item_list_set.add(item_list[j][0])
item_cor_list.append(item_list[j][0])
if len(item_list_set) >= args.top_n:
break
iid_list = list(filter(lambda x: x != 0, list(iid_list)))
for no, iid in enumerate(iid_list):
true_item_set = set(iid_list)
for no, iid in enumerate(item_cor_list):
if iid == 0:
break
if iid in item_list_set:
if iid in true_item_set:
recall += 1
dcg += 1.0 / math.log(no + 2, 2)
idcg = 0.0
Expand Down
6 changes: 4 additions & 2 deletions models/recall/mind/mind_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from paddle.io import IterableDataset
import random

random.seed(12345)


class RecDataset(IterableDataset):
def __init__(self, file_list, config):
Expand Down Expand Up @@ -51,7 +53,7 @@ def init(self):
self.items = list(self.items)

def __iter__(self):
random.seed(12345)
# random.seed(12345)
while True:
user_id_list = random.sample(self.users, self.batch_size)
if self.count >= self.batches_per_epoch * self.batch_size:
Expand All @@ -61,7 +63,7 @@ def __iter__(self):
item_list = self.graph[user_id]
if len(item_list) <= 4:
continue
random.seed(12345)
# random.seed(12345)
k = random.choice(range(4, len(item_list)))
item_id = item_list[k]

Expand Down
89 changes: 56 additions & 33 deletions models/recall/mind/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self,
self.new_prob = paddle.assign(self.prob.astype("float32"))
self.log_q = paddle.log(-(paddle.exp((-paddle.log1p(self.new_prob) * 2
* n_sample)) - 1.0))
self.loss = nn.CrossEntropyLoss(soft_label=True)

def sample(self, labels):
"""Random sample neg_samples
Expand All @@ -65,6 +66,7 @@ def forward(self, inputs, labels, weights, bias):
# weights.stop_gradient = False
embedding_dim = paddle.shape(weights)[-1]
true_log_probs, samp_log_probs, neg_samples = self.sample(labels)
# print(neg_samples)
n_sample = neg_samples.shape[0]

b1 = paddle.shape(labels)[0]
Expand All @@ -82,33 +84,33 @@ def forward(self, inputs, labels, weights, bias):
sample_b = all_b[-n_sample:]

# [B, D] * [B, 1,D]
true_logist = paddle.matmul(
true_w, inputs.unsqueeze(1), transpose_y=True).squeeze(1) + true_b
true_logist = paddle.sum(paddle.multiply(true_w, inputs.unsqueeze(1)),
axis=-1) + true_b
# print(true_logist)

sample_logist = paddle.matmul(
inputs.unsqueeze(1), sample_w, transpose_y=True) + sample_b

if self.subtract_log_q:
true_logist = true_logist - true_log_probs.unsqueeze(1)
sample_logist = sample_logist - samp_log_probs
inputs, sample_w, transpose_y=True) + sample_b

if self.remove_accidental_hits:
hit = (paddle.equal(labels[:, :], neg_samples)).unsqueeze(1)
hit = (paddle.equal(labels[:, :], neg_samples))
padding = paddle.ones_like(sample_logist) * -1e30
sample_logist = paddle.where(hit, padding, sample_logist)

sample_logist = sample_logist.squeeze(1)
if self.subtract_log_q:
true_logist = true_logist - true_log_probs.unsqueeze(1)
sample_logist = sample_logist - samp_log_probs

out_logist = paddle.concat([true_logist, sample_logist], axis=1)
out_label = paddle.concat(
[
paddle.ones_like(true_logist) / self.num_true,
paddle.zeros_like(sample_logist)
],
axis=1)
out_label.stop_gradient = True

sampled_loss = F.softmax_with_cross_entropy(
logits=out_logist, label=out_label, soft_label=True)
return sampled_loss, out_logist, out_label
loss = self.loss(out_logist, out_label)
return loss, out_logist, out_label


class Mind_Capsual_Layer(nn.Layer):
Expand Down Expand Up @@ -148,6 +150,7 @@ def __init__(self,
name="bilinear_mapping_matrix", trainable=True),
default_initializer=nn.initializer.Normal(
mean=0.0, std=self.init_std))
self.relu_layer = nn.Linear(self.output_units, self.output_units)

def squash(self, Z):
"""squash
Expand All @@ -164,8 +167,10 @@ def sequence_mask(self, lengths, maxlen=None, dtype="bool"):
batch_size = paddle.shape(lengths)[0]
if maxlen is None:
maxlen = lengths.max()
row_vector = paddle.arange(0, maxlen, 1).unsqueeze(0).expand(
shape=(batch_size, maxlen)).reshape((batch_size, -1, maxlen))
row_vector = paddle.arange(
0, maxlen,
1).unsqueeze(0).expand(shape=(batch_size, maxlen)).reshape(
(batch_size, -1, maxlen))
lengths = lengths.unsqueeze(-1)
mask = row_vector < lengths
return mask.astype(dtype)
Expand All @@ -182,39 +187,50 @@ def forward(self, item_his_emb, seq_len):

mask = self.sequence_mask(seq_len_tile, self.maxlen)
pad = paddle.ones_like(mask, dtype="float32") * (-2**32 + 1)

# S*e
low_capsule_new = paddle.matmul(item_his_emb,
self.bilinear_mapping_matrix)

low_capsule_new_nograd = paddle.assign(low_capsule_new)
low_capsule_new_tile = paddle.tile(low_capsule_new, [1, 1, self.k_max])
low_capsule_new_tile = paddle.reshape(
low_capsule_new_tile,
[-1, self.maxlen, self.k_max, self.output_units])
low_capsule_new_tile = paddle.transpose(low_capsule_new_tile,
[0, 2, 1, 3])
low_capsule_new_tile = paddle.reshape(
low_capsule_new_tile,
[-1, self.k_max, self.maxlen, self.output_units])
low_capsule_new_nograd = paddle.assign(low_capsule_new_tile)
low_capsule_new_nograd.stop_gradient = True

B = paddle.tile(self.routing_logits,
[paddle.shape(item_his_emb)[0], 1, 1])
B.stop_gradient = True

for i in range(self.iters - 1):
B_mask = paddle.where(mask, B, pad)
# print(B_mask)
W = F.softmax(B_mask, axis=1)
W = paddle.unsqueeze(W, axis=2)
high_capsule_tmp = paddle.matmul(W, low_capsule_new_nograd)
# print(low_capsule_new_nograd.shape)
high_capsule = self.squash(high_capsule_tmp)
B_delta = paddle.matmul(
high_capsule, low_capsule_new_nograd, transpose_y=True)
B += B_delta / paddle.maximum(
paddle.norm(
B_delta, p=2, axis=-1, keepdim=True),
paddle.ones_like(B_delta))
low_capsule_new_nograd,
paddle.transpose(high_capsule, [0, 1, 3, 2]))
B_delta = paddle.reshape(
B_delta, shape=[-1, self.k_max, self.maxlen])
B += B_delta

B_mask = paddle.where(mask, B, pad)
W = F.softmax(B_mask, axis=1)
# paddle.static.Print(W)
high_capsule_tmp = paddle.matmul(W, low_capsule_new)
# high_capsule_tmp.stop_gradient = False

high_capsule = self.squash(high_capsule_tmp)
# high_capsule.stop_gradient = False
W = paddle.unsqueeze(W, axis=2)
interest_capsule = paddle.matmul(W, low_capsule_new_tile)
interest_capsule = self.squash(interest_capsule)
high_capsule = paddle.reshape(interest_capsule,
[-1, self.k_max, self.output_units])

high_capsule = F.relu(self.relu_layer(high_capsule))
return high_capsule, W, seq_len


Expand Down Expand Up @@ -246,6 +262,7 @@ def __init__(self,
name="item_emb",
initializer=nn.initializer.XavierUniform(
fan_in=item_count, fan_out=embedding_dim)))
# print(self.item_emb.weight)
self.embedding_bias = self.create_parameter(
shape=(item_count, ),
is_bias=True,
Expand All @@ -267,11 +284,17 @@ def __init__(self,
def label_aware_attention(self, keys, query):
"""label_aware_attention
"""
weight = paddle.sum(keys * query, axis=-1, keepdim=True)
weight = paddle.pow(weight, self.pow_p) # [x,k_max,1]
weight = F.softmax(weight, axis=1)
output = paddle.sum(keys * weight, axis=1)
return output, weight
weight = paddle.matmul(keys,
paddle.reshape(query, [
-1, paddle.shape(query)[-1], 1
])) #[B, K, dim] * [B, dim, 1] == [B, k, 1]
weight = paddle.squeeze(weight, axis=-1)
weight = paddle.pow(weight, self.pow_p) # [x,k_max]
weight = F.softmax(weight) #[x, k_max]
weight = paddle.unsqueeze(weight, 1) #[B, 1, k_max]
output = paddle.matmul(
weight, keys) #[B, 1, k_max] * [B, k_max, dim] => [B, 1, dim]
return output.squeeze(1), weight

def forward(self, hist_item, seqlen, labels=None):
"""forward
Expand All @@ -281,7 +304,7 @@ def forward(self, hist_item, seqlen, labels=None):
seqlen : [B, 1]
target : [B, 1]
"""

# print(hist_item)
hit_item_emb = self.item_emb(hist_item) # [B, seqlen, embed_dim]
user_cap, cap_weights, cap_mask = self.capsual_layer(hit_item_emb,
seqlen)
Expand Down

0 comments on commit 43314c8

Please sign in to comment.