Integration with DeepSpeed and PyG: expected scalar type Float but found Half #16793

Open

Description

@adm995

Bug description

Hi, I have a problem integrating DeepSpeed and PyG.
With 32-bit precision on the Lightning Trainer on a single Quadro RTX 6000 GPU, everything works fine. This looks similar to the issue in #8426, I guess.
Switching to 16-bit precision, however, calling Trainer.fit() produces the traceback below, even after calling torch.Tensor.half() on the model, on the input, or on both.
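
For reference, this is roughly the manual cast I tried (a sketch; cast_batch_to_half is just an illustrative helper of mine, not part of PyG, and it did not make the error go away):

import torch

def cast_batch_to_half(batch):
    # Illustrative helper (assumption, not PyG API): cast only the
    # floating-point attributes of a PyG Data/Batch to float16; integer
    # tensors such as edge_index and the batch vector keep their dtype.
    for key, value in batch:
        if torch.is_tensor(value) and value.is_floating_point():
            batch[key] = value.half()
    return batch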

How to reproduce the bug

import torch
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from pytorch_lightning.strategies import DeepSpeedStrategy
from torch_geometric.nn import global_mean_pool
from torch_geometric_temporal import GConvGRU
import torch.nn.functional as F
from torch.nn import Linear
import pytorch_lightning as pl
import time
from typing import Union, Dict, List
import math
import torchmetrics as tm


class GraphNN(pl.LightningModule):
    def __init__(self,
                 num_features,
                 num_classes,
                 k,
                 dropout: Union[int, float] = 0.2,
                 learning_rate: float = 1e-4,
                 weight_decay: float = 1e-4,
                 num_layers: int = 2,
                 batch_size: int = 32,
                 ):
        super().__init__()
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.criterion = torch.nn.CrossEntropyLoss()
        self.weight_decay = weight_decay
        self.k = k
        self.num_classes = num_classes+1
        self.layers = torch.nn.ModuleList()
        self.batch_size = batch_size
        h_channels = round(num_features/4)
        for layer in range(num_layers):
            if layer == 0:
                self.layers.append(GConvGRU(num_features, h_channels, k))

            else:
                output = round(math.sqrt(h_channels * self.num_classes))
                self.layers.append(GConvGRU(h_channels, output, k))
                h_channels = output

        output = round(math.sqrt(h_channels * self.num_classes))
        self.fc = Linear(h_channels, output)
        self.output = Linear(output, self.num_classes)
        self.save_hyperparameters()

    def forward(self, x, edge_index, edge_attr, batch):
        for f in self.layers:
            x = f(x, edge_index, edge_attr)
            x = x.relu()
        x = self.fc(x)
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.output(x)
        return x

    def training_step(self, batch, batch_idx):
        return self.step(batch)

    def validation_step(self, batch, batch_idx):
        return self.step(batch)

    def step(self, batch):
        phase: str = "train" if self.training is True else "val"
        x: torch.Tensor = batch.x
        x.requires_grad = True
        y = batch.y
        starting_time = time.time()
        net_outputs = self(x, batch.edge_index, batch.edge_attr, batch.batch)
        loss = self.criterion(net_outputs, y)

        results = {
            "loss": loss,
            "acc": torch.as_tensor(
                tm.functional.accuracy(net_outputs, y, average="micro")),
            "time": time.time() - starting_time,
        }
        if phase == "val":
            results["val_preds"] = net_outputs
            results["val_target"] = y

        elif phase == "train":
            results["train_preds"] = net_outputs
            results["train_target"] = y
        return results

    def training_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):

        self.log_stats(outputs)

    def validation_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
        self.log_stats(outputs)

    def log_stats(self, outputs: List[Dict[str, torch.Tensor]]):

        phase: str = "train" if self.training is True else "val"

        self.log(f"time_{phase}", sum([e["time"] for e in outputs]) / len(outputs),
                 prog_bar=False, sync_dist=True)

        self.log(f"loss_{phase}", torch.stack([e["loss"] for e in outputs]).mean(),
                 prog_bar=True if phase == "val" else False, sync_dist=True)

        for metric in ["acc"]: #, "f1", "precision", "recall", "mcc"
            metric_data = torch.stack([e[metric] for e in outputs]).float()
            self.log(f"{metric}_mean_{phase}", metric_data.mean(),
                     prog_bar=True if metric == "acc" else False, sync_dist=True)
            del metric_data
        del outputs

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
        for param in self.parameters():
            param.grad = None

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        return optimizer


def run():
    dataset = TUDataset(root='data/TUDataset', name='MUTAG')
    classes = 2
    max_epochs = 100
    batch_size = 64
    k = 1
    num_layers = 2
    num_features = dataset.num_features
    lr = 0.001
    dropout = 0.1

    torch.manual_seed(12345)
    dataset = dataset.shuffle()
    train_dataset = dataset[:150]
    test_dataset = dataset[150:]

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = GraphNN(
        dropout=dropout,
        num_classes=classes,
        num_features=num_features,
        k=k,
        learning_rate=lr,
        weight_decay=0,
        num_layers=num_layers
    )

    trainer = pl.Trainer(
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        devices=1,
        precision=16,
        max_epochs=max_epochs,
        gradient_clip_val=1,  # if disable_gradient_clipping is False else 0
        strategy=DeepSpeedStrategy(
            stage=3,
            offload_optimizer=True,
            offload_parameters=True,
            allgather_partitions=False,
        ),
    )
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)


if __name__ == "__main__":
    run()
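
For comparison, the 32-bit run that works fine in my setup (assuming the same DeepSpeed strategy; only the precision flag differs from the failing run above):

# Control run (works fine): identical Trainer, full precision.
trainer = pl.Trainer(
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    devices=1,
    precision=32,  # the only change vs. the failing run
    max_epochs=max_epochs,
    gradient_clip_val=1,
    strategy=DeepSpeedStrategy(
        stage=3,
        offload_optimizer=True,
        offload_parameters=True,
        allgather_partitions=False,
    ),
)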

Error messages and logs

Traceback (most recent call last):
  File "/projects/pyg/user/project/main_git.py", line 73, in <module>
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 88, in launch
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1103, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1182, in _run_stage
    self._run_train()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run_train
    self._run_sanity_check()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1267, in _run_sanity_check
    val_loop.run()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(**kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1485, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/strategies/deepspeed.py", line 917, in validation_step
    return self.model(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/utils/nvtx.py", line 11, in wrapped_fn
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 1836, in forward
    loss = self.module(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/overrides/base.py", line 110, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "/projects/pyg/user/project/GraphNN.py", line 125, in validation_step
    return self.step(batch)
  File "/projects/pyg/user/project/GraphNN.py", line 139, in step
    net_outputs = self(x, batch.edge_index, batch.edge_attr, batch.batch)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/projects/pyg/user/project/GraphNN.py", line 103, in forward
    x = f(x, edge_index, edge_attr)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric_temporal/nn/recurrent/gconv_gru.py", line 166, in forward
    Z = self._calculate_update_gate(X, edge_index, edge_weight, H, lambda_max)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric_temporal/nn/recurrent/gconv_gru.py", line 120, in _calculate_update_gate
    Z = self.conv_x_z(X, edge_index, edge_weight, lambda_max=lambda_max)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/conv/cheb_conv.py", line 170, in forward
    out = self.lins[0](Tx_0)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/dense/linear.py", line 136, in forward
    return F.linear(x, self.weight, self.bias)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/linear.py", line 116, in zero3_linear_wrap
    return LinearFunctionForZeroStage3.apply(input, weight)
  File "/usr/local/lib/python3.8/dist-packages/torch/cuda/amp/autocast_mode.py", line 110, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/linear.py", line 61, in forward
    output = input.matmul(weight.t())
RuntimeError: expected scalar type Float but found Half
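
For what it's worth, the operation at the bottom of the trace can be reproduced in isolation (a sketch of my understanding: the ZeRO-3 linear wrapper ends up multiplying a float32 activation by a float16 weight):

import torch

# Mirrors input.matmul(weight.t()) from deepspeed/runtime/zero/linear.py:
x = torch.randn(4, 8)                            # float32 activations
weight = torch.randn(3, 8, dtype=torch.float16)  # float16 weight
try:
    x.matmul(weight.t())
except RuntimeError as err:
    print(err)  # dtype-mismatch error like the one above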

Environment

System info:

OS: Ubuntu 20.04.4 LTS
GPU count and types: 1 x Quadro RTX 6000
Python version: Python 3.8.10

Pip installed libraries:

torch==1.12.1+cu116
torch-cluster==1.6.0+pt112cu116
torch-geometric==2.2.0
torch-geometric-temporal==0.54.0
torch-scatter==2.1.0+pt112cu116
torch-sparse==0.6.16+pt112cu116
torch-spline-conv==1.2.1+pt112cu116
torchaudio==0.12.1+cu116
torchfile==0.1.0
torchmetrics==0.9.3
torchvision==0.13.1+cu116
deepspeed==0.8.0
pytorch-lightning==1.9.0

More info

No response

cc @awaelchli
