Commit dcbde52

upload

Re-bin committed Apr 26, 2023
Showing 14 changed files with 411,346 additions and 0 deletions.
Binary file added DCCF.png
152 changes: 152 additions & 0 deletions DCCF_PyTorch.py
@@ -0,0 +1,152 @@
import torch.optim as optim
import random
import logging
import datetime
import os
from utility.parser import parse_args
from utility.batch_test import *
from utility.load_data import *
from model import *
from tqdm import tqdm
from time import time
from copy import deepcopy

args = parse_args()
seed = args.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
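# Note: fixing the Python, NumPy, and (all) CUDA RNG seeds and forcing
# deterministic cuDNN kernels makes runs reproducible at some speed cost.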

def load_adjacency_list_data(adj_mat):
    # Unpack the sparse adjacency matrix (COO format) into head-node,
    # tail-node, and edge-value lists.
    tmp = adj_mat.tocoo()
    all_h_list = list(tmp.row)
    all_t_list = list(tmp.col)
    all_v_list = list(tmp.data)

    return all_h_list, all_t_list, all_v_list
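
# Example (illustrative): for a sparse matrix with unit entries at (0, 1) and
# (2, 0), load_adjacency_list_data returns ([0, 2], [1, 0], [1.0, 1.0]).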

if __name__ == '__main__':

    """
    *********************************************************
    Prepare the log file
    """
    curr_time = datetime.datetime.now()
    if not os.path.exists('log'):
        os.mkdir('log')
    logger = logging.getLogger('train_logger')
    logger.setLevel(logging.INFO)
    logfile = logging.FileHandler('log/{}.log'.format(args.dataset), 'a', encoding='utf-8')
    logfile.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    logfile.setFormatter(formatter)
    logger.addHandler(logfile)

"""
*********************************************************
Prepare the dataset
"""
data_generator = Data(args)
logger.info(data_generator.get_statistics())

print("************************* Run with following settings 🏃 ***************************")
print(args)
logger.info(args)
print("************************************************************************************")

config = dict()
config['n_users'] = data_generator.n_users
config['n_items'] = data_generator.n_items

"""
*********************************************************
Generate the adj matrix
"""
plain_adj = data_generator.get_adj_mat()
all_h_list, all_t_list, all_v_list = load_adjacency_list_data(plain_adj)
config['plain_adj'] = plain_adj
config['all_h_list'] = all_h_list
config['all_t_list'] = all_t_list

"""
*********************************************************
Prepare the model and start training
"""
_model = DCCF(config, args).cuda()
optimizer = optim.Adam(_model.parameters(), lr=args.lr)

print("Start Training")
stopping_step = 0
last_state_dict = None
for epoch in range(args.epoch):

## train
t1 = time()

n_samples = data_generator.uniform_sample()
n_batch = int(np.ceil(n_samples / args.batch_size))

_model.train()
loss, mf_loss, emb_loss, cen_loss, cl_loss = 0., 0., 0., 0., 0.
for idx in tqdm(range(n_batch)):

optimizer.zero_grad()

users, pos_items, neg_items = data_generator.mini_batch(idx)
batch_mf_loss, batch_emb_loss, batch_cen_loss, batch_cl_loss = _model(users, pos_items, neg_items)
batch_loss = batch_mf_loss + batch_emb_loss + batch_cen_loss + batch_cl_loss

loss += float(batch_loss) / n_batch
mf_loss += float(batch_mf_loss) / n_batch
emb_loss += float(batch_emb_loss) / n_batch
cen_loss += float(batch_cen_loss) / n_batch
cl_loss += float(batch_cl_loss) / n_batch

batch_loss.backward()
optimizer.step()

## update the saved model parameters after each epoch
last_state_dict = deepcopy(_model.state_dict())
torch.cuda.empty_cache()
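        # Run the full evaluation only every `show_step` epochs and at the
        # final epoch; otherwise just report the training losses and continue.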

        if epoch % args.show_step != 0 and epoch != args.epoch - 1:
            perf_str = 'Epoch %2d [%.1fs]: train==[%.5f=%.5f + %.5f + %.5f + %.5f]' % (epoch, time() - t1, loss, mf_loss, emb_loss, cen_loss, cl_loss)
            print(perf_str)
            logger.info(perf_str)
            continue

        t2 = time()

        ## test the model on test set for observation
        with torch.no_grad():
            _model.eval()
            _model.inference()
            test_ret = eval_PyTorch(_model, data_generator, eval(args.Ks))
        torch.cuda.empty_cache()

        t3 = time()

        perf_str = 'Epoch %2d [%.1fs + %.1fs]: train==[%.5f=%.5f + %.5f + %.5f + %.5f], test-recall=[%.4f, %.4f], test-ndcg=[%.4f, %.4f]' % \
                   (epoch, t2 - t1, t3 - t2, loss, mf_loss, emb_loss, cen_loss, cl_loss, test_ret['recall'][0], test_ret['recall'][1], test_ret['ndcg'][0], test_ret['ndcg'][1])
        print(perf_str)

        logger.info(perf_str)

    ## final test and report it in the paper
    if not os.path.exists('saved'):
        os.mkdir('saved')
    if args.save_model:
        torch.save(last_state_dict, 'saved/{}.pth'.format(args.dataset))
    _model.load_state_dict(last_state_dict)
    with torch.no_grad():
        _model.eval()
        _model.inference()
        final_test_ret = eval_PyTorch(_model, data_generator, eval(args.Ks))

    perf_str = 'Final Test Set Result: test-recall=[%.4f, %.4f], test-ndcg=[%.4f, %.4f]' % (final_test_ret['recall'][0], final_test_ret['recall'][1], final_test_ret['ndcg'][0], final_test_ret['ndcg'][1])
    print(perf_str)
    logger.info(perf_str)
51 changes: 51 additions & 0 deletions README.md
@@ -0,0 +1,51 @@
# Disentangled Contrastive Collaborative Filtering

This is the PyTorch implementation by <a href='https://github.com/Re-bin'>@Re-bin</a> for the DCCF model proposed in this paper:

>**Disentangled Contrastive Collaborative Filtering**
> Xubin Ren, Chao Huang, Lianghao Xia, Jiashu Zhao and Dawei Yin\
>*SIGIR 2023*
<p align="center">
<img src="DCCF.png" alt="DCCF" />
</p>

In this paper, we propose a disentangled contrastive learning method for recommendation, which explores the latent factors underlying implicit intents behind user-item interactions. In particular, a graph structure learning layer is devised to enable adaptive interaction augmentation, based on the learned disentangled user (item) intent-aware dependencies. Along the augmented intent-aware graph structures, we propose an intent-aware contrastive learning scheme to bring the benefits of disentangled self-supervision signals.
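
For readers unfamiliar with the contrastive component, below is a minimal sketch of an InfoNCE-style alignment objective between two augmented views of the same node embeddings. It is illustrative only: the function name `infonce_loss`, the two-view setup, and the temperature value are assumptions, not the exact loss implemented in `model.py`.

```python
import torch
import torch.nn.functional as F

def infonce_loss(view_a: torch.Tensor, view_b: torch.Tensor, temperature: float = 0.2) -> torch.Tensor:
    # Normalize both views so dot products equal cosine similarities.
    a = F.normalize(view_a, dim=1)
    b = F.normalize(view_b, dim=1)
    # Positives: the same node seen under the two augmented graph views.
    pos = (a * b).sum(dim=1)
    # Contrast each node in view A against every node in view B.
    logits = a @ b.t() / temperature
    return (-pos / temperature + torch.logsumexp(logits, dim=1)).mean()

# Toy usage: random embeddings for 4 nodes of dimension 8.
loss = infonce_loss(torch.randn(4, 8), torch.randn(4, 8))
```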

## Environment

The code is written in Python 3.8.13 with the following dependencies (a quick version-check sketch follows the list).

- numpy == 1.22.3
- pytorch == 1.11.0 (GPU version)
- torch-scatter == 2.0.9
- torch-sparse == 0.6.14
- scipy == 1.7.3
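
A quick way to verify the pinned versions (an illustrative snippet, not part of the repository):

```python
# Illustrative version check -- not part of the repository.
import numpy, scipy, torch, torch_scatter, torch_sparse

for pkg in (numpy, scipy, torch, torch_scatter, torch_sparse):
    print(pkg.__name__, pkg.__version__)
print("CUDA available:", torch.cuda.is_available())  # the GPU build is required
```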

## Dataset

We evaluate DCCF on three public datasets: *Gowalla*, *Amazon-book*, and *Tmall*.

Note that the validation set is used only for tuning hyperparameters; for *Gowalla* and *Tmall*, it is merged into the training set during training.

## Examples to run the code

The commands to train DCCF on the Gowalla / Amazon-book / Tmall datasets are as follows.

We train DCCF for a fixed number of epochs and save the parameters from the final epoch for testing.

- Gowalla:

```python DCCF_PyTorch.py --dataset gowalla --epoch 150```

- Amazon-book:

```python DCCF_PyTorch.py --dataset amazon --epoch 100```

- Tmall:

```python DCCF_PyTorch.py --dataset tmall --epoch 100```

**For advanced usage of arguments, run the code with the `--help` argument.**
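
For example, the training script also reads arguments such as `--lr`, `--batch_size`, `--seed`, `--show_step`, `--save_model`, and `--Ks` (see `utility/parser.py`); assuming the flag names match the attributes used in `DCCF_PyTorch.py`, a combined invocation might look like the following, with illustrative values:

```python DCCF_PyTorch.py --dataset gowalla --epoch 150 --lr 0.001 --seed 2023```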

**Thanks for your interest in our work!**