-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c919de1
commit 526f7e2
Showing
15 changed files
with
1,112 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,98 @@ | ||
# LightHGNN | ||
Source code for ICLR 2024 "LightHGNN: Distilling Hypergraph Neural Networks into MLPs for 100x Faster Inference" | ||
This repository contains the source code for the paper "LightHGNN: Distilling Hypergraph Neural Networks into MLPs for 100x Faster Inference" published in IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI) 2024 by [Yifan Feng](https://fengyifan.site/), Yihe Luo, Shihui Ying, Yue Gao*. This paper is available at [here](https://openreview.net/forum?id=lHasEfGsXL). | ||
|
||
 | ||
|
||
## Introduction | ||
In this repository, we provide the implementation of our LightHGNNs, including LightHGNN and LightHGNN+, which is based on the following environments: | ||
* [python 3.9](https://www.python.org/): basic programming language. | ||
* [dhg 0.9.4](https://github.com/iMoonLab/DeepHypergraph): for hypergraph representation and learning. | ||
* [torch 1.12.1](https://pytorch.org/): for computation. | ||
* [hydra-core 1.3.2](https://hydra.cc/docs/intro/): for configuration and multi-run management. | ||
|
||
|
||
## Installation | ||
1. Clone this repository. | ||
2. Install the required libraries. | ||
``` bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Usage | ||
For transtive setting, you can run the following command: | ||
```bash | ||
python trans_train.py | ||
``` | ||
For multi-run and obtain the average results, you can run the following command: | ||
```bash | ||
python trans_multi_exp.py | ||
``` | ||
For production setting, you can run the following command: | ||
```bash | ||
python prod_train.py | ||
``` | ||
For multi-run and obtain the average results, you can run the following command: | ||
```bash | ||
python prod_multi_exp.py | ||
``` | ||
|
||
**Change Models** | ||
You can change the teacher by modifying the `teacher` in `trans_config.yaml` and `prod_config.yaml` as following: | ||
```yaml | ||
model: | ||
teacher: hgnn # hgnn, hgnnp, hnhn, unigcn | ||
``` | ||
Also, you can change the student by modifying the `student` in `trans_config.yaml` and `prod_config.yaml` as following: | ||
```yaml | ||
model: | ||
student: light_hgnn # light_hgnn, light_hgnnp | ||
``` | ||
|
||
**Change Datasets** | ||
In our paper, 13 grpah/hypergraph datasets are adopted for evaluation. | ||
- Graph datasets: `cora`, `pubmed`, `citeseer` | ||
- Hypergraph datasets: `news20`, `ca_cora`, `cc_cora`, `cc_citeseer`, `dblp4k_conf`, `dblp4k_paper`, `dblp4k_term`, `imdb_aw`, `recipe_100k`, `recipe_200k` | ||
|
||
You can change the dataset by modifying the `dataset` in `trans_config.yaml` and `prod_config.yaml` as following: | ||
```yaml | ||
data: | ||
name: dblp4k_paper # cora, pubmed, news20, ca_cora, dblp4k_term, imdb_aw, ... | ||
``` | ||
|
||
**Important Note** | ||
Since the `recipe_100k` and `recipe_200k` datasets are too large and contains more than 10k vertices, the two dataset can only be used under the production setting. **Please do not use the two datasets for the transitive setting.** | ||
|
||
## Citation | ||
If you find this repository useful in your research, please cite the following papers: | ||
``` | ||
@inproceedings{feng2024lighthgnn, | ||
title={Light{HGNN}: Distilling Hypergraph Neural Networks into {MLP}s for 100x Faster Inference}, | ||
author={Feng, Yifan and Luo, Yihe and Ying, Shihui and Gao, Yue}, | ||
booktitle={The Twelfth International Conference on Learning Representations}, | ||
year={2024}, | ||
} | ||
@article{gao2022hgnn+, | ||
title={HGNN+: General hypergraph neural networks}, | ||
author={Gao, Yue and Feng, Yifan and Ji, Shuyi and Ji, Rongrong}, | ||
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, | ||
volume={45}, | ||
number={3}, | ||
pages={3181--3199}, | ||
year={2022}, | ||
publisher={IEEE} | ||
} | ||
@inproceedings{feng2019hypergraph, | ||
title={Hypergraph neural networks}, | ||
author={Feng, Yifan and You, Haoxuan and Zhang, Zizhao and Ji, Rongrong and Gao, Yue}, | ||
booktitle={Proceedings of the AAAI conference on artificial intelligence}, | ||
volume={33}, | ||
number={01}, | ||
pages={3558--3565}, | ||
year={2019} | ||
} | ||
``` | ||
|
||
|
||
|
19 changes: 19 additions & 0 deletions
19
...0-1.0__hgnnp-light_hgnnp__hid-128__lamb-0__tau-1.0/2024-01-30_17-05-05/.hydra/config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
data: | ||
name: recipe_100k | ||
num_train: 20 | ||
num_val: 100 | ||
test_ind_ratio: 0.9 | ||
ft_noise_level: 0.0 | ||
hc_noise_level: 1.0 | ||
model: | ||
teacher: hgnnp | ||
student: light_hgnnp | ||
hid: 128 | ||
loss: | ||
lamb: 0 | ||
tau: 1.0 | ||
data_marker: ${data.name}__${data.num_train}-${data.num_val}-${data.test_ind_ratio}__noise-${data.ft_noise_level}-${data.hc_noise_level} | ||
model_marker: ${model.teacher}-${model.student}__hid-${model.hid} | ||
loss_marker: lamb-${loss.lamb}__tau-${loss.tau} | ||
task: ${data_marker}__${model_marker}__${loss_marker} | ||
res_path: cache/ind/${task} |
154 changes: 154 additions & 0 deletions
154
....0-1.0__hgnnp-light_hgnnp__hid-128__lamb-0__tau-1.0/2024-01-30_17-05-05/.hydra/hydra.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
hydra: | ||
run: | ||
dir: ${res_path}/${now:%Y-%m-%d}_${now:%H-%M-%S} | ||
sweep: | ||
dir: ${res_path}/${now:%Y-%m-%d}_${now:%H-%M-%S} | ||
subdir: ${hydra.job.num} | ||
launcher: | ||
_target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher | ||
sweeper: | ||
_target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper | ||
max_batch_size: null | ||
params: null | ||
help: | ||
app_name: ${hydra.job.name} | ||
header: '${hydra.help.app_name} is powered by Hydra. | ||
' | ||
footer: 'Powered by Hydra (https://hydra.cc) | ||
Use --hydra-help to view Hydra specific help | ||
' | ||
template: '${hydra.help.header} | ||
== Configuration groups == | ||
Compose your configuration from those groups (group=option) | ||
$APP_CONFIG_GROUPS | ||
== Config == | ||
Override anything in the config (foo.bar=value) | ||
$CONFIG | ||
${hydra.help.footer} | ||
' | ||
hydra_help: | ||
template: 'Hydra (${hydra.runtime.version}) | ||
See https://hydra.cc for more info. | ||
== Flags == | ||
$FLAGS_HELP | ||
== Configuration groups == | ||
Compose your configuration from those groups (For example, append hydra/job_logging=disabled | ||
to command line) | ||
$HYDRA_CONFIG_GROUPS | ||
Use ''--cfg hydra'' to Show the Hydra config. | ||
' | ||
hydra_help: ??? | ||
hydra_logging: | ||
version: 1 | ||
formatters: | ||
simple: | ||
format: '[%(asctime)s][HYDRA] %(message)s' | ||
handlers: | ||
console: | ||
class: logging.StreamHandler | ||
formatter: simple | ||
stream: ext://sys.stdout | ||
root: | ||
level: INFO | ||
handlers: | ||
- console | ||
loggers: | ||
logging_example: | ||
level: DEBUG | ||
disable_existing_loggers: false | ||
job_logging: | ||
version: 1 | ||
formatters: | ||
simple: | ||
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' | ||
handlers: | ||
console: | ||
class: logging.StreamHandler | ||
formatter: simple | ||
stream: ext://sys.stdout | ||
file: | ||
class: logging.FileHandler | ||
formatter: simple | ||
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log | ||
root: | ||
level: INFO | ||
handlers: | ||
- console | ||
- file | ||
disable_existing_loggers: false | ||
env: {} | ||
mode: RUN | ||
searchpath: [] | ||
callbacks: {} | ||
output_subdir: .hydra | ||
overrides: | ||
hydra: | ||
- hydra.mode=RUN | ||
task: [] | ||
job: | ||
name: prod_train | ||
chdir: null | ||
override_dirname: '' | ||
id: ??? | ||
num: ??? | ||
config_name: prod_config | ||
env_set: {} | ||
env_copy: [] | ||
config: | ||
override_dirname: | ||
kv_sep: '=' | ||
item_sep: ',' | ||
exclude_keys: [] | ||
runtime: | ||
version: 1.2.0 | ||
version_base: '1.1' | ||
cwd: /home/fengyifan/OS3D/LightHGNN/LightHGNN | ||
config_sources: | ||
- path: hydra.conf | ||
schema: pkg | ||
provider: hydra | ||
- path: /home/fengyifan/OS3D/LightHGNN/LightHGNN | ||
schema: file | ||
provider: main | ||
- path: '' | ||
schema: structured | ||
provider: schema | ||
output_dir: /home/fengyifan/OS3D/LightHGNN/LightHGNN/cache/ind/recipe_100k__20-100-0.9__noise-0.0-1.0__hgnnp-light_hgnnp__hid-128__lamb-0__tau-1.0/2024-01-30_17-05-05 | ||
choices: | ||
hydra/env: default | ||
hydra/callbacks: null | ||
hydra/job_logging: default | ||
hydra/hydra_logging: default | ||
hydra/hydra_help: default | ||
hydra/help: default | ||
hydra/sweeper: basic | ||
hydra/launcher: basic | ||
hydra/output: default | ||
verbose: false |
1 change: 1 addition & 0 deletions
1
....0__hgnnp-light_hgnnp__hid-128__lamb-0__tau-1.0/2024-01-30_17-05-05/.hydra/overrides.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[] |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
import dhg | ||
from dhg.nn import MLP | ||
from dhg.nn import GCNConv | ||
from dhg.nn import HGNNConv | ||
|
||
|
||
class MyGCN(nn.Module): | ||
r"""The GCN model proposed in `Semi-Supervised Classification with Graph Convolutional Networks <https://arxiv.org/pdf/1609.02907>`_ paper (ICLR 2017). | ||
Args: | ||
``in_channels`` (``int``): :math:`C_{in}` is the number of input channels. | ||
``hid_channels`` (``int``): :math:`C_{hid}` is the number of hidden channels. | ||
``num_classes`` (``int``): The Number of class of the classification task. | ||
``use_bn`` (``bool``): If set to ``True``, use batch normalization. Defaults to ``False``. | ||
``drop_rate`` (``float``): Dropout ratio. Defaults to ``0.5``. | ||
""" | ||
def __init__(self, in_channels: int, | ||
hid_channels: int, | ||
num_classes: int, | ||
use_bn: bool = False, | ||
drop_rate: float = 0.5) -> None: | ||
super().__init__() | ||
self.layers0 = GCNConv(in_channels, hid_channels, use_bn=use_bn, drop_rate=drop_rate) | ||
self.layers1 = GCNConv(hid_channels, num_classes, use_bn=use_bn, is_last=True) | ||
|
||
def forward(self, X: torch.Tensor, g: "dhg.Graph", get_emb=False) -> torch.Tensor: | ||
r"""The forward function. | ||
Args: | ||
``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`. | ||
``g`` (``dhg.Graph``): The graph structure that contains :math:`N` vertices. | ||
""" | ||
emb = self.layers0(X, g) | ||
X = self.layers1(emb, g) | ||
if get_emb: | ||
return emb | ||
else: | ||
return X | ||
|
||
|
||
class MyHGNN(nn.Module): | ||
r"""The HGNN model proposed in `Hypergraph Neural Networks <https://arxiv.org/pdf/1809.09401>`_ paper (AAAI 2019). | ||
Args: | ||
``in_channels`` (``int``): :math:`C_{in}` is the number of input channels. | ||
``hid_channels`` (``int``): :math:`C_{hid}` is the number of hidden channels. | ||
``num_classes`` (``int``): The Number of class of the classification task. | ||
``use_bn`` (``bool``): If set to ``True``, use batch normalization. Defaults to ``False``. | ||
``drop_rate`` (``float``, optional): Dropout ratio. Defaults to 0.5. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
in_channels: int, | ||
hid_channels: int, | ||
num_classes: int, | ||
use_bn: bool = False, | ||
drop_rate: float = 0.5, | ||
) -> None: | ||
super().__init__() | ||
self.layers0 = HGNNConv(in_channels, hid_channels, use_bn=use_bn, drop_rate=drop_rate) | ||
self.layers1 = HGNNConv(hid_channels, num_classes, use_bn=use_bn, is_last=True) | ||
|
||
def forward(self, X: torch.Tensor, hg: "dhg.Hypergraph", get_emb=False) -> torch.Tensor: | ||
r"""The forward function. | ||
Args: | ||
``X`` (``torch.Tensor``): Input vertex feature matrix. Size :math:`(N, C_{in})`. | ||
``hg`` (``dhg.Hypergraph``): The hypergraph structure that contains :math:`N` vertices. | ||
""" | ||
emb = self.layers0(X, hg) | ||
X = self.layers1(emb, hg) | ||
if get_emb: | ||
return emb | ||
else: | ||
return X | ||
|
||
|
||
class MyMLPs(nn.Module): | ||
def __init__(self, dim_in, dim_hid, n_classes) -> None: | ||
super().__init__() | ||
self.layer0 = MLP([dim_in, dim_hid]) | ||
self.layer1 = nn.Linear(dim_hid, n_classes) | ||
|
||
def forward(self, X, get_emb=False): | ||
emb = self.layer0(X) | ||
X = self.layer1(emb) | ||
if get_emb: | ||
return emb | ||
else: | ||
return X |
Oops, something went wrong.