diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py
index 6d9316663..42478efa9 100644
--- a/graphium/nn/architectures/global_architectures.py
+++ b/graphium/nn/architectures/global_architectures.py
@@ -1155,12 +1155,8 @@ def forward(self, g: Batch) -> Tensor:
 
         # Apply the positional encoders
         g = self.encoder_manager(g)
 
-        g["feat"] = g["feat"]
         e = None
-        if "edge_feat" in get_keys(g):
-            g["edge_feat"] = g["edge_feat"]
-
         # Run the pre-processing network on node features
         if self.pre_nn is not None:
             g["feat"] = self.pre_nn.forward(g["feat"])
diff --git a/graphium/nn/pyg_layers/gated_gcn_pyg.py b/graphium/nn/pyg_layers/gated_gcn_pyg.py
index 97bbfbf5e..348b6ad7a 100644
--- a/graphium/nn/pyg_layers/gated_gcn_pyg.py
+++ b/graphium/nn/pyg_layers/gated_gcn_pyg.py
@@ -31,6 +31,7 @@ def __init__(
         activation: Union[Callable, str] = "relu",
         dropout: float = 0.0,
         normalization: Union[str, Callable] = "none",
+        eps: float = 1e-5,
         **kwargs,
     ):
         r"""
@@ -63,6 +64,10 @@ def __init__(
                 - "layer_norm": Layer normalization
                 - `Callable`: Any callable function
 
+            eps:
+                Epsilon value for the normalization `sum(gate_weights * messages) / (sum(gate_weights) + eps)`,
+                where `gate_weights` are the weights of the gates and follow a sigmoid function.
+
         """
         MessagePassing.__init__(self, aggr="add", flow="source_to_target", node_dim=-2)
         BaseGraphStructure.__init__(
@@ -92,6 +97,7 @@ def __init__(
         self.edge_out = FCLayer(
             in_dim=out_dim, out_dim=out_dim_edges, activation=None, dropout=dropout, bias=True
         )
+        self.eps = eps
 
     def forward(
         self,
@@ -163,15 +169,24 @@ def aggregate(
         Returns:
             out: dim [n_nodes, out_dim]
         """
-        dim_size = Bx.shape[0]  # or None ?? <--- Double check this
+        dim_size = Bx.shape[0]
+
+        # Sum the messages, weighted by the gates. Sum the gates.
+        numerator_eta_xj = scatter(sigma_ij * Bx_j, index, 0, None, dim_size, reduce="sum")
+        denominator_eta_xj = scatter(sigma_ij, index, 0, None, dim_size, reduce="sum")
 
-        sum_sigma_x = sigma_ij * Bx_j
-        numerator_eta_xj = scatter(sum_sigma_x, index, 0, None, dim_size, reduce="sum")
+        # Cast to float32 if needed
+        dtype = denominator_eta_xj.dtype
+        if dtype == torch.float16:
+            numerator_eta_xj = numerator_eta_xj.to(dtype=torch.float32)
+            denominator_eta_xj = denominator_eta_xj.to(dtype=torch.float32)
 
-        sum_sigma = sigma_ij
-        denominator_eta_xj = scatter(sum_sigma, index, 0, None, dim_size, reduce="sum")
+        # Normalize the messages by the sum of the gates
+        out = numerator_eta_xj / (denominator_eta_xj + self.eps)
 
-        out = numerator_eta_xj / (denominator_eta_xj + 1e-6)
+        # Cast back to float16 if needed
+        if dtype == torch.float16:
+            out = out.to(dtype=dtype)
         return out
 
     def update(self, aggr_out: torch.Tensor, Ax: torch.Tensor):
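
For reference, below is a minimal, self-contained sketch of the aggregation pattern the gated_gcn_pyg.py change introduces: sum the gate-weighted messages and the gates per destination node, upcast to float32 when running in float16, then divide by the eps-stabilized gate sum. It assumes `scatter` comes from torch_scatter; the function name `gated_aggregate` and the example shapes are illustrative only and are not part of graphium's API.

    import torch
    from torch_scatter import scatter

    def gated_aggregate(sigma_ij, Bx_j, index, dim_size, eps=1e-5):
        # Per-node sums of gate-weighted messages and of the gates themselves
        numerator = scatter(sigma_ij * Bx_j, index, dim=0, dim_size=dim_size, reduce="sum")
        denominator = scatter(sigma_ij, index, dim=0, dim_size=dim_size, reduce="sum")

        # Upcast so eps is not rounded away in float16 (1.0 + 1e-5 == 1.0 in fp16)
        dtype = denominator.dtype
        if dtype == torch.float16:
            numerator = numerator.to(torch.float32)
            denominator = denominator.to(torch.float32)

        # Normalize the messages by the summed gates, then restore the input dtype
        out = numerator / (denominator + eps)
        return out.to(dtype) if dtype == torch.float16 else out

    # Example: 6 edges into 4 nodes, 8-dimensional messages
    sigma = torch.rand(6, 8, dtype=torch.float16)
    msgs = torch.randn(6, 8, dtype=torch.float16)
    dst = torch.tensor([0, 0, 1, 2, 2, 3])
    print(gated_aggregate(sigma, msgs, dst, dim_size=4).shape)  # torch.Size([4, 8])

Doing the division in float32 keeps the small epsilon from vanishing next to large float16 gate sums; the result is cast back so the layer's output dtype is unchanged.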