
Question about using multi-gpu #5

Open
@WeiyiLi


Hello Ruby,

I am modifying your code to run it on multiple GPUs with DataParallel. For the CIFAR-10 example, I changed the code in the main function as follows:

    import torch.nn as nn

    # Split each batch across GPUs 0 and 1
    model = nn.DataParallel(model, device_ids=[0, 1]).cuda()

    # if isinstance(model, nn.DataParallel):
    #     model = model.module
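The commented-out lines follow a common pattern: forward passes go through the DataParallel wrapper, while code that inspects or modifies layers (such as a pruner) works on the underlying network through .module. A minimal sketch of that pattern; the helper name unwrap_model is hypothetical and not part of this repository:

```python
import torch.nn as nn


def unwrap_model(model: nn.Module) -> nn.Module:
    """Hypothetical helper: return the underlying network if model is wrapped."""
    # DataParallel and DistributedDataParallel both keep the original network in .module
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    # Already a bare network
    return model
```

For example, a pruner could be given unwrap_model(model) while test() keeps receiving the wrapped model.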

But it raises the following error:

    Traceback (most recent call last):
      File "legr.py", line 302, in <module>
        legr.prune(args.name, args.model, args.long_ft, (1-(prune_away)/100.))
      File "legr.py", line 173, in prune
        acc = test(self.model, self.test_loader, device=self.device)
      File "/home/workspace/test_gpu/utils/drivers.py", line 142, in test
        output = model(batch)
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
        raise output
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
        output = module(*input, **kwargs)
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/workspace/test_gpu/model/resnet_cifar10.py", line 87, in forward
        x = self.features(x)
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
        input = module(input)
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
        input = module(input)
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 320, in forward
        self.padding, self.dilation, self.groups)
    RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)
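For context on the message itself: when nn.DataParallel replicates a module for each GPU, it broadcasts the module's registered parameters and buffers, but a tensor kept as an ordinary Python attribute is only shallow-copied and stays on the original device (cuda:0). A replica running on GPU 1 that uses such a tensor would mix devices, which is one common way to get "device 1 does not equal 0". Whether that is what happens in this code is an assumption the traceback does not confirm. The same registered-vs-plain distinction can be checked on CPU, because Module.to() also touches only parameters and buffers (the class names below are illustrative):

```python
import torch
import torch.nn as nn


class PlainAttr(nn.Module):
    """Keeps a tensor as an ordinary attribute, invisible to .to()/.cuda()/replication."""

    def __init__(self):
        super().__init__()
        self.scale = torch.ones(3)  # plain attribute, NOT registered


class RegisteredBuffer(nn.Module):
    """Registers the same tensor as a buffer, so module-wide moves include it."""

    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(3))


plain = PlainAttr().to(torch.float64)
registered = RegisteredBuffer().to(torch.float64)

# The plain attribute keeps its original dtype; the buffer is converted.
print(plain.scale.dtype)       # torch.float32
print(registered.scale.dtype)  # torch.float64
```

Registering a tensor with register_buffer (or as an nn.Parameter if it should be trained) is what lets .to(), .cuda() and DataParallel's replication carry it along.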

When I comment out legr.pruner.forward(torch.zeros((1,3,dummy_size, dummy_size), device=device)) in the main function, the test call at line 173 no longer raises the error. So I wonder whether the forward function in fp_resnet.py modifies the original model in a way that prevents it from being tested on multiple GPUs?
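Since removing the pruner's dummy forward pass makes the error go away, one way to test this hypothesis is to look for tensors left on the model as plain attributes instead of registered parameters or buffers. A minimal diagnostic sketch; find_unregistered_tensors is a hypothetical helper, and it only checks direct tensor attributes and lists/tuples of tensors, not hooks or other indirect references:

```python
import torch
import torch.nn as nn


def find_unregistered_tensors(model: nn.Module):
    """Hypothetical helper: list (module_name, attr_name) for tensors stored as plain attributes.

    Registered parameters and buffers live inside the _parameters/_buffers dicts,
    so any tensor found directly in a module's __dict__ is unregistered.
    """
    hits = []
    for mod_name, module in model.named_modules():
        for attr, value in vars(module).items():
            if isinstance(value, torch.Tensor):
                hits.append((mod_name, attr))
            elif isinstance(value, (list, tuple)) and any(
                isinstance(v, torch.Tensor) for v in value
            ):
                # A container of tensors is not replicated per device either
                hits.append((mod_name, attr))
    return hits
```

Calling print(find_unregistered_tensors(model)) right after the pruner's forward pass would show whether any such attributes appeared.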

When I ran ResNet-50 on ImageNet with multiple GPUs, I got a similar error.

Could you also please advise how I can use multiple GPUs to run your code?
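As a general pattern rather than a fix confirmed for this repository: nn.DataParallel expects the wrapped module's parameters and buffers to be on device_ids[0] before the first forward pass, so the model is usually moved to cuda:0 first and then wrapped. A sketch with a hypothetical helper name that falls back to the bare model when fewer than two GPUs are visible:

```python
import torch
import torch.nn as nn


def wrap_for_multi_gpu(model: nn.Module) -> nn.Module:
    """Hypothetical helper: place the model on cuda:0, then wrap it in DataParallel."""
    n_gpus = torch.cuda.device_count()
    if n_gpus < 2:
        # Nothing to parallelize; use the model as-is.
        return model
    # DataParallel requires parameters and buffers on device_ids[0] before forward.
    model = model.to(torch.device("cuda:0"))
    return nn.DataParallel(model, device_ids=list(range(n_gpus)))
```

Moving before wrapping keeps the original module object reachable as .module, so pruning code can still operate on it.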

Thank you!
