Commit
Merge pull request #426 from datamol-io/cleanup
Fixed GatedGCN in float16. Minor cleanups.
DomInvivo authored Aug 3, 2023
2 parents 2e3d0bf + c7966f0 commit d7c910e
Showing 2 changed files with 21 additions and 10 deletions.
4 changes: 0 additions & 4 deletions graphium/nn/architectures/global_architectures.py
@@ -1155,12 +1155,8 @@ def forward(self, g: Batch) -> Tensor:
         # Apply the positional encoders
         g = self.encoder_manager(g)
 
-        g["feat"] = g["feat"]
         e = None
 
-        if "edge_feat" in get_keys(g):
-            g["edge_feat"] = g["edge_feat"]
-
         # Run the pre-processing network on node features
         if self.pre_nn is not None:
             g["feat"] = self.pre_nn.forward(g["feat"])
27 changes: 21 additions & 6 deletions graphium/nn/pyg_layers/gated_gcn_pyg.py
@@ -31,6 +31,7 @@ def __init__(
         activation: Union[Callable, str] = "relu",
         dropout: float = 0.0,
         normalization: Union[str, Callable] = "none",
+        eps: float = 1e-5,
         **kwargs,
     ):
         r"""
@@ -63,6 +64,10 @@ def __init__(
- "layer_norm": Layer normalization
- `Callable`: Any callable function
eps:
Epsilon value for the normalization `sum(gate_weights * messages) / (sum(gate_weights) + eps)`,
where `gate_weights` are the weights of the gates and follow a sigmoid function.
"""
MessagePassing.__init__(self, aggr="add", flow="source_to_target", node_dim=-2)
BaseGraphStructure.__init__(
@@ -92,6 +97,7 @@ def __init__(
         self.edge_out = FCLayer(
             in_dim=out_dim, out_dim=out_dim_edges, activation=None, dropout=dropout, bias=True
         )
+        self.eps = eps
 
     def forward(
         self,
@@ -163,15 +169,24 @@ def aggregate(
         Returns:
             out: dim [n_nodes, out_dim]
         """
-        dim_size = Bx.shape[0]  # or None ?? <--- Double check this
+        dim_size = Bx.shape[0]
 
+        # Sum the messages, weighted by the gates. Sum the gates.
+        numerator_eta_xj = scatter(sigma_ij * Bx_j, index, 0, None, dim_size, reduce="sum")
+        denominator_eta_xj = scatter(sigma_ij, index, 0, None, dim_size, reduce="sum")
+
-        sum_sigma_x = sigma_ij * Bx_j
-        numerator_eta_xj = scatter(sum_sigma_x, index, 0, None, dim_size, reduce="sum")
+        # Cast to float32 if needed
+        dtype = denominator_eta_xj.dtype
+        if dtype == torch.float16:
+            numerator_eta_xj = numerator_eta_xj.to(dtype=torch.float32)
+            denominator_eta_xj = denominator_eta_xj.to(dtype=torch.float32)
 
-        sum_sigma = sigma_ij
-        denominator_eta_xj = scatter(sum_sigma, index, 0, None, dim_size, reduce="sum")
+        # Normalize the messages by the sum of the gates
+        out = numerator_eta_xj / (denominator_eta_xj + self.eps)
 
-        out = numerator_eta_xj / (denominator_eta_xj + 1e-6)
+        # Cast back to float16 if needed
+        if dtype == torch.float16:
+            out = out.to(dtype=dtype)
         return out

     def update(self, aggr_out: torch.Tensor, Ax: torch.Tensor):
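
For readers outside the diff context, here is a minimal, self-contained sketch of the aggregation this patch stabilizes. It assumes `scatter` from torch_scatter (the layer's actual import may differ), and the function name `gated_aggregate` is ours, not the library's. The idea behind the fix: in float16 a small epsilon can be rounded away next to a gate sum of order 1 (half precision resolves only about 3 decimal digits), so the eps-stabilized division is done in float32 and the result cast back.

import torch
from torch_scatter import scatter

def gated_aggregate(sigma_ij, Bx_j, index, dim_size, eps=1e-5):
    # Weighted mean of messages: sum(gate * message) / (sum(gate) + eps),
    # the formula documented for `eps` in the layer's docstring.
    numerator = scatter(sigma_ij * Bx_j, index, 0, None, dim_size, reduce="sum")
    denominator = scatter(sigma_ij, index, 0, None, dim_size, reduce="sum")

    dtype = denominator.dtype
    if dtype == torch.float16:
        # Upcast so eps survives the addition and large sums cannot overflow.
        numerator = numerator.to(dtype=torch.float32)
        denominator = denominator.to(dtype=torch.float32)

    out = numerator / (denominator + eps)

    if dtype == torch.float16:
        out = out.to(dtype=dtype)  # back to half precision for downstream layers
    return out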
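A toy check of the sketch above (values hypothetical): two edges point at node 0 and one at node 1.

sigma = torch.tensor([[0.9], [0.1], [0.5]], dtype=torch.float16)  # gate weights (sigmoid outputs)
msgs = torch.tensor([[2.0], [4.0], [6.0]], dtype=torch.float16)   # messages Bx_j
index = torch.tensor([0, 0, 1])                                   # destination node per edge
out = gated_aggregate(sigma, msgs, index, dim_size=2)
# node 0: (0.9*2 + 0.1*4) / (0.9 + 0.1 + eps) ≈ 2.2
# node 1: (0.5*6) / (0.5 + eps) ≈ 6.0

Upcasting only around the division keeps the rest of the layer in half precision, and the configurable eps (default 1e-5) replaces the old hard-coded 1e-6.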
