Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CuGraph Example Fixes #9577

Merged
merged 34 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
64d6fc9
initial write
alexbarghi-nv Jun 1, 2022
32cd7ca
more hard work, demo to follow
alexbarghi-nv Jun 3, 2022
d578b15
add example
alexbarghi-nv Jun 3, 2022
cbdfc26
follow
alexbarghi-nv Jun 3, 2022
c1cb247
add newline
alexbarghi-nv Jun 3, 2022
fac4f9e
Update torch_geometric/data/cugraph/cugraph_data.py
alexbarghi-nv Jun 6, 2022
9198d97
Update torch_geometric/data/cugraph/cugraph_data.py
alexbarghi-nv Jun 6, 2022
23bbefe
Update torch_geometric/nn/conv/message_passing.py
alexbarghi-nv Jun 6, 2022
ba58ca2
Update torch_geometric/nn/conv/message_passing.py
alexbarghi-nv Jun 6, 2022
8830e82
Update torch_geometric/typing.py
alexbarghi-nv Jun 6, 2022
4221cf6
Update torch_geometric/data/cugraph/cugraph_data.py
alexbarghi-nv Jun 6, 2022
4063d13
Update torch_geometric/data/cugraph/cugraph_data.py
alexbarghi-nv Jun 6, 2022
1f2f78a
remove anytensor typing
alexbarghi-nv Jun 7, 2022
9eb546c
Merge branch 'initial-cugraph-storage' of https://github.com/alexbarg…
alexbarghi-nv Jun 7, 2022
a6b6432
fixes and changes
alexbarghi-nv Jun 7, 2022
172e953
small fixes
alexbarghi-nv Jun 7, 2022
d7cf2ba
Merge pull request #1 from alexbarghi-nv/initial-cugraph-storage
alexbarghi-nv Jun 14, 2022
7982e17
fix merge conflict
alexbarghi-nv Aug 3, 2022
5ef3ee8
fix merge conflict
alexbarghi-nv Sep 9, 2022
d476386
Merge branch 'master' of https://github.com/pyg-team/pytorch_geometric
alexbarghi-nv Sep 23, 2022
e81881e
resolve merge conflict
alexbarghi-nv Aug 6, 2024
06ce89a
minor fixes
alexbarghi-nv Aug 6, 2024
9aedc38
remove unwanted files
alexbarghi-nv Aug 6, 2024
23f1ed8
remove old unwanted file
alexbarghi-nv Aug 6, 2024
68d2db7
fix
alexbarghi-nv Aug 6, 2024
1035711
c
alexbarghi-nv Aug 6, 2024
e23035d
revert
alexbarghi-nv Aug 6, 2024
bd9960e
revert message passing
alexbarghi-nv Aug 6, 2024
cbcf690
add line
alexbarghi-nv Aug 6, 2024
ba98b6a
revert typing
alexbarghi-nv Aug 6, 2024
c2c4c7a
z
alexbarghi-nv Aug 6, 2024
009db20
remove log file
alexbarghi-nv Aug 6, 2024
4cf1a3e
update changelog
alexbarghi-nv Aug 6, 2024
316b30f
correct changelog
alexbarghi-nv Aug 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
more hard work, demo to follow
  • Loading branch information
alexbarghi-nv committed Jun 3, 2022
commit 32cd7cad8789b0536b5c6e8a2d545db02cd6030f
59 changes: 59 additions & 0 deletions test/data/test_cugraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import sys
sys.path += ['/work/pytorch_geometric', '/work/gaas/python']

from torch_geometric.data import Data
from torch_geometric.data.cugraph import CuGraphData
from torch_geometric.nn import GCNConv

import numpy as np

import torch
import torch.nn.functional as F

from gaas_client import GaasClient

client = GaasClient()
client.load_csv_as_edge_data('/work/cugraph/datasets/karate.csv', ['int32','int32','float32'], ['0','1'])

cd = CuGraphData(client)


class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GCNConv(8, 16)
self.conv2 = GCNConv(16, 2)

def forward(self, data):
x, edge_index = data.x, data.edge_index

x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)

return F.log_softmax(x, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)

data = cd.to(device)
data.x = torch.Tensor(np.ones((34,8))).type(torch.FloatTensor).to(device)
data.y = torch.Tensor(np.zeros(34)).type(torch.LongTensor).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out, data.y)
loss.backward()
optimizer.step()

model.eval()
pred = model(data).argmax(dim=1)
correct = (pred == data.y).sum()
acc = int(correct) / 34
print(f'Accuracy: {acc:.4f}')
13 changes: 11 additions & 2 deletions torch_geometric/data/cugraph/cugraph_data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from torch import device as TorchDevice
from torch_geometric.data import Data
from torch_geometric.data.cugraph.cugraph_storage import CuGraphStorage

from gaas_client.client import GaasClient
from gaas_client.defaults import graph_id as DEFAULT_GRAPH_ID

class CuGraphData(Data):
def __init__(self, gaas_client: GaasClient, graph_id: int=DEFAULT_GRAPH_ID):
def __init__(self, gaas_client: GaasClient, graph_id: int=DEFAULT_GRAPH_ID, device=TorchDevice('cpu')):
super().__init__()

self.__dict__['_store'] = CuGraphStorage(gaas_client, graph_id)
self.__dict__['_store'] = CuGraphStorage(gaas_client, graph_id, device=device)
self.__dict__['device'] = device

def to(self, to_device: TorchDevice) -> Data:
return CuGraphData(
self.__dict__['_store'].gaas_client,
self.__dict__['_store'].gaas_graph_id,
to_device
)
105 changes: 85 additions & 20 deletions torch_geometric/data/cugraph/cugraph_storage.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from typing import Any
from typing import List

import torch
from torch import device as TorchDevice
from torch_geometric.typing import ProxyTensor
from torch_geometric.data.storage import GlobalStorage

from gaas_client.client import GaasClient

import torch
EDGE_KEYS = ["_SRC_", "_DST_"]
VERTEX_KEYS = ["_VERTEX_ID_"]


class TorchTensorGaasGraphDataProxy(ProxyTensor):
"""
Expand All @@ -14,13 +21,29 @@ class TorchTensorGaasGraphDataProxy(ProxyTensor):
"""
_data_categories = ["vertex", "edge"]

def __init__(self, gaas_client: GaasClient, gaas_graph_id: int, data_category: str):
def __init__(self,
gaas_client: GaasClient,
gaas_graph_id: int,
data_category: str,
device:TorchDevice=TorchDevice('cpu'),
property_keys: List[str]=None,
transposed: bool=False):
if data_category not in self._data_categories:
raise ValueError("data_category must be one of "
f"{self._data_categories}, got {data_category}")

if property_keys is None:
if data_category == 'vertex':
property_keys = VERTEX_KEYS
else:
property_keys = EDGE_KEYS

self.__client = gaas_client
self.__graph_id = gaas_graph_id
self.__category = data_category
self.__device = device
self.__property_keys = property_keys
self.__transposed = transposed

def __getitem__(self, index: int):
"""
Expand All @@ -33,37 +56,60 @@ def __getitem__(self, index: int):
index = [int(i) for i in index]

if self.__category == "edge":
if index > 1:
raise IndexError(index)
# FIXME find a more efficient way to do this that doesn't transfer so much data
idx = -1 if self.__transposed else index
data = self.__client.get_graph_edge_dataframe_rows(
index_or_indices=-1, graph_id=self.__graph_id)
index_or_indices=idx, graph_id=self.__graph_id,
property_keys=self.__property_keys)

else:
# FIXME find a more efficient way to do this that doesn't transfer so much data
idx = -1 if self.__transposed else index
data = self.__client.get_graph_vertex_dataframe_rows(
index_or_indices=index, graph_id=self.__graph_id)
index_or_indices=idx, graph_id=self.__graph_id,
property_keys=self.__property_keys)

torch_data = torch.from_numpy(data.T)[index]
if self.__category == 'vertex':
return torch_data.to(torch.float32)
return torch_data.to(torch.long)
if self.__transposed:
torch_data = torch.from_numpy(data.T)[index].to(self.device)
if self.__property_keys[index] in EDGE_KEYS:
torch_data = torch_data.to(torch.long)
elif self.__property_keys[index] in VERTEX_KEYS:
torch_data = torch_data.to(torch.long)
else:
# FIXME handle non-numeric datatypes
torch_data = torch.from_numpy(data).to(torch.float32)

return torch_data.to(self.__device)

@property
def shape(self) -> torch.Size:
if self.__category == "edge":
shape = self.__client.get_graph_edge_dataframe_shape(
graph_id=self.__graph_id)
return torch.Size([shape[1] - 1, shape[0]])
num_edges = self.__client.get_num_edges(self.__graph_id)
return torch.Size([len(self.__property_keys), num_edges])
else:
shape = self.__client.get_graph_vertex_dataframe_shape(
graph_id=self.__graph_id)
return torch.Size(shape)
num_vertices = self.__client.get_num_vertices(self.__graph_id)
return torch.Size([len(self.__property_keys), num_vertices])

@property
def dtype(self) -> Any:
if self.__category == 'edge':
return torch.long
else:
return torch.float32

@property
def device(self) -> TorchDevice:
return self.__device

def to(self, to_device: TorchDevice):
return TorchTensorGaasGraphDataProxy(
self.__client,
self.__graph_id,
self.__category,
to_device,
property_keys=self.__property_keys,
transposed=self.transposed
)

def dim(self) -> int:
return self.shape[0]
Expand All @@ -76,19 +122,38 @@ def size(self, idx=None) -> Any:


class CuGraphStorage(GlobalStorage):
def __init__(self, gaas_client: GaasClient, gaas_graph_id: int):
def __init__(self, gaas_client: GaasClient, gaas_graph_id: int, device: TorchDevice=TorchDevice('cpu')):
super().__init__()
self.gaas_client = gaas_client
self.gaas_graph_id = gaas_graph_id
self.node_index = TorchTensorGaasGraphDataProxy(gaas_client, gaas_graph_id, 'vertex')
self.edge_index = TorchTensorGaasGraphDataProxy(gaas_client, gaas_graph_id, 'edge')
self.node_index = TorchTensorGaasGraphDataProxy(gaas_client, gaas_graph_id, 'vertex', device)
self.edge_index = TorchTensorGaasGraphDataProxy(gaas_client, gaas_graph_id, 'edge', device, transposed=True)

@property
def num_nodes(self) -> int:
return self.gaas_client.get_num_vertices(self.gaas_graph_id)

@property
def num_node_features(self) -> int:
return self.gaas_client.get_graph_vertex_dataframe_shape(self.gaas_graph_id)[1]

@property
def num_edge_features(self) -> int:
# includes the original src and dst columns w/ original names
return self.gaas_client.get_graph_edge_dataframe_shape(self.gaas_graph_id)[1]

@property
def num_edges(self) -> int:
return self.gaas_client.get_num_edges(self.gaas_graph_id)


def __getattr__(self, key: str) -> Any:
if key in self:
return self[key]
else:
return TorchTensorGaasGraphDataProxy(
self.gaas_client,
self.gaas_graph_id,
'vertex',
self.node_index.device,
[key]
)