Implementation of DCN-v2 model. #39

Merged (2 commits) on Sep 9, 2022
7 changes: 5 additions & 2 deletions examples/ranking/run_criteo.py
@@ -5,7 +5,7 @@
import numpy as np
import pandas as pd
import torch
-from torch_rechub.models.ranking import WideDeep, DeepFM, DCN, DeepFFM, FatDeepFFM
+from torch_rechub.models.ranking import WideDeep, DeepFM, DCN, DCNv2, DeepFFM, FatDeepFFM
from torch_rechub.trainers import CTRTrainer
from torch_rechub.basic.features import DenseFeature, SparseFeature
from torch_rechub.utils.data import DataGenerator
@@ -65,6 +65,8 @@ def main(dataset_path, model_name, epoch, learning_rate, batch_size, weight_decay
        model = DeepFM(deep_features=dense_feas, fm_features=sparse_feas, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
    elif model_name == "dcn":
        model = DCN(features=dense_feas + sparse_feas, n_cross_layers=3, mlp_params={"dims": [256, 128]})
    elif model_name == "dcn_v2":
        model = DCNv2(features=dense_feas + sparse_feas, n_cross_layers=3, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
    elif model_name == "deepffm":
        model = DeepFFM(linear_features=ffm_linear_feas, cross_features=ffm_cross_feas, embed_dim=10, mlp_params={"dims": [1600, 1600], "dropout": 0.5, "activation": "relu"})
    elif model_name == "fat_deepffm":
@@ -80,7 +82,7 @@ def main(dataset_path, model_name, epoch, learning_rate, batch_size, weight_decay
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', default="./data/criteo/criteo_sample.csv")
-    parser.add_argument('--model_name', default='widedeep')
+    parser.add_argument('--model_name', default='dcn_v2')
    parser.add_argument('--epoch', type=int, default=2)  # 100
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--batch_size', type=int, default=2048)  # 4096
@@ -95,6 +97,7 @@
python run_criteo.py --model_name widedeep
python run_criteo.py --model_name deepfm
python run_criteo.py --model_name dcn
python run_criteo.py --model_name dcn_v2
python run_criteo.py --model_name deepffm
python run_criteo.py --model_name fat_deepffm
"""
76 changes: 76 additions & 0 deletions torch_rechub/basic/layers.py
@@ -362,6 +362,82 @@ def forward(self, x):
            x = x0 * xw + self.b[i] + x
        return x

class CrossNetV2(nn.Module):
    """CrossNetwork V2 from the DCN-v2 paper: each layer computes
    x_{l+1} = x_0 * (W_l x_l + b_l) + x_l with a full-rank weight matrix W_l.

    :param input_dim: int, the flattened embedding dimension of the input.
    :param num_layers: int, the number of cross layers.
    """

    def __init__(self, input_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])

    def forward(self, x):
        x0 = x
        for i in range(self.num_layers):
            x = x0 * self.w[i](x) + self.b[i] + x  # elementwise cross with residual connection
        return x
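
Since each CrossNetV2 layer preserves the input dimension, a minimal shape-check sketch looks like this (assuming the class is exported from torch_rechub.basic.layers, as the dcn_v2.py import below suggests; all values are illustrative):

import torch
from torch_rechub.basic.layers import CrossNetV2

cross = CrossNetV2(input_dim=16, num_layers=3)  # 16 = flattened embedding size
x = torch.randn(8, 16)                          # a batch of 8 samples
print(cross(x).shape)                           # torch.Size([8, 16]): dimension preserved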

class CrossNetMix(nn.Module):
    """CrossNetMix improves CrossNetwork by:
    1. adding a mixture of experts (MoE) to learn feature interactions in different subspaces
    2. adding nonlinear transformations in a low-dimensional (low-rank) space

    :param x: Float tensor of size ``(batch_size, input_dim)``
    """

    def __init__(self, input_dim, num_layers=2, low_rank=32, num_experts=4):
        super(CrossNetMix, self).__init__()
        self.num_layers = num_layers
        self.num_experts = num_experts

        # U: (input_dim, low_rank) per expert
        self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
        # V: (input_dim, low_rank) per expert
        self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
        # C: (low_rank, low_rank) per expert
        self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
            torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
        self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])

        self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(
            torch.empty(input_dim, 1))) for i in range(self.num_layers)])

    def forward(self, x):
        x_0 = x.unsqueeze(2)  # (bs, in_features, 1)
        x_l = x_0
        for i in range(self.num_layers):
            output_of_experts = []
            gating_score_experts = []
            for expert_id in range(self.num_experts):
                # (1) G(x_l)
                # compute the gating score from x_l
                gating_score_experts.append(self.gating[expert_id](x_l.squeeze(2)))

                # (2) E(x_l)
                # project the input x_l to $\mathbb{R}^{r}$
                v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l)  # (bs, low_rank, 1)

                # nonlinear activation in the low-rank space
                v_x = torch.tanh(v_x)
                v_x = torch.matmul(self.c_list[i][expert_id], v_x)
                v_x = torch.tanh(v_x)

                # project back to $\mathbb{R}^{d}$
                uv_x = torch.matmul(self.u_list[i][expert_id], v_x)  # (bs, in_features, 1)

                dot_ = uv_x + self.bias[i]
                dot_ = x_0 * dot_  # Hadamard product

                output_of_experts.append(dot_.squeeze(2))

            # (3) mixture of low-rank experts
            output_of_experts = torch.stack(output_of_experts, 2)  # (bs, in_features, num_experts)
            gating_score_experts = torch.stack(gating_score_experts, 1)  # (bs, num_experts, 1)
            moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
            x_l = moe_out + x_l  # (bs, in_features, 1)

        x_l = x_l.squeeze(2)  # (bs, in_features); squeeze only the last dim so batch_size=1 stays 2-D
        return x_l
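
CrossNetMix keeps the same (batch_size, input_dim) interface while mixing several low-rank experts per layer; a small standalone sketch (hyperparameter values are illustrative):

import torch
from torch_rechub.basic.layers import CrossNetMix

mix = CrossNetMix(input_dim=16, num_layers=2, low_rank=8, num_experts=4)
x = torch.randn(8, 16)
print(mix(x).shape)  # torch.Size([8, 16]); each layer gates 4 low-rank experts, then adds the residual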

class MultiInterestSA(nn.Module):
    """MultiInterest Attention mentioned in the Comirec paper.
1 change: 1 addition & 0 deletions torch_rechub/models/ranking/__init__.py
@@ -2,4 +2,5 @@
from .deepfm import DeepFM
from .din import DIN
from .dcn import DCN
from .dcn_v2 import DCNv2
from .deepffm import DeepFFM, FatDeepFFM
59 changes: 59 additions & 0 deletions torch_rechub/models/ranking/dcn_v2.py
@@ -0,0 +1,59 @@
"""
Date: create on 09/01/2022
References:
paper: (WWW'21) Dcn v2: Improved deep & cross network and practical lessons for web-scale learning to rank systems
url: https://arxiv.org/abs/2008.13535
Authors: lailai, lailai_zxy@tju.edu.cn
"""
import torch
from ...basic.layers import LR, MLP,CrossNetV2, CrossNetMix, EmbeddingLayer

class DCNv2(torch.nn.Module):
    """Deep & Cross Network v2.

    :param features: the list of feature objects fed to the embedding layer.
    :param n_cross_layers: int, the number of cross layers.
    :param mlp_params: dict, the parameters of the MLP module, keys include ``{"dims":list, "activation":str, "dropout":float}``.
    :param model_structure: str, one of ``"crossnet_only"``, ``"stacked"`` or ``"parallel"``.
    :param use_low_rank_mixture: bool, use CrossNetMix (low-rank MoE) instead of CrossNetV2.
    :param low_rank: int, the rank of the low-rank decomposition in CrossNetMix.
    :param num_experts: int, the number of experts in CrossNetMix.
    """

    def __init__(self,
                 features,
                 n_cross_layers,
                 mlp_params,
                 model_structure="parallel",
                 use_low_rank_mixture=True,
                 low_rank=32,
                 num_experts=4,
                 **kwargs):
        super().__init__()
        self.features = features
        self.dims = sum([fea.embed_dim for fea in features])
        self.embedding = EmbeddingLayer(features)
        if use_low_rank_mixture:
            self.crossnet = CrossNetMix(self.dims, n_cross_layers, low_rank=low_rank, num_experts=num_experts)
        else:
            self.crossnet = CrossNetV2(self.dims, n_cross_layers)
        self.model_structure = model_structure
        assert self.model_structure in ["crossnet_only", "stacked", "parallel"], \
            "model_structure={} not supported!".format(self.model_structure)
        if self.model_structure == "stacked":
            self.stacked_dnn = MLP(self.dims, output_layer=False, **mlp_params)
            final_dim = mlp_params["dims"][-1]
        if self.model_structure == "parallel":
            self.parallel_dnn = MLP(self.dims, output_layer=False, **mlp_params)
            final_dim = mlp_params["dims"][-1] + self.dims
        if self.model_structure == "crossnet_only":  # only CrossNet
            final_dim = self.dims
        self.linear = LR(final_dim)  # final_dim depends on model_structure, not always dims + mlp output

    def forward(self, x):
        embed_x = self.embedding(x, self.features, squeeze_dim=True)
        cross_out = self.crossnet(embed_x)
        if self.model_structure == "crossnet_only":
            final_out = cross_out
        elif self.model_structure == "stacked":
            final_out = self.stacked_dnn(cross_out)
        elif self.model_structure == "parallel":
            dnn_out = self.parallel_dnn(embed_x)
            final_out = torch.cat([cross_out, dnn_out], dim=1)
        y_pred = self.linear(final_out)
        y_pred = torch.sigmoid(y_pred.squeeze(1))
        return y_pred
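
Beyond the Criteo script, a hedged end-to-end sketch of driving the new model directly; the feature names, vocab size, and input tensor shapes are illustrative assumptions modeled on how run_criteo.py builds its features, not part of this PR:

import torch
from torch_rechub.basic.features import DenseFeature, SparseFeature
from torch_rechub.models.ranking import DCNv2

dense_feas = [DenseFeature("price")]  # hypothetical dense feature
sparse_feas = [SparseFeature("user_id", vocab_size=100, embed_dim=16)]  # hypothetical sparse feature
model = DCNv2(features=dense_feas + sparse_feas, n_cross_layers=2,
              mlp_params={"dims": [64, 32], "dropout": 0.2, "activation": "relu"})

# assumed input format: a dict mapping feature names to batch tensors,
# as consumed by EmbeddingLayer(x, features, squeeze_dim=True)
x = {"price": torch.rand(4), "user_id": torch.randint(0, 100, (4,))}
y_pred = model(x)  # shape (4,), sigmoid click-through probabilities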