Description
Bug description
Hi, i have a problem integrating DeepSpeed and PyG.
In particular Setting 32 precision on Lightning Trainer on single GPU Quadro RTX 6000 everything works fine. Something similar to the issue in #8426, i guess.
But, switching to 16 precision i have the following Traceback calling Trainer.fit()
(even calling torch.Tensor.half()
on model, or on input, or both).
How to reproduce the bug
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset
from pytorch_lightning.strategies import DeepSpeedStrategy
from torch_geometric.nn import global_mean_pool
from torch_geometric_temporal import GConvGRU
import torch.nn.functional as F
from torch.nn import Linear
import pytorch_lightning as pl
import time
from typing import Union, Dict, List
import math
import torchmetrics as tm
class GraphNN(pl.LightningModule):
def __init__(self,
num_features,
num_classes,
k,
dropout: Union[int, float] = 0.2,
learning_rate: float = 1e-4,
weight_decay: float = 1e-4,
num_layers: int = 2,
batch_size: int = 32,
):
super(GraphNN, self).__init__()
self.dropout = dropout
self.learning_rate = learning_rate
self.criterion = torch.nn.CrossEntropyLoss()
self.weight_decay = weight_decay
self.k = k
self.num_classes = num_classes+1
self.layers = torch.nn.ModuleList()
self.batch_size = batch_size
h_channels = round(num_features/4)
for layer in range(num_layers):
if layer == 0:
self.layers.append(GConvGRU(num_features, h_channels, k))
else:
output = round(math.sqrt(h_channels * self.num_classes))
self.layers.append(GConvGRU(h_channels, output, k))
h_channels = output
output = round(math.sqrt(h_channels * self.num_classes))
self.fc = Linear(h_channels, output)
self.output = Linear(output, self.num_classes)
self.save_hyperparameters()
def forward(self, x, edge_index, edge_attr, batch):
for f in self.layers:
x = f(x, edge_index, edge_attr)
x = x.relu()
x = self.fc(x)
x = global_mean_pool(x, batch) # [batch_size, hidden_channels]
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.output(x)
return x
def training_step(self, batch, batch_idx):
return self.step(batch)
def validation_step(self, batch, batch_idx):
return self.step(batch)
def step(self, batch):
phase: str = "train" if self.training is True else "val"
x: torch.Tensor = batch.x
x.requires_grad = True
y = batch.y
starting_time = time.time()
net_outputs = self(x, batch.edge_index, batch.edge_attr, batch.batch)
loss = self.criterion(net_outputs, y)
results = {
"loss": loss,
"acc": torch.as_tensor(
tm.functional.accuracy(net_outputs, y, average="micro")),
"time": time.time() - starting_time,
}
if phase == "val":
results["val_preds"] = net_outputs
results["val_target"] = y
elif phase == "train":
results["train_preds"] = net_outputs
results["train_target"] = y
return results
def training_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
self.log_stats(outputs)
def validation_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
self.log_stats(outputs)
def log_stats(self, outputs: List[Dict[str, torch.Tensor]]):
phase: str = "train" if self.training is True else "val"
self.log(f"time_{phase}", sum([e["time"] for e in outputs]) / len(outputs),
prog_bar=False, sync_dist=True)
self.log(f"loss_{phase}", torch.stack([e["loss"] for e in outputs]).mean(),
prog_bar=True if phase == "val" else False, sync_dist=True)
for metric in ["acc"]: #, "f1", "precision", "recall", "mcc"
metric_data = torch.stack([e[metric] for e in outputs]).float()
self.log(f"{metric}_mean_{phase}", metric_data.mean(),
prog_bar=True if metric == "acc" else False, sync_dist=True)
del metric_data
del outputs
def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
for param in self.parameters():
param.grad = None
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
return optimizer
def run():
dataset = TUDataset(root='data/TUDataset', name='MUTAG')
channels = 64
classes = 2
max_epochs = 100
batch_size = 64
k = 1
in_channels = 2
num_features = dataset.num_features
lr = 0.001
dropout = 0.1
torch.manual_seed(12345)
dataset = dataset.shuffle()
train_dataset = dataset[:150]
test_dataset = dataset[150:]
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
model = GraphNN(
dropout=dropout,
num_classes=classes,
num_features=num_features,
k=k,
learning_rate=lr,
weight_decay=0,
num_layers=in_channels
)
trainer = pl.Trainer(
accelerator="cuda" if torch.cuda.is_available() else "cpu",
devices=1,
precision=16,
max_epochs=max_epochs,
gradient_clip_val=1, # if disable_gradient_clipping is False else 0
strategy=DeepSpeedStrategy(
stage=3,
offload_optimizer=True,
offload_parameters=True,
allgather_partitions=False,
),
)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)
if __name__ == "__main__":
run()
Error messages and logs
Traceback (most recent call last):
File "/projects/pyg/user/project/main_git.py", line 73, in <module>
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
call._call_and_handle_interrupt(
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 88, in launch
return function(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
self._run(model, ckpt_path=self.ckpt_path)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1103, in _run
results = self._run_stage()
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1182, in _run_stage
self._run_train()
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run_train
self._run_sanity_check()
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1267, in _run_sanity_check
val_loop.run()
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
output = self._evaluation_step(**kwargs)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1485, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/strategies/deepspeed.py", line 917, in validation_step
return self.model(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/deepspeed/utils/nvtx.py", line 11, in wrapped_fn
return func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 1836, in forward
loss = self.module(*inputs, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/overrides/base.py", line 110, in forward
return self._forward_module.validation_step(*inputs, **kwargs)
File "/projects/pyg/user/project/GraphNN.py", line 125, in validation_step
return self.step(batch)
File "/projects/pyg/user/project/GraphNN.py", line 139, in step
net_outputs = self(x, batch.edge_index, batch.edge_attr, batch.batch)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/projects/pyg/user/project/GraphNN.py", line 103, in forward
x = f(x, edge_index, edge_attr)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_geometric_temporal/nn/recurrent/gconv_gru.py", line 166, in forward
Z = self._calculate_update_gate(X, edge_index, edge_weight, H, lambda_max)
File "/usr/local/lib/python3.8/dist-packages/torch_geometric_temporal/nn/recurrent/gconv_gru.py", line 120, in _calculate_update_gate
Z = self.conv_x_z(X, edge_index, edge_weight, lambda_max=lambda_max)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/conv/cheb_conv.py", line 170, in forward
out = self.lins[0](Tx_0)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1148, in _call_impl
result = forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/dense/linear.py", line 136, in forward
return F.linear(x, self.weight, self.bias)
File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/linear.py", line 116, in zero3_linear_wrap
return LinearFunctionForZeroStage3.apply(input, weight)
File "/usr/local/lib/python3.8/dist-packages/torch/cuda/amp/autocast_mode.py", line 110, in decorate_fwd
return fwd(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/linear.py", line 61, in forward
output = input.matmul(weight.t())
RuntimeError: expected scalar type Float but found Half
Environment
Environment
System info:
OS: Ubuntu 20.04.4 LTS
GPU count and types: 1 x Quadro RTX 6000
Python version: Python 3.8.10
Pip installed libraries:
torch==1.12.1+cu116
torch-cluster==1.6.0+pt112cu116
torch-geometric==2.2.0
torch-geometric-temporal==0.54.0
torch-scatter==2.1.0+pt112cu116
torch-sparse==0.6.16+pt112cu116
torch-spline-conv==1.2.1+pt112cu116
torchaudio==0.12.1+cu116
torchfile==0.1.0
torchmetrics==0.9.3
torchvision==0.13.1+cu116
DeepSpeed 0.8.0
pytorch-lightning==1.9.0
More info
No response
cc @awaelchli
Activity