forked from deepinsight/insightface
Commit a87d25b · 1 parent 6de9652 · Showing 1 changed file with 257 additions and 0 deletions.
@@ -0,0 +1,257 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author    : Qingping Zheng
@Contact   : qingpingzheng2014@gmail.com
@File      : encoding.py
@Time      : 10/01/21 00:00 PM
@Desc      :
@License   : Licensed under the Apache License, Version 2.0 (the "License");
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import threading
import torch
import torch.cuda.comm as comm

from torch.autograd import Variable, Function
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

torch_ver = torch.__version__[:3]

__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
           'patch_replication_callback']

def allreduce(*inputs):
    """Cross-GPU all-reduce autograd operation, used to compute the mean and
    variance in SyncBN.
    """
    return AllReduce.apply(*inputs)

class AllReduce(Function):
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
        inputs = [inputs[i:i + num_inputs]
                  for i in range(0, len(inputs), num_inputs)]
        # sort before reduce sum
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *inputs):
        inputs = [i.data for i in inputs]
        inputs = [inputs[i:i + ctx.num_inputs]
                  for i in range(0, len(inputs), ctx.num_inputs)]
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
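

# Hedged usage sketch (not part of the original file; the function and argument
# names below are illustrative only). `allreduce` sums one group of tensors per
# GPU and broadcasts the result back to every device while keeping the autograd
# graph intact, which is how SyncBN shares its per-GPU mean/variance statistics.
def _allreduce_example(partial_sum_gpu0, partial_sum_gpu1):
    # Each argument is assumed to already live on its own device (cuda:0 / cuda:1).
    # The leading 1 is `num_inputs`, i.e. one tensor per GPU; both returned
    # tensors hold the total sum, one copy per device.
    total_gpu0, total_gpu1 = allreduce(1, partial_sum_gpu0, partial_sum_gpu1)
    return total_gpu0, total_gpu1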


class Reduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
        return Broadcast.apply(ctx.target_gpus, gradOutput)


class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices, chunking along the
    batch dimension.
    In the forward pass, the module is replicated on each device, and each
    replica handles a portion of the input. During the backward pass,
    gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; please use the compatible
    :class:`encoding.parallel.DataParallelCriterion`.
    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (and each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


class DataParallelCriterion(DataParallel):
    """
    Calculates the loss across multiple GPUs, which balances memory usage for
    semantic segmentation.

    The targets are split across the specified devices by chunking along the
    batch dimension. Please use together with
    :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
    def forward(self, inputs, *targets, **kwargs):
        # The inputs are expected to be already scattered (one chunk per device);
        # only the targets are scattered here.
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
        return Reduce.apply(*outputs) / len(outputs)
        # return self.gather(outputs, self.output_device).mean()


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    """Apply each criterion replica to its input chunk and target chunk, one
    thread per device, and return the per-device losses."""
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    if torch_ver != "0.3":
        grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        if torch_ver != "0.3":
            torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            if not isinstance(input, tuple):
                input = (input,)
            with torch.cuda.device(device):
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target,
                                          kwargs, device),)
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # Single replica: run the worker inline. Note that `targets[0]` must be
        # passed here; omitting it shifts the kwargs/device arguments.
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs


###########################################################################
# Adapted from Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
#
class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module
    created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Note that, as all copies are isomorphic, we assign each sub-module a context
    (shared among the copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be
    called ahead of the callbacks of all other copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)
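

# Hedged illustration (not part of the original file): a module can opt into the
# replication callback by defining `__data_parallel_replicate__`. The class name
# and attributes below are placeholders, not an API of this package.
class _CallbackAwareModule(torch.nn.Module):
    def __data_parallel_replicate__(self, ctx, copy_id):
        # `ctx` is the CallbackContext shared by all replicas of this sub-module;
        # `copy_id` is 0 for the master copy. The master callback runs first, so
        # the other copies can safely read what the master stored on `ctx`.
        if copy_id == 0:
            ctx.shared_state = {}
        self._shared_state = ctx.shared_state

    def forward(self, x):
        return x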


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object to add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
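

# Hedged end-to-end sketch (not part of the original file): how DataParallelModel
# and DataParallelCriterion are meant to be combined. `model`, `criterion`,
# `loader` and `optimizer` are caller-supplied placeholders; at least two visible
# GPUs are assumed.
def _example_train(model, criterion, loader, optimizer, device_ids=(0, 1)):
    net = DataParallelModel(model.cuda(), device_ids=list(device_ids))
    parallel_criterion = DataParallelCriterion(criterion, device_ids=list(device_ids))
    for images, labels in loader:
        outputs = net(images.cuda())                       # per-GPU outputs, not gathered
        loss = parallel_criterion(outputs, labels.cuda())  # targets scattered, losses averaged
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()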