Skip to content

Commit

Permalink
Merge pull request #364 from datamol-io/hybridCEloss_changes
Browse files Browse the repository at this point in the history
Hybrid ce loss changes
  • Loading branch information
DomInvivo committed Jun 23, 2023
2 parents 64a1fc3 + 2f28208 commit 892de3b
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
4 changes: 2 additions & 2 deletions expts/main_run_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
# CONFIG_FILE = "expts/configs/config_mpnn_10M_pcqm4m.yaml"
# CONFIG_FILE = "expts/neurips2023_configs/config_debug.yaml"
# CONFIG_FILE = "expts/neurips2023_configs/config_large_mpnn.yaml"
# CONFIG_FILE = "expts/neurips2023_configs/debug/config_large_gcn_debug.yaml"
CONFIG_FILE = "expts/neurips2023_configs/config_large_gin.yaml"
CONFIG_FILE = "expts/neurips2023_configs/debug/config_large_gcn_debug.yaml"
# CONFIG_FILE = "expts/neurips2023_configs/config_large_gin.yaml"
# CONFIG_FILE = "expts/neurips2023_configs/config_large_gcn.yaml"
# CONFIG_FILE = "expts/neurips2023_configs/config_large_gine.yaml"
# CONFIG_FILE = "expts/neurips2023_configs/config_small_gcn.yaml"
Expand Down
4 changes: 2 additions & 2 deletions expts/neurips2023_configs/debug/config_large_gcn_debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ architecture:
hidden_dims: 128
depth: 2
activation: none
last_activation: sigmoid
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
Expand All @@ -261,7 +261,7 @@ architecture:
hidden_dims: 128
depth: 2
activation: none
last_activation: sigmoid
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
Expand Down
9 changes: 7 additions & 2 deletions graphium/trainer/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(

self.brackets = Tensor(range(n_brackets))
self.alpha = alpha
self.softmax = torch.nn.Softmax(dim=1)

def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""
Expand All @@ -59,10 +60,14 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""

target = target.flatten()

regression_input = torch.inner(input, self.brackets.to(input.device))
# regression loss needs normalized logits to probability as input to do inner product with self.brackets
# we apply softmax on the raw logits first
softmax_input = self.softmax(input)
# [batch_size, n_classes] * [n_classes] ([0, 1, 2...n_brakets-1]) -> [batch_size]
regression_input = torch.inner(softmax_input, self.brackets.to(input.device))
regression_loss = self.regression_loss(regression_input, target.float(), reduction=self.reduction)

# cross_entropy loss needs raw logits as input
ce_loss = F.cross_entropy(input, target.long(), weight=self.weight, reduction=self.reduction)

return self.alpha * ce_loss + (1 - self.alpha) * regression_loss
12 changes: 8 additions & 4 deletions tests/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@ class test_HybridCELoss(ut.TestCase):
input = torch.Tensor([[0.1, 0.1, 0.3, 0.5, 0.0], [0.1, 0.0, 0.7, 0.2, 0.0]])
target = torch.Tensor([3, 0]).long()
brackets = torch.Tensor([0, 1, 2, 3, 4])
regression_input = torch.Tensor([2.2, 2.0]) # inner product of input and brackets
regression_input = torch.Tensor([2.0537, 2.0017]) # inner product of input and brackets
regression_target = torch.Tensor([3, 0]).float()

def test_pure_ce_loss(self):
loss = HybridCELoss(n_brackets=len(self.brackets), alpha=1.0, reduction="none")

assert torch.allclose(
loss(self.input, self.target),
F.cross_entropy(self.input, self.target, reduction="none"),
Expand All @@ -38,10 +37,11 @@ def test_pure_mae_loss(self):
regression_loss="mae",
reduction="none",
)

assert torch.allclose(
loss(self.input, self.target),
F.l1_loss(self.regression_input, self.regression_target, reduction="none"),
rtol=1e-04,
atol=1e-07,
)
assert loss(self.input, self.target).shape == (2,)

Expand All @@ -56,6 +56,8 @@ def test_pure_mse_loss(self):
assert torch.allclose(
loss(self.input, self.target),
F.mse_loss(self.regression_input, self.regression_target, reduction="none"),
rtol=1e-04,
atol=1e-07,
)
assert loss(self.input, self.target).shape == (2,)

Expand All @@ -65,7 +67,9 @@ def test_hybrid_loss(self):
ce_loss = F.cross_entropy(self.input, self.target)
mse_loss = F.mse_loss(self.regression_input, self.regression_target)

assert torch.allclose(loss(self.input, self.target), 0.5 * ce_loss + 0.5 * mse_loss)
assert torch.allclose(
loss(self.input, self.target), 0.5 * ce_loss + 0.5 * mse_loss, rtol=1e-04, atol=1e-07
)
assert loss(self.input, self.target).shape == torch.Size([])

def test_loss_parser(self):
Expand Down

0 comments on commit 892de3b

Please sign in to comment.