Commit

fit CachedFP32Buff to the situation where the chunk size is larger than the default chunk size.
feifeibear committed May 21, 2021
1 parent 6849adc commit 6d98bd4
Showing 6 changed files with 429 additions and 88 deletions.
9 changes: 6 additions & 3 deletions client/README.md
@@ -2,12 +2,15 @@
The client manages the GPU memory and host memory used by its own process.
The GPU and host memory managed by each process's client is isolated from that of every other process, so no inter-process communication is needed between clients.

The client can register a module as well as register a parameter.
The Client's member variables include a ChunkList and a ChunkTensorIndex.
A ChunkList is a linked list of Chunks.
A Parameter's data and grad are stored in tensors managed in Chunk form (chunked tensors).
A Chunk is a contiguous span of memory that can live in host memory or GPU memory, and can be moved between the two on demand.
The Client uses the param as a key to index the data and grad of each Parameter.

A Parameter is the smallest unit the client manages.
The client is able to rework `torch.nn.Parameter`: it attaches a ps_attr attribute to the Parameter, superseding its original data and grad.
The client can only register an nn.Parameter; it cannot register a bare tensor.
The reason is that once tensor.data is relocated, the tensor is no longer the original tensor.
To turn a tensor into a chunked tensor, wrap it in an nn.Parameter and register that with PS.
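
A minimal sketch of the registration constraint described above. `client.register_param` is a hypothetical method name used only for illustration; the actual registration API may differ:

import torch

# client: an assumed, already-constructed HybridPSClient instance.
# A bare tensor cannot be registered: once its tensor.data is moved
# into a chunk, it would no longer be the original tensor. Wrap the
# tensor in an nn.Parameter and register the Parameter instead.
t = torch.randn(1024)
param = torch.nn.Parameter(t, requires_grad=False)

# Hypothetical call: the client attaches a ps_attr attribute to the
# Parameter and takes over its data and grad storage.
client.register_param(param)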
6 changes: 6 additions & 0 deletions client/chunk_list.py
@@ -52,6 +52,12 @@ def size(self) -> int:
"""
return len(self.chunk_id_to_chunk_dict)

    def max_chunk_size(self):
        """
        Return the largest capacity among all chunks in the list.
        """
        max_size = 0
        for chunk_id, chunk in self.chunk_id_to_chunk_dict.items():
            max_size = max(chunk.capacity, max_size)
        return max_size
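
A hedged usage sketch: `max_chunk_size` presumably lets a caller size a reusable staging buffer to the largest chunk (as the new CachedFP32Buff below must do once a chunk outgrows the default size). `chunk_list` is an assumed ChunkList instance:

import torch

# Allocate once, large enough that any chunk's payload fits
# without reallocating on later accesses.
staging = torch.zeros(chunk_list.max_chunk_size(), dtype=torch.float)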

def access_chunk(self, chunk_id: int, compute_device: torch.device):
"""
        Access chunk_id and stage the chunk's memory onto the compute device.
7 changes: 5 additions & 2 deletions client/chunk_tensor_index.py
@@ -104,12 +104,14 @@ def generate_grad_tensor_param(self):
"""
        Produce, following the layout order inside each chunk, the params owning all grad tensors that are not currently freed.
"""
        res_list = []
        for chunk_id, tensor_id_list in self.dict_chunk_id_tensor_id.items():
            for tensor_id in tensor_id_list:
                info = self.dict_tensor_id_info[tensor_id]
                if info.access_type == AccessType.GRAD and info.status(
                ) != PSTensorStatus.FREE:
                    res_list.append(info.param)
        return res_list
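
A hedged usage sketch of the change from `yield` to a returned list: the list is a snapshot, so callers may alter tensor status while iterating without invalidating the traversal. `chunk_tensor_index` is an assumed ChunkTensorIndex instance, and AccessType is assumed to be imported from the client package:

# Iterating over a materialized list; releasing or re-registering a
# grad tensor inside the loop cannot break the iteration, which the
# previous generator version could not guarantee.
for param in chunk_tensor_index.generate_grad_tensor_param():
    print(param.ps_attr.get_tensor_id(AccessType.GRAD))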

def generate_all_tensor_info(self):
"""
@@ -232,7 +234,8 @@ def tensor_id_to_chunk_id(self, tensor_id) -> int:
else:
return info.chunk_id

    def get_chunk_id(self, param: torch.nn.Parameter,
                     access_type: AccessType) -> int:
tensor_id = param.ps_attr.get_tensor_id(access_type)
info = self.dict_tensor_id_info.get(tensor_id)
if info is None:
70 changes: 68 additions & 2 deletions client/client.py
@@ -32,6 +32,63 @@
from utils.memory_monitor import get_memory_used


class CachedFP32Buff(object):
# TODO release max_chunk_size
def __init__(self, default_chunk_size: int):
self.cached_chunk_id = None
self.max_chunk_size = default_chunk_size
self.cpu_cached_fp32_payload = torch.zeros(self.max_chunk_size,
dtype=torch.float,
pin_memory=True)
self.cuda_cached_fp32_payload = torch.zeros(
self.max_chunk_size,
dtype=torch.float,
device=torch.device('cuda:0'))

def reset(self):
self.cached_chunk_id = None
self.cpu_cached_fp32_payload.zero_()
self.cuda_cached_fp32_payload.zero_()

def update_chunk(self, chunk: Chunk, time_profile=True):
"""
        If the chunk id is already cached, index directly into the cached buffer.
        The chunk lives on cuda; the result is returned on cpu.
        cuda fp16 -> cuda fp32 -> cpu fp32
        A one-step cuda fp16 -> cpu fp32 copy is slow!
"""
chunk_id = chunk.chunk_id
if self.cached_chunk_id is None or self.cached_chunk_id != chunk_id:
if time_profile:
start_time = time.time()

self.cached_chunk_id = chunk_id
chunk_size = chunk.capacity
if chunk_size > self.max_chunk_size:
self.max_chunk_size = chunk_size
self.cpu_cached_fp32_payload = torch.zeros(self.max_chunk_size,
dtype=torch.float,
pin_memory=True)
self.cuda_cached_fp32_payload = torch.zeros(
self.max_chunk_size,
dtype=torch.float,
device=torch.device('cuda:0'))

cuda_buff = self.cuda_cached_fp32_payload.narrow(0, 0, chunk_size)
cuda_buff.copy_(chunk.payload)
cpu_buff = self.cpu_cached_fp32_payload.narrow(0, 0, chunk_size)
cpu_buff.copy_(cuda_buff)
# self.cpu_cached_fp32_payload.copy_(chunk.payload)

if time_profile:
global_timer.gpu_cpu_move_elapse += time.time() - start_time
global_timer.gpu_cpu_move_times += 1
global_timer.gpu_cpu_move_data_amount += chunk.capacity

def access_chunk(self, start_offset, numel):
return self.cpu_cached_fp32_payload.narrow(0, start_offset, numel)
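
A standalone sketch of the staged copy that update_chunk performs, purely illustrative and assuming a CUDA device is available:

import torch

chunk_payload = torch.randn(4096, dtype=torch.float16, device='cuda:0')

# Staged path: cast fp16 -> fp32 on the GPU, then move the fp32
# result into pinned host memory as a same-dtype transfer.
cuda_fp32 = torch.empty(4096, dtype=torch.float, device='cuda:0')
cpu_fp32 = torch.empty(4096, dtype=torch.float, pin_memory=True)
cuda_fp32.copy_(chunk_payload)  # on-device dtype conversion
cpu_fp32.copy_(cuda_fp32)       # device-to-host copy, no conversion

# Direct path the docstring warns against: one copy_ that crosses
# devices and converts dtype at the same time.
cpu_fp32_direct = torch.empty(4096, dtype=torch.float)
cpu_fp32_direct.copy_(chunk_payload)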


class HybridPSClient(object):
def __init__(self,
gpu_index: int = 0,
@@ -67,6 +124,8 @@ def __init__(self,

self._chunk_id = -1

self._cached_fp32_buff = CachedFP32Buff(default_chunk_size)

def _generate_chunk_id(self):
self._chunk_id += 1
return self._chunk_id
@@ -143,6 +202,13 @@ def generate_grad_params(self):
"""
return self.chunk_tensor_index.generate_grad_tensor_param()

    def fp16_to_fp32_copy(self, param, access_type):
        """
        Return an fp32 cpu view of the tensor that param accesses through
        access_type, served from the cached fp32 copy of its chunk.
        """
        tensor_id = param.ps_attr.get_tensor_id(access_type)
        info = self.chunk_tensor_index.get_tensor_info(tensor_id)
        self._cached_fp32_buff.update_chunk(self.chunk_list[info.chunk_id])
        return self._cached_fp32_buff.access_chunk(info.start_offset,
                                                   info.numel)
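
A hedged usage sketch of fp16_to_fp32_copy; `client` and `param` are assumed to be an initialized HybridPSClient and a registered parameter:

# Returns a narrowed fp32 cpu view into the cached chunk buffer that
# covers exactly this tensor's elements.
fp32_grad = client.fp16_to_fp32_copy(param, AccessType.GRAD)

# Note: writes to this view land in the cache buffer; they are not
# automatically propagated back to the fp16 chunk.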

def _assign_chunk_for_tensor(self, param, access_type):
"""
        Assign a chunk to param; if an existing chunk has a free gap, insert the tensor into the gap.
@@ -292,8 +358,8 @@ def release(self,
start_time = time.time()

assert isinstance(reset_to_status, PSTensorStatus)
        # if param.ps_attr.get_status(access_type) != PSTensorStatus.COMPUTE:
        #     return

chunk_id = self.chunk_tensor_index.get_chunk_id(param, access_type)
logging.debug(