adding version check #51

Merged 4 commits on Jul 27, 2023
13 changes: 11 additions & 2 deletions README.md
@@ -1,4 +1,6 @@
# TGB
<!-- # TGB -->
![TGB logo](imgs/logo.png)

<h4>
<a href="https://arxiv.org/abs/2307.01026"><img src="https://img.shields.io/badge/arXiv-pdf-yellowgreen"></a>
<a href="https://pypi.org/project/py-tgb/"><img src="https://img.shields.io/pypi/v/py-tgb.svg?color=brightgreen"></a>
@@ -7,14 +9,21 @@
</h4>
Temporal Graph Benchmark for Machine Learning on Temporal Graphs

![TGB dataloading and evaluation pipeline](imgs/pipeline.png)

Overview of the Temporal Graph Benchmark (TGB) pipeline:
- TGB includes large-scale and realistic datasets from five different domains, with both dynamic link prediction and node property prediction tasks.
- TGB automatically downloads datasets and processes them into `numpy`, `PyTorch`, and `PyG compatible TemporalData` formats (a minimal sketch follows this list).
- Novel TG models can be easily evaluated on TGB datasets via reproducible and realistic evaluation protocols.
- TGB provides public and online leaderboards to track recent developments in the temporal graph learning domain.
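
For illustration, a minimal sketch of the loading workflow (assuming the `LinkPropPredDataset` loader and the `tgbl-wiki` dataset; see the API documentation for the exact signatures):

```python
from tgb.linkproppred.dataset import LinkPropPredDataset

# first use downloads and processes the dataset automatically
dataset = LinkPropPredDataset(name="tgbl-wiki", root="datasets", preprocess=True)

# processed data as numpy arrays (sources, destinations, timestamps, ...)
data = dataset.full_data
```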

### Announcements

**Please update to version `0.7.5`**

The negative samples for the `tgbl-wiki` and `tgbl-review` datasets have been updated, so these datasets need to be re-downloaded (the dataloader in this version will prompt you automatically).
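
For reference, a minimal sketch of how the new version check surfaces (assuming the `LinkPropPredDataset` loader; the exact prompt text may differ):

```python
from tgb.linkproppred.dataset import LinkPropPredDataset

# on an outdated local copy, the loader detects the missing files for the
# current version and prompts you to re-download the dataset
dataset = LinkPropPredDataset(name="tgbl-wiki", root="datasets")

# True once the local files match the current dataset version
print(dataset.version_passed)
```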


### Pip Install

1 change: 1 addition & 0 deletions docs/about.md
@@ -1,4 +1,5 @@
# Temporal Graph Benchmark (TGB)
![TGB logo](assets/logo.png)

## Overview

Binary file added docs/assets/logo.png
22 changes: 20 additions & 2 deletions docs/index.md
@@ -1,7 +1,25 @@
# Welcome to Temporal Graph Benchmark
![TGB logo](assets/logo.png)



### Pip Install

You can install TGB via [pip](https://pypi.org/project/py-tgb/):
```
pip install py-tgb
```

### Links and Datasets

The project website can be found [here](https://tgb.complexdatalab.com/).

The API documentation can be found [here](https://shenyanghuang.github.io/TGB/).

All dataset download links can be found in [info.py](https://github.com/shenyangHuang/TGB/blob/main/tgb/utils/info.py).

The TGB dataloaders also automatically download each dataset, along with the negative samples for the link property prediction datasets.
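
For illustration, a minimal sketch of the automatic download (assuming the `LinkPropPredDataset` loader; the method names are taken from the repository):

```python
from tgb.linkproppred.dataset import LinkPropPredDataset

# first use downloads the edgelist and the pre-generated negative samples
dataset = LinkPropPredDataset(name="tgbl-wiki", root="datasets")

dataset.load_val_ns()   # negative samples for the validation split
dataset.load_test_ns()  # negative samples for the test split
```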

# TGB
Temporal Graph Benchmark project repo

### Install dependency
Our implementation works with Python >= 3.9 and can be installed as follows:
Binary file added imgs/logo.png
7 changes: 4 additions & 3 deletions mkdocs.yml
@@ -1,4 +1,4 @@
site_name: TGB
site_name: Temporal Graph Benchmark

nav:
- Overview: index.md
@@ -12,6 +12,7 @@ nav:
- Access Edge Data in Numpy: tutorials/Edge_data_numpy.ipynb

theme:
logo: assets/logo.png
name: material
features:
- navigation.tabs
@@ -29,8 +30,8 @@ theme:
toggle:
icon: material/toggle-switch-off-outline
name: Switch to dark mode
primary: orange
accent: purple
primary: purple
accent: orange
- scheme: slate
toggle:
icon: material/toggle-switch
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "py-tgb"
version = "0.7.0"
version = "0.7.5"
description = "Temporal Graph Benchmark project repo"
authors = ["shenyang Huang <shenyang.huang@mail.mcgill.ca>", "Farimah Poursafaei", "Emanuele Rossi <emanuele.rossi1909@gmail.com>", "Jacob Danovitch <jacob.danovitch@mila.quebec>"]
readme = "README.md"
2 changes: 1 addition & 1 deletion setup.py
@@ -1,3 +1,3 @@
from setuptools import setup, find_packages

setup(name="py-tgb", version="0.7.0", packages=find_packages())
setup(name="py-tgb", version="0.7.5", packages=find_packages())
182 changes: 182 additions & 0 deletions tgb/datasets/token_network/token.py
@@ -0,0 +1,182 @@
import csv


def store_token_address(token_dict, outname, topk=1000):
    """
    Store the topk most frequent token addresses in a csv file.

    Parameters:
        token_dict: dictionary mapping token addresses to their frequencies
        outname: name of the output csv file
        topk: number of token addresses to keep
    Output:
        output csv file with the topk token addresses
    """
    # sort token addresses by frequency, most frequent first
    sorted_tokens = sorted(token_dict.items(), key=lambda item: item[1], reverse=True)
    with open(outname, "w") as csv_file:
        csv_writer = csv.writer(csv_file, delimiter=",")
        csv_writer.writerow(["token_address", "frequency"])
        for key, value in sorted_tokens[:topk]:  # keep exactly topk rows
            csv_writer.writerow([key, value])

def analyze_token_frequency(fname):
# ['token_address', 'from_address', 'to_address', 'value', 'block_timestamp']
token_dict = {}
node_dict = {}
time_dict = {}
    max_w = 0.0
    min_w = float("inf")  # so any first weight updates the minimum

with open(fname, "r") as csv_file:
csv_reader = csv.reader(csv_file, delimiter=",")
ctr = 0
for row in csv_reader:
if ctr == 0:
ctr += 1
continue
else:
token_type = row[0]
if (token_type not in token_dict):
token_dict[token_type] = 1
else:
token_dict[token_type] += 1
src = row[1]
if (src not in node_dict):
node_dict[src] = 1
else:
node_dict[src] += 1
dst = row[2]
if (dst not in node_dict):
node_dict[dst] = 1
else:
node_dict[dst] += 1

                w = float(row[3])
                # separate checks (not elif): a value can update both bounds
                if w > max_w:
                    max_w = w
                if w < min_w:
                    min_w = w
timestamp = row[4]
if (timestamp not in time_dict):
time_dict[timestamp] = 1
ctr += 1

print (" number of unique tokens are ", len(token_dict))
print (" number of unique nodes are ", len(node_dict))
print (" number of unique timestamps are ", len(time_dict))
print (" max weight is ", max_w)
print (" min weight is ", min_w)

topk = 1000
    store_token_address(token_dict, "token_list.csv", topk=topk)


def print_csv(fname):
    # ['token_address', 'from_address', 'to_address', 'value', 'block_timestamp']
    with open(fname, "r") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        ctr = 0
        for row in csv_reader:
            ctr += 1
    # subtract the header row from the count
    print("there are", ctr - 1, "rows in the csv file")


def analyze_csv(fname):
node_dict = {}
edge_dict = {}
num_edges = 0
num_time = 0
    prev_t = "none"
    min_w = float("inf")  # so any first weight updates the minimum
    max_w = 0.0

with open(fname, "r") as csv_file:
csv_reader = csv.reader(csv_file, delimiter=",")
line_count = 0
for row in csv_reader:
if line_count == 0:
line_count += 1
else:
# t,u,v,w
t = row[0]
u = row[1]
v = row[2]
w = float(row[3].strip())

# min & max edge weights
if w > max_w:
max_w = w

if w < min_w:
min_w = w

                # count unique timestamps (assumes rows are sorted by time)
                if t != prev_t:
num_time += 1
prev_t = t

# unique nodes
if u not in node_dict:
node_dict[u] = 1
else:
node_dict[u] += 1

if v not in node_dict:
node_dict[v] = 1
else:
node_dict[v] += 1

# unique edges
num_edges += 1
if (u, v) not in edge_dict:
edge_dict[(u, v)] = 1
else:
edge_dict[(u, v)] += 1

print("----------------------high level statistics-------------------------")
print("number of total edges are ", num_edges)
print("number of nodes are ", len(node_dict))
print("number of unique edges are ", len(edge_dict))
print("number of unique timestamps are ", num_time)
print("maximum edge weight is ", max_w)
print("minimum edge weight is ", min_w)

num_10 = 0
num_100 = 0
num_1000 = 0

for node in node_dict:
if node_dict[node] >= 10:
num_10 += 1
if node_dict[node] >= 100:
num_100 += 1
if node_dict[node] >= 1000:
num_1000 += 1
print("number of nodes with # edges >= 10 is ", num_10)
print("number of nodes with # edges >= 100 is ", num_100)
print("number of nodes with # edges >= 1000 is ", num_1000)
    print("----------------------high level statistics-------------------------")


def main():
fname = "ERC20_token_network.csv"
analyze_token_frequency(fname)
#print_csv(fname)
#analyze_csv(fname)


if __name__ == "__main__":
main()
51 changes: 46 additions & 5 deletions tgb/linkproppred/dataset.py
@@ -8,7 +13 @@
from clint.textui import progress

from tgb.linkproppred.negative_sampler import NegativeEdgeSampler
from tgb.utils.info import PROJ_DIR, DATA_URL_DICT, DATA_EVAL_METRIC_DICT, BColors
from tgb.utils.info import (
PROJ_DIR,
DATA_URL_DICT,
DATA_VERSION_DICT,
DATA_EVAL_METRIC_DICT,
BColors
)
from tgb.utils.pre_process import (
csv_to_pd_data,
process_node_feat,
@@ -67,9 +73,15 @@ def __init__(
self.meta_dict["fname"] = self.root + "/" + self.name + "_edgelist.csv"
self.meta_dict["nodefile"] = None

# TODO update the logic here to load the filenames from info.py
if name == "tgbl-flight":
self.meta_dict["nodefile"] = self.root + "/" + "airport_node_feat.csv"

self.meta_dict["val_ns"] = self.root + "/" + self.name + "_val_ns.pkl"
self.meta_dict["test_ns"] = self.root + "/" + self.name + "_test_ns.pkl"

#! version check
self.version_passed = True
self._version_check()

# initialize
self._node_feat = None
@@ -94,6 +106,34 @@ def __init__(
dataset_name=self.name, strategy="hist_rnd"
)

def _version_check(self) -> None:
r"""Implement Version checks for dataset files
updates the file names based on the current version number
prompt the user to download the new version via self.version_passed variable
"""
if (self.name in DATA_VERSION_DICT):
version = DATA_VERSION_DICT[self.name]
else:
print(f"Dataset {self.name} version number not found.")
self.version_passed = False
return None

if (version > 1):
#* check if current version is outdated
self.meta_dict["fname"] = self.root + "/" + self.name + "_edgelist_v" + str(int(version)) + ".csv"
self.meta_dict["nodefile"] = None
if self.name == "tgbl-flight":
self.meta_dict["nodefile"] = self.root + "/" + "airport_node_feat_v" + str(int(version)) + ".csv"
self.meta_dict["val_ns"] = self.root + "/" + self.name + "_val_ns_v" + str(int(version)) + ".pkl"
self.meta_dict["test_ns"] = self.root + "/" + self.name + "_test_ns_v" + str(int(version)) + ".pkl"

if (not osp.exists(self.meta_dict["fname"])):
print(f"Dataset {self.name} version {int(version)} not found.")
print(f"Please download the latest version of the dataset.")
self.version_passed = False
return None


def download(self):
"""
downloads this dataset from url
@@ -138,6 +178,7 @@ def download(self):
with zipfile.ZipFile(path_download, "r") as zip_ref:
zip_ref.extractall(self.root)
print(f"{BColors.OKGREEN}Download completed {BColors.ENDC}")
self.version_passed = True
else:
raise Exception(
BColors.FAIL + "Data not found error, download " + self.name + " failed"
@@ -163,7 +204,7 @@ def generate_processed_files(self) -> pd.DataFrame:
if self.meta_dict["nodefile"] is not None:
OUT_NODE_FEAT = self.root + "/" + "ml_{}.pkl".format(self.name + "_node")

if osp.exists(OUT_DF):
if (osp.exists(OUT_DF)) and (self.version_passed is True):
print("loading processed file")
df = pd.read_pickle(OUT_DF)
edge_feat = load_pkl(OUT_EDGE_FEAT)
@@ -281,15 +322,15 @@ def load_val_ns(self) -> None:
load the negative samples for the validation set
"""
self.ns_sampler.load_eval_set(
fname=self.root + "/" + self.name + "_val_ns.pkl", split_mode="val"
fname=self.meta_dict["val_ns"], split_mode="val"
)

def load_test_ns(self) -> None:
r"""
load the negative samples for the test set
"""
self.ns_sampler.load_eval_set(
fname=self.root + "/" + self.name + "_test_ns.pkl", split_mode="test"
fname=self.meta_dict["test_ns"], split_mode="test"
)

@property