Add the unsupervised bipartite GraphSAGE model on the Taobao dataset #6144

Merged: 37 commits merged on Jan 16, 2023

Changes shown from 36 commits.

Commits (37)
45fac38  Add unsupervised bipartite graphsage & dataset taobao  (HuxleyHu98, Dec 2, 2022)
9e4e61c  Merge branch 'pyg-team:master' into bpsage  (husimplicity, Dec 2, 2022)
e4baf1f  style  (HuxleyHu98, Dec 2, 2022)
e4df11b  minor  (HuxleyHu98, Dec 4, 2022)
b3b5ad4  Merge branch 'pyg-team:master' into bpsage  (husimplicity, Dec 4, 2022)
23c84eb  Merge branch 'bpsage' of https://github.com/husimplicity/pytorch_geom…  (HuxleyHu98, Dec 4, 2022)
8768df9  Merge branch 'pyg-team:master' into bpsage  (husimplicity, Dec 5, 2022)
0959d3a  minor  (HuxleyHu98, Dec 5, 2022)
57c5830  Merge branch 'pyg-team:master' into bpsage  (husimplicity, Dec 5, 2022)
b7cb029  Merge branch 'bpsage' of https://github.com/husimplicity/pytorch_geom…  (HuxleyHu98, Dec 5, 2022)
eae6442  minor  (HuxleyHu98, Dec 5, 2022)
318a90e  Merge branch 'pyg-team:master' into bpsage  (husimplicity, Dec 5, 2022)
96e3f9c  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 5, 2022)
82dfb41  Merge branch 'pyg-team:master' into bpsage  (husimplicity, Dec 6, 2022)
37c68c7  format  (HuxleyHu98, Dec 6, 2022)
d62518c  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 6, 2022)
dada43b  Merge branch 'pyg-team:master' into bpsage  (husimplicity, Dec 7, 2022)
c59109e  fix:limit test sampling data within split test data  (HuxleyHu98, Dec 7, 2022)
e6d4fca  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 7, 2022)
d4a374c  Merge branch 'master' into bpsage  (husimplicity, Dec 14, 2022)
754e3f6  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 14, 2022)
e752eab  Apply triplet loss  (HuxleyHu98, Dec 14, 2022)
a1a5b03  Merge branch 'bpsage' of https://github.com/husimplicity/pytorch_geom…  (HuxleyHu98, Dec 14, 2022)
1fa0b68  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 14, 2022)
ac35e61  Merge branch 'master' into bpsage  (husimplicity, Dec 20, 2022)
6ce5fbe  Merge branch 'master' into bpsage  (husimplicity, Dec 21, 2022)
3d6c20c  format  (HuxleyHu98, Dec 21, 2022)
be4e7ee  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 21, 2022)
3924c29  Merge branch 'pyg-team:master' into bpsage  (husimplicity, Dec 22, 2022)
ad90da5  Merge branch 'master' into bpsage  (husimplicity, Dec 27, 2022)
559a5b1  Merge branch 'pyg-team:master' into bpsage  (husimplicity, Jan 16, 2023)
6b762ad  changelog  (rusty1s, Jan 16, 2023)
c6a7972  Merge branch 'master' into bpsage  (rusty1s, Jan 16, 2023)
143d6c3  update  (rusty1s, Jan 16, 2023)
f0c2a7a  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jan 16, 2023)
4b856d3  update  (rusty1s, Jan 16, 2023)
6b9cf19  typo  (rusty1s, Jan 16, 2023)
4 changes: 4 additions & 0 deletions .style.yapf
@@ -0,0 +1,4 @@
[style]
based_on_style=pep8
split_before_named_assigns=False
blank_line_before_nested_class_or_def=False
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `Taobao` dataset and a corresponding example for it ([#6144](https://github.com/pyg-team/pytorch_geometric/pull/6144))
- Added `pyproject.toml` ([#6431](https://github.com/pyg-team/pytorch_geometric/pull/6431))
- Added the `torch_geometric.contrib` sub-package ([#6422](https://github.com/pyg-team/pytorch_geometric/pull/6422))
- Warn on using latest documentation ([#6418](https://github.com/pyg-team/pytorch_geometric/pull/6418))
14 changes: 7 additions & 7 deletions docs/source/tutorial/heterogeneous.rst
@@ -13,7 +13,7 @@ As a consequence of the different data structure, the message passing formulatio
Example Graph
-------------

As a guiding example, we take a look at the heterogenous `ogbn-mag <https://ogb.stanford.edu/docs/nodeprop>`__ network from the `OGB datasets <https://ogb.stanford.edu>`_:
As a guiding example, we take a look at the heterogeneous `ogbn-mag <https://ogb.stanford.edu/docs/nodeprop>`__ network from the `OGB datasets <https://ogb.stanford.edu>`_:

.. image:: ../_figures/hg_example.svg
:align: center
@@ -192,7 +192,7 @@ The transform :meth:`~torch_geometric.transforms.NormalizeFeatures` works like i
Creating Heterogeneous GNNs
---------------------------

Standard Message Passing GNNs (MP-GNNs) can not trivially be applied to heterogenous graph data, as node and edge features from different types can not be processed by the same functions due to differences in feature type.
Standard Message Passing GNNs (MP-GNNs) can not trivially be applied to heterogeneous graph data, as node and edge features from different types can not be processed by the same functions due to differences in feature type.
A natural way to circumvent this is to implement message and update functions individually for each edge type.
During runtime, the MP-GNN algorithm would need to iterate over edge type dictionaries during message computation and over node type dictionaries during node updates.
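
Spelled out, this per-edge-type scheme is essentially a loop over the edge type dictionary. Below is a minimal sketch with hypothetical edge types and a plain Python dict of bipartite :class:`~torch_geometric.nn.conv.SAGEConv` operators; :pyg:`PyG` automates exactly this via :meth:`~torch_geometric.nn.to_hetero` and :class:`~torch_geometric.nn.conv.HeteroConv`, covered below:

    import torch
    from torch_geometric.nn import SAGEConv

    convs = {  # one (hypothetical) operator per edge type:
        ('paper', 'cites', 'paper'): SAGEConv((-1, -1), 64),
        ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
        ('paper', 'rev_writes', 'author'): SAGEConv((-1, -1), 64),
    }

    def hetero_layer(x_dict, edge_index_dict):
        # Compute messages separately for every edge type ...
        out_dict = {}
        for (src, rel, dst), edge_index in edge_index_dict.items():
            out = convs[(src, rel, dst)]((x_dict[src], x_dict[dst]), edge_index)
            out_dict.setdefault(dst, []).append(out)
        # ... and aggregate the messages arriving at each destination node type:
        return {key: torch.stack(outs, dim=0).sum(dim=0)
                for key, outs in out_dict.items()}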

@@ -298,10 +298,10 @@ Afterwards, the created model can be trained as usual:
optimizer.step()
return float(loss)

Using the Heterogenous Convolution Wrapper
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Using the Heterogeneous Convolution Wrapper
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The heterogenous convolution wrapper :class:`torch_geometric.nn.conv.HeteroConv` allows to define custom heterogenous message and update functions to build arbitrary MP-GNNs for heterogeneous graphs from scratch.
The heterogeneous convolution wrapper :class:`torch_geometric.nn.conv.HeteroConv` allows to define custom heterogeneous message and update functions to build arbitrary MP-GNNs for heterogeneous graphs from scratch.
While the automatic converter :meth:`~torch_geometric.nn.to_hetero` uses the same operator for all edge types, the wrapper allows to define different operators for different edge types.
Here, :class:`~torch_geometric.nn.conv.HeteroConv` takes a dictionary of submodules as input, one for each edge type in the graph data.
The following `example <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hetero_conv_dblp.py>`__ shows how to apply it.
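
As a condensed sketch of that pattern (hypothetical edge types; see the linked file for the full version), each layer wraps one operator per edge type and aggregates their outputs per node type:

    import torch
    from torch_geometric.nn import HeteroConv, SAGEConv

    class HeteroGNN(torch.nn.Module):
        def __init__(self, hidden_channels, num_layers):
            super().__init__()
            self.convs = torch.nn.ModuleList()
            for _ in range(num_layers):
                # One (possibly different) operator per edge type:
                conv = HeteroConv({
                    ('paper', 'cites', 'paper'): SAGEConv((-1, -1), hidden_channels),
                    ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
                    ('paper', 'rev_writes', 'author'): SAGEConv((-1, -1), hidden_channels),
                }, aggr='sum')
                self.convs.append(conv)

        def forward(self, x_dict, edge_index_dict):
            for conv in self.convs:
                x_dict = conv(x_dict, edge_index_dict)
                x_dict = {key: x.relu() for key, x in x_dict.items()}
            return x_dict
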
@@ -349,8 +349,8 @@ We can initialize the model by calling it once (see :ref:`here<lazyinit>` for mo

and run the standard training procedure as outlined :ref:`here<trainfunc>`.

Deploy Existing Heterogenous Operators
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Deploy Existing Heterogeneous Operators
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:pyg:`PyG` provides operators (*e.g.*, :class:`torch_geometric.nn.conv.HGTConv`), which are specifically designed for heterogeneous graphs.
These operators can be directly used to build heterogeneous GNN models as can be seen in the following `example <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hgt_dblp.py>`__:
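
As a rough sketch (not the linked example file), a model built directly from :class:`~torch_geometric.nn.conv.HGTConv` might look as follows, assuming the ogbn-mag data object from above; the lazy :obj:`-1` input size and :obj:`data.metadata()` are real PyG API, the rest is illustrative:

    import torch
    from torch_geometric.nn import HGTConv, Linear

    class HGT(torch.nn.Module):
        def __init__(self, hidden_channels, out_channels, num_heads, num_layers,
                     metadata):
            super().__init__()
            self.convs = torch.nn.ModuleList()
            for _ in range(num_layers):
                self.convs.append(
                    HGTConv(-1, hidden_channels, metadata, heads=num_heads))
            self.lin = Linear(hidden_channels, out_channels)

        def forward(self, x_dict, edge_index_dict):
            for conv in self.convs:
                x_dict = conv(x_dict, edge_index_dict)
            return self.lin(x_dict['paper'])

    # model = HGT(64, dataset.num_classes, num_heads=2, num_layers=2,
    #             metadata=data.metadata())
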
4 changes: 2 additions & 2 deletions docs/source/tutorial/load_csv.rst
@@ -1,7 +1,7 @@
Loading Graphs from CSV
=======================

In this example, we will show how to load a set of :obj:`*.csv` files as input and construct a **heterogeneous graph** from it, which can be used as input to a `heterogenous graph model <heterogeneous.html>`__.
In this example, we will show how to load a set of :obj:`*.csv` files as input and construct a **heterogeneous graph** from it, which can be used as input to a `heterogeneous graph model <heterogeneous.html>`__.
This tutorial is also available as an executable `example script <https://github.com/pyg-team/pytorch_geometric/tree/master/examples/hetero/load_csv.py>`_ in the :obj:`examples/hetero` directory.

We are going to use the `MovieLens dataset <https://grouplens.org/datasets/movielens/>`_ collected by the GroupLens research group.
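
At its core, this amounts to reading each :obj:`*.csv` file with :obj:`pandas`, mapping its raw IDs to consecutive node indices, and filling a :class:`~torch_geometric.data.HeteroData` object. A minimal sketch, assuming a MovieLens-style :obj:`ratings.csv` with :obj:`userId` and :obj:`movieId` columns (file name and column handling here are illustrative, not the tutorial's helper functions):

    import pandas as pd
    import torch
    from torch_geometric.data import HeteroData

    ratings = pd.read_csv('ratings.csv')  # hypothetical path

    # Map the raw CSV identifiers to consecutive integer node indices:
    user_mapping = {raw_id: i for i, raw_id in enumerate(ratings['userId'].unique())}
    movie_mapping = {raw_id: i for i, raw_id in enumerate(ratings['movieId'].unique())}

    data = HeteroData()
    data['user'].num_nodes = len(user_mapping)
    data['movie'].num_nodes = len(movie_mapping)

    src = [user_mapping[raw_id] for raw_id in ratings['userId']]
    dst = [movie_mapping[raw_id] for raw_id in ratings['movieId']]
    data['user', 'rates', 'movie'].edge_index = torch.tensor([src, dst])
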
@@ -251,7 +251,7 @@ With this, we are ready to finalize our :class:`~torch_geometric.data.HeteroData
}
)

This :class:`~torch_geometric.data.HeteroData` object is the native format of heterogenous graphs in :pyg:`PyG` and can be used as input for `heterogenous graph models <heterogeneous.html>`__.
This :class:`~torch_geometric.data.HeteroData` object is the native format of heterogeneous graphs in :pyg:`PyG` and can be used as input for `heterogeneous graph models <heterogeneous.html>`__.

.. note::

2 changes: 1 addition & 1 deletion examples/hetero/bipartite_sage.py
@@ -110,7 +110,7 @@ def __init__(self, num_users, hidden_channels, out_channels):
self.user_emb = Embedding(num_users, hidden_channels)
self.user_encoder = UserGNNEncoder(hidden_channels, out_channels)
self.movie_encoder = MovieGNNEncoder(hidden_channels, out_channels)
self.decoder = EdgeDecoder(hidden_channels)
self.decoder = EdgeDecoder(out_channels)

def forward(self, x_dict, edge_index_dict, edge_label_index):
z_dict = {}
256 changes: 256 additions & 0 deletions examples/hetero/bipartite_sage_unsup.py
@@ -0,0 +1,256 @@
# An implementation of unsupervised bipartite GraphSAGE using the Alibaba
# Taobao dataset.
import os.path as osp

import torch
import torch.nn.functional as F
import tqdm
from sklearn.metrics import (
accuracy_score,
f1_score,
precision_score,
recall_score,
)
from torch.nn import Embedding, Linear

import torch_geometric.transforms as T
from torch_geometric.datasets import Taobao
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.utils.convert import to_scipy_sparse_matrix

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/Taobao')

dataset = Taobao(path)
data = dataset[0]

data['user'].x = torch.arange(0, data['user'].num_nodes)
data['item'].x = torch.arange(0, data['item'].num_nodes)
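# These node "features" are just consecutive IDs; the model below uses them as
# indices into learnable user/item embedding tables.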

# Only consider user<>item relationships for simplicity:
del data['category']
del data['item', 'category']
del data['user', 'item'].time
del data['user', 'item'].behavior

# Add a reverse ('item', 'rev_to', 'user') relation for message passing:
data = T.ToUndirected()(data)

# Perform a link-level split into training, validation, and test edges:
print('Computing data splits...')
train_data, val_data, test_data = T.RandomLinkSplit(
num_val=0.1,
num_test=0.1,
neg_sampling_ratio=1.0,
add_negative_train_samples=False,
edge_types=[('user', 'to', 'item')],
rev_edge_types=[('item', 'rev_to', 'user')],
)(data)
print('Done!')

# Compute sparsified item<>item relationships through users:
print('Computing item<>item relationships...')
mat = to_scipy_sparse_matrix(data['user', 'item'].edge_index).tocsr()
mat = mat[:data['user'].num_nodes, :data['item'].num_nodes]
comat = mat.T @ mat
comat.setdiag(0)
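# Keep only item pairs with a co-occurrence count of at least three: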
comat = comat >= 3.
comat = comat.tocoo()
row = torch.from_numpy(comat.row).to(torch.long)
col = torch.from_numpy(comat.col).to(torch.long)
item_to_item_edge_index = torch.stack([row, col], dim=0)

# Add the generated item<>item relationships for high-order information:
train_data['item', 'item'].edge_index = item_to_item_edge_index
val_data['item', 'item'].edge_index = item_to_item_edge_index
test_data['item', 'item'].edge_index = item_to_item_edge_index
print('Done!')

train_loader = LinkNeighborLoader(
data=train_data,
num_neighbors=[8, 4],
edge_label_index=('user', 'to', 'item'),
neg_sampling='binary',
batch_size=2048,
shuffle=True,
num_workers=16,
drop_last=True,
)

val_loader = LinkNeighborLoader(
data=val_data,
num_neighbors=[8, 4],
edge_label_index=(
('user', 'to', 'item'),
val_data[('user', 'to', 'item')].edge_label_index,
),
edge_label=val_data[('user', 'to', 'item')].edge_label,
batch_size=2048,
shuffle=True,
num_workers=16,
)

test_loader = LinkNeighborLoader(
data=test_data,
num_neighbors=[8, 4],
edge_label_index=(
('user', 'to', 'item'),
test_data[('user', 'to', 'item')].edge_label_index,
),
edge_label=test_data[('user', 'to', 'item')].edge_label,
batch_size=2048,
shuffle=True,
num_workers=16,
)


class ItemGNNEncoder(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv(-1, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, hidden_channels)
self.lin = Linear(hidden_channels, out_channels)

def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index).relu()
return self.lin(x)


class UserGNNEncoder(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv((-1, -1), hidden_channels)
self.conv2 = SAGEConv((-1, -1), hidden_channels)
self.conv3 = SAGEConv((-1, -1), hidden_channels)
self.lin = Linear(hidden_channels, out_channels)

def forward(self, x_dict, edge_index_dict):
item_x = self.conv1(
x_dict['item'],
edge_index_dict[('item', 'to', 'item')],
).relu()

user_x = self.conv2(
(x_dict['item'], x_dict['user']),
edge_index_dict[('item', 'rev_to', 'user')],
).relu()

user_x = self.conv3(
(item_x, user_x),
edge_index_dict[('item', 'rev_to', 'user')],  # only item->user relation
).relu()

return self.lin(user_x)


class EdgeDecoder(torch.nn.Module):
def __init__(self, hidden_channels):
super().__init__()
self.lin1 = Linear(2 * hidden_channels, hidden_channels)
self.lin2 = Linear(hidden_channels, 1)

def forward(self, z_src, z_dst, edge_label_index):
row, col = edge_label_index
z = torch.cat([z_src[row], z_dst[col]], dim=-1)

z = self.lin1(z).relu()
z = self.lin2(z)
return z.view(-1)


class Model(torch.nn.Module):
def __init__(self, num_users, num_items, hidden_channels, out_channels):
super().__init__()
self.user_emb = Embedding(num_users, hidden_channels, device=device)
self.item_emb = Embedding(num_items, hidden_channels, device=device)
self.item_encoder = ItemGNNEncoder(hidden_channels, out_channels)
self.user_encoder = UserGNNEncoder(hidden_channels, out_channels)
self.decoder = EdgeDecoder(out_channels)

def forward(self, x_dict, edge_index_dict, edge_label_index):
z_dict = {}
x_dict['user'] = self.user_emb(x_dict['user'])
x_dict['item'] = self.item_emb(x_dict['item'])
z_dict['item'] = self.item_encoder(
x_dict['item'],
edge_index_dict[('item', 'to', 'item')],
)
z_dict['user'] = self.user_encoder(x_dict, edge_index_dict)

return self.decoder(z_dict['user'], z_dict['item'], edge_label_index)


model = Model(
num_users=data['user'].num_nodes,
num_items=data['item'].num_nodes,
hidden_channels=64,
out_channels=64,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train():
model.train()

total_loss = total_examples = 0
for batch in tqdm.tqdm(train_loader):
batch = batch.to(device)
optimizer.zero_grad()

pred = model(
batch.x_dict,
batch.edge_index_dict,
batch['user', 'item'].edge_label_index,
)
loss = F.binary_cross_entropy_with_logits(
pred, batch['user', 'item'].edge_label)

loss.backward()
optimizer.step()
total_loss += float(loss) * pred.numel()
total_examples += pred.numel()

return total_loss / total_examples


@torch.no_grad()
def test(loader):
model.eval()

preds, targets = [], []
for batch in tqdm.tqdm(loader):
batch = batch.to(device)

pred = model(
batch.x_dict,
batch.edge_index_dict,
batch['user', 'item'].edge_label_index,
).sigmoid().view(-1).cpu()
target = batch['user', 'item'].edge_label.long().cpu()

preds.append(pred)
targets.append(target)

pred = torch.cat(preds, dim=0).numpy()
target = torch.cat(targets, dim=0).numpy()

# Threshold the predicted probabilities at 0.5 to obtain binary labels:
acc = accuracy_score(target, pred > 0.5)
prec = precision_score(target, pred > 0.5)
rec = recall_score(target, pred > 0.5)
f1 = f1_score(target, pred > 0.5)

return acc, prec, rec, f1


for epoch in range(1, 21):
loss = train()
val_acc, val_prec, val_rec, val_f1 = test(val_loader)
test_acc, test_prec, test_rec, test_f1 = test(test_loader)

print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
print(f'Val Acc: {val_acc:.4f}, Val Precision {val_prec:.4f}, '
f'Val Recall {val_rec:.4f}, Val F1 {val_f1:.4f}')
print(f'Test Acc: {test_acc:.4f}, Test Precision {test_prec:.4f}, '
f'Test Recall {test_rec:.4f}, Test F1 {test_f1:.4f}')
2 changes: 1 addition & 1 deletion test/data/test_data.py
@@ -62,7 +62,7 @@ def test_data():
assert clone.edge_index.data_ptr() != data.edge_index.data_ptr()
assert clone.edge_index.tolist() == data.edge_index.tolist()

# Test `data.to_heterogenous()`:
# Test `data.to_heterogeneous()`:
out = data.to_heterogeneous()
assert torch.allclose(data.x, out['0'].x)
assert torch.allclose(data.edge_index, out['0', '0'].edge_index)
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
@@ -82,6 +82,7 @@
from .infection_dataset import InfectionDataset
from .ba2motif_dataset import BA2MotifDataset
from .airfrans import AirfRANS
from .taobao import Taobao

import torch_geometric.datasets.utils # noqa

@@ -173,6 +174,7 @@
'InfectionDataset',
'BA2MotifDataset',
'AirfRANS',
'Taobao',
]

classes = __all__