Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the unsupervised bipartite GraphSAGE model on the Taobao dataset #6144

Merged
merged 37 commits into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
45fac38
Add unsupervised bipartite graphsage & dataset taobao
HuxleyHu98 Dec 2, 2022
9e4e61c
Merge branch 'pyg-team:master' into bpsage
husimplicity Dec 2, 2022
e4baf1f
style
HuxleyHu98 Dec 2, 2022
e4df11b
minor
HuxleyHu98 Dec 4, 2022
b3b5ad4
Merge branch 'pyg-team:master' into bpsage
husimplicity Dec 4, 2022
23c84eb
Merge branch 'bpsage' of https://github.com/husimplicity/pytorch_geom…
HuxleyHu98 Dec 4, 2022
8768df9
Merge branch 'pyg-team:master' into bpsage
husimplicity Dec 5, 2022
0959d3a
minor
HuxleyHu98 Dec 5, 2022
57c5830
Merge branch 'pyg-team:master' into bpsage
husimplicity Dec 5, 2022
b7cb029
Merge branch 'bpsage' of https://github.com/husimplicity/pytorch_geom…
HuxleyHu98 Dec 5, 2022
eae6442
minor
HuxleyHu98 Dec 5, 2022
318a90e
Merge branch 'pyg-team:master' into bpsage
husimplicity Dec 5, 2022
96e3f9c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2022
82dfb41
Merge branch 'pyg-team:master' into bpsage
husimplicity Dec 6, 2022
37c68c7
format
HuxleyHu98 Dec 6, 2022
d62518c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2022
dada43b
Merge branch 'pyg-team:master' into bpsage
husimplicity Dec 7, 2022
c59109e
fix:limit test sampling data within split test data
HuxleyHu98 Dec 7, 2022
e6d4fca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2022
d4a374c
Merge branch 'master' into bpsage
husimplicity Dec 14, 2022
754e3f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 14, 2022
e752eab
Apply triplet loss
HuxleyHu98 Dec 14, 2022
a1a5b03
Merge branch 'bpsage' of https://github.com/husimplicity/pytorch_geom…
HuxleyHu98 Dec 14, 2022
1fa0b68
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 14, 2022
ac35e61
Merge branch 'master' into bpsage
husimplicity Dec 20, 2022
6ce5fbe
Merge branch 'master' into bpsage
husimplicity Dec 21, 2022
3d6c20c
format
HuxleyHu98 Dec 21, 2022
be4e7ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2022
3924c29
Merge branch 'pyg-team:master' into bpsage
husimplicity Dec 22, 2022
ad90da5
Merge branch 'master' into bpsage
husimplicity Dec 27, 2022
559a5b1
Merge branch 'pyg-team:master' into bpsage
husimplicity Jan 16, 2023
6b762ad
changelog
rusty1s Jan 16, 2023
c6a7972
Merge branch 'master' into bpsage
rusty1s Jan 16, 2023
143d6c3
update
rusty1s Jan 16, 2023
f0c2a7a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2023
4b856d3
update
rusty1s Jan 16, 2023
6b9cf19
typo
rusty1s Jan 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
format
  • Loading branch information
HuxleyHu98 committed Dec 21, 2022
commit 3d6c20cf9f561b08b4376d5b2bd3659d0539fee3
30 changes: 15 additions & 15 deletions examples/hetero/bipartite_sage_unsup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# An implementation of unsupervised bipartite GraphSAGE using the Alibaba Taobao
# dataset.
# An implementation of unsupervised bipartite GraphSAGE using the Alibaba
# Taobao dataset.
import os.path as osp

import torch
Expand Down Expand Up @@ -221,7 +221,7 @@ def train():
@torch.no_grad()
def test(loader):
model.eval()
accs, precisions, recalls, f1s = [], [], [], []
accs, precs, recs, f1s = [], [], [], []

for batch in tqdm.tqdm(loader):
batch = batch.to(device)
Expand All @@ -232,34 +232,34 @@ def test(loader):
target = batch['user', 'item'].edge_label.round().cpu()

acc = accuracy_score(target, out)
precision = precision_score(target, out)
recall = recall_score(target, out)
prec = precision_score(target, out)
rec = recall_score(target, out)
f1 = f1_score(target, out)
accs.append(acc)
precisions.append(precision)
recalls.append(recall)
precs.append(prec)
recs.append(rec)
f1s.append(f1)

import numpy as np

total_acc = float(np.mean(accs))
total_precision = float(np.mean(precisions))
total_recall = float(np.mean(recalls))
total_prec = float(np.mean(precs))
total_rec = float(np.mean(recs))
total_f1 = float(np.mean(f1s))

return total_acc, total_precision, total_recall, total_f1
return total_acc, total_prec, total_rec, total_f1


for epoch in range(1, 21):
loss = train()
val_acc, val_precision, val_recall, val_f1 = test(val_loader)
test_acc, test_precision, test_recall, test_f1 = test(test_loader)
val_acc, val_prec, val_rec, val_f1 = test(val_loader)
tst_acc, tst_prec, tst_rec, tst_f1 = test(test_loader)

print(f'Epoch: {epoch:03d} | Loss: {loss:4f}')
print(f'Eval: Accuracy | Precision | Recall | F1 score')
print('Eval: Accuracy | Precision | Recall | F1 score')
print(
f'Val: {val_acc:.4f} | {val_precision:.4f} | {val_recall:.4f} | {val_f1:.4f}'
f'Val: {val_acc:.4f} | {val_prec:.4f} | {val_rec:.4f} | {val_f1:.4f}'
)
print(
f'Test: {test_acc:.4f} | {test_precision:.4f} | {test_recall:.4f} | {test_f1:.4f}'
f'Test: {tst_acc:.4f} | {tst_prec:.4f} | {tst_rec:.4f} | {tst_f1:.4f}'
)
7 changes: 4 additions & 3 deletions torch_geometric/datasets/taobao.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class Taobao(InMemoryDataset):
r"""Taobao User Behavior is a dataset of user behaviors from Taobao offered
by Alibaba. The dataset is from the platform Tianchi Alicloud.
https://tianchi.aliyun.com/dataset/649.
Taobao is a heterogeous graph for recommendation. Nodes represent users with
User IDs and items with Item IDs and Category IDs, and edges represent
Taobao is a heterogeous graph for recommendation. Nodes represent users
with User IDs and items with Item IDs and Category IDs, and edges represent
different types of user behaviours towards items with timestamps.

Args:
Expand All @@ -32,7 +32,8 @@ class Taobao(InMemoryDataset):
being saved to disk. (default: :obj:`None`)

"""
url = 'https://alicloud-dev.oss-cn-hangzhou.aliyuncs.com/UserBehavior.csv.zip'
dataset = 'UserBehavior.csv.zip'
url = 'https://alicloud-dev.oss-cn-hangzhou.aliyuncs.com/' + dataset

def __init__(
self,
Expand Down