Enhancing GraphGym Documentation #7885

Merged
merged 12 commits on Sep 1, 2023
135 changes: 114 additions & 21 deletions torch_geometric/graphgym/models/gnn.py
@@ -23,11 +23,27 @@ def GNNLayer(dim_in, dim_out, has_act=True):
dim_out (int): Output dimension
has_act (bool): Whether to apply an activation function after the layer

Returns:
GeneralLayer: A GNN layer configured according to the provided parameters.

This function creates a GNN layer with the specified input and output dimensions and
an optional activation function. The layer configuration is built by
`new_layer_config` from the global GraphGym configuration (`cfg`).

Example:
To create a GNN layer with input dimension 16 and output dimension 32:
>>> layer = GNNLayer(dim_in=16, dim_out=32)

Note:
Make sure the global configuration (`cfg`) used by `new_layer_config` is set up
appropriately before calling this function.

"""
return GeneralLayer(
cfg.gnn.layer_type,
layer_config=new_layer_config(dim_in, dim_out, 1, has_act=has_act,
has_bias=False, cfg=cfg))
has_bias=False, cfg=cfg),
)
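
A minimal usage sketch, assuming GraphGym's global `cfg` has been initialized with
defaults via `set_cfg` and that `'gcnconv'` is among the registered layer types:

    from torch_geometric.graphgym.config import cfg, set_cfg
    from torch_geometric.graphgym.models.gnn import GNNLayer

    set_cfg(cfg)                    # fill cfg with GraphGym defaults
    cfg.gnn.layer_type = 'gcnconv'  # any registered layer type works here
    layer = GNNLayer(dim_in=16, dim_out=32, has_act=True)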


def GNNPreMP(dim_in, dim_out, num_layers):
@@ -39,16 +55,33 @@ def GNNPreMP(dim_in, dim_out, num_layers):
dim_out (int): Output dimension
num_layers (int): Number of layers

Returns:
GeneralMultiLayer: A stack of neural network layers for preprocessing before GNN
message passing.

This function creates a stack of `num_layers` linear layers that preprocess input
features before GNN message passing, mapping `dim_in` to `dim_out`. The layer
configuration is built by `new_layer_config` from the global GraphGym configuration
(`cfg`).

Example:
To create a stack of 3 linear layers with input dimension 16 and output dimension 32:
>>> pre_mp_layers = GNNPreMP(dim_in=16, dim_out=32, num_layers=3)

Note:
Make sure the global configuration (`cfg`) used by `new_layer_config` is set up
appropriately before calling this function.
"""
return GeneralMultiLayer(
'linear',
"linear",
layer_config=new_layer_config(dim_in, dim_out, num_layers,
has_act=False, has_bias=False, cfg=cfg))
has_act=False, has_bias=False, cfg=cfg),
)
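
Continuing the sketch above, a pre-message-passing stack can be built the same way;
the dimensions here are illustrative:

    pre_mp = GNNPreMP(dim_in=16, dim_out=64, num_layers=2)
    # a stack of two 'linear' GeneralLayers mapping 16-dim input to
    # 64-dim output, without bias or activation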


@register_stage('stack')
@register_stage('skipsum')
@register_stage('skipconcat')
@register_stage("stack")
@register_stage("skipsum")
@register_stage("skipconcat")
class GNNStackStage(nn.Module):
"""
Simple stage that stacks GNN layers
@@ -62,21 +95,20 @@ def __init__(self, dim_in, dim_out, num_layers):
super().__init__()
self.num_layers = num_layers
for i in range(num_layers):
if cfg.gnn.stage_type == 'skipconcat':
if cfg.gnn.stage_type == "skipconcat":
d_in = dim_in if i == 0 else dim_in + i * dim_out
else:
d_in = dim_in if i == 0 else dim_out
layer = GNNLayer(d_in, dim_out)
self.add_module('layer{}'.format(i), layer)
self.add_module("layer{}".format(i), layer)

def forward(self, batch):
for i, layer in enumerate(self.children()):
x = batch.x
batch = layer(batch)
if cfg.gnn.stage_type == 'skipsum':
if cfg.gnn.stage_type == "skipsum":
batch.x = x + batch.x
elif cfg.gnn.stage_type == 'skipconcat' and \
i < self.num_layers - 1:
elif cfg.gnn.stage_type == "skipconcat" and i < self.num_layers - 1:
batch.x = torch.cat([x, batch.x], dim=1)
if cfg.gnn.l2norm:
batch.x = F.normalize(batch.x, p=2, dim=-1)
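
A worked sketch of the `skipconcat` arithmetic above, continuing the earlier
snippets: with `dim_in=16`, `dim_out=32`, and `num_layers=3`, layer `i` receives
`dim_in + i * dim_out` input features, and only the last layer skips the
concatenation:

    cfg.gnn.stage_type = 'skipconcat'
    stage = GNNStackStage(dim_in=16, dim_out=32, num_layers=3)
    # layer0: 16             -> 32, then x becomes cat([16, 32]) = 48 dims
    # layer1: 16 + 1*32 = 48 -> 32, then x becomes cat([48, 32]) = 80 dims
    # layer2: 16 + 2*32 = 80 -> 32 (last layer: no concatenation)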
@@ -85,10 +117,38 @@ def forward(self, batch):

class FeatureEncoder(nn.Module):
"""
Encoding node and edge features
Encodes node and edge features based on the provided configurations.

Args:
dim_in (int): Input feature dimension
dim_in (int): Input feature dimension.

Attributes:
dim_in (int): The current input feature dimension after applying encoders.
node_encoder (nn.Module): Node feature encoder module, if enabled.
edge_encoder (nn.Module): Edge feature encoder module, if enabled.
node_encoder_bn (BatchNorm1dNode): Batch normalization for node encoder output,
if enabled.
edge_encoder_bn (BatchNorm1dNode): Batch normalization for edge encoder output,
if enabled.

The FeatureEncoder module encodes node and edge features based on the settings in
the global `cfg` object. It supports encoding integer node and edge
features using embeddings, optionally followed by batch normalization. The output
dimension of the encoded features is determined by the `cfg.gnn.dim_inner` parameter.

If `cfg.dataset.node_encoder` or `cfg.dataset.edge_encoder` is enabled, the respective
encoder modules are created based on the provided encoder names. If batch
normalization is enabled for either encoder, the corresponding batch normalization
layer is added after the encoder.

Example:
Given an instance of FeatureEncoder:
>>> encoder = FeatureEncoder(dim_in=16)
>>> encoded_features = encoder(batch)

Note:
Make sure to set up the configuration (`cfg`) appropriately before creating an
instance of FeatureEncoder.
"""
def __init__(self, dim_in):
super().__init__()
@@ -100,8 +160,14 @@ def __init__(self, dim_in):
self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
if cfg.dataset.node_encoder_bn:
self.node_encoder_bn = BatchNorm1dNode(
new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
has_bias=False, cfg=cfg))
new_layer_config(
cfg.gnn.dim_inner,
-1,
-1,
has_act=False,
has_bias=False,
cfg=cfg,
))
# Update dim_in to reflect the new dimension of the node features
self.dim_in = cfg.gnn.dim_inner
if cfg.dataset.edge_encoder:
@@ -111,8 +177,14 @@ def __init__(self, dim_in):
self.edge_encoder = EdgeEncoder(cfg.gnn.dim_inner)
if cfg.dataset.edge_encoder_bn:
self.edge_encoder_bn = BatchNorm1dNode(
new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
has_bias=False, cfg=cfg))
new_layer_config(
cfg.gnn.dim_inner,
-1,
-1,
has_act=False,
has_bias=False,
cfg=cfg,
))

def forward(self, batch):
for module in self.children():
@@ -122,12 +194,33 @@ def forward(self, batch):

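A hedged usage sketch for FeatureEncoder; `'Atom'` is assumed here to be a registered
node encoder name and `batch` a `torch_geometric.data.Batch` whose integer node
features match that encoder:

    cfg.gnn.dim_inner = 64
    cfg.dataset.node_encoder = True
    cfg.dataset.node_encoder_name = 'Atom'  # assumed registered encoder
    cfg.dataset.node_encoder_bn = True
    encoder = FeatureEncoder(dim_in=16)
    batch = encoder(batch)  # batch.x is now cfg.gnn.dim_inner-dimensional
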
class GNN(nn.Module):
"""
General GNN model: encoder + stage + head
General Graph Neural Network (GNN) model composed of an encoder, processing stage,
and head.

Args:
dim_in (int): Input dimension
dim_out (int): Output dimension
**kwargs (optional): Optional additional args
dim_in (int): Input feature dimension.
dim_out (int): Output feature dimension.
**kwargs (optional): Optional additional keyword arguments.

Attributes:
encoder (FeatureEncoder): Node and edge feature encoder.
pre_mp (GNNPreMP, optional): Pre-message-passing processing layers, if any.
mp (GNNStage, optional): Message-passing stage, if any.
post_mp (GNNHead): Post-message-passing processing layers.

The GNN model consists of three main components: an encoder to transform input
features, a message-passing stage for information exchange, and a head to produce
final output features. The processing layers in each component are determined by
the provided configurations.

Example:
Given an instance of GNN:
>>> gnn = GNN(dim_in=16, dim_out=32)
>>> output = gnn(batch)

Note:
Make sure to set up the configuration (`cfg`) and any required module registrations
(`register`) before creating an instance of GNN.
"""
def __init__(self, dim_in, dim_out, **kwargs):
super().__init__()
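
A minimal sketch of constructing the full model directly, with `batch` a data batch
as above; in a typical GraphGym run the model is instead built via `create_model`,
and `cfg` is assumed to be fully initialized:

    model = GNN(dim_in=16, dim_out=2)  # e.g. a binary classification task
    out = model(batch)  # output format depends on the configured task head
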
41 changes: 41 additions & 0 deletions torch_geometric/graphgym/train.py
@@ -14,6 +14,25 @@


class GraphGymDataModule(LightningDataModule):
"""
LightningDataModule for handling data loading in GraphGym.

This class provides data loaders for training, validation, and testing, which are
created using the `create_loader` function. The data loaders can be accessed through
the `train_dataloader`, `val_dataloader`, and `test_dataloader` methods, respectively.

Note:
Make sure the global configuration (`cfg`) is set up appropriately before
instantiating this class.

Example:
>>> data_module = GraphGymDataModule()
>>> model = create_model()  # builds a GraphGymModule from the global cfg
>>> train(model, datamodule=data_module)

Attributes:
loaders: List of DataLoader instances for train, validation, and test.
"""
def __init__(self):
self.loaders = create_loader()
super().__init__(has_val=True, has_test=True)
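
A short sketch of what this module exposes, assuming `cfg.dataset` points at a real
dataset so that `create_loader` succeeds:

    datamodule = GraphGymDataModule()
    train_loader = datamodule.train_dataloader()
    val_loader = datamodule.val_dataloader()
    test_loader = datamodule.test_dataloader()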
@@ -32,6 +51,28 @@ def test_dataloader(self) -> DataLoader:

def train(model: GraphGymModule, datamodule, logger: bool = True,
trainer_config: Optional[dict] = None):
"""
Train a GraphGym model using PyTorch Lightning.

Args:
model (GraphGymModule): The GraphGym model to be trained.
datamodule (GraphGymDataModule): The data module containing data loaders.
logger (bool): Whether to enable logging during training.
trainer_config (dict, optional): Additional trainer configurations.

This function trains the provided GraphGym model using PyTorch Lightning. It sets up
the trainer with the given configuration, including callbacks for logging and checkpointing.
After training, the function also tests the model using the provided data module.

Example:
>>> data_module = GraphGymDataModule()
>>> model = create_model()  # builds a GraphGymModule from the global cfg
>>> train(model, datamodule=data_module)

Note:
Make sure the appropriate configuration (`cfg`) is set up before calling this
function.
"""
warnings.filterwarnings('ignore', '.*use `CSVLogger` as the default.*')

callbacks = []
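
Putting the pieces together, a hedged end-to-end sketch; in a real run `cfg` is
populated from a YAML config through GraphGym's CLI rather than edited by hand:

    from torch_geometric.graphgym.config import cfg, set_cfg
    from torch_geometric.graphgym.model_builder import create_model
    from torch_geometric.graphgym.train import GraphGymDataModule, train

    set_cfg(cfg)
    # ... populate cfg here, e.g. load_cfg(cfg, args) with parsed CLI args ...

    datamodule = GraphGymDataModule()
    model = create_model()  # builds a GraphGymModule from the global cfg
    train(model, datamodule, logger=True)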