DeepSpeed stage 3 and mixed precision cause an error #10510

Open
ktrapeznikov opened this issue Nov 12, 2021 · 19 comments · Fixed by #10655
Labels: 3rd party · bug · strategy: deepspeed

Comments

@ktrapeznikov

ktrapeznikov commented Nov 12, 2021

🐛 Bug

Using strategy="deepspeed_stage_3" and precision=16 causes an error

To Reproduce

import os
import torch
from torch.utils.data import DataLoader, Dataset
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return FusedAdam(self.layer.parameters(), lr=0.1)
        # return torch.optim.Adam(self.parameters(),lr = .1)

def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        logger=False,
        enable_checkpointing=False,
        gpus=4,
        precision=16,
        strategy="deepspeed_stage_3",
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

I get the following error:

Traceback (most recent call last):
  File "bug.py", line 69, in <module>
    run()
  File "bug.py", line 64, in run
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
    self._call_and_handle_interrupt(
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1182, in _run
    self._pre_dispatch()
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1217, in _pre_dispatch
    self.accelerator.pre_dispatch(self)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 136, in pre_dispatch
    self.training_type_plugin.pre_dispatch()
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 397, in pre_dispatch
    self.init_deepspeed()
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 474, in init_deepspeed
    self._initialize_deepspeed_train(model)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 507, in _initialize_deepspeed_train
    model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 431, in _setup_model_and_optimizer
    deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/__init__.py", line 131, in initialize
    engine = DeepSpeedEngine(args=args,
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 223, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 905, in _configure_optimizer
    self.optimizer = self._configure_zero_optimizer(basic_optimizer)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1152, in _configure_zero_optimizer
    optimizer = FP16_DeepSpeedZeroOptimizer_Stage3(
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 905, in __init__
    self.create_reduce_and_remove_grad_hooks()
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 1885, in create_reduce_and_remove_grad_hooks
    param.all_gather()
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 590, in all_gather
    return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 699, in _all_gather
    ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 947, in _allgather_params_coalesced
    h = dist._all_gather_base(allgather_params[param_idx],
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2070, in _all_gather_base
    work = group._allgather_base(output_tensor, input_tensor)
RuntimeError: output tensor must have the same type as input tensor

Expected behavior

It should work: training should run without raising an error.

Environment

* CUDA:
        - GPU:
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
        - available:         True
        - version:           11.3
* Packages:
        - numpy:             1.21.1
        - pyTorch_debug:     False
        - pyTorch_version:   1.10.0+cu113
        - pytorch-lightning: 1.5.1
        - tqdm:              4.62.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.11
        - version:           #1 SMP Wed Feb 3 15:06:38 UTC 2021
  
* Any other relevant information:
        - deepspeed:         0.5.6

Additional context

cc @SeanNaren @awaelchli @rohitgr7

@ktrapeznikov added the "bug" and "help wanted" labels Nov 12, 2021
@ktrapeznikov
Author

If I change strategy="deepspeed_stage_3_offload" and use DeepSpeedCPUAdam instead, then I don't get an error.
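For reference, a minimal sketch of that working variant, assuming the same RandomDataset, BoringModel, dataloaders and imports as the reproduction script above (the BoringModelOffload subclass name is made up here; the only real changes are the optimizer and the strategy alias):

from deepspeed.ops.adam import DeepSpeedCPUAdam


class BoringModelOffload(BoringModel):
    def configure_optimizers(self):
        # DeepSpeedCPUAdam keeps the optimizer state on the host, which is what the offload strategy expects
        return DeepSpeedCPUAdam(self.layer.parameters(), lr=0.1)


trainer = Trainer(
    default_root_dir=os.getcwd(),
    gpus=4,
    precision=16,
    strategy="deepspeed_stage_3_offload",
)
trainer.fit(BoringModelOffload(), train_dataloaders=train_data, val_dataloaders=val_data)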

@tchaton added the "priority: 0" label Nov 15, 2021
@tchaton added this to the 1.5.x milestone Nov 15, 2021
@SeanNaren
Contributor

Pinging @tjruwase again in case he has any insight into why DeepSpeed stage 3 isn't working anymore with fp16 precision.

@ktrapeznikov
Author

Stage 2, with and without offload, works with fp16.

@tjruwase

@SeanNaren and @ktrapeznikov, I am taking a look.

@tjruwase

Quick update. The problem is that the model parameters remain in fp32 despite precision=16 in the Trainer constructor. @ktrapeznikov, can you please confirm that the RuntimeError is avoided by

    model = BoringModel().half()

For some reason zero.Init() is not using dtype to change parameter dtypes as promised here. Please give me some time to sync internally to understand this behavior. Thanks!

@tjruwase

@SeanNaren, the problem seems to be due to Lightning calling zero.Init() here on an already constructed model. In particular, zero.Init() is meant for constructing massive models that are too large for a single device. For already constructed models, the stage 3 optimizer will automatically set up the required partitioning, as seen here. Hope that helps.
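For readers following along, a rough sketch of the distinction being made (illustrative only, not the Lightning-internal code): zero.Init() is a context manager under which parameters are partitioned as the layers are constructed, whereas an already built LightningModule can simply be handed to deepspeed.initialize() and be partitioned by the stage 3 optimizer.

import torch
import deepspeed

# Intended zero.Init() usage: layers created inside the context are sharded across
# ranks at construction time, so the full fp16 model never has to fit on one device.
# (Assumes a distributed process group, e.g. launched via the deepspeed launcher.)
with deepspeed.zero.Init(dtype=torch.half):
    huge_model = torch.nn.Sequential(*[torch.nn.Linear(4096, 4096) for _ in range(16)])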

@SeanNaren
Contributor

ahhh huge thanks @tjruwase! I recall having this in place because there were some internal deepspeed assertions that were raised if the model was partially partitioned! I've removed the code and have confirmed all tests are passing.

@ktrapeznikov #10655 should fix this issue :)

@SeanNaren
Contributor

SeanNaren commented Nov 22, 2021

Hey @tjruwase, this breaks a case we had in our CI where only some of the parameters are defined in the deepspeed.Init context, while some stray parameters remain outside it. Should these be detected/partitioned afterwards by DeepSpeed?

Here is the reproduction:

import os

import torch
from deepspeed.ops.adam import FusedAdam
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def configure_sharded_model(self) -> None:
        self.layer_2 = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return FusedAdam(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        gpus=1,
        fast_dev_run=True,
        precision=16,
        strategy="deepspeed_stage_3"
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

The stacktrace:

Traceback (most recent call last):
  File "reprod.py", line 68, in <module>
    run()
  File "reprod.py", line 63, in run
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 720, in fit
    self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 671, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 754, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 1145, in _run
    self._pre_dispatch()
  File "/home/sean/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 1180, in _pre_dispatch
    self.accelerator.pre_dispatch(self)
  File "/home/sean/pytorch-lightning/pytorch_lightning/accelerators/accelerator.py", line 92, in pre_dispatch
    self.training_type_plugin.pre_dispatch()
  File "/home/sean/pytorch-lightning/pytorch_lightning/plugins/training_type/deepspeed.py", line 388, in pre_dispatch
    self.init_deepspeed()
  File "/home/sean/pytorch-lightning/pytorch_lightning/plugins/training_type/deepspeed.py", line 457, in init_deepspeed
    self._initialize_deepspeed_train(model)
  File "/home/sean/pytorch-lightning/pytorch_lightning/plugins/training_type/deepspeed.py", line 490, in _initialize_deepspeed_train
    model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
  File "/home/sean/pytorch-lightning/pytorch_lightning/plugins/training_type/deepspeed.py", line 429, in _setup_model_and_optimizer
    dist_init_required=False,
  File "/home/sean/anaconda3/lib/python3.7/site-packages/deepspeed/__init__.py", line 129, in initialize
    config_params=config_params)
  File "/home/sean/anaconda3/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 260, in __init__
    self._configure_distributed_model(model)
  File "/home/sean/anaconda3/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 977, in _configure_distributed_model
    f"fp16 is enabled but the following parameters have dtype that is not fp16: {', '.join(names)}"
ValueError: fp16 is enabled but the following parameters have dtype that is not fp16: module.layer.weight, module.layer.bias

Removing the suggested code, which wrapped the entire module in deepspeed.Init, means that some of the parameters are not being transferred, which causes this assertion to be raised.

@SeanNaren
Contributor

There is a delay here primarily because of the above (cc @tjruwase)

In the meantime, we can bypass this issue by passing the plugin manually and setting the partition_module=False flag:

from pytorch_lightning.plugins import DeepSpeedPlugin

trainer = Trainer(
    default_root_dir=os.getcwd(),
    limit_train_batches=1,
    limit_val_batches=1,
    num_sanity_val_steps=0,
    max_epochs=1,
    enable_model_summary=False,
    logger=False,
    enable_checkpointing=False,
    gpus=4,
    precision=16,
    strategy=DeepSpeedPlugin(stage=3, offload_parameters=True, offload_optimizer=True, partition_module=False),
)

@tjruwase

tjruwase commented Dec 8, 2021

@SeanNaren, apologies for the delay with this investigation. Can you please test whether microsoft/DeepSpeed#1606 could help?

@jeffra, FYI

@SeanNaren
Contributor

SeanNaren commented Dec 20, 2021

I tested microsoft/DeepSpeed#1606 and it seems the error still persists, @tjruwase.

As mentioned in the merged PR, I've fixed the case where the user doesn't use partition_parameters OR the user defines all parameters in the context manager.

However, in the case above where the user defines only a piece of the model within the context manager (say, the transformer blocks or feedforward layers), it still crashes as shown above!

@carmocca added the "3rd party" and "strategy: deepspeed" labels and removed the "help wanted" and "priority: 0" labels Mar 1, 2022
@carmocca removed this from the 1.5.x milestone Mar 1, 2022
@chenzhekl

chenzhekl commented Jul 22, 2022

I have tried all stages of DeepSpeed, and they all failed with "RuntimeError: expected scalar type Half but found Float"

@tjruwase

@chenzhekl, this is likely due to the fact that DeepSpeed does not automatically cast the inputs. Can you please share your full stack trace?

@chenzhekl

Hi @tjruwase, thanks for your reply! Here is the stack trace; hope it helps.

Traceback (most recent call last):                                                                                                                   
  File "/home/zchen/workspace/demo/src/train.py", line 343, in <module>                                                             
    main(args)                                                                                                                                       
  File "/home/zchen/workspace/demo/src/train.py", line 261, in main                                                                 
    trainer.fit(                                                                                                                                     
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit                                               
    self._call_and_handle_interrupt(                                                                                                                 
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt                        
    return trainer_fn(*args, **kwargs)                                                                                                               
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl                                         
    results = self._run(model, ckpt_path=self.ckpt_path)                                                                                             
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run                                             
    results = self._run_stage()                                                                                                                      
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage                                       
    return self._run_train()                                                                                                                         
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1345, in _run_train                                       
    self._run_sanity_check()
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1413, in _run_sanity_check                                
    val_loop.run()                                                                                                                                   
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run                                                    
    self.advance(*args, **kwargs)                                                                                                                    
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 155, in advance                          
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)                                                                     
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run                                                    
    self.advance(*args, **kwargs)                                                                                                                    
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 128, in advance                         
    output = self._evaluation_step(**kwargs)                                                                                                         
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 226, in _evaluation_step
    output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1765, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/strategies/deepspeed.py", line 906, in validation_step
    return self.model(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 11, in wrapped_fn
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1601, in forward
    loss = self.module(*inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/strategies/deepspeed.py", line 80, in forward
    return super().forward(*inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 93, in forward
    return self.module.validation_step(*inputs, **kwargs)
  File "/home/zchen/workspace/demo/src/model/model.py", line 64, in validation_step
    y_ = self(batch["x"], batch["batch"], batch["pos"], batch["edge_index"])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zchen/workspace/demo/src/model/model.py", line 28, in forward
    features = F.elu(self.features(x, edge_index, pos))
  File "/home/zchen/workspace/demo/src/model/model.py", line 35, in features
    h1 = F.gelu(self.conv1(torch.cat([x, pos], dim=1), edge_index))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_geometric/nn/conv/gatv2_conv.py", line 200, in forward
    x_l = self.lin_l(x).view(-1, H, C)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_geometric/nn/dense/linear.py", line 118, in forward
    return F.linear(x, self.weight, self.bias)
RuntimeError: expected scalar type Half but found Float

@tjruwase

@chenzhekl, thanks for sharing the stack trace. It does look like what I suspected. The problem is that DeepSpeed does not automatically cast input tensors to match the parameter dtype. In this case, I suspect that self.weight and self.bias are fp16 while x is fp32, or vice versa. One solution would be for the client to manually cast the input tensors, such as x, to match the parameter dtype. Could you please try that change? Also, could you share the client script to reproduce this? Thanks!
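For illustration, a hedged sketch of that manual cast (the cast_to_param_dtype helper and the commented call sites are hypothetical; the point is simply to downcast the fp32 inputs, such as x and pos from the traceback, to the fp16 parameter dtype):

import torch

def cast_to_param_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Only floating-point tensors are cast; integer tensors (e.g. indices) are left alone.
    return t.to(dtype) if t.is_floating_point() else t

# e.g. at the top of the model's forward/validation_step (self.dtype is torch.half under precision=16):
#     x = cast_to_param_dtype(x, self.dtype)
#     pos = cast_to_param_dtype(pos, self.dtype)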

@rohitgr7
Contributor

It'd be great if someone could share a reproducible script to replicate the issue.

@rohitgr7 self-assigned this and unassigned SeanNaren Jul 27, 2022
@chenzhekl

@tjruwase Thanks for your suggestions. You are right: after manually casting the input to half, the error disappears. I guess this might have something to do with PyG's way of packing data. Here is a small script to reproduce the error.

import os
import torch
import torch.nn.functional as F
from torch.optim import Adam
from pytorch_lightning import LightningModule, Trainer
from torch_geometric.nn import EdgeConv, MLP, global_max_pool, knn_graph
from torch_geometric.data import Data, Dataset, DataLoader


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = Data(
            x=torch.randn(size, 3),
            pos=torch.randn(size, 3),
            y=torch.randn(size, 5),
        )

    def __getitem__(self, index):
        return self.data

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = EdgeConv(MLP([2 * 6, 64, 64, 64]))
        self.conv2 = EdgeConv(MLP([2 * 64, 128]))
        self.lin1 = torch.nn.Linear(64 + 128, 1024)
        self.mlp = MLP([1024, 512, 256, 5], dropout=0.0, batch_norm=False)

    def forward(self, x, batch, pos, edge_index):
        features = F.elu(self.features(x, edge_index, pos))
        pooled_features = global_max_pool(features, batch)
        y = self.mlp(pooled_features)

        return y.sigmoid()

    def features(self, x, edge_index, pos):
        h1 = self.conv1(torch.cat([x, pos], dim=1), edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.lin1(torch.cat([h1, h2], dim=1))

        return h3

    def training_step(self, batch, batch_idx):
        edge_index = knn_graph(batch.pos, k=5, batch=batch.batch)
        y_ = self(batch.x, batch.batch, batch.pos, edge_index)
        loss = F.l1_loss(y_, batch.y)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        edge_index = knn_graph(batch.pos, k=5, batch=batch.batch)
        y_ = self(batch.x, batch.batch, batch.pos, edge_index)
        loss = F.l1_loss(y_, batch.y)
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        logger=False,
        enable_checkpointing=False,
        gpus=1,
        precision=16,
        strategy = "deepspeed_stage_3"
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


if __name__ == "__main__":
    run()

@SeanNaren
Contributor

I recall running into this before with PyTorch Geometric. I think this is down to the Data object: Lightning cannot automatically traverse the object to move its tensors to the correct device. A manual cast may be the cleanest solution currently!
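One possible shape for that manual cast, sketched under the assumption that the batch is a PyG Data/Batch object and the module runs in fp16 (the half_pyg_batch helper and the hook shown in the comment are suggestions, not a prescribed fix):

def half_pyg_batch(batch):
    # Data.apply() runs the function over every tensor attribute; only floating-point
    # tensors are downcast, so integer attributes like edge_index stay intact.
    return batch.apply(lambda t: t.half() if t.is_floating_point() else t)

# e.g. from a LightningModule hook, so every batch is cast once after device transfer:
#     def on_after_batch_transfer(self, batch, dataloader_idx):
#         return half_pyg_batch(batch)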

@tjruwase

@chenzhekl, thanks for confirming and sharing a repro.

@SeanNaren, thanks for sharing your experience.

@Borda self-assigned this Nov 7, 2022
@awaelchli assigned themselves and unassigned Borda and rohitgr7 Mar 18, 2023
@awaelchli removed their assignment Nov 25, 2023