DeepSpeed stage 3 and mixed precision cause an error #10510
Comments
if I change ...
Pinging @tjruwase again in case he has any insight into why DeepSpeed stage 3 isn't working anymore with precision fp16.
Stage 2, with and without offload, works with fp16.
@SeanNaren and @ktrapeznikov, I am taking a look.
Quick update. The problem is that the model parameters remained in fp32 despite model = BoringModel().half(). For some reason zero.Init() is not using ...
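A quick way to see what is described above (a minimal sketch, not code from the thread; the small Linear layer stands in for BoringModel) is to print the parameter dtypes after calling .half():

```python
import torch

# Stand-in for BoringModel().half(); a plain nn.Module prints torch.float16 here,
# whereas the reported problem is that under the DeepSpeed stage 3 path the
# parameters stayed at torch.float32.
model = torch.nn.Linear(32, 2).half()
for name, param in model.named_parameters():
    print(name, param.dtype)
```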
@SeanNaren, the problem seems to be due to Lightning calling zero.Init() here on an already constructed model. In particular, zero.Init() is meant for constructing massive models that are too large for a single device. For already constructed models, the stage 3 optimizer will automatically set up the required partitioning, as seen here. Hope that helps.
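For context, here is a minimal sketch (not from the thread) of the distinction being drawn; the tiny Linear modules are placeholders for a real model, and it is meant to run under a DeepSpeed/distributed launch:

```python
import torch
import deepspeed

# Intended use of zero.Init(): build the model *inside* the context so its
# parameters are partitioned across ranks as they are created. This is what
# makes models too large for a single device constructible at all.
with deepspeed.zero.Init():
    huge_model = torch.nn.Linear(32, 2)  # placeholder for a very large module

# An already constructed model does not need zero.Init(); the ZeRO stage 3
# optimizer partitions its existing parameters when DeepSpeed is initialized.
small_model = torch.nn.Linear(32, 2)
```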
ahhh, huge thanks @tjruwase! I recall having this in place because there were some internal DeepSpeed assertions that were raised if the model was only partially partitioned! I've removed the code and confirmed all tests are passing. @ktrapeznikov, #10655 should fix this issue :)
hey @tjruwase, this breaks a case we had in our CI where only some of the parameters are defined in the configure_sharded_model hook. Here is the reproduction:

```python
import os
import torch
from deepspeed.ops.adam import FusedAdam
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
class BoringModel(LightningModule):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(32, 2)
def forward(self, x):
return self.layer(x)
def configure_sharded_model(self) -> None:
self.layer_2 = torch.nn.Linear(32, 2)
def training_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log("train_loss", loss)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log("valid_loss", loss)
def test_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log("test_loss", loss)
def configure_optimizers(self):
return FusedAdam(self.layer.parameters(), lr=0.1)
def run():
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
test_data = DataLoader(RandomDataset(32, 64), batch_size=2)
model = BoringModel()
trainer = Trainer(
default_root_dir=os.getcwd(),
gpus=1,
fast_dev_run=True,
precision=16,
strategy="deepspeed_stage_3"
)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
trainer.test(model, dataloaders=test_data)
if __name__ == "__main__":
    run()
```

The stack trace:
Removing the suggested code which wraps the entire module with ...
There is a delay here primarily because of the above (cc @tjruwase). In the meantime, we can bypass this issue by passing the plugin manually and adding the partition_module=False flag:

```python
from pytorch_lightning.plugins import DeepSpeedPlugin
trainer = Trainer(
default_root_dir=os.getcwd(),
limit_train_batches=1,
limit_val_batches=1,
num_sanity_val_steps=0,
max_epochs=1,
enable_model_summary=False,
logger=False,
enable_checkpointing=False,
gpus=4,
precision=16,
strategy=DeepSpeedPlugin(stage=3, offload_parameters=True, offload_optimizer=True, partition_module=False)
)
```
@SeanNaren, apologies for the delay with this investigation. Can you please test whether microsoft/DeepSpeed#1606 could help? @jeffra, FYI
I tested microsoft/DeepSpeed#1606 and it seems the error still persists, @tjruwase. As mentioned in the merged PR, I've fixed the case where the user doesn't use partition_parameters OR defines all parameters in the context manager. However, in the case above where the user defines only a piece of the model within the context manager (say, the transformer blocks or feed-forwards), it crashes as shown above!
I have tried all stages of DeepSpeed, and they all failed with "RuntimeError: expected scalar type Half but found Float".
@chenzhekl, this is likely due to the fact that DeepSpeed does not automatically cast the inputs. Can you please share your full stack trace?
Hi @tjruwase, thanks for your reply! Here is the stack trace; hope this is helpful.
@chenzhekl, thanks for sharing the stack trace. It does look like what I suspected. The problem is that DeepSpeed does not do automatic casting of input tensors to match the parameter dtype. In this case, I suspect that ...
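To illustrate the manual cast being suggested (a hedged sketch, not code from the thread; cast_inputs_like is a hypothetical helper name), floating-point inputs can be converted to the parameters' dtype before the forward pass:

```python
import torch

def cast_inputs_like(model: torch.nn.Module, *tensors: torch.Tensor):
    """Cast floating-point inputs to the model's parameter dtype.

    DeepSpeed does not cast inputs automatically, so under precision=16 the
    inputs must be converted to fp16 by hand; integer tensors (e.g. indices)
    are left untouched.
    """
    target_dtype = next(model.parameters()).dtype
    return tuple(
        t.to(dtype=target_dtype) if t.is_floating_point() else t
        for t in tensors
    )

# Hypothetical usage at the top of a training_step:
#   x, pos = cast_inputs_like(self, batch.x, batch.pos)
```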
It'd be great if someone could share a reproducible script to replicate the issue.
@tjruwase Thanks for your suggestions. You are right: after manually casting the input to half, the error message disappears. I guess this might have something to do with PyG's way of packing data. Here is a small script to reproduce the error:

```python
import os
import torch
import torch.nn.functional as F
from torch.optim import Adam
from pytorch_lightning import LightningModule, Trainer
from torch_geometric.nn import EdgeConv, MLP, knn_graph
from torch_geometric import nn  # provides nn.global_max_pool used below
from torch_geometric.data import Data, Dataset, DataLoader
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = Data(
x=torch.randn(size, 3),
pos=torch.randn(size, 3),
y=torch.randn(size, 5),
)
def __getitem__(self, index):
return self.data
def __len__(self):
return self.len
class BoringModel(LightningModule):
def __init__(self):
super().__init__()
self.conv1 = EdgeConv(MLP(([2 * 6, 64, 64, 64])))
self.conv2 = EdgeConv(MLP(([2 * 64, 128])))
self.lin1 = torch.nn.Linear(64 + 128, 1024)
self.mlp = MLP([1024, 512, 256, 5], dropout=0.0, batch_norm=False)
def forward(self, x, batch, pos, edge_index):
features = F.elu(self.features(x, edge_index, pos))
pooled_features = nn.global_max_pool(features, batch)
y = self.mlp(pooled_features)
return y.sigmoid()
def features(self, x, edge_index, pos):
h1 = self.conv1(torch.cat([x, pos], dim=1), edge_index)
h2 = self.conv2(h1, edge_index)
h3 = self.lin1(torch.cat([h1, h2], dim=1))
return h3
def training_step(self, batch, batch_idx):
edge_index = knn_graph(batch.pos, k=5, batch=batch.batch)
y_ = self(batch.x, batch.batch, batch.pos, edge_index)
loss = F.l1_loss(y_, batch.y)
self.log("train_loss", loss)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
edge_index = knn_graph(batch.pos, k=5, batch=batch.batch)
y_ = self(batch.x, batch.batch, batch.pos, edge_index)
loss = F.l1_loss(y_, batch.y)
self.log("valid_loss", loss)
def configure_optimizers(self):
return Adam(self.parameters(), lr=0.1)
def run():
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
model = BoringModel()
trainer = Trainer(
default_root_dir=os.getcwd(),
limit_train_batches=1,
limit_val_batches=1,
num_sanity_val_steps=0,
max_epochs=1,
enable_model_summary=False,
logger=False,
enable_checkpointing=False,
gpus=1,
precision=16,
strategy = "deepspeed_stage_3"
)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
if __name__ == "__main__":
    run()
```
I recall running into this before with geometric. I think this is down to the Data object: Lightning cannot automatically traverse the object to move tensors to the correct device. A manual cast may be the cleanest solution currently!
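To make that concrete for the PyG case, here is a hedged sketch (not from the thread) that casts the floating-point attributes of a Data/Batch object to fp16 while leaving index tensors alone; it assumes Data.apply behaves as in recent torch_geometric releases, and half_data is a hypothetical helper:

```python
import torch
from torch_geometric.data import Data

def half_data(batch: Data) -> Data:
    # Data.apply maps the function over every tensor attribute (x, pos, y, ...),
    # so float fields become fp16 while edge_index and other long tensors stay intact.
    return batch.apply(lambda t: t.half() if t.is_floating_point() else t)

# Hypothetical usage at the top of training_step / validation_step:
#   batch = half_data(batch)
#   edge_index = knn_graph(batch.pos, k=5, batch=batch.batch)
```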
@chenzhekl, thanks for confirming and sharing a repro. @SeanNaren, thanks for sharing your experience.
🐛 Bug

Using strategy="deepspeed_stage_3" and precision=16 causes an error.

To Reproduce
I get the following error:
Expected behavior
It should work, right?
Environment
deepspeed 0.5.6
Additional context
cc @SeanNaren @awaelchli @rohitgr7