Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ADMMSLIM #662

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix item centering. Add item-specific L2 regularization.
  • Loading branch information
deklanw committed Jan 8, 2021
commit 3013021ac8b00d817126e010d02dcdcb2bcedb8a
29 changes: 17 additions & 12 deletions recbole/model/general_recommender/admmslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,28 @@ def __init__(self, config, dataset):

lambda1 = config['lambda1']
lambda2 = config['lambda2']
alpha = config['alpha']
rho = config['rho']
k = config['k']
positive_only = config['positive_only']
self.center_columns = config['center_columns']
self.item_means = X.mean(axis=0).getA1()

if self.center_columns:
self.item_means = X.mean(axis=0).getA1()
zero_mean_X = X.toarray() - self.item_means
G = (zero_mean_X.T @ zero_mean_X)
# large memory cost because we need to make X dense to subtract mean, delete asap
del zero_mean_X
else:
G = (X.T @ X).toarray()

diag = (lambda2 + rho) * np.identity(num_items)

# alpha = 0 corresponds to this case
# diag = (lambda2 + rho) * np.identity(num_items)

# add 1 to means to prevent singular matrix in what follows
diag = (lambda2 + rho) * np.diag(np.power(self.item_means + 1, alpha / 2))

P = np.linalg.inv(G + diag)
B_aux = P @ G

Expand Down Expand Up @@ -93,22 +100,20 @@ def predict(self, interaction):
user_interactions = self.interaction_matrix[user, :].toarray()

if self.center_columns:
user_interactions -= self.item_means

r = torch.from_numpy(
(user_interactions * self.item_similarity[:, item].T).sum(axis=1).flatten())
r = (((user_interactions - self.item_means) * self.item_similarity[:, item].T).sum(axis=1)).flatten() + self.item_means[item]
else:
r = (user_interactions * self.item_similarity[:, item].T).sum(axis=1).flatten()

return add_noise(r)
return add_noise(torch.from_numpy(r))

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID].cpu().numpy()

user_interactions = self.interaction_matrix[user, :].toarray()

if self.center_columns:
user_interactions -= self.item_means

r = user_interactions @ self.item_similarity
r = torch.from_numpy(r.flatten())
r = ((user_interactions - self.item_means) @ self.item_similarity + self.item_means).flatten()
else:
r = (user_interactions @ self.item_similarity).flatten()

return add_noise(r)
return add_noise(torch.from_numpy(r))
7 changes: 4 additions & 3 deletions recbole/properties/model/ADMMSLIM.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
lambda1: 3
lambda2: 150
rho: 300
lambda2: 200
alpha: 0.25
rho: 4000
k: 100
positive_only: False
positive_only: True
center_columns: False