Skip to content

Commit

Permalink
Merge pull request #429 from datamol-io/gpu_fp16
Browse files Browse the repository at this point in the history
Gpu mixed-precision overflow fix
  • Loading branch information
DomInvivo authored Aug 9, 2023
2 parents d7c910e + d6e00fd commit 92ac9da
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 29 deletions.
12 changes: 8 additions & 4 deletions graphium/trainer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from graphium.trainer.predictor_summaries import TaskSummaries
from graphium.data.datamodule import BaseDataModule
from graphium.utils.moving_average_tracker import MovingAverageTracker
from graphium.utils.tensor import dict_tensor_fp16_to_fp32

GRAPHIUM_PRETRAINED_MODELS = {
"graphium-zinc-micro-dummy-test": "gcs://graphium-public/pretrained-models/graphium-zinc-micro-dummy-test/model.ckpt"
Expand Down Expand Up @@ -492,7 +493,7 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None:
self.task_epoch_summary.update_predictor_state(
step_name="train",
targets=outputs["targets"],
predictions=outputs["preds"],
preds=outputs["preds"],
loss=outputs["loss"], # This is the weighted loss for now, but change to task-specific loss
task_losses=outputs["task_losses"],
n_epochs=self.current_epoch,
Expand Down Expand Up @@ -554,9 +555,12 @@ def _general_epoch_end(self, outputs: Dict[str, Any], step_name: str) -> None:
weights = torch.cat([out["weights"] for out in outputs], dim=0)
else:
weights = None

# NOTE: Computing the loss over the entire split may cause
# overflow issues when using fp16
loss, task_losses = self.compute_loss(
preds=preds,
targets=targets,
preds=dict_tensor_fp16_to_fp32(preds),
targets=dict_tensor_fp16_to_fp32(targets),
weights=weights,
target_nan_mask=self.target_nan_mask,
multitask_handling=self.multitask_handling,
Expand All @@ -565,7 +569,7 @@ def _general_epoch_end(self, outputs: Dict[str, Any], step_name: str) -> None:

self.task_epoch_summary.update_predictor_state(
step_name=step_name,
predictions=preds,
preds=preds,
targets=targets,
loss=loss,
task_losses=task_losses,
Expand Down
43 changes: 21 additions & 22 deletions graphium/trainer/predictor_summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from torch import Tensor

from graphium.utils.tensor import nan_mean, nan_std, nan_median
from graphium.utils.tensor import nan_mean, nan_std, nan_median, tensor_fp16_to_fp32


class SummaryInterface(object):
Expand Down Expand Up @@ -80,15 +80,15 @@ def __init__(
# self.predictor_outputs = None
self.step_name: str = None
self.targets: Tensor = None
self.predictions: Tensor = None
self.preds: Tensor = None
self.loss = None # What type?
self.n_epochs: int = None

self.task_name = task_name
self.logged_metrics_exceptions = [] # Track which metric exceptions have been logged

def update_predictor_state(
self, step_name: str, targets: Tensor, predictions: Tensor, loss: Tensor, n_epochs: int
self, step_name: str, targets: Tensor, preds: Tensor, loss: Tensor, n_epochs: int
):
r"""
update the state of the predictor
Expand All @@ -101,7 +101,7 @@ def update_predictor_state(
"""
self.step_name = step_name
self.targets = targets
self.predictions = predictions
self.preds = preds
self.loss = loss
self.n_epochs = n_epochs

Expand All @@ -120,7 +120,7 @@ def set_results(
metrics[self.metric_log_name(self.task_name, "loss", self.step_name)] = self.loss
self.summaries[self.step_name] = Summary.Results(
targets=self.targets,
predictions=self.predictions,
preds=self.preds,
loss=self.loss,
metrics=metrics, # Should include task name from get_metrics_logs()
monitored_metric=f"{self.monitor}/{self.step_name}", # Include task name?
Expand Down Expand Up @@ -232,18 +232,17 @@ def get_metrics_logs(self) -> Dict[str, Any]:
Returns:
A dictionary of metrics to log.
"""
targets = self.targets.to(dtype=self.predictions.dtype, device=self.predictions.device)

targets = tensor_fp16_to_fp32(self.targets)
preds = tensor_fp16_to_fp32(self.preds)

targets = targets.to(dtype=preds.dtype, device=preds.device)

# Compute the metrics always used in regression tasks
metric_logs = {}
metric_logs[self.metric_log_name(self.task_name, "mean_pred", self.step_name)] = nan_mean(
self.predictions
)
metric_logs[self.metric_log_name(self.task_name, "std_pred", self.step_name)] = nan_std(
self.predictions
)
metric_logs[self.metric_log_name(self.task_name, "median_pred", self.step_name)] = nan_median(
self.predictions
)
metric_logs[self.metric_log_name(self.task_name, "mean_pred", self.step_name)] = nan_mean(preds)
metric_logs[self.metric_log_name(self.task_name, "std_pred", self.step_name)] = nan_std(preds)
metric_logs[self.metric_log_name(self.task_name, "median_pred", self.step_name)] = nan_median(preds)
metric_logs[self.metric_log_name(self.task_name, "mean_target", self.step_name)] = nan_mean(targets)
metric_logs[self.metric_log_name(self.task_name, "std_target", self.step_name)] = nan_std(targets)
metric_logs[self.metric_log_name(self.task_name, "median_target", self.step_name)] = nan_median(
Expand All @@ -264,7 +263,7 @@ def get_metrics_logs(self) -> Dict[str, Any]:
self.task_name, key, self.step_name
) # f"{key}/{self.step_name}"
try:
metric_logs[metric_name] = metric(self.predictions, targets)
metric_logs[metric_name] = metric(preds, targets)
except Exception as e:
metric_logs[metric_name] = torch.as_tensor(float("nan"))
# Warn only if it's the first warning for that metric
Expand Down Expand Up @@ -292,7 +291,7 @@ class Results:
def __init__(
self,
targets: Tensor = None,
predictions: Tensor = None,
preds: Tensor = None,
loss: float = None, # Is this supposed to be a Tensor or float?
metrics: dict = None,
monitored_metric: str = None,
Expand All @@ -302,14 +301,14 @@ def __init__(
This inner class is used as a container for storing the results of the summary.
Parameters:
targets: the targets
predictions: the prediction tensor
preds: the prediction tensor
loss: the loss, float or tensor
metrics: the metrics
monitored_metric: the monitored metric
n_epochs: the number of epochs
"""
self.targets = targets.detach().cpu()
self.predictions = predictions.detach().cpu()
self.preds = preds.detach().cpu()
self.loss = loss.item() if isinstance(loss, Tensor) else loss
self.monitored_metric = monitored_metric
if monitored_metric in metrics.keys():
Expand Down Expand Up @@ -371,7 +370,7 @@ def update_predictor_state(
self,
step_name: str,
targets: Dict[str, Tensor],
predictions: Dict[str, Tensor],
preds: Dict[str, Tensor],
loss: Tensor,
task_losses: Dict[str, Tensor],
n_epochs: int,
Expand All @@ -381,7 +380,7 @@ def update_predictor_state(
Parameters:
step_name: the name of the step
targets: the target tensors
predictions: the prediction tensors
preds: the prediction tensors
loss: the loss tensor
task_losses: the task losses
n_epochs: the number of epochs
Expand All @@ -392,7 +391,7 @@ def update_predictor_state(
self.task_summaries[task].update_predictor_state(
step_name,
targets[task],
predictions[task].detach(),
preds[task].detach(),
task_losses[task].detach(),
n_epochs,
)
Expand Down
36 changes: 35 additions & 1 deletion graphium/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from typing import Iterable, List, Union, Any, Callable
from typing import Iterable, List, Union, Any, Callable, Dict
from inspect import getfullargspec
from copy import copy, deepcopy
from loguru import logger
Expand Down Expand Up @@ -384,3 +384,37 @@ def arg_in_func(fn, arg):
"""
fn_args = getfullargspec(fn)
return (fn_args.varkw is not None) or (arg in fn_args[0])


def tensor_fp16_to_fp32(tensor: Tensor) -> Tensor:
    r"""Cast a tensor from fp16 to fp32 if it is in fp16

    Parameters:
        tensor: A tensor. If it is in fp16, it will be cast to fp32

    Returns:
        tensor: The tensor cast to fp32 if it was in fp16. A tensor of any
            other dtype is returned as-is (the same object, no conversion).
    """
    # Only half-precision tensors are up-cast; every other dtype
    # (fp32, fp64, integer, ...) is handed back untouched.
    if tensor.dtype == torch.float16:
        return tensor.to(dtype=torch.float32)
    return tensor


def dict_tensor_fp16_to_fp32(
    dict_tensor: Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]
) -> Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]:
    r"""Recursively cast a dictionary of tensors from fp16 to fp32 if it is in fp16

    The input is never modified: every dictionary encountered is shallow-copied
    before its values are replaced, so the caller's dictionaries keep their
    original tensors and dtypes.

    Parameters:
        dict_tensor: A recursive dictionary of tensors, or a single tensor.
            To be cast to fp32 if it was in fp16.

    Returns:
        dict_tensor: A new recursive dictionary (same keys, same nesting) whose
            fp16 tensors are cast to fp32. If a single tensor was given, the
            result of `tensor_fp16_to_fp32` on that tensor.
    """
    # Leaf: delegate the dtype check to the single-tensor helper.
    if not isinstance(dict_tensor, dict):
        return tensor_fp16_to_fp32(dict_tensor)

    # Shallow copy so the caller's mapping is left untouched, while the
    # mapping type and key order are preserved.
    converted = dict_tensor.copy()
    for key, value in dict_tensor.items():
        converted[key] = dict_tensor_fp16_to_fp32(value)
    return converted
68 changes: 66 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,23 @@
Unit tests for the metrics and wrappers of graphium/utils/...
"""

from graphium.utils.tensor import nan_mad, nan_mean, nan_std, nan_var, nan_median
from graphium.utils.safe_run import SafeRun
import torch
import numpy as np
import scipy as sp
import unittest as ut
import gzip

from graphium.utils.read_file import file_opener
from graphium.utils.tensor import (
nan_mad,
nan_mean,
nan_std,
nan_var,
nan_median,
dict_tensor_fp16_to_fp32,
tensor_fp16_to_fp32,
)
from graphium.utils.safe_run import SafeRun


class test_nan_statistics(ut.TestCase):
Expand Down Expand Up @@ -182,5 +191,60 @@ def test_safe_run(self):
print("This is not an error")


class TestTensorFp16ToFp32(ut.TestCase):
    """Unit tests for `tensor_fp16_to_fp32` and `dict_tensor_fp16_to_fp32`."""

    def _assert_dtypes(self, converted, expected, path=""):
        """Recursively check that `converted` has exactly the keys and dtypes of `expected`."""
        self.assertEqual(set(converted.keys()), set(expected.keys()), msg=f"keys differ at '{path}'")
        for key, expected_dtype in expected.items():
            location = f"{path}/{key}"
            if isinstance(expected_dtype, dict):
                self.assertIsInstance(converted[key], dict, msg=location)
                self._assert_dtypes(converted[key], expected_dtype, path=location)
            else:
                self.assertEqual(converted[key].dtype, expected_dtype, msg=location)

    def test_tensor_fp16_to_fp32(self):
        # An fp16 tensor must be converted to fp32, keeping shape and values
        tensor = torch.randn(10, 10).half()
        expected_values = tensor.float()
        tensor_fp32 = tensor_fp16_to_fp32(tensor)
        self.assertEqual(tensor_fp32.dtype, torch.float32)
        self.assertEqual(tensor_fp32.shape, expected_values.shape)
        self.assertTrue(torch.equal(tensor_fp32, expected_values))

        # Any other dtype must come back with its dtype unchanged. Asserting the
        # exact dtype is stricter than only asserting "not float32".
        unchanged_inputs = {
            "int": torch.randn(10, 10).int(),
            "double": torch.randn(10, 10).double(),
            "float": torch.randn(10, 10).float(),
        }
        for name, tensor in unchanged_inputs.items():
            with self.subTest(dtype=name):
                original_dtype = tensor.dtype
                result = tensor_fp16_to_fp32(tensor)
                self.assertEqual(result.dtype, original_dtype)

    def test_dict_tensor_fp16_to_fp32(self):
        # Create a (nested) dictionary of tensors
        tensor_dict = {
            "a": torch.randn(10, 10).half(),
            "b": torch.randn(10, 10).half(),
            "c": torch.randn(10, 10).double(),
            "d": torch.randn(10, 10).half(),
            "e": torch.randn(10, 10).float(),
            "f": torch.randn(10, 10).half(),
            "g": torch.randn(10, 10).int(),
            "h": {
                "h1": torch.randn(10, 10).double(),
                "h2": torch.randn(10, 10).half(),
                "h3": torch.randn(10, 10).float(),
                "h4": torch.randn(10, 10).half(),
                "h5": torch.randn(10, 10).int(),
            },
        }

        # Expected dtypes after conversion: fp16 -> fp32, everything else untouched
        expected_dtypes = {
            "a": torch.float32,
            "b": torch.float32,
            "c": torch.float64,
            "d": torch.float32,
            "e": torch.float32,
            "f": torch.float32,
            "g": torch.int32,
            "h": {
                "h1": torch.float64,
                "h2": torch.float32,
                "h3": torch.float32,
                "h4": torch.float32,
                "h5": torch.int32,
            },
        }

        # Convert the dictionary to fp32
        tensor_dict_fp32 = dict_tensor_fp16_to_fp32(tensor_dict)

        # Check that the dictionary is correctly converted, at every nesting level
        self._assert_dtypes(tensor_dict_fp32, expected_dtypes)

        # A bare tensor (no dictionary) must be handled as well
        bare = dict_tensor_fp16_to_fp32(torch.randn(10, 10).half())
        self.assertEqual(bare.dtype, torch.float32)


# Run the unit tests when this file is executed as a script
if __name__ == "__main__":
    ut.main()

0 comments on commit 92ac9da

Please sign in to comment.