Add DataParallel and make Block support DataParallel #87

Open
wants to merge 3 commits into base: master
Conversation

@gngdb commented Sep 10, 2022

There's a lambda function in Block that raises an error when training with DataParallel enabled. I'm not sure you want DataParallel in minGPT at all, since it may not be minimal enough; if not, this PR could be reduced to just the change in mingpt/model.py, which removes the error for anyone who wraps the model in DataParallel themselves. A sketch of the trainer-side change is below.
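For the trainer side, this is roughly what I have in mind; the names (`model`, `x`, `y`, `device`) follow minGPT's Trainer, but the exact hook point in trainer.py is an assumption on my part, not the final diff:

```python
import torch

# wrap the model in DataParallel when more than one GPU is visible
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model = model.to(device)

# the training step looks the same as before; DataParallel scatters the batch
# along dim 0 and gathers the per-replica outputs back onto the default device
logits, loss = model(x, y)
loss = loss.mean()  # each replica returns its own scalar loss, so reduce them
loss.backward()
```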

Example error:

Traceback (most recent call last):
  File "projects/chargpt/chargpt.py", line 134, in <module>
    trainer.run()
  File "/h/gngdb/repos/minGNS/mingpt/trainer.py", line 97, in run
    logits, self.loss = model(x, y)
  File "/nobackup/gngdb/envs/humor_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/nobackup/gngdb/envs/humor_env/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/nobackup/gngdb/envs/humor_env/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/nobackup/gngdb/envs/humor_env/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/nobackup/gngdb/envs/humor_env/lib/python3.7/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/nobackup/gngdb/envs/humor_env/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/nobackup/gngdb/envs/humor_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/h/gngdb/repos/minGNS/mingpt/model.py", line 271, in forward
    x = block(x)                       
  File "/nobackup/gngdb/envs/humor_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/h/gngdb/repos/minGNS/mingpt/model.py", line 92, in forward                                           
    x = x + self.mlpf(self.ln_2(x))                                             
  File "/h/gngdb/repos/minGNS/mingpt/model.py", line 88, in <lambda>           
    self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward      
  File "/nobackup/gngdb/envs/humor_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)                                       
  File "/nobackup/gngdb/envs/humor_env/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)         
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_addmm)

Removing the lambda function resolves it. I don't fully understand the mechanics, but my guess is that DataParallel replicates the submodules and parameters onto each device, whereas the lambda stored on self.mlpf is a plain Python attribute, so its closure still points at the original mlp modules on cuda:0, which is what produces the mixed-device error above.
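For reference, this is roughly what the forward pass of Block looks like without the stored lambda, assuming the current layout in mingpt/model.py where `self.mlp` is an `nn.ModuleDict` holding `c_fc`, `act`, `c_proj` and `dropout`:

```python
def forward(self, x):
    x = x + self.attn(self.ln_1(x))
    m = self.mlp
    # call the MLP sub-modules directly so each DataParallel replica uses its
    # own copies of the weights instead of the ones captured by the lambda
    x = x + m.dropout(m.c_proj(m.act(m.c_fc(self.ln_2(x)))))
    return x
```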
