Grax: Graph Neural Networks in Jax

This project aims to provide re-implementations of graph neural networks in jax, as close to the original authors' implementations as practical. Apart from different default initializations, known deviations from the original implementation logic are documented.

Implementations are in their own projects. Current implementations include:

  • APPNP: Approximate Personalized Propagation of Neural Predictions
  • DAGNN: Deep Adaptive Graph Neural Networks
  • DEQ_GCN: (Stalled WIP) Deep Equilibrium Graph Convolution Networks
  • GAT: Graph Attention Networks
  • GCN: Graph Convolution Networks
  • GCN2: Graph Convolution Networks 2
  • IGAT: (Stalled WIP) Inverse Graph Attention Networks
  • igcn: Inverse Graph Convolution Networks
  • pigcn: Pseudo-inverse Graph Convolution Networks
  • sgc: Simple Graph Convolution Networks

See the relevant subdirectory README.md for more details and example usage.

Dependencies

This project uses the following large open-source projects:

  • jax for performant low-level operations;
  • haiku for parameter / state management;
  • optax for optimizer implementations; and
  • gin for configuration.

Additional functionality is provided in smaller repositories:

  • spax for sparse jax classes and operations; and
  • huf, a minimal framework built on top of haiku and optax.

This library is in early, rapid development. Expect frequent breaking changes.

Installation

After installing jax, clone the repository and install it locally:

git clone https://github.com/jackd/grax
cd grax
pip install -r requirements.txt
pip install -e .  # local install

Datasets

DGL: Citations, Amazon and Coauthor

Citation datasets are loaded with dgl, which is installed by the steps above. To customize where the relevant files are downloaded and extracted, set:

export DGL_DATA=/path/to/dgl_data_dir  # otherwise uses ~/.dgl
# Open Graph Benchmark (ogb) datasets need the ogb package
pip install ogb
export OGB_DATA=/path/to/ogb_data_dir  # otherwise uses ~/ogb
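
The fallback described in the comments above (use the environment variable if set, otherwise a directory under the home folder) can be sketched in Python as follows. This only mirrors the documented defaults; `resolve_data_dir` is a hypothetical helper, and the lookup that dgl, ogb or grax actually performs may differ.

```python
import os
from pathlib import Path


def resolve_data_dir(env_var, default, environ=None):
    """Return the directory named by env_var, or the expanded default if unset.

    Hypothetical helper for illustration only; not part of grax, dgl or ogb.
    """
    environ = os.environ if environ is None else environ
    value = environ.get(env_var)
    # An unset or empty variable falls back to the documented default.
    return Path(value if value else default).expanduser()


# Defaults taken from the comments in the shell snippet above.
print(resolve_data_dir("DGL_DATA", "~/.dgl"))
print(resolve_data_dir("OGB_DATA", "~/ogb"))
```

Printing both paths before a first run is a quick way to check where downloads are expected to land.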

Quick Start

After installing:

# run a single GCN model on pubmed dataset
python -m grax grax_config/single/fit.gin gcn/config/pubmed.gin
# customize configuration
python -m grax grax_config/single/fit.gin gcn/config/pubmed.gin --bindings='
dropout_rate=0.6
seed=1
'
# perform multiple runs
python -m grax grax_config/single/fit_many.gin gcn/config/pubmed.gin
# perform multiple runs with ray
python -m grax grax_config/single/ray/fit_many.gin gcn/config/pubmed.gin
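
When launching runs from a script instead of the shell, the same command line can be assembled programmatically. The sketch below builds an argument list equivalent to the "customize configuration" command above. `build_command` is a hypothetical helper, not part of grax, and it assumes the config paths resolve from the directory the command is run in, as in the shell examples.

```python
import sys


def build_command(configs, bindings=None, python=sys.executable):
    """Compose an argv list matching the shell invocations above.

    Hypothetical helper for illustration only; not part of grax.
    """
    argv = [python, "-m", "grax", *configs]
    if bindings:
        # One `key=value` per line, in the same shape as the quoted
        # multi-line --bindings string in the shell example.
        lines = "".join(f"{key}={value}\n" for key, value in bindings.items())
        argv.append("--bindings=\n" + lines)
    return argv


command = build_command(
    ["grax_config/single/fit.gin", "gcn/config/pubmed.gin"],
    {"dropout_rate": 0.6, "seed": 1},
)
print(command)
# To actually launch the run from an installed checkout (not executed here):
# import subprocess; subprocess.run(command, check=True)
```

Passing the arguments as a list avoids shell quoting issues with the multi-line bindings string.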

Pre-commit

This package uses pre-commit to ensure commits meet minimum criteria. To install, use

pip install pre-commit
pre-commit install

This will ensure git hooks are run before each commit. While it is not advised to do so, you can skip these hooks with

git commit --no-verify -m "commit message"
