Changed GatedGCNPyg to cast to float32
DomInvivo committed Aug 2, 2023
1 parent fed52cb commit c7966f0
Showing 1 changed file with 15 additions and 6 deletions.
graphium/nn/pyg_layers/gated_gcn_pyg.py
@@ -31,7 +31,7 @@ def __init__(
         activation: Union[Callable, str] = "relu",
         dropout: float = 0.0,
         normalization: Union[str, Callable] = "none",
-        eps: float = 1e-3,
+        eps: float = 1e-5,
         **kwargs,
     ):
         r"""
@@ -169,15 +169,24 @@ def aggregate(
         Returns:
             out: dim [n_nodes, out_dim]
         """
-        dim_size = Bx.shape[0]  # or None ?? <--- Double check this
+        dim_size = Bx.shape[0]
 
-        sum_sigma_x = sigma_ij * Bx_j
-        numerator_eta_xj = scatter(sum_sigma_x, index, 0, None, dim_size, reduce="sum")
+        # Sum the messages, weighted by the gates. Sum the gates.
+        numerator_eta_xj = scatter(sigma_ij * Bx_j, index, 0, None, dim_size, reduce="sum")
+        denominator_eta_xj = scatter(sigma_ij, index, 0, None, dim_size, reduce="sum")
 
-        sum_sigma = sigma_ij
-        denominator_eta_xj = scatter(sum_sigma, index, 0, None, dim_size, reduce="sum")
+        # Cast to float32 if needed
+        dtype = denominator_eta_xj.dtype
+        if dtype == torch.float16:
+            numerator_eta_xj = numerator_eta_xj.to(dtype=torch.float32)
+            denominator_eta_xj = denominator_eta_xj.to(dtype=torch.float32)
 
+        # Normalize the messages by the sum of the gates
         out = numerator_eta_xj / (denominator_eta_xj + self.eps)
+
+        # Cast back to float16 if needed
+        if dtype == torch.float16:
+            out = out.to(dtype=dtype)
         return out
 
     def update(self, aggr_out: torch.Tensor, Ax: torch.Tensor):
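To make the aggregation step easier to follow outside the diff, here is a minimal, self-contained sketch of what the new aggregate body computes. It is not the library's code: it swaps the diff's scatter call for plain index_add_, and the function name, shapes, and example values are illustrative. The idea is a gated mean: sum the gated messages and the gates per destination node, do the division in float32 when the inputs are float16, then cast the result back.

import torch

def gated_mean_aggregate(Bx_j, sigma_ij, index, dim_size, eps=1e-5):
    """Gated-mean aggregation with a float32 upcast around the division (illustrative)."""
    feat = Bx_j.shape[-1]
    zeros = lambda: torch.zeros(dim_size, feat, dtype=Bx_j.dtype, device=Bx_j.device)

    # Per-node sums of the gated messages and of the gates (the two scatter-sums in the diff).
    numerator = zeros().index_add_(0, index, sigma_ij * Bx_j)
    denominator = zeros().index_add_(0, index, sigma_ij)

    # Do the division in float32 when running in half precision, then cast back.
    dtype = denominator.dtype
    if dtype == torch.float16:
        numerator = numerator.to(torch.float32)
        denominator = denominator.to(torch.float32)
    out = numerator / (denominator + eps)
    return out.to(dtype) if dtype == torch.float16 else out

# Example: 4 edge messages flowing into 3 nodes, in float16 to mimic mixed precision.
Bx_j = torch.randn(4, 8, dtype=torch.float16)
sigma_ij = torch.rand(4, 8, dtype=torch.float16)
index = torch.tensor([0, 0, 1, 2])
out = gated_mean_aggregate(Bx_j, sigma_ij, index, dim_size=3)  # shape [3, 8]

The upcast-divide-downcast pattern around the division is the part this commit adds; the per-node sums themselves were already there.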