This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

use all_gather to gather results from all gpus #383

Merged
merged 1 commit into facebookresearch:master on Jan 25, 2019

Conversation

wat3rBro

from @ppwwyyxx

@facebook-github-bot added the CLA Signed label on Jan 24, 2019
@qianyizhang

What's the advantage of using one over the other?

@fmassa left a comment


This is much better, thanks a lot @wat3rBro and @ppwwyyxx !

I've left a few nits, but I'm merging this as this is already much better than before.

    return data_list


def reduce_dict(input_dict, average=True):

nit: I believe this is not used anywhere (it probably is due to a refactor in the engine/trainer?)

for _ in size_list:
    tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
if local_size != max_size:
    padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")

nit: this could contain garbage, because the data is uninitialized. This per se doesn't affect the overall results, since we remove the padded values afterwards, but I'm not sure whether it could cause problems with dist.all_gather.
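
One way to address this nit (a sketch, not the code merged in this PR; it assumes local_size is available as a plain int at this point) is to zero-initialize the padding:

# zero-filled padding so no uninitialized bytes are sent to dist.all_gather
padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)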

# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
    tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))

nit: it would be good to use the new API for this:

device = torch.device("cuda")
...
torch.empty((max_size,), dtype=torch.uint8, device=device)
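
Applied to the loop above, the suggestion would look roughly like this (a sketch reusing max_size and size_list from the surrounding code):

device = torch.device("cuda")
tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=device) for _ in size_list]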

tensor = torch.ByteTensor(storage).to("cuda")

# obtain Tensor size of each rank
local_size = torch.IntTensor([tensor.numel()]).to("cuda")

nit: it would be good to replace the usages of IntTensor/ByteTensor with their modern equivalents. In this case (because it's initialized):

device = torch.device("cuda")
...
local_size = torch.tensor([tensor.numel()], dtype=torch.int32, device=device)
# or, if 0d tensors work with dist.all_gather
local_size = torch.tensor(tensor.numel(), dtype=torch.int32, device=device)

@fmassa commented Jan 25, 2019

@qianyizhang the advantage is that it makes it possible to do multi-machine testing, which was not possible before.

@fmassa fmassa merged commit 5f2a826 into facebookresearch:master Jan 25, 2019
@qianyizhang

@fmassa
Thanks. I also have another, somewhat related question: is it possible to run multiple sessions of maskrcnn_benchmark (or even torch.distributed processes) on the same node? For example, I have a server with 8 GPUs and want to run two sets of experiments using 4 cards each.

The rendezvous complains with "RuntimeError: Address already in use"; how can I make it work?

@fmassa commented Jan 25, 2019

@qianyizhang yes, it's possible, but you need to change the master_addr and the master_port in torch.distributed.launch, see https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py#L164-L169
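
For example (a sketch; the GPU indices, port numbers, and config paths below are placeholders, not values from this thread), the two experiments could be launched as:

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29500 tools/train_net.py --config-file path/to/experiment_a.yaml
CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29501 tools/train_net.py --config-file path/to/experiment_b.yaml

Giving each launch a distinct --master_port avoids the "Address already in use" rendezvous error.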

# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")

Actually, @ppwwyyxx this will probably be problematic for large datasets, as we will run out of memory on the GPU when trying to perform this communication.

The idea that I had was to use shared memory on the CPU and communicate the address of the shared memory, but this doesn't work in the multi-machine case.


I agree. But is there any way to do all-gather on CPUs (given that the dist backend was initialized with "nccl")?


With c10d, it is now possible to have more than one dist backend at a time. So one could potentially have one nccl backend and one mpi backend?

@yelantf Feb 27, 2019


Yes, I am using this communication code in another task. Because of the large size of each data item, I get an OOM error.


@yelantingfeng I'd recommend either:

  • reverting this change locally for now
  • try creating a new process group with c10d which is on the CPU, and communicate this data on the CPU instead (see the sketch right below)
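
A rough sketch of that second option (illustrative only, not code from this repository; it assumes every rank pads its serialized CPU buffer to the same max_size):

import torch
import torch.distributed as dist

# extra CPU process group (gloo) alongside the default NCCL group
cpu_group = dist.new_group(backend="gloo")

cpu_tensor = torch.empty(max_size, dtype=torch.uint8)  # padded, serialized data on the CPU
gathered = [torch.empty(max_size, dtype=torch.uint8) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, cpu_tensor, group=cpu_group)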


> Yes, I am using this communication code in another task. Because of the large size of each data item, I get an OOM error.

I got the OOM error too. Have you implemented the second method recommended by @fmassa?


> @yelantingfeng I'd recommend either:
>
>   • reverting this change locally for now
>   • try creating a new process group with c10d which is on the CPU, and communicate this data on the CPU instead

I think we could just set a memory limit for this all_gather function. I implemented this by splitting those ByteTensors into chunks. After some tests, I found my implementation can limit the total memory usage to the MiB level. It is not a 100% precise limit, but I think it should be useful enough. I would be glad to send a PR if you think this is a good improvement for this repository.
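
For reference, a minimal sketch of that chunking idea (chunked_all_gather_bytes is a hypothetical helper, not the implementation described above; it assumes every rank has already padded its serialized uint8 CUDA tensor to the same max_size):

import torch
import torch.distributed as dist

def chunked_all_gather_bytes(tensor, max_size, chunk_size=2 ** 20):
    # Gather chunk_size bytes at a time so only world_size * chunk_size bytes of
    # the serialized buffers live on the GPU at once; results are staged on the CPU.
    world_size = dist.get_world_size()
    out = [torch.empty(max_size, dtype=torch.uint8) for _ in range(world_size)]
    for start in range(0, max_size, chunk_size):
        end = min(start + chunk_size, max_size)
        chunk_list = [torch.empty(end - start, dtype=torch.uint8, device=tensor.device)
                      for _ in range(world_size)]
        dist.all_gather(chunk_list, tensor[start:end].contiguous())
        for rank, chunk in enumerate(chunk_list):
            out[rank][start:end] = chunk.cpu()
    return out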

size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
@pietern Jan 30, 2019

@wat3rBro You can use dist.all_reduce(local_size, op=dist.ReduceOp.MAX) here for a little less code.


The sizes of all ranks are needed later.


Ah, I see: it's because pickle requires the exact size and doesn't tolerate additional NULs. Thanks for clarifying.

@JoyHuYY1412 commented Aug 27, 2019

> This is much better, thanks a lot @wat3rBro and @ppwwyyxx !
>
> I've left a few nits, but I'm merging this as this is already much better than before.

Hi, can I use all_gather to gather the weights from all GPUs?
Similar to the issue here: https://discuss.pytorch.org/t/how-to-preserve-backward-grad-fn-after-distributed-operations/49343

I want each GPU to output different weights in each batch, and the loss to be calculated using all of the weights. When I use all_gather, I found that the gathered weights lose their grad_fn.
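
For reference, a commonly used workaround (not part of this PR, shown only as a sketch): torch.distributed.all_gather does not propagate gradients, so after gathering you can put the local tensor, which still carries its grad_fn, back into its slot before computing the loss.

import torch
import torch.distributed as dist

def gather_keep_local_grad(weights):
    # all_gather fills the list with detached copies; re-insert the local tensor
    # so autograd can flow back through this rank's own contribution.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    gathered = [torch.zeros_like(weights) for _ in range(world_size)]
    dist.all_gather(gathered, weights)
    gathered[rank] = weights
    return torch.cat(gathered, dim=0)

Note that with this trick each rank only backpropagates through its own shard; gradients with respect to the tensors gathered from other ranks are not exchanged.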

Lyears pushed a commit to Lyears/maskrcnn-benchmark that referenced this pull request Jun 28, 2020