Skip to content

Commit

Permalink
Add unsupervised bipartite graphsage & dataset taobao
Browse files Browse the repository at this point in the history
  • Loading branch information
HuxleyHu98 committed Dec 2, 2022
1 parent 07ba384 commit 45fac38
Show file tree
Hide file tree
Showing 3 changed files with 334 additions and 0 deletions.
243 changes: 243 additions & 0 deletions examples/hetero/bipartite_sage_unsup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
import os.path as osp

This comment has been minimized.

Copy link
@Seventeen17

Seventeen17 Dec 2, 2022

Add description like:

# An implementation of GraphSAGE in Alibaba recommendation system on user-item bipartite graph
import tqdm

import torch
import torch.nn.functional as F
from torch.nn import Embedding, Linear

from torch_geometric.utils.convert import to_scipy_sparse_matrix
from torch_geometric.loader import LinkNeighborLoader
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Taobao

This comment has been minimized.

Copy link
@Seventeen17

Seventeen17 Dec 2, 2022

Change the import order:

import torch_geometric.transforms as T
from torch_geometric.datasets import Taobao
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.utils.convert import to_scipy_sparse_matrix

from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score


# Run on GPU when available; model and batches are moved here explicitly.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Dataset root: <repo>/data/Taobao (downloaded/processed on first use).
path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/Taobao')

dataset = Taobao(path)
data = dataset[0]

# Use consecutive node ids as the input "features"; they index into the
# learnable Embedding tables inside `Model`.
# NOTE(review): torch.arange already returns int64, so the LongTensor
# wrapper is redundant.
data['user'].x = torch.LongTensor(torch.arange(
    0, data['user'].num_nodes))
data['item'].x = torch.LongTensor(torch.arange(
    0, data['item'].num_nodes))

# Add a reverse ('item', 'rev_2', 'user') relation for message passing:
data = T.ToUndirected()(data)
del data['item', 'rev_2', 'user'].edge_label  # Remove "reverse" label.
del data[('user', '2', 'item')].edge_attr

# Perform a link-level split into training, validation, and test edges:
train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[('user', '2', 'item')],
    rev_edge_types=[('item', 'rev_2', 'user')],
)(data)

This comment has been minimized.

Copy link
@Seventeen17

Seventeen17 Dec 2, 2022

TODO: add a ChronologicalLinkSplit to split train, val, test data with timestamp.


def to_u2i_mat(edge_index, u_num, i_num):
    """Convert a bipartite ``edge_index`` into a sparse user-item matrix.

    The COO edge list is turned into a CSR matrix and cropped to the
    ``u_num x i_num`` user/item shape (the conversion builds a square
    matrix over the combined id range).
    """
    full_mat = to_scipy_sparse_matrix(edge_index).tocsr()
    return full_mat[:u_num, :i_num]


def get_coocur_mat(train_mat, threshold):
    """Build an item-item co-occurrence ``edge_index`` from a user-item matrix.

    Two items are connected when at least ``threshold`` users interacted
    with both of them; self-loops are removed via the zeroed diagonal.
    Returns a ``[2, num_edges]`` LongTensor.
    """
    cooc = train_mat.T @ train_mat  # item-item co-occurrence counts
    cooc.setdiag(0)                 # drop trivial self co-occurrence
    rows, cols = (cooc >= threshold).nonzero()
    return torch.stack(
        (torch.from_numpy(rows), torch.from_numpy(cols)), dim=0)


# Build the user-item matrix from the *training* split only, so the derived
# item-item graph does not leak validation/test edges.
u2i_mat = to_u2i_mat(train_data.edge_index_dict[('user', '2', 'item')],
                     train_data['user'].num_nodes,
                     train_data['item'].num_nodes)
i2i_edge_index = get_coocur_mat(u2i_mat, 3)

# Add the generated i2i graph for high-order information
train_data[('item', 'sims', 'item')].edge_index = i2i_edge_index
val_data[('item', 'sims', 'item')].edge_index = i2i_edge_index
test_data[('item', 'sims', 'item')].edge_index = i2i_edge_index


# Sample 8 first-hop / 4 second-hop neighbors around each seed edge; one
# negative edge is sampled per positive (neg_sampling_ratio=1.).
train_loader = LinkNeighborLoader(data=train_data,
                                  num_neighbors=[8, 4],
                                  edge_label_index=['user', '2', 'item'],
                                  neg_sampling_ratio=1.,
                                  batch_size=2048,
                                  num_workers=32,
                                  pin_memory=True,
                                  )

This comment has been minimized.

Copy link
@Seventeen17

Seventeen17 Dec 2, 2022

Move this closing parenthesis to the end of line 77.


# Validation/test loaders mirror the training loader configuration (same
# fan-out and 1:1 negative sampling) so metrics are comparable across splits.
val_loader = LinkNeighborLoader(data=val_data,
                                num_neighbors=[8, 4],
                                edge_label_index=['user', '2', 'item'],
                                neg_sampling_ratio=1.,
                                batch_size=2048,
                                num_workers=32,
                                pin_memory=True,
                                )

test_loader = LinkNeighborLoader(data=test_data,
                                 num_neighbors=[8, 4],
                                 edge_label_index=['user', '2', 'item'],
                                 neg_sampling_ratio=1.,
                                 batch_size=2048,
                                 num_workers=32,
                                 pin_memory=True,
                                 )


class ItemGNNEncoder(torch.nn.Module):
    """Two-layer GraphSAGE encoder over the item-item similarity graph.

    Lazily infers the input feature size (``in_channels=-1``) and projects
    the final hidden state to ``out_channels`` with a linear head.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(-1, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        hidden = self.conv2(hidden, edge_index).relu()
        return self.lin(hidden)



class UserGNNEncoder(torch.nn.Module):
    """Encodes user nodes by aggregating information from items.

    ``conv1`` refines item features over the item-item similarity graph,
    ``conv2`` lifts raw item features to users over the reversed user-item
    edges, and ``conv3`` combines the refined item view with the user view
    before the final linear projection to ``out_channels``.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1) lazily infers (source, destination) feature sizes for
        # the bipartite message passing.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        self.conv3 = SAGEConv((-1, -1), hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        item_x = self.conv1(
            x_dict['item'],
            edge_index_dict[('item', 'sims', 'item')],
        ).relu()

        user_x = self.conv2(
            (x_dict['item'], x_dict['user']),
            edge_index_dict[('item', 'rev_2', 'user')],
        ).relu()

        user_x = self.conv3(
            (item_x, user_x),
            edge_index_dict[('item', 'rev_2', 'user')],
        ).relu()

        return self.lin(user_x)


class EdgeDecoder(torch.nn.Module):
    """Scores candidate (user, item) pairs with a small MLP.

    For each edge in ``edge_label_index`` the source and destination
    embeddings are concatenated and mapped to a single logit.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_src, z_dst, edge_label_index):
        src_idx, dst_idx = edge_label_index
        pair = torch.cat((z_src[src_idx], z_dst[dst_idx]), dim=-1)
        hidden = self.lin1(pair).relu()
        logits = self.lin2(hidden)
        return logits.view(-1)


class Model(torch.nn.Module):
    """Unsupervised bipartite GraphSAGE: id embeddings -> encoders -> decoder.

    Raw user/item ids are looked up in learnable embedding tables, encoded
    into ``out_channels``-dim node representations, and scored pairwise by
    the edge decoder.
    """

    def __init__(self, user_input_size, item_input_size, hidden_channels,
                 out_channels):
        super().__init__()
        # Embedding tables are created directly on the target device.
        self.user_emb = Embedding(
            user_input_size, hidden_channels, device=device)
        self.item_emb = Embedding(
            item_input_size, hidden_channels, device=device)
        self.item_encoder = ItemGNNEncoder(hidden_channels, out_channels)
        self.user_encoder = UserGNNEncoder(hidden_channels, out_channels)
        # Fix: both encoders emit `out_channels`-dim embeddings, so the
        # decoder must be sized by `out_channels`. The original passed
        # `hidden_channels`, which only worked because both were 64.
        self.decoder = EdgeDecoder(out_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = {}
        # Replace integer node ids with learned embeddings.
        x_dict['user'] = self.user_emb(x_dict['user'])
        x_dict['item'] = self.item_emb(x_dict['item'])
        z_dict['item'] = self.item_encoder(
            x_dict['item'], edge_index_dict[('item', 'sims', 'item')])
        z_dict['user'] = self.user_encoder(x_dict, edge_index_dict)

        return self.decoder(z_dict['user'], z_dict['item'], edge_label_index)


# 64-dim hidden and output embeddings; Adam with a fixed learning rate.
model = Model(user_input_size=data['user'].num_nodes,
              item_input_size=data['item'].num_nodes,
              hidden_channels=64,
              out_channels=64)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)


def train():
    """Run one training epoch over `train_loader`.

    Optimizes binary cross-entropy on the sampled positive/negative link
    labels and returns the sum of per-batch losses as a float.
    """
    model.train()
    total_loss = 0.0
    for batch in tqdm.tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch.x_dict,
                     batch.edge_index_dict,
                     batch['user', 'item'].edge_label_index,
                     )
        target = batch['user', 'item'].edge_label
        loss = F.binary_cross_entropy_with_logits(pred, target)
        loss.backward()
        optimizer.step()
        # Accumulate a Python float, not the loss tensor: `total_loss +=
        # loss` would keep every batch's autograd graph (and GPU memory)
        # alive until the end of the epoch.
        total_loss += float(loss)

    return total_loss


@torch.no_grad()
def test(loader):
    """Evaluate link prediction on `loader`.

    Returns the per-batch mean of (accuracy, precision, recall, f1) as
    floats in [0, 1].
    """
    model.eval()
    total_acc, total_precision, total_recall, total_f1 = 0., 0., 0., 0.
    num_batches = 0

    for batch in tqdm.tqdm(loader):
        batch = batch.to(device)
        # NOTE(review): clamp(0, 1).round() on raw logits predicts positive
        # only for logits >= 0.5 (sigmoid ~0.62), not the usual cut at
        # logit 0 (probability 0.5) — confirm this threshold is intended.
        out = model(batch.x_dict,
                    batch.edge_index_dict,
                    batch['user', 'item'].edge_label_index
                    ).clamp(min=0, max=1).round().cpu()

        target = batch['user', 'item'].edge_label.round().cpu()
        total_acc += accuracy_score(target, out)
        total_precision += precision_score(target, out)
        total_recall += recall_score(target, out)
        total_f1 += f1_score(target, out)
        num_batches += 1

    # Fix: return per-batch means instead of sums — the caller prints these
    # as metric values, and summed totals exceed 1 with multiple batches.
    num_batches = max(num_batches, 1)
    return (total_acc / num_batches, total_precision / num_batches,
            total_recall / num_batches, total_f1 / num_batches)


# Train for 50 epochs, reporting metrics on all three splits every epoch.
# (Backslash-continued f-strings keep the table columns aligned.)
for epoch in range(1, 51):
    loss = train()
    train_acc, train_precision, train_recall, train_f1 = test(train_loader)
    val_acc, val_precision, val_recall, val_f1 = test(val_loader)
    test_acc, test_precision, test_recall, test_f1 = test(test_loader)

    print(f'Epoch: {epoch:03d} | Loss: {loss:4f}')
    print(f'Eval Index: | Accuracy | Precision | Recall | F1_score')
    print(f'Train: {train_acc:.4f} | {train_precision:.4f} | {train_recall:.4f} \
| {train_f1:.4f}')
    print(f'Val: {val_acc:.4f} | {val_precision:.4f} | {val_recall:.4f} \
| {val_f1:.4f}')
    print(f'Test: {test_acc:.4f} | {test_precision:.4f} | {test_recall:.4f} \
| {test_f1:.4f}')
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from .elliptic import EllipticBitcoinDataset
from .dgraph import DGraphFin
from .hydro_net import HydroNet
from .taobao import Taobao

import torch_geometric.datasets.utils # noqa

Expand Down Expand Up @@ -165,6 +166,7 @@
'EllipticBitcoinDataset',
'DGraphFin',
'HydroNet',
'Taobao',
]

classes = __all__
89 changes: 89 additions & 0 deletions torch_geometric/datasets/taobao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
from typing import Callable, List, Optional

import numpy as np
import torch
from torch_geometric.data import (
HeteroData,
InMemoryDataset,
download_url,
extract_zip,
)

class Taobao(InMemoryDataset):
    r"""Taobao User Behavior is a dataset of user behaviors from Taobao
    offered by Alibaba, provided via the Tianchi Alicloud platform:
    https://tianchi.aliyun.com/dataset/649.

    Taobao is a heterogeneous graph for recommendation. Nodes represent
    users (User IDs) and items (Item IDs and Category IDs), and edges
    represent different types of user behaviours towards items, each with
    a timestamp.

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """
    url = 'https://alicloud-dev.oss-cn-hangzhou.aliyuncs.com/UserBehavior.csv.zip'

    def __init__(self, root, transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['UserBehavior.csv']

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self):
        # Download the zipped CSV into `raw_dir`, unpack it, and drop the
        # archive. (The original stray `print(self.raw_dir)` debug line was
        # removed.)
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.remove(path)

    def process(self):
        import pandas as pd

        data = HeteroData()

        # Fix: UserBehavior.csv ships without a header row, so pass the
        # column names via `names=`. The original `pd.read_csv(path)`
        # followed by `df.columns = [...]` consumed the first data record
        # as a header and silently dropped it.
        cols = ['userId', 'itemId', 'categoryId', 'behaviorType', 'timestamp']
        df = pd.read_csv(self.raw_paths[0], names=cols)

        # Keep only behaviors within the official evaluation window:
        # 1511539200 = 2017.11.25-00:00:00, 1512316799 = 2017.12.03-23:59:59.
        df = df[(df["timestamp"] >= 1511539200)
                & (df["timestamp"] <= 1512316799)]

        df = df.drop_duplicates(subset=cols, keep='first')

        # Encode the behavior type and compress ids to contiguous ranges
        # starting at 0 (np.unique's inverse mapping).
        behavior_dict = {'pv': 0, 'cart': 1, 'buy': 2, 'fav': 3}
        df['behaviorType'] = df['behaviorType'].map(behavior_dict).values
        _, df['userId'] = np.unique(df[['userId']].values, return_inverse=True)
        _, df['itemId'] = np.unique(df[['itemId']].values, return_inverse=True)
        _, df['categoryId'] = np.unique(df[['categoryId']].values,
                                        return_inverse=True)

        data['user'].num_nodes = df['userId'].nunique()
        data['item'].num_nodes = df['itemId'].nunique()

        # Deduplicate (user, item) pairs, keeping the last interaction in
        # the sorted unique ordering; columns 0/1 are the edge index and
        # columns 2/3 (behavior type, timestamp) become edge attributes.
        edge_feat, _ = np.unique(
            df[['userId', 'itemId', 'behaviorType', 'timestamp']].values,
            return_index=True, axis=0)
        edge_feat = pd.DataFrame(edge_feat).drop_duplicates(
            subset=[0, 1], keep='last')
        data['user', '2', 'item'].edge_index = torch.from_numpy(
            edge_feat[[0, 1]].values).T
        data['user', '2', 'item'].edge_attr = torch.from_numpy(
            edge_feat[[2, 3]].values).T

        data = data if self.pre_transform is None else self.pre_transform(data)

        torch.save(self.collate([data]), self.processed_paths[0])

1 comment on commit 45fac38

@Seventeen17
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 Nice :)

Please sign in to comment.