Integration with DeepSpeed and PyG: expected scalar type Float but found Half #16793

Open
adm995 opened this issue Feb 17, 2023 · 0 comments
Labels: 3rd party, bug, strategy: deepspeed

adm995 commented Feb 17, 2023

Bug description

Hi, I have a problem integrating DeepSpeed with PyG.
With precision=32 on the Lightning Trainer, on a single Quadro RTX 6000 GPU, everything works fine.
After switching to precision=16, calling Trainer.fit() raises the traceback below, even if I call torch.Tensor.half() on the model, on the input, or on both. I guess this is similar to the issue in #8426.

How to reproduce the bug

import torch
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from pytorch_lightning.strategies import DeepSpeedStrategy
from torch_geometric.nn import global_mean_pool
from torch_geometric_temporal import GConvGRU
import torch.nn.functional as F
from torch.nn import Linear
import pytorch_lightning as pl
import time
from typing import Union, Dict, List
import math
import torchmetrics as tm


class GraphNN(pl.LightningModule):
    def __init__(self,
                 num_features,
                 num_classes,
                 k,
                 dropout: Union[int, float] = 0.2,
                 learning_rate: float = 1e-4,
                 weight_decay: float = 1e-4,
                 num_layers: int = 2,
                 batch_size: int = 32,
                 ):
        super(GraphNN, self).__init__()
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.criterion = torch.nn.CrossEntropyLoss()
        self.weight_decay = weight_decay
        self.k = k
        self.num_classes = num_classes+1
        self.layers = torch.nn.ModuleList()
        self.batch_size = batch_size
        h_channels = round(num_features/4)
        for layer in range(num_layers):
            if layer == 0:
                self.layers.append(GConvGRU(num_features, h_channels, k))

            else:
                output = round(math.sqrt(h_channels * self.num_classes))
                self.layers.append(GConvGRU(h_channels, output, k))
                h_channels = output

        output = round(math.sqrt(h_channels * self.num_classes))
        self.fc = Linear(h_channels, output)
        self.output = Linear(output, self.num_classes)
        self.save_hyperparameters()

    def forward(self, x, edge_index, edge_attr, batch):
        for f in self.layers:
            x = f(x, edge_index, edge_attr)
            x = x.relu()
        x = self.fc(x)
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.output(x)
        return x

    def training_step(self, batch, batch_idx):
        return self.step(batch)

    def validation_step(self, batch, batch_idx):
        return self.step(batch)

    def step(self, batch):
        phase: str = "train" if self.training is True else "val"
        x: torch.Tensor = batch.x
        x.requires_grad = True
        y = batch.y
        starting_time = time.time()
        net_outputs = self(x, batch.edge_index, batch.edge_attr, batch.batch)
        loss = self.criterion(net_outputs, y)

        results = {
            "loss": loss,
            "acc": torch.as_tensor(
                tm.functional.accuracy(net_outputs, y, average="micro")),
            "time": time.time() - starting_time,
        }
        if phase == "val":
            results["val_preds"] = net_outputs
            results["val_target"] = y

        elif phase == "train":
            results["train_preds"] = net_outputs
            results["train_target"] = y
        return results

    def training_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):

        self.log_stats(outputs)

    def validation_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
        self.log_stats(outputs)

    def log_stats(self, outputs: List[Dict[str, torch.Tensor]]):

        phase: str = "train" if self.training is True else "val"

        self.log(f"time_{phase}", sum([e["time"] for e in outputs]) / len(outputs),
                 prog_bar=False, sync_dist=True)

        self.log(f"loss_{phase}", torch.stack([e["loss"] for e in outputs]).mean(),
                 prog_bar=True if phase == "val" else False, sync_dist=True)

        for metric in ["acc"]: #, "f1", "precision", "recall", "mcc"
            metric_data = torch.stack([e[metric] for e in outputs]).float()
            self.log(f"{metric}_mean_{phase}", metric_data.mean(),
                     prog_bar=True if metric == "acc" else False, sync_dist=True)
            del metric_data
        del outputs

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
        for param in self.parameters():
            param.grad = None

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        return optimizer


def run():
    dataset = TUDataset(root='data/TUDataset', name='MUTAG')
    channels = 64
    classes = 2
    max_epochs = 100
    batch_size = 64
    k = 1
    in_channels = 2
    num_features = dataset.num_features
    lr = 0.001

    dropout = 0.1

    torch.manual_seed(12345)
    dataset = dataset.shuffle()
    train_dataset = dataset[:150]
    test_dataset = dataset[150:]

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


    model = GraphNN(
        dropout=dropout,
        num_classes=classes,
        num_features=num_features,
        k=k,
        learning_rate=lr,
        weight_decay=0,
        num_layers=in_channels
    )

    trainer = pl.Trainer(
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        devices=1,
        precision=16,
        max_epochs=max_epochs,
        gradient_clip_val=1,  # if disable_gradient_clipping is False else 0
        strategy=DeepSpeedStrategy(
            stage=3,
            offload_optimizer=True,
            offload_parameters=True,
            allgather_partitions=False,
        ),
    )
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)


if __name__ == "__main__":
    run()

Error messages and logs

Traceback (most recent call last):
  File "/projects/pyg/user/project/main_git.py", line 73, in <module>
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 88, in launch
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1103, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1182, in _run_stage
    self._run_train()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run_train
    self._run_sanity_check()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1267, in _run_sanity_check
    val_loop.run()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(**kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1485, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/strategies/deepspeed.py", line 917, in validation_step
    return self.model(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/utils/nvtx.py", line 11, in wrapped_fn
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 1836, in forward
    loss = self.module(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/overrides/base.py", line 110, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "/projects/pyg/user/project/GraphNN.py", line 125, in validation_step
    return self.step(batch)
  File "/projects/pyg/user/project/GraphNN.py", line 139, in step
    net_outputs = self(x, batch.edge_index, batch.edge_attr, batch.batch)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/projects/pyg/user/project/GraphNN.py", line 103, in forward
    x = f(x, edge_index, edge_attr)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric_temporal/nn/recurrent/gconv_gru.py", line 166, in forward
    Z = self._calculate_update_gate(X, edge_index, edge_weight, H, lambda_max)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric_temporal/nn/recurrent/gconv_gru.py", line 120, in _calculate_update_gate
    Z = self.conv_x_z(X, edge_index, edge_weight, lambda_max=lambda_max)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/conv/cheb_conv.py", line 170, in forward
    out = self.lins[0](Tx_0)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/dense/linear.py", line 136, in forward
    return F.linear(x, self.weight, self.bias)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/linear.py", line 116, in zero3_linear_wrap
    return LinearFunctionForZeroStage3.apply(input, weight)
  File "/usr/local/lib/python3.8/dist-packages/torch/cuda/amp/autocast_mode.py", line 110, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/linear.py", line 61, in forward
    output = input.matmul(weight.t())
RuntimeError: expected scalar type Float but found Half
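
The last frame of the traceback is input.matmul(weight.t()), so the error suggests that the input and the weight reaching that call have different floating-point dtypes. A minimal sketch for narrowing this down, assuming nothing beyond a .dtype attribute on each tensor: the helper names group_by_dtype and report_dtype_mismatch are hypothetical and not part of Lightning, DeepSpeed, or PyG.

```python
from collections import defaultdict


def group_by_dtype(named_tensors):
    """Group tensor names by dtype.

    `named_tensors` is any iterable of (name, tensor) pairs. Only the
    `.dtype` attribute of each tensor is read, never the data itself.
    """
    groups = defaultdict(list)
    for name, tensor in named_tensors:
        groups[str(tensor.dtype)].append(name)
    return dict(groups)


def report_dtype_mismatch(named_params, named_inputs):
    """Return a report if parameters and inputs do not all share one dtype.

    Pass only floating-point inputs (integer tensors such as an edge index
    would always show up as a mismatch). Returns None when every parameter
    and every input has the same dtype.
    """
    param_groups = group_by_dtype(named_params)
    input_groups = group_by_dtype(named_inputs)
    if len(set(param_groups) | set(input_groups)) <= 1:
        return None
    return {"parameters": param_groups, "inputs": input_groups}
```

In the reproduction script above, this could be called at the top of step(), for example as report_dtype_mismatch(self.named_parameters(), [("x", batch.x), ("edge_attr", batch.edge_attr)]), and the result printed before the forward pass. It only reports what it sees at that point; tensors created inside third-party layers during forward are not covered.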

Environment

System info:

OS: Ubuntu 20.04.4 LTS
GPU count and types: 1 x Quadro RTX 6000
Python version: Python 3.8.10

Pip installed libraries:

torch==1.12.1+cu116
torch-cluster==1.6.0+pt112cu116
torch-geometric==2.2.0
torch-geometric-temporal==0.54.0
torch-scatter==2.1.0+pt112cu116
torch-sparse==0.6.16+pt112cu116
torch-spline-conv==1.2.1+pt112cu116
torchaudio==0.12.1+cu116
torchfile==0.1.0
torchmetrics==0.9.3
torchvision==0.13.1+cu116
deepspeed==0.8.0
pytorch-lightning==1.9.0

More info

No response

cc @awaelchli

adm995 added the bug and needs triage labels on Feb 17, 2023
awaelchli added the 3rd party and strategy: deepspeed labels and removed the needs triage label on Mar 18, 2023
awaelchli self-assigned this on Mar 18, 2023
awaelchli removed their assignment on Nov 25, 2023