-
Notifications
You must be signed in to change notification settings - Fork 627
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1148 from chenyuwuxin/1.0.x
FEA: Add ADMMSLIM in General models
- Loading branch information
Showing
6 changed files
with
235 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
ADMMSLIM | ||
============ | ||
|
||
Introduction | ||
------------------ | ||
|
||
`[paper] <https://doi.org/10.1145/3336191.3371774>`_ | ||
|
||
**Title:** ADMM SLIM: Sparse Recommendations for Many Users | ||
|
||
**Authors:** Harald Steck,Maria Dimakopoulou,Nickolai Riabov,Tony Jebara | ||
|
||
|
||
**Abstract:** The Sparse Linear Method (Slim) is a well-established approach | ||
for top-N recommendations. This article proposes several improvements | ||
that are enabled by the Alternating Directions Method of | ||
Multipliers (ADMM), a well-known optimization method | ||
with many application areas. First, we show that optimizing the | ||
original Slim-objective by ADMM results in an approach where the | ||
training time is independent of the number of users in the training | ||
data, and hence trivially scales to large numbers of users. Second, | ||
the flexibility of ADMM allows us to switch on and off the various | ||
constraints and regularization terms in the original Slim-objective, | ||
in order to empirically assess their contributions to ranking accuracy | ||
on given data. Third, we also propose two extensions to the | ||
original Slim training-objective in order to improve recommendation | ||
accuracy further without increasing the computational cost. In | ||
our experiments on three well-known data-sets, we first compare | ||
to the original Slim-implementation and find that not only ADMM | ||
reduces training time considerably, but also achieves an improvement | ||
in recommendation accuracy due to better optimization. We | ||
then compare to various state-of-the-art approaches and observe | ||
up to 25% improvement in recommendation accuracy in our experiments. | ||
Finally, we evaluate the importance of sparsity and the | ||
non-negativity constraint in the original Slim-objective with subsampling | ||
experiments that simulate scenarios of cold-starting and | ||
large catalog sizes compared to relatively small user base, which | ||
often occur in practice. | ||
|
||
Running with RecBole | ||
------------------------- | ||
|
||
**Model Hyper-Parameters:** | ||
|
||
- ``lambda1 (float)`` : L1-norm regularization parameter. Defaults to ``3``. | ||
|
||
- ``lambda2 (float)`` : L2-norm regularization parameter. Defaults to ``200``. | ||
|
||
- ``alpha (float)`` : The exponents to control the power-law in the regularization terms. Defaults to ``0.5``. | ||
|
||
- ``rho (float)`` : The penalty parameter that applies to the squared difference between primal variables. Defaults to ``4000``. | ||
|
||
- ``k (int)`` : The number of running iterations. Defaults to ``100``. | ||
|
||
- ``positive_only (bool)`` : Whether only preserves all positive values. Defaults to ``True``. | ||
|
||
- ``center_columns (bool)`` : Whether to use additional item-bias terms.. Defaults to ``False``. | ||
|
||
|
||
**A Running Example:** | ||
|
||
Write the following code to a python file, such as `run.py` | ||
|
||
.. code:: python | ||
from recbole.quick_start import run_recbole | ||
run_recbole(model='ADMMSLIM', dataset='ml-100k') | ||
And then: | ||
|
||
.. code:: bash | ||
python run.py | ||
Tuning Hyper Parameters | ||
------------------------- | ||
|
||
If you want to use ``HyperTuning`` to tune hyper parameters of this model, you can copy the following settings and name it as ``hyper.test``. | ||
|
||
.. code:: bash | ||
lambda1 choice [0.1 , 0.5 , 5 , 10] | ||
lambda2 choice [5 , 50 , 1000 , 5000] | ||
alpha choice [0.25 , 0.5 , 0.75 , 1] | ||
Note that we just provide these hyper parameter ranges for reference only, and we can not guarantee that they are the optimal range of this model. | ||
|
||
Then, with the source code of RecBole (you can download it from GitHub), you can run the ``run_hyper.py`` to tuning: | ||
|
||
.. code:: bash | ||
python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test | ||
For more details about Parameter Tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`. | ||
|
||
If you want to change parameters, dataset or evaluation settings, take a look at | ||
|
||
- :doc:`../../../user_guide/config_settings` | ||
- :doc:`../../../user_guide/data_intro` | ||
- :doc:`../../../user_guide/train_eval_intro` | ||
- :doc:`../../../user_guide/usage` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# @Time : 2021/01/09 | ||
# @Author : Deklan Webster | ||
|
||
r""" | ||
ADMMSLIM | ||
################################################ | ||
Reference: | ||
Steck et al. ADMM SLIM: Sparse Recommendations for Many Users. https://doi.org/10.1145/3336191.3371774 | ||
""" | ||
|
||
from recbole.utils.enum_type import ModelType | ||
import numpy as np | ||
import scipy.sparse as sp | ||
import torch | ||
|
||
from recbole.utils import InputType | ||
from recbole.model.abstract_recommender import GeneralRecommender | ||
|
||
|
||
def soft_threshold(x, threshold): | ||
return (np.abs(x) > threshold) * (np.abs(x) - threshold) * np.sign(x) | ||
|
||
|
||
def zero_mean_columns(a): | ||
return a - np.mean(a, axis=0) | ||
|
||
|
||
def add_noise(t, mag=1e-5): | ||
return t + mag * torch.rand(t.shape) | ||
|
||
|
||
class ADMMSLIM(GeneralRecommender): | ||
input_type = InputType.POINTWISE | ||
type = ModelType.TRADITIONAL | ||
|
||
def __init__(self, config, dataset): | ||
super().__init__(config, dataset) | ||
|
||
# need at least one param | ||
self.dummy_param = torch.nn.Parameter(torch.zeros(1)) | ||
|
||
X = dataset.inter_matrix(form='csr').astype(np.float32) | ||
|
||
num_users, num_items = X.shape | ||
|
||
lambda1 = config['lambda1'] | ||
lambda2 = config['lambda2'] | ||
alpha = config['alpha'] | ||
rho = config['rho'] | ||
k = config['k'] | ||
positive_only = config['positive_only'] | ||
self.center_columns = config['center_columns'] | ||
self.item_means = X.mean(axis=0).getA1() | ||
|
||
if self.center_columns: | ||
zero_mean_X = X.toarray() - self.item_means | ||
G = (zero_mean_X.T @ zero_mean_X) | ||
# large memory cost because we need to make X dense to subtract mean, delete asap | ||
del zero_mean_X | ||
else: | ||
G = (X.T @ X).toarray() | ||
|
||
diag = lambda2 * np.diag(np.power(self.item_means, alpha)) + \ | ||
rho * np.identity(num_items) | ||
|
||
P = np.linalg.inv(G + diag).astype(np.float32) | ||
B_aux = (P @ G).astype(np.float32) | ||
# initialize | ||
Gamma = np.zeros_like(G, dtype=np.float32) | ||
C = np.zeros_like(G, dtype=np.float32) | ||
|
||
del diag, G | ||
# fixed number of iterations | ||
for _ in range(k): | ||
B_tilde = B_aux + P @ (rho * C - Gamma) | ||
gamma = np.diag(B_tilde) / (np.diag(P) + 1e-7) | ||
B = B_tilde - P * gamma | ||
C = soft_threshold(B + Gamma / rho, lambda1 / rho) | ||
if positive_only: | ||
C = (C > 0) * C | ||
Gamma += rho * (B - C) | ||
# torch doesn't support sparse tensor slicing, so will do everything with np/scipy | ||
self.item_similarity = C | ||
self.interaction_matrix = X | ||
|
||
def forward(self): | ||
pass | ||
|
||
def calculate_loss(self, interaction): | ||
return torch.nn.Parameter(torch.zeros(1)) | ||
|
||
def predict(self, interaction): | ||
user = interaction[self.USER_ID].cpu().numpy() | ||
item = interaction[self.ITEM_ID].cpu().numpy() | ||
|
||
user_interactions = self.interaction_matrix[user, :].toarray() | ||
|
||
if self.center_columns: | ||
r = (((user_interactions - self.item_means) * | ||
self.item_similarity[:, item].T).sum(axis=1)).flatten() + self.item_means[item] | ||
else: | ||
r = (user_interactions * self.item_similarity[:, item].T).sum(axis=1).flatten() | ||
|
||
return add_noise(torch.from_numpy(r)) | ||
|
||
def full_sort_predict(self, interaction): | ||
user = interaction[self.USER_ID].cpu().numpy() | ||
|
||
user_interactions = self.interaction_matrix[user, :].toarray() | ||
|
||
if self.center_columns: | ||
r = ((user_interactions - self.item_means) @ self.item_similarity + self.item_means).flatten() | ||
else: | ||
r = (user_interactions @ self.item_similarity).flatten() | ||
|
||
return add_noise(torch.from_numpy(r)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
lambda1: 3 | ||
lambda2: 200 | ||
alpha: 0.50 | ||
rho: 4000 | ||
k: 100 | ||
positive_only: True | ||
center_columns: False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters