From d271bcdea234d1fdf3e94cdd1a30ed8eb48983bd Mon Sep 17 00:00:00 2001 From: kerstink-GC Date: Wed, 25 Oct 2023 11:37:04 +0000 Subject: [PATCH 1/3] fixed output scaling bug and removed double residual edges --- graphium/nn/pyg_layers/gps_pyg.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/graphium/nn/pyg_layers/gps_pyg.py b/graphium/nn/pyg_layers/gps_pyg.py index e706bd5b3..53939b24a 100644 --- a/graphium/nn/pyg_layers/gps_pyg.py +++ b/graphium/nn/pyg_layers/gps_pyg.py @@ -47,7 +47,6 @@ def __init__( activation: Union[Callable, str] = "relu", dropout: float = 0.0, node_residual: Optional[bool] = True, - edge_residual: Optional[bool] = True, normalization: Union[str, Callable] = "none", mpnn_type: str = "pyg:gine", mpnn_kwargs: Optional[dict] = None, @@ -101,9 +100,6 @@ def __init__( node_residual: If node residual is used after on the gnn layer output - edge_residual: - If edge residual is used after on the gnn layer output - normalization: Normalization to use. Choices: @@ -175,7 +171,6 @@ def __init__( # Residual connections self.node_residual = node_residual - self.edge_residual = edge_residual self.precision = precision @@ -229,7 +224,7 @@ def scale_activations(self, feature: Tensor, scale_factor: Tensor) -> Tensor: Tensor: The scaled features """ scale_factor = torch.tensor(scale_factor).to(feature.device) - feature *= scale_factor.to(dtype=feature.dtype) + feature = feature/scale_factor.to(dtype=feature.dtype) return feature def forward(self, batch: Batch) -> Batch: @@ -252,25 +247,16 @@ def forward(self, batch: Batch) -> Batch: if self.mpnn is not None: batch_out = self.mpnn(batch_out) h_local = batch_out.feat - e_local = batch_out.edge_feat if self.dropout_local is not None: h_local = self.dropout_local(h_local) # Apply the residual connection for the node features and scale the activations by some value to help reduce activation growth if self.node_residual: if self.layer_depth < 1: h_local = self.residual_add(h_local, feat_in) - h_local *= 1 / self.scale_activations(h_local, self.output_scale) + h_local *= self.scale_activations(h_local, self.output_scale) else: - h_local *= 1 / self.scale_activations(h_local, self.output_scale) + h_local *= self.scale_activations(h_local, self.output_scale) h_local = self.residual_add(h_local, feat_in) - # Apply the residual connection for the edge features and scale the activations by some value to help reduce activation growth - if self.edge_residual and self.use_edges: - if self.layer_depth < 1: - e_local = self.residual_add(e_local, edges_feat_in) - e_local *= 1 / self.scale_activations(e_local, self.output_scale) - else: - e_local *= 1 / self.scale_activations(e_local, self.output_scale) - e_local = self.residual_add(e_local, edges_feat_in) if self.norm_layer_local is not None: h_local = self.norm_layer_local(h_local) From 8ceb167b764a31e2f5c676f62919a6815b80bb4a Mon Sep 17 00:00:00 2001 From: kerstink-GC Date: Wed, 25 Oct 2023 13:00:31 +0000 Subject: [PATCH 2/3] lint --- graphium/nn/pyg_layers/gps_pyg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphium/nn/pyg_layers/gps_pyg.py b/graphium/nn/pyg_layers/gps_pyg.py index 53939b24a..6195cb270 100644 --- a/graphium/nn/pyg_layers/gps_pyg.py +++ b/graphium/nn/pyg_layers/gps_pyg.py @@ -224,7 +224,7 @@ def scale_activations(self, feature: Tensor, scale_factor: Tensor) -> Tensor: Tensor: The scaled features """ scale_factor = torch.tensor(scale_factor).to(feature.device) - feature = feature/scale_factor.to(dtype=feature.dtype) + feature = feature / scale_factor.to(dtype=feature.dtype) return feature def forward(self, batch: Batch) -> Batch: From ee56497e857b080a7d73ff6e59b5d3ae4a479f06 Mon Sep 17 00:00:00 2001 From: kerstink-GC Date: Wed, 25 Oct 2023 13:23:26 +0000 Subject: [PATCH 3/3] implemented feedback --- graphium/nn/pyg_layers/gps_pyg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphium/nn/pyg_layers/gps_pyg.py b/graphium/nn/pyg_layers/gps_pyg.py index 6195cb270..bc04c8288 100644 --- a/graphium/nn/pyg_layers/gps_pyg.py +++ b/graphium/nn/pyg_layers/gps_pyg.py @@ -253,9 +253,9 @@ def forward(self, batch: Batch) -> Batch: if self.node_residual: if self.layer_depth < 1: h_local = self.residual_add(h_local, feat_in) - h_local *= self.scale_activations(h_local, self.output_scale) + h_local = self.scale_activations(h_local, self.output_scale) else: - h_local *= self.scale_activations(h_local, self.output_scale) + h_local = self.scale_activations(h_local, self.output_scale) h_local = self.residual_add(h_local, feat_in) if self.norm_layer_local is not None: