Merged
5 changes: 4 additions & 1 deletion .gitignore
@@ -27,4 +27,7 @@ dist
*.npy
.coverage
dev-stgraph/
htmlconv/
htmlconv/
*.txt
egl_kernel.cu
egl_kernel.ptx
21 changes: 13 additions & 8 deletions stgraph/benchmark_tools/table.py
@@ -1,25 +1,30 @@
from __future__ import annotations

from rich.console import Console
from rich.table import Table

console = Console()

class BenchmarkTable:
    def __init__(self, title: str, col_name_list: list[str]):
        self.title = '\n' + title + '\n'
        self.title = "\n" + title + "\n"
        self.col_name_list = col_name_list
        self._table = Table(title=self.title, show_edge=False, style="black bold")
        self._num_cols = len(col_name_list)
        self._num_rows = 0

        self._table_add_columns()

    def _table_add_columns(self):
        for col_name in self.col_name_list:
            self._table.add_column(col_name, justify="left")

    def add_row(self, values: list):
        values_str = tuple([str(val) for val in values])
        self._table.add_row(*values_str)

    def display(self):
        console.print(self._table)

    def display(self, output_file=None):
        if not output_file:
            console = Console()
        else:
            console = Console(file=output_file)
        console.print(self._table)
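
The updated `display` method writes the table to an arbitrary file object instead of always printing to stdout. A minimal usage sketch (column names and values are hypothetical, assuming the merged signature above):

```
from stgraph.benchmark_tools.table import BenchmarkTable

table = BenchmarkTable("Demo run", ["Epoch", "Time(s)"])
table.add_row([0, 0.42])

# Print to stdout (no argument), or write into an open file handle.
table.display()
with open("results.txt", "w") as f:
    table.display(output_file=f)
```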
78 changes: 78 additions & 0 deletions tests/scripts/README.md
@@ -0,0 +1,78 @@
# STGraph Script Testing

This directory contains scripts intended for manual execution by users or developers, outside of STGraph's automated testing suite. These scripts primarily exercise modules whose correctness cannot be directly unit tested with PyTest; a table of such modules can be found in the [List of Modules Tested](#list-of-modules-tested) section.

Additional scripts may be added as the project evolves.

## Usage

To execute the script tests, use the following command:

```
python3 stgraph_script.py [-h] [-v VERSION] [-t [TESTPACK_NAMES ...]]
```

For instance, to evaluate the GCN Dataloaders for STGraph version v1.1.0, execute:

```
python3 stgraph_script.py -v 1_1_0 -t gcn_dataloaders
```

Please ensure that the exact version of STGraph under test is installed in your virtual environment before running these tests.
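
As a quick sanity check, the snippet below prints the installed version; it assumes the `stgraph` package exposes a `__version__` attribute, which is not guaranteed by this PR:

```
# Hypothetical check: assumes stgraph exposes __version__.
import stgraph

print(stgraph.__version__)
```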

### Command Line Arguments

| Argument | Description | Possible Values |
| -------------------- | ------------------------------------------------------ | ---------------------------------------------- |
| -h, --help | Display a brief description of the command | - |
| -v, --version | Specify the version of the STGraph testpack to execute | `1_1_0` |
| -t, --testpack-names | Provide a list of testpack names | `gcn_dataloaders`, `temporal_tgcn_dataloaders` |

## A Note to the Developers

The rest of this document outlines essential procedures that developers, contributors, and maintainers of the STGraph project must follow. As the project evolves, the test scripts must be kept up to date with new functionality to prevent issues for end users.

### Post-Release Protocol

Following the release of a new version of STGraph, code owners are tasked with maintaining the integrity of the testpack folders corresponding to each version. These folders are located within the directory `tests/scripts`.

For instance, upon the release of STGraph version `v1.1.0`, and with the subsequent planning of version `v1.2.0`, the following steps are to be taken:

1. **Creation of New Test Pack:** A copy of the current version's testpack folder `v1_1_0` should be created and renamed to reflect the upcoming version `v1_2_0`, as sketched after this list.
2. **Development Phase Updates:** Any further enhancements or additions to the test scripts during the development phase of `v1.2.0` must be implemented within the new `v1_2_0` folder.
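
For step 1, a minimal sketch using only the standard library (assuming it is run from the repository root):

```
import shutil

# Copy the v1.1.0 testpack folder as the starting point for v1.2.0.
shutil.copytree("tests/scripts/v1_1_0", "tests/scripts/v1_2_0")
```

After the copy, the directory layout looks like this: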

```
tests/
└── scripts
├── v1_1_0
└── v1_2_0
```

By adhering to this protocol, the project maintains a structured and reliable testing framework, ensuring correctness and stability across successive releases.

### Test Script Creation Protocol

When preparing to write test scripts for newly implemented functionality, adhere to the following protocol:

1. **Choose a Descriptive Name:** Select a meaningful, self-documenting name for your testpack, following the lower snake-case convention.
2. **Create Testpack Folder:** Within the current development version of STGraph, create a folder bearing the chosen testpack name.
3. **Script Creation:** Write a Python script named `<testpack_name>.py` containing the necessary testing logic.
4. **Supplementary Files:** Include any additional folders and files needed for testing within the testpack folder.
5. **Status Reporting:** Ensure that the Python script displays the status of each test case on screen.
6. **Integration with STGraph Script:** Add your new testpack as a selectable option for the `--testpack-names` command-line argument in `stgraph_script.py`.

To get a better idea of how to develop your scripts, refer to previously maintained testpacks; a minimal skeleton is also sketched below. This practice ensures uniformity and easier maintenance across all scripts.
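
A minimal, hypothetical skeleton for `<testpack_name>.py` that follows steps 3–6. The `-o` flag mirrors how `stgraph_script.py` invokes each testpack; everything else is illustrative, not a fixed API:

```
import argparse
import os


def run_tests(output_folder: str) -> None:
    # Step 3: replace this placeholder with real testing logic.
    passed = True

    # Step 5: report the status of each test case on screen.
    print("my_testpack:", "PASS" if passed else "FAIL")

    # Persist any artifacts under the output folder supplied by stgraph_script.py.
    os.makedirs(output_folder, exist_ok=True)
    with open(os.path.join(output_folder, "status.txt"), "w") as f:
        f.write("PASS" if passed else "FAIL")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Example testpack skeleton")
    parser.add_argument("-o", "--output-folder", type=str, required=True)
    args = parser.parse_args()
    run_tests(args.output_folder)
```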

## List of Modules Tested

| Module | Test Pack Name | Initial Version Release |
| ------------------------ | ------------------------- | ----------------------- |
| GraphConv | gcn_dataloaders | `v1.1.0` |
| CoraDataLoader | gcn_dataloaders | `v1.1.0` |
| TGCN | temporal_tgcn_dataloaders | `v1.1.0` |
| HungaryCPDataLoader | temporal_tgcn_dataloaders | `v1.1.0` |
| METRLADataLoader | temporal_tgcn_dataloaders | `v1.1.0` |
| MontevideoBusDataLoader | temporal_tgcn_dataloaders | `v1.1.0` |
| PedalMeDataLoader | temporal_tgcn_dataloaders | `v1.1.0` |
| WikiMathDataLoader | temporal_tgcn_dataloaders | `v1.1.0` |
| WindmillOutputDataLoader | temporal_tgcn_dataloaders | `v1.1.0` |
42 changes: 42 additions & 0 deletions tests/scripts/stgraph_script.py
@@ -0,0 +1,42 @@
import argparse
import os
import subprocess


def main(args):
    version_number = args.version
    testpack_names = args.testpack_names

    for testpack in testpack_names:
        script_path = os.path.join("v" + version_number, testpack, testpack + ".py")
        output_folder_path = os.path.join("v" + version_number, testpack, "outputs")
        if os.path.exists(script_path):
            subprocess.run(["python3", script_path, "-o", output_folder_path])
        else:
            print(f"Script {script_path} does not exist")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="STGraph Test Scripts")

    parser.add_argument(
        "-v",
        "--version",
        type=str,
        default="1_1_0",
        choices=["1_1_0"],
        help="Version of STGraph",
    )

    parser.add_argument(
        "-t",
        "--testpack-names",
        nargs="*",
        default=["temporal_tgcn_dataloaders", "gcn_dataloaders"],
        choices=["temporal_tgcn_dataloaders", "gcn_dataloaders"],
        help="Names of the testpacks to be executed",
    )

    args = parser.parse_args()

    main(args=args)
Empty file.
27 changes: 27 additions & 0 deletions tests/scripts/v1_1_0/gcn_dataloaders/gcn/model.py
@@ -0,0 +1,27 @@
import torch.nn as nn
from stgraph.nn.pytorch.graph_conv import GraphConv

class GCN(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation):
        super(GCN, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(GraphConv(in_feats, n_hidden, activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation))
        # output layer
        self.layers.append(GraphConv(n_hidden, n_classes, None))

    def forward(self, g, features):
        h = features
        for layer in self.layers:
            h = layer(g, h)
        return h
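
A brief usage sketch for the model above. It mirrors the call pattern in `train.py` (toy graph, illustrative layer sizes) and assumes a CUDA-capable setup, since `train.py` runs everything on the GPU:

```
import torch
import torch.nn.functional as F

from stgraph.graph.static.StaticGraph import StaticGraph

# Toy one-node self-loop graph, built the same way train.py builds its warm-up graph.
g = StaticGraph([(0, 0)], [1], 1)

# Normalization ndata, mirroring train.py.
degs = torch.from_numpy(g.weighted_in_degrees()).type(torch.int32)
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0
g.set_ndata("norm", norm.cuda().unsqueeze(1))

features = torch.rand(1, 8).cuda()
model = GCN(g, 8, 16, 3, 1, F.relu).cuda()  # in_feats=8, n_hidden=16, n_classes=3, n_layers=1
out = model(g, features)  # expected shape: (1, 3)
```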
139 changes: 139 additions & 0 deletions tests/scripts/v1_1_0/gcn_dataloaders/gcn/train.py
@@ -0,0 +1,139 @@
import argparse
import time

import numpy as np
import pynvml
import snoop
import torch
import torch.nn as nn
import torch.nn.functional as F
import traceback

from .model import GCN
from .utils import accuracy, generate_test_mask, generate_train_mask, to_default_device

from stgraph.dataset import CoraDataLoader
from stgraph.graph.static.StaticGraph import StaticGraph
from stgraph.benchmark_tools.table import BenchmarkTable


def train(
    dataset: str,
    lr: float,
    num_epochs: int,
    num_hidden: int,
    num_layers: int,
    weight_decay: float,
    self_loops: bool,
    output_file_path: str,
) -> int:
    with open(output_file_path, "w") as f:
        if torch.cuda.is_available():
            print("🎉 CUDA is available", file=f)
        else:
            print("😔 CUDA is not available", file=f)
            return 1

        tmp = StaticGraph([(0, 0)], [1], 1)

        if dataset == "Cora":
            dataloader = CoraDataLoader()
        else:
            print("😔 Unrecognized dataset", file=f)
            return 1

        features = torch.FloatTensor(dataloader.get_all_features())
        labels = torch.LongTensor(dataloader.get_all_targets())

        train_mask = generate_train_mask(len(features), 0.6)
        test_mask = generate_test_mask(len(features), 0.6)

        train_mask = torch.BoolTensor(train_mask)
        test_mask = torch.BoolTensor(test_mask)

        cuda = True
        torch.cuda.set_device(0)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        test_mask = test_mask.cuda()
        edge_weight = [1 for _ in range(len(dataloader.get_edges()))]

        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        initial_used_gpu_mem = pynvml.nvmlDeviceGetMemoryInfo(handle).used
        g = StaticGraph(dataloader.get_edges(), edge_weight, features.shape[0])
        graph_mem = pynvml.nvmlDeviceGetMemoryInfo(handle).used - initial_used_gpu_mem

        degs = torch.from_numpy(g.weighted_in_degrees()).type(torch.int32)
        norm = torch.pow(degs, -0.5)
        norm[torch.isinf(norm)] = 0
        norm = to_default_device(norm)
        g.set_ndata("norm", norm.unsqueeze(1))

        num_feats = features.shape[1]
        n_classes = int(max(labels) - min(labels) + 1)

        model = GCN(g, num_feats, num_hidden, n_classes, num_layers, F.relu)
        model.cuda()

        loss_fcn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(
            model.parameters(), lr=lr, weight_decay=weight_decay
        )

        dur = []
        Used_memory = 0
        table = BenchmarkTable(
            f"STGraph GCN on {dataloader.name} dataset",
            ["Epoch", "Time(s)", "Train Accuracy", "Used GPU Memory (Max MB)"],
        )

        try:
            for epoch in range(num_epochs):
                torch.cuda.reset_peak_memory_stats(0)
                model.train()
                if cuda:
                    torch.cuda.synchronize()
                t0 = time.time()

                # forward
                logits = model(g, features)
                loss = loss_fcn(logits[train_mask], labels[train_mask])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                now_mem = torch.cuda.max_memory_allocated(0) + graph_mem
                Used_memory = max(now_mem, Used_memory)

                if cuda:
                    torch.cuda.synchronize()

                run_time_this_epoch = time.time() - t0

                if epoch >= 3:
                    dur.append(run_time_this_epoch)

                train_acc = accuracy(logits[train_mask], labels[train_mask])
                table.add_row(
                    [epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024**2))]
                )

            table.display(output_file=f)
            print("Average Time taken: {:6f}".format(np.mean(dur)), file=f)
            return 0

        except Exception as e:
            print("---------------- Error ----------------\n", file=f)
            print(e, file=f)
            print("\n", file=f)

            traceback.print_exc(file=f)
            print("\n", file=f)

            if "out of memory" in str(e):
                table.add_row(["OOM", "OOM", "OOM", "OOM"])
                table.display(output_file=f)

            return 1
32 changes: 32 additions & 0 deletions tests/scripts/v1_1_0/gcn_dataloaders/gcn/utils.py
@@ -0,0 +1,32 @@
import torch


def accuracy(logits, labels):
    _, indices = torch.max(logits, dim=1)
    correct = torch.sum(indices == labels)
    return correct.item() * 1.0 / len(labels)


# GPU | CPU
def get_default_device():
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    else:
        return torch.device("cpu")


def to_default_device(data):
    if isinstance(data, (list, tuple)):
        # Recurse element-wise; get_default_device() is resolved inside each call.
        return [to_default_device(x) for x in data]

    return data.to(get_default_device(), non_blocking=True)


def generate_train_mask(size: int, train_test_split: float) -> list:
    cutoff = size * train_test_split
    return [1 if i < cutoff else 0 for i in range(size)]


def generate_test_mask(size: int, train_test_split: float) -> list:
    cutoff = size * train_test_split
    return [0 if i < cutoff else 1 for i in range(size)]
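
A quick illustration of the mask helpers above:

```
# With size=5 and train_test_split=0.6, the cutoff is 3.0.
print(generate_train_mask(5, 0.6))  # [1, 1, 1, 0, 0]
print(generate_test_mask(5, 0.6))   # [0, 0, 0, 1, 1]
```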