# DCCF_PyTorch.py

```python
import torch.optim as optim
import random
import logging
import datetime
import os
from utility.parser import parse_args
from utility.batch_test import *
from utility.load_data import *
from model import *
from tqdm import tqdm
from time import time
from copy import deepcopy

args = parse_args()

# Fix all sources of randomness (Python, NumPy, PyTorch CPU/GPU) for reproducibility.
seed = args.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def load_adjacency_list_data(adj_mat):
    # Unpack a SciPy sparse adjacency matrix into COO-format
    # head-node, tail-node and edge-value lists.
    tmp = adj_mat.tocoo()
    all_h_list = list(tmp.row)
    all_t_list = list(tmp.col)
    all_v_list = list(tmp.data)

    return all_h_list, all_t_list, all_v_list


if __name__ == '__main__':

    """
    *********************************************************
    Prepare the log file
    """
    curr_time = datetime.datetime.now()
    if not os.path.exists('log'):
        os.mkdir('log')
    logger = logging.getLogger('train_logger')
    logger.setLevel(logging.INFO)
    logfile = logging.FileHandler('log/{}.log'.format(args.dataset), 'a', encoding='utf-8')
    logfile.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    logfile.setFormatter(formatter)
    logger.addHandler(logfile)

    """
    *********************************************************
    Prepare the dataset
    """
    data_generator = Data(args)
    logger.info(data_generator.get_statistics())

    print("************************* Run with following settings 🏃 ***************************")
    print(args)
    logger.info(args)
    print("************************************************************************************")

    config = dict()
    config['n_users'] = data_generator.n_users
    config['n_items'] = data_generator.n_items

    """
    *********************************************************
    Generate the adj matrix
    """
    plain_adj = data_generator.get_adj_mat()
    all_h_list, all_t_list, all_v_list = load_adjacency_list_data(plain_adj)
    config['plain_adj'] = plain_adj
    config['all_h_list'] = all_h_list
    config['all_t_list'] = all_t_list

    """
    *********************************************************
    Prepare the model and start training
    """
    _model = DCCF(config, args).cuda()
    optimizer = optim.Adam(_model.parameters(), lr=args.lr)

    print("Start Training")
    stopping_step = 0  # reserved for early stopping; unused in this fixed-epoch setup
    last_state_dict = None
    for epoch in range(args.epoch):

        ## train
        t1 = time()

        n_samples = data_generator.uniform_sample()
        n_batch = int(np.ceil(n_samples / args.batch_size))

        _model.train()
        loss, mf_loss, emb_loss, cen_loss, cl_loss = 0., 0., 0., 0., 0.
        for idx in tqdm(range(n_batch)):

            optimizer.zero_grad()

            users, pos_items, neg_items = data_generator.mini_batch(idx)
            batch_mf_loss, batch_emb_loss, batch_cen_loss, batch_cl_loss = _model(users, pos_items, neg_items)
            batch_loss = batch_mf_loss + batch_emb_loss + batch_cen_loss + batch_cl_loss

            # Accumulate per-batch averages for the epoch-level report.
            loss += float(batch_loss) / n_batch
            mf_loss += float(batch_mf_loss) / n_batch
            emb_loss += float(batch_emb_loss) / n_batch
            cen_loss += float(batch_cen_loss) / n_batch
            cl_loss += float(batch_cl_loss) / n_batch

            batch_loss.backward()
            optimizer.step()

        ## update the saved model parameters after each epoch
        last_state_dict = deepcopy(_model.state_dict())
        torch.cuda.empty_cache()

        ## only evaluate every args.show_step epochs (and on the final epoch)
        if epoch % args.show_step != 0 and epoch != args.epoch - 1:
            perf_str = 'Epoch %2d [%.1fs]: train==[%.5f=%.5f + %.5f + %.5f + %.5f]' % (epoch, time() - t1, loss, mf_loss, emb_loss, cen_loss, cl_loss)
            print(perf_str)
            logger.info(perf_str)
            continue

        t2 = time()

        ## test the model on test set for observation
        with torch.no_grad():
            _model.eval()
            _model.inference()
            test_ret = eval_PyTorch(_model, data_generator, eval(args.Ks))  # args.Ks is a list literal in string form, e.g. "[20, 40]"
        torch.cuda.empty_cache()

        t3 = time()

        perf_str = 'Epoch %2d [%.1fs + %.1fs]: train==[%.5f=%.5f + %.5f + %.5f + %.5f], test-recall=[%.4f, %.4f], test-ndcg=[%.4f, %.4f]' % \
                   (epoch, t2 - t1, t3 - t2, loss, mf_loss, emb_loss, cen_loss, cl_loss, test_ret['recall'][0], test_ret['recall'][1], test_ret['ndcg'][0], test_ret['ndcg'][1])
        print(perf_str)
        logger.info(perf_str)

    ## final test and report it in the paper
    if not os.path.exists('saved'):
        os.mkdir('saved')
    if args.save_model:
        torch.save(last_state_dict, 'saved/{}.pth'.format(args.dataset))
    _model.load_state_dict(last_state_dict)
    with torch.no_grad():
        _model.eval()
        _model.inference()
        final_test_ret = eval_PyTorch(_model, data_generator, eval(args.Ks))

    perf_str = 'Final Test Set Result: test-recall=[%.4f, %.4f], test-ndcg=[%.4f, %.4f]' % (final_test_ret['recall'][0], final_test_ret['recall'][1], final_test_ret['ndcg'][0], final_test_ret['ndcg'][1])
    print(perf_str)
    logger.info(perf_str)
```
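
As a side note, the `load_adjacency_list_data` helper above merely unpacks a SciPy sparse matrix into COO-format head-node, tail-node, and edge-value lists. A self-contained illustration with a made-up 3x3 matrix:

```python
import numpy as np
import scipy.sparse as sp

# Toy 3x3 adjacency matrix with three nonzero entries (made up for illustration).
adj = sp.csr_matrix(np.array([[0, 1, 0],
                              [1, 0, 0],
                              [0, 0, 1]]))
coo = adj.tocoo()
print(list(coo.row))   # [0, 1, 2] -> head-node indices (all_h_list)
print(list(coo.col))   # [1, 0, 2] -> tail-node indices (all_t_list)
print(list(coo.data))  # [1, 1, 1] -> edge values (all_v_list)
```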
# Disentangled Contrastive Collaborative Filtering

This is the PyTorch implementation by <a href='https://github.com/Re-bin'>@Re-bin</a> of the DCCF model proposed in the following paper:

>**Disentangled Contrastive Collaborative Filtering**
> Xubin Ren, Chao Huang, Lianghao Xia, Jiashu Zhao and Dawei Yin\
>*SIGIR 2023*

<p align="center">
<img src="DCCF.png" alt="DCCF" />
</p>

In this paper, we propose a disentangled contrastive learning method for recommendation that explores the latent factors underlying the implicit intents behind interactions. In particular, a graph structure learning layer is devised to enable adaptive interaction augmentation based on the learned, disentangled user (item) intent-aware dependencies. Along the augmented intent-aware graph structures, we propose an intent-aware contrastive learning scheme that brings the benefits of disentangled self-supervision signals.
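
For intuition, here is a minimal PyTorch sketch of the InfoNCE-style objective that such intent-aware contrastive schemes typically optimize, contrasting two augmented views of the same nodes. The names, shapes, and temperature value are illustrative assumptions, not the paper's exact formulation.

```python
import torch
import torch.nn.functional as F

def infonce_loss(view1: torch.Tensor, view2: torch.Tensor, temperature: float = 0.2) -> torch.Tensor:
    """InfoNCE between two augmented views of the same batch of node embeddings.

    view1, view2: (batch, dim) embeddings of the same users (or items) under two
    augmented graph structures. Row i of view1 and row i of view2 form the
    positive pair; every other row in the batch serves as a negative.
    """
    z1 = F.normalize(view1, dim=1)
    z2 = F.normalize(view2, dim=1)
    logits = z1 @ z2.t() / temperature                   # (batch, batch) cosine similarities
    labels = torch.arange(z1.size(0), device=z1.device)  # positives sit on the diagonal
    return F.cross_entropy(logits, labels)               # -log softmax of each positive pair
```

In `DCCF_PyTorch.py`, this role is presumably played by the `batch_cl_loss` term returned by the model's forward pass.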

## Environment

The code is written in Python 3.8.13 with the following dependencies; one possible installation sequence is sketched after the list.

- numpy == 1.22.3
- pytorch == 1.11.0 (GPU version)
- torch-scatter == 2.0.9
- torch-sparse == 0.6.14
- scipy == 1.7.3
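
A sketch of one way to install them (assumptions: the `cu113` CUDA tag is only an example and must match your local CUDA build, and `tqdm` is added because the training script imports it):

```
pip install numpy==1.22.3 scipy==1.7.3 tqdm
pip install torch==1.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install torch-scatter==2.0.9 torch-sparse==0.6.14 -f https://data.pyg.org/whl/torch-1.11.0+cu113.html
```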

## Dataset

We utilized three public datasets to evaluate DCCF: *Gowalla*, *Amazon-book*, and *Tmall*.

Note that the validation set is used only for tuning hyperparameters; for *Gowalla* and *Tmall*, it is merged into the training set for the final training run.

## Examples to run the codes

The commands to train DCCF on the *Gowalla*, *Amazon-book*, and *Tmall* datasets are as follows.

We train DCCF for a fixed number of epochs and save the parameters obtained after the final epoch for testing (a sketch of loading them back is given after the command list).

- Gowalla:

```python DCCF_PyTorch.py --dataset gowalla --epoch 150```

- Amazon-book:

```python DCCF_PyTorch.py --dataset amazon --epoch 100```

- Tmall:

```python DCCF_PyTorch.py --dataset tmall --epoch 100```
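
As referenced above: with `--save_model` enabled, the script writes the final-epoch parameters to `saved/<dataset>.pth`. A sketch of loading them back for later evaluation, assuming `config` and `args` are rebuilt exactly as in `DCCF_PyTorch.py`:

```python
import torch
from model import DCCF

# Rebuild the model with the same config/args used for training, then
# restore the parameters saved after the final epoch.
model = DCCF(config, args).cuda()
model.load_state_dict(torch.load('saved/gowalla.pth'))
model.eval()
with torch.no_grad():
    model.inference()  # precompute embeddings, as the script does before eval_PyTorch
```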

**For advanced usage, run the code with the `--help` argument.**
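
For instance, the training script also reads `--lr`, `--batch_size`, `--seed`, `--show_step`, `--Ks`, and `--save_model` (all referenced in `DCCF_PyTorch.py`); a hypothetical invocation with illustrative, untuned values:

```python DCCF_PyTorch.py --dataset gowalla --epoch 150 --lr 0.001 --seed 2023```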

**Thanks for your interest in our work!**