diff --git a/client/README.md b/client/README.md
index a7537040d..9403e47a6 100644
--- a/client/README.md
+++ b/client/README.md
@@ -2,12 +2,15 @@
 The client manages the GPU and CPU memory used by its own process.
 The GPU and CPU memory managed by the clients of different processes is mutually isolated, which avoids inter-process communication between clients.
-The client can register a module, and it can also register a parameter.
-The data and grad of each Parameter are indexed with the param as the key.
+The Client's member variables include a ChunkList and a ChunkTensorIndex.
+A Chunk is a contiguous piece of memory that can reside in either CPU or GPU memory.
+A Chunk can be moved between CPU and GPU memory on demand.
+A ChunkList is a linked list of Chunks.
 A Parameter's data and grad are stored in tensors managed in Chunk form (chunked tensors).
-A Chunk is a contiguous piece of memory that can reside in either CPU or GPU memory. A Chunk can be moved between CPU and GPU memory on demand.
+The Client uses the param as the key to index the data and grad of each Parameter.
 A Parameter is the smallest unit the client manages.
+The client can adapt a `torch.nn.Parameter`: it attaches a ps_attr attribute to the Parameter, superseding the Parameter's original data and grad.
 The client can only register an nn.Parameter; it cannot register a standalone tensor.
 This is because once the location of tensor.data changes, the tensor is no longer the original tensor.
 To manage a tensor as a chunked tensor, wrap it in an nn.Parameter and register it with the PS.
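A minimal sketch of the registration flow the README describes, based on the calls visible in this patch; the `register_param` entry point and the `client.const` import path are assumptions, not confirmed API:

```python
import torch

from client.client import HybridPSClient
from client.const import AccessType, PSTensorStatus  # import path assumed

client = HybridPSClient(gpu_index=0, default_chunk_size=1024 * 1024)

# A bare tensor cannot be registered; wrap it in an nn.Parameter first.
param = torch.nn.Parameter(torch.randn(1024, dtype=torch.half),
                           requires_grad=False)
client.register_param(param)  # hypothetical registration entry point

# After registration, param.ps_attr supersedes the original data/grad.
client.access_data(param, torch.device('cuda:0'))
data = param.ps_attr.access_tensor(AccessType.DATA)  # lives inside a chunk
data.zero_()
client.release_data(param, PSTensorStatus.HOLD)
```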
diff --git a/client/chunk_list.py b/client/chunk_list.py
index 968afcc2c..34739c5c4 100644
--- a/client/chunk_list.py
+++ b/client/chunk_list.py
@@ -52,6 +52,12 @@ def size(self) -> int:
         """
         return len(self.chunk_id_to_chunk_dict)
 
+    def max_chunk_size(self):
+        max_size = 0
+        for chunk_id, chunk in self.chunk_id_to_chunk_dict.items():
+            max_size = max(chunk.capacity, max_size)
+        return max_size
+
     def access_chunk(self, chunk_id: int, compute_device: torch.device):
         """
        Access chunk_id, making the chunk's memory available on the compute device.
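`max_chunk_size` lets a caller size a staging buffer once, to the largest chunk, instead of per tensor. A sketch of the pattern that `CPUAdam.step` adopts later in this patch (the pinned-memory branch mirrors the code below; the helper name is illustrative):

```python
import torch

def make_grad_staging_buffer(chunk_list, prefer_device: torch.device):
    # One fp32 buffer large enough to hold any chunk's payload.
    numel = chunk_list.max_chunk_size()
    if prefer_device.type == 'cpu':
        # Pinned memory speeds up GPU<->CPU copies.
        return torch.zeros(numel, dtype=torch.float, pin_memory=True)
    return torch.zeros(numel, dtype=torch.float, device=prefer_device)
```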
diff --git a/client/chunk_tensor_index.py b/client/chunk_tensor_index.py
index c20fbf465..fd0f737ad 100644
--- a/client/chunk_tensor_index.py
+++ b/client/chunk_tensor_index.py
@@ -104,12 +104,14 @@ def generate_grad_tensor_param(self):
         """
         Generate, in chunk-internal storage order, the params owning every grad tensor that is not currently FREE.
         """
+        res_list = []
         for chunk_id, tensor_id_list in self.dict_chunk_id_tensor_id.items():
             for tensor_id in tensor_id_list:
                 info = self.dict_tensor_id_info[tensor_id]
                 if info.access_type == AccessType.GRAD and info.status(
                 ) != PSTensorStatus.FREE:
-                    yield info.param
+                    res_list.append(info.param)
+        return res_list
 
     def generate_all_tensor_info(self):
         """
@@ -232,7 +234,8 @@ def tensor_id_to_chunk_id(self, tensor_id) -> int:
         else:
             return info.chunk_id
 
-    def get_chunk_id(self, param: PSParameter, access_type: AccessType) -> int:
+    def get_chunk_id(self, param: torch.nn.Parameter,
+                     access_type: AccessType) -> int:
         tensor_id = param.ps_attr.get_tensor_id(access_type)
         info = self.dict_tensor_id_info.get(tensor_id)
         if info is None:
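Returning a materialized list instead of yielding is likely significant here: the optimizer flips tensor statuses (release_grad to FREE) while it iterates, and a lazy generator would evaluate its filter against the already-mutated index. A toy illustration of the hazard, with a plain dict standing in for the tensor-status index (hypothetical data, not this repo's API):

```python
# Toy statuses: why materializing matters when statuses are flipped
# mid-iteration, as the Adam loop below does with release_grad(..., FREE).
status = {0: 'HOLD', 1: 'HOLD', 2: 'HOLD'}

snapshot = [t for t in status if status[t] != 'FREE']  # stable: [0, 1, 2]

lazy = (t for t in status if status[t] != 'FREE')
assert next(lazy) == 0        # generator filters one element at a time
status[1] = 'FREE'            # a release happens mid-iteration
assert list(lazy) == [2]      # tensor 1 is silently skipped by the generator
assert snapshot == [0, 1, 2]  # the list is a snapshot taken up front
```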
diff --git a/client/client.py b/client/client.py
index 46c35ec34..b5f0ef272 100644
--- a/client/client.py
+++ b/client/client.py
@@ -32,6 +32,63 @@
 from utils.memory_monitor import get_memory_used
 
 
+class CachedFP32Buff(object):
+    # TODO release max_chunk_size
+    def __init__(self, default_chunk_size: int):
+        self.cached_chunk_id = None
+        self.max_chunk_size = default_chunk_size
+        self.cpu_cached_fp32_payload = torch.zeros(self.max_chunk_size,
+                                                   dtype=torch.float,
+                                                   pin_memory=True)
+        self.cuda_cached_fp32_payload = torch.zeros(
+            self.max_chunk_size,
+            dtype=torch.float,
+            device=torch.device('cuda:0'))
+
+    def reset(self):
+        self.cached_chunk_id = None
+        self.cpu_cached_fp32_payload.zero_()
+        self.cuda_cached_fp32_payload.zero_()
+
+    def update_chunk(self, chunk: Chunk, time_profile=True):
+        """
+        If the chunk id is already cached, index directly into the cached buffer.
+        The chunk lives on cuda; the result is returned on cpu.
+        The cast runs on the GPU: cuda fp16 -> cuda fp32 -> cpu fp32,
+        because a direct cuda fp16 -> cpu fp32 copy is slow!
+        """
+        chunk_id = chunk.chunk_id
+        if self.cached_chunk_id is None or self.cached_chunk_id != chunk_id:
+            if time_profile:
+                start_time = time.time()
+
+            self.cached_chunk_id = chunk_id
+            chunk_size = chunk.capacity
+            if chunk_size > self.max_chunk_size:
+                self.max_chunk_size = chunk_size
+                self.cpu_cached_fp32_payload = torch.zeros(self.max_chunk_size,
+                                                           dtype=torch.float,
+                                                           pin_memory=True)
+                self.cuda_cached_fp32_payload = torch.zeros(
+                    self.max_chunk_size,
+                    dtype=torch.float,
+                    device=torch.device('cuda:0'))
+
+            cuda_buff = self.cuda_cached_fp32_payload.narrow(0, 0, chunk_size)
+            cuda_buff.copy_(chunk.payload)
+            cpu_buff = self.cpu_cached_fp32_payload.narrow(0, 0, chunk_size)
+            cpu_buff.copy_(cuda_buff)
+            # self.cpu_cached_fp32_payload.copy_(chunk.payload)
+
+            if time_profile:
+                global_timer.gpu_cpu_move_elapse += time.time() - start_time
+                global_timer.gpu_cpu_move_times += 1
+                global_timer.gpu_cpu_move_data_amount += chunk.capacity
+
+    def access_chunk(self, start_offset, numel):
+        return self.cpu_cached_fp32_payload.narrow(0, start_offset, numel)
+
+
 class HybridPSClient(object):
     def __init__(self,
                  gpu_index: int = 0,
@@ -67,6 +124,8 @@ def __init__(self,
 
         self._chunk_id = -1
 
+        self._cached_fp32_buff = CachedFP32Buff(default_chunk_size)
+
     def _generate_chunk_id(self):
         self._chunk_id += 1
         return self._chunk_id
@@ -143,6 +202,13 @@ def generate_grad_params(self):
         """
         return self.chunk_tensor_index.generate_grad_tensor_param()
 
+    def fp16_to_fp32_copy(self, param, access_type):
+        tensor_id = param.ps_attr.get_tensor_id(access_type)
+        info = self.chunk_tensor_index.get_tensor_info(tensor_id)
+        self._cached_fp32_buff.update_chunk(self.chunk_list[info.chunk_id])
+        return self._cached_fp32_buff.access_chunk(info.start_offset,
+                                                   info.numel)
+
     def _assign_chunk_for_tensor(self, param, access_type):
         """
         Assign a chunk for the param; if an existing chunk has a free gap, insert into the gap.
@@ -292,8 +358,8 @@ def release(self,
             start_time = time.time()
 
         assert isinstance(reset_to_status, PSTensorStatus)
-        if param.ps_attr.get_status(access_type) != PSTensorStatus.COMPUTE:
-            return
+        # if param.ps_attr.get_status(access_type) != PSTensorStatus.COMPUTE:
+        #     return
 
         chunk_id = self.chunk_tensor_index.get_chunk_id(param, access_type)
         logging.debug(
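Because parameters are visited in chunk storage order, consecutive `fp16_to_fp32_copy` calls usually hit the same cached chunk, so the GPU-to-CPU copy in `update_chunk` is paid once per chunk rather than once per tensor. A sketch of the intended call pattern, mirroring the `FP16_f_adamv2` loop later in this patch:

```python
# Sketch: read each param's fp16 grad as an fp32 CPU view, chunk-cached.
client._cached_fp32_buff.reset()                  # invalidate a stale cache
for fp16_param in client.generate_grad_params():  # chunk storage order
    grad_fp32 = client.fp16_to_fp32_copy(fp16_param, AccessType.GRAD)
    grad_fp32 = grad_fp32.view(fp16_param.ps_attr.ps_shape)
    # ... consume grad_fp32 on the CPU (e.g. the Adam update) ...
    client.release_grad(fp16_param, PSTensorStatus.FREE)
```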
diff --git a/ops/fp16_cpu_adam.py b/ops/fp16_cpu_adam.py
index e53571b28..69a2f5fe0 100644
--- a/ops/fp16_cpu_adam.py
+++ b/ops/fp16_cpu_adam.py
@@ -25,7 +25,7 @@
 def FP16_f_adam(client,
                 fp32_params: List[torch.nn.Parameter],
-                fp16_param_with_grad,
+                fp16_param_with_grad_list,
                 exp_avgs: List[torch.nn.Parameter],
                 exp_avg_sqs: List[torch.nn.Parameter],
                 max_exp_avg_sqs: List[Tensor],
@@ -41,7 +41,8 @@ def FP16_f_adam(client,
                 prefer_device,
                 time_profile=True):
     r"""Functional API that performs Adam algorithm computation.
-    See :class:`~torch.optim.Adam` for details.
+    Access the params of fp16_param_with_grad_list contiguously, in their storage
+    order inside each chunk; fetch the fp16 grads, copying them into a tmp buff chunk by chunk.
     """
     timer = global_timer.IterationTimer()
     if time_profile:
@@ -50,14 +51,13 @@ def FP16_f_adam(client,
         adam_start_time = time.time()
     for i, param in enumerate(fp32_params):
         if time_profile:
             adam_iter_access_start = time.time()
-
         compute_device = prefer_device
         client.access_data(param, compute_device)
         param_data = param.ps_attr.access_tensor(AccessType.DATA)
 
         param_grad = param_grad_buff.narrow(0, 0, param_data.numel()).view(
             param_data.shape)
-        fp16_param = fp16_param_with_grad[i]
+        fp16_param = fp16_param_with_grad_list[i]
         client.access_grad(fp16_param, torch.device('cuda:0'))
         fp16_param_grad = fp16_param.ps_attr.access_tensor(AccessType.GRAD)
@@ -72,7 +72,7 @@
             global_timer.cpu_gpu_move_data_amount += param_grad.numel()
 
         #TODO(jiaruifang) HOLD->FREE
-        fp16_param_grad.zero_()
+        # fp16_param_grad.zero_()
         client.release_grad(fp16_param, PSTensorStatus.FREE)
 
         exp_avg_param = exp_avgs[i]
@@ -122,7 +122,7 @@
             ) - f_adam_compute_start_time
             adam_iter_release_start = time.time()
 
-        fp16_param = fp16_param_with_grad[i]
+        fp16_param = fp16_param_with_grad_list[i]
         client.access_data(fp16_param, torch.device('cuda:0'))
         fp16_data = fp16_param.ps_attr.access_tensor(AccessType.DATA)
         if time_profile:
@@ -132,8 +132,147 @@
             global_timer.gpu_cpu_move_elapse += time.time() - start_time
             global_timer.gpu_cpu_move_data_amount += fp16_data.numel()
             global_timer.gpu_cpu_move_times += 1
+        client.release_data(fp16_param, PSTensorStatus.HOLD)
+
         client.release_data(param)
         client.release_data(exp_avg_param)
         client.release_data(exp_avg_sq_param)
+
+        if time_profile:
+            global_timer.cpu_adam_release_elapse += time.time(
+            ) - adam_iter_release_start
+
+        timer.tik(device_type='all')
+    global_timer.cpu_adam_elapse += time.time() - adam_start_time
+
+
+def FP16_f_adamv2(client,
+                  fp32_params: List[torch.nn.Parameter],
+                  fp16_param_with_grad_list,
+                  exp_avgs: List[torch.nn.Parameter],
+                  exp_avg_sqs: List[torch.nn.Parameter],
+                  max_exp_avg_sqs: List[Tensor],
+                  state_steps: List[int],
+                  amsgrad: bool,
+                  beta1_list: List[float],
+                  beta2_list: List[float],
+                  lr_list: List[float],
+                  weight_decay_list: List[float],
+                  eps_list: List[float],
+                  prefer_device,
+                  param_grad_buff,
+                  time_profile=True):
+    r"""Functional API that performs Adam algorithm computation.
+    Access the params of fp16_param_with_grad_list contiguously, in their storage
+    order inside each chunk; fetch the fp16 grads, copying them into a tmp buff chunk by chunk.
+    """
+    assert prefer_device.type == 'cpu'
+    timer = global_timer.IterationTimer()
+    if time_profile:
+        adam_start_time = time.time()
+    # TODO(jiaruifang) why is the compute granularity a tensor rather than a chunk?
+    for i, param in enumerate(fp32_params):
+        if time_profile:
+            adam_iter_access_start = time.time()
+        compute_device = prefer_device
+        client.access_data(param, compute_device)
+        param_data = param.ps_attr.access_tensor(AccessType.DATA)
+
+        fp16_param = fp16_param_with_grad_list[i]
+
+        # Copy the chunk holding fp16_param into the tmp buff and return the
+        # corresponding tensor.
+        if True:
+            # client.access_grad(fp16_param, torch.device('cuda:0'))
+            param_grad = client.fp16_to_fp32_copy(
+                fp16_param, AccessType.GRAD).view(param_data.shape)
+            # necessary to reset grads
+            client.release_grad(fp16_param, PSTensorStatus.FREE)
+        else:
+            client.access_grad(fp16_param, torch.device('cuda:0'))
+            fp16_param_grad = fp16_param.ps_attr.access_tensor(AccessType.GRAD)
+
+            if time_profile:
+                start_time = time.time()
+            param_grad = param_grad_buff.narrow(0, 0, param_data.numel()).view(
+                param_data.shape)
+            # torch.cuda.synchronize()
+            param_grad.copy_(fp16_param_grad, non_blocking=False)
+            # torch.cuda.synchronize()
+            if time_profile:
+                global_timer.gpu_cpu_move_elapse += time.time() - start_time
+                global_timer.gpu_cpu_move_times += 1
+                global_timer.gpu_cpu_move_data_amount += param_grad.numel()
+
+            client.release_grad(fp16_param, PSTensorStatus.FREE)
+        # print(i, 'param_grad', param_grad)
+
+        exp_avg_param = exp_avgs[i]
+        exp_avg_sq_param = exp_avg_sqs[i]
+
+        client.access_data(exp_avg_param, compute_device)
+        client.access_data(exp_avg_sq_param, compute_device)
+
+        exp_avg = exp_avg_param.ps_attr.access_tensor(AccessType.DATA)
+        exp_avg_sq = exp_avg_sq_param.ps_attr.access_tensor(AccessType.DATA)
+
+        if time_profile:
+            global_timer.cpu_adam_access_elapse += time.time(
+            ) - adam_iter_access_start
+            f_adam_compute_start_time = time.time()
+
+        step = state_steps[i]
+        beta1 = beta1_list[i]
+        beta2 = beta2_list[i]
+        eps = eps_list[i]
+        bias_correction1 = 1 - beta1**step
+        bias_correction2 = 1 - beta2**step
+
+        weight_decay = weight_decay_list[i]
+        if weight_decay != 0:
+            param_grad = param_grad.add(param_data, alpha=weight_decay)
+
+        exp_avg.mul_(beta1).add_(param_grad, alpha=1 - beta1)
+        exp_avg_sq.mul_(beta2).addcmul_(param_grad,
+                                        param_grad,
+                                        value=1 - beta2)
+        if amsgrad:
+            # Maintains the maximum of all 2nd moment running avg. till now
+            torch.maximum(max_exp_avg_sqs[i],
+                          exp_avg_sq,
+                          out=max_exp_avg_sqs[i])
+            # Use the max. for normalizing running avg. of gradient
+            denom = (max_exp_avg_sqs[i].sqrt() /
+                     math.sqrt(bias_correction2)).add_(eps)
+        else:
+            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+
+        lr = lr_list[i]
+        step_size = lr / bias_correction1
+
+        param_data.addcdiv_(exp_avg, denom, value=-step_size)
+
+        if time_profile:
+            global_timer.cpu_adam_f_elapse += time.time(
+            ) - f_adam_compute_start_time
+            adam_iter_release_start = time.time()
+
+        fp16_param = fp16_param_with_grad_list[i]
+        client.access_data(fp16_param, torch.device('cuda:0'))
+        fp16_data = fp16_param.ps_attr.access_tensor(AccessType.DATA)
+        if time_profile:
+            start_time = time.time()
+        # TODO copy a whole chunk at once
+        fp16_data.copy_(param_data, non_blocking=False)
+        if time_profile:
+            global_timer.cpu_gpu_move_elapse += time.time() - start_time
+            global_timer.cpu_gpu_move_data_amount += fp16_data.numel()
+            global_timer.cpu_gpu_move_times += 1
+
+        client.release_data(fp16_param, PSTensorStatus.HOLD)
+        client.release_data(param)
+        client.release_data(exp_avg_param)
+        client.release_data(exp_avg_sq_param)
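For reference, both FP16_f_adam and FP16_f_adamv2 implement the standard bias-corrected Adam step, with weight decay folded into the gradient exactly as in the code above (lr, beta1, beta2, eps, lambda are the per-param hyperparameters; t = state['step']):

```latex
g_t \gets g_t + \lambda\,\theta_{t-1} \quad\text{(weight decay)}\\
m_t = \beta_1\, m_{t-1} + (1-\beta_1)\, g_t,\qquad
v_t = \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^{2}\\
\theta_t = \theta_{t-1} \;-\; \frac{\mathrm{lr}}{1-\beta_1^{t}} \cdot
\frac{m_t}{\sqrt{v_t/(1-\beta_2^{t})} + \epsilon}
```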
@@ -186,31 +325,32 @@ def __init__(self,
 
         max_param_size = 0
         data_type = None
+
+        # Stash each group's hyperparameters inside the per-param state.
         for group in self.param_groups:
             for p in group['params']:
                 max_param_size = max(max_param_size, p.numel())
                 data_type = p.dtype
+                self.state[p]['betas'] = group['betas']
+                self.state[p]['lr'] = group['lr']
+                self.state[p]['weight_decay'] = group['weight_decay']
+                self.state[p]['eps'] = group['eps']
 
         self.max_param_size = max_param_size
         assert data_type == torch.half
-        if self.prefer_device.type == 'cpu':
-            self.param_grad_buff = torch.zeros(max_param_size,
-                                               dtype=torch.float,
-                                               device=self.prefer_device,
-                                               pin_memory=True)
-        else:
-            self.param_grad_buff = torch.zeros(max_param_size,
-                                               dtype=torch.float,
-                                               device=self.prefer_device)
+        # TODO(jiaruifang) the buff should be sized to the largest chunk
+        # rather than the default chunk size.
+
+        # Moved to the first-touch initialization in step().
 
         # Store the fp32 params' data.
-        # At init, copy the fp16 data into the fp32 parameters first.
-        # Copy in initialization order.
+        # At init, copy the fp16 data into the fp32 parameters first.
+        # Copy in initialization order; this is not ideal.
+
+        # Could build a p -> group mapping to fetch the correct group['betas'].
         self.fp32_params_list = []
         for i, group in enumerate(self.param_groups):
             for j, p in enumerate(group['params']):
                 state = self.state[p]
-                state['step'] = 0
+                # state['step'] = 0
                 fp32_param = torch.nn.Parameter(torch.zeros_like(
                     p, dtype=torch.float, device=torch.device('cpu:0')),
                                                 requires_grad=False)
@@ -222,23 +362,23 @@ def __init__(self,
                 state['fp32_param_data'] = fp32_param
                 self.client.release_data(fp32_param)
 
-                state['exp_avg'] = torch.nn.Parameter(torch.zeros(
-                    p.shape, dtype=torch.float, device=self.prefer_device),
-                                                      requires_grad=False)
-                # Exponential moving average of squared gradient values
-                state['exp_avg_sq'] = torch.nn.Parameter(torch.zeros(
-                    p.shape, dtype=torch.float, device=self.prefer_device),
-                                                         requires_grad=False)
+                # state['exp_avg'] = torch.nn.Parameter(torch.zeros(
+                #     p.shape, dtype=torch.float, device=self.prefer_device),
+                #                                       requires_grad=False)
+                # # Exponential moving average of squared gradient values
+                # state['exp_avg_sq'] = torch.nn.Parameter(torch.zeros(
+                #     p.shape, dtype=torch.float, device=self.prefer_device),
+                #                                          requires_grad=False)
 
-                self.client.access_data(state['exp_avg'], self.prefer_device)
-                state['exp_avg'].ps_attr.access_tensor(AccessType.DATA).zero_()
-                self.client.release_data(state['exp_avg'])
+                # self.client.access_data(state['exp_avg'], self.prefer_device)
+                # state['exp_avg'].ps_attr.access_tensor(AccessType.DATA).zero_()
+                # self.client.release_data(state['exp_avg'])
 
-                self.client.access_data(state['exp_avg_sq'],
-                                        self.prefer_device)
-                state['exp_avg_sq'].ps_attr.access_tensor(
-                    AccessType.DATA).zero_()
-                self.client.release_data(state['exp_avg_sq'])
+                # self.client.access_data(state['exp_avg_sq'],
+                #                         self.prefer_device)
+                # state['exp_avg_sq'].ps_attr.access_tensor(
+                #     AccessType.DATA).zero_()
+                # self.client.release_data(state['exp_avg_sq'])
 
     def __setstate__(self, state):
         super(CPUAdam, self).__setstate__(state)
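The hyperparameters are stashed per param because step() now walks params in chunk order across all groups, losing the group association. The comment above suggests an alternative; a sketch of that p -> group mapping (hypothetical helper, not part of this patch):

```python
def build_param_to_group(param_groups):
    """Map each parameter (by identity) to the optimizer group it came from."""
    mapping = {}
    for group in param_groups:
        for p in group['params']:
            mapping[id(p)] = group
    return mapping

# While iterating params in chunk order:
#   group = mapping[id(p)]
#   beta1, beta2 = group['betas']
```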
@@ -265,53 +405,176 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()
 
-        for i, group in enumerate(self.param_groups):
-            fp16_param_with_grad = []
-            fp32_param = []
-            exp_avgs = []
-            exp_avg_sqs = []
-            state_sums = []
-            max_exp_avg_sqs = []
-            state_steps = []
+        fp16_param_with_grad_list = []
+        fp32_param_list = []
+        exp_avgs = []
+        exp_avg_sqs = []
+        state_sums = []
+        max_exp_avg_sqs = []
+        state_steps = []
+        beta1_list = []
+        beta2_list = []
+        weight_decay_list = []
+        eps_list = []
+        lr_list = []
+
+        self.client._cached_fp32_buff.reset()
+        for p in self.client.generate_grad_params():
+            if p.requires_grad:
+                fp16_param_with_grad_list.append(p)
+                state = self.state[p]
 
-            for j, p in enumerate(group['params']):
-                if p.requires_grad:
-                    fp16_param_with_grad.append(p)
-                    state = self.state[p]
-
-                    # Initialize M, V, and the FP32 data
-                    # Release the original FP32 data
-                    if len(state) == 0:
-                        # During the first warm-up step, copy the FP32 data
-                        state['step'] = 0
-                        if group['amsgrad']:
-                            # Maintains max of all exp. moving avg. of sq. grad. values
-                            state['max_exp_avg_sq'] = torch.zeros_like(
-                                p, memory_format=torch.preserve_format)
-
-                    exp_avgs.append(state['exp_avg'])
-                    exp_avg_sqs.append(state['exp_avg_sq'])
-                    fp32_param.append(state['fp32_param_data'])
-
-                    if group['amsgrad']:
-                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])
-
-                    # update the steps for each param group update
-                    state['step'] += 1
-                    # record the step after step update
-                    state_steps.append(state['step'])
-                else:
-                    raise RuntimeError(f"tensor id {p.ps_attr.grad_id()}")
-
-            beta1, beta2 = group['betas']
-
-            # self.client.chunk_tensor_index.visit_chunks(self.client.chunk_list)
-            # input('wait')
-            FP16_f_adam(
-                self.client, fp32_param, fp16_param_with_grad, exp_avgs,
-                exp_avg_sqs, max_exp_avg_sqs, state_steps, group['amsgrad'],
-                beta1, beta2, group['lr'], group['weight_decay'], group['eps'],
-                self.max_param_size, self.param_grad_buff if hasattr(
-                    self, 'param_grad_buff') else None, self.prefer_device)
+                # if len(state) == 0:
+                if 'exp_avg' not in state:
+                    # During the first warm-up step, copy the FP32 data
+                    state['step'] = 0
+
+                    if self.prefer_device.type == 'cpu':
+                        self.param_grad_buff = torch.zeros(
+                            self.client.chunk_list.max_chunk_size(),
+                            dtype=torch.float,
+                            device=self.prefer_device,
+                            pin_memory=True)
+                    else:
+                        self.param_grad_buff = torch.zeros(
+                            self.client.chunk_list.max_chunk_size(),
+                            dtype=torch.float,
+                            device=self.prefer_device)
+
+                    state['exp_avg'] = torch.nn.Parameter(torch.zeros(
+                        p.ps_attr.ps_shape,
+                        dtype=torch.float,
+                        device=self.prefer_device),
+                                                          requires_grad=False)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.nn.Parameter(
+                        torch.zeros(p.ps_attr.ps_shape,
+                                    dtype=torch.float,
+                                    device=self.prefer_device),
+                        requires_grad=False)
+
+                    self.client.access_data(state['exp_avg'],
+                                            self.prefer_device)
+                    state['exp_avg'].ps_attr.access_tensor(
+                        AccessType.DATA).zero_()
+                    self.client.release_data(state['exp_avg'])
+
+                    self.client.access_data(state['exp_avg_sq'],
+                                            self.prefer_device)
+                    state['exp_avg_sq'].ps_attr.access_tensor(
+                        AccessType.DATA).zero_()
+                    self.client.release_data(state['exp_avg_sq'])
+
+                exp_avgs.append(state['exp_avg'])
+                exp_avg_sqs.append(state['exp_avg_sq'])
+                fp32_param_list.append(state['fp32_param_data'])
+                beta1, beta2 = state['betas']
+
+                beta1_list.append(beta1)
+                beta2_list.append(beta2)
+                lr_list.append(state['lr'])
+                weight_decay_list.append(state['weight_decay'])
+                eps_list.append(state['eps'])
+
+                # update the steps for each param group update
+                state['step'] += 1
+                # record the step after step update
+                state_steps.append(state['step'])
+            else:
+                raise RuntimeError(f"tensor id {p.ps_attr.grad_id()}")
+
+        FP16_f_adamv2(self.client, fp32_param_list, fp16_param_with_grad_list,
+                      exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
+                      False, beta1_list, beta2_list, lr_list,
+                      weight_decay_list, eps_list, self.prefer_device,
+                      self.param_grad_buff)
+
+        # for i, group in enumerate(self.param_groups):
+        #     fp16_param_with_grad_list = []
+        #     fp32_param_list = []
+        #     exp_avgs = []
+        #     exp_avg_sqs = []
+        #     state_sums = []
+        #     max_exp_avg_sqs = []
+        #     state_steps = []
+
+        #     for j, p in enumerate(group['params']):
+        #         if p.requires_grad:
+        #             fp16_param_with_grad_list.append(p)
+        #             state = self.state[p]
+
+        #             # if len(state) == 0:
+        #             if 'exp_avg' not in state:
+        #                 # During the first warm-up step, copy the FP32 data
+        #                 state['step'] = 0
+
+        #                 # fp32_param = torch.nn.Parameter(torch.zeros(
+        #                 #     p.ps_attr.ps_shape,
+        #                 #     dtype=torch.float,
+        #                 #     device=torch.device('cpu:0')),
+        #                 #                                 requires_grad=False)
+        #                 # self.client.access_data(fp32_param, self.prefer_device)
+        #                 # self.client.access_data(p, self.prefer_device)
+
+        #                 # fp32_param_data = fp32_param.ps_attr.access_tensor(
+        #                 #     AccessType.DATA)
+        #                 # fp16_param_data = p.ps_attr.access_tensor(AccessType.DATA)
+
+        #                 # fp32_param_data.copy_(fp16_param_data.float())
+        #                 # state['fp32_param_data'] = fp32_param
+
+        #                 # self.client.release_data(fp32_param)
+        #                 # self.client.release_data(p)
+
+        #                 state['exp_avg'] = torch.nn.Parameter(torch.zeros(
+        #                     p.ps_attr.ps_shape,
+        #                     dtype=torch.float, device=self.prefer_device),
+        #                     requires_grad=False)
+        #                 # Exponential moving average of squared gradient values
+        #                 state['exp_avg_sq'] = torch.nn.Parameter(torch.zeros(
+        #                     p.ps_attr.ps_shape,
+        #                     dtype=torch.float, device=self.prefer_device),
+        #                     requires_grad=False)
+
+        #                 self.client.access_data(state['exp_avg'], self.prefer_device)
+        #                 state['exp_avg'].ps_attr.access_tensor(AccessType.DATA).zero_()
+        #                 self.client.release_data(state['exp_avg'])
+
+        #                 self.client.access_data(state['exp_avg_sq'],
+        #                                         self.prefer_device)
+        #                 state['exp_avg_sq'].ps_attr.access_tensor(
+        #                     AccessType.DATA).zero_()
+        #                 self.client.release_data(state['exp_avg_sq'])
+
+        #             if group['amsgrad']:
+        #                 raise NotImplementedError
+        #                 # Maintains max of all exp. moving avg. of sq. grad. values
+        #                 state['max_exp_avg_sq'] = torch.zeros_like(
+        #                     p, memory_format=torch.preserve_format)
+
+        #             exp_avgs.append(state['exp_avg'])
+        #             exp_avg_sqs.append(state['exp_avg_sq'])
+        #             fp32_param_list.append(state['fp32_param_data'])
+
+        #             if group['amsgrad']:
+        #                 max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+        #             # update the steps for each param group update
+        #             state['step'] += 1
+        #             # record the step after step update
+        #             state_steps.append(state['step'])
+        #         else:
+        #             raise RuntimeError(f"tensor id {p.ps_attr.grad_id()}")
+
+        #     beta1, beta2 = group['betas']
+
+        #     # self.client.chunk_tensor_index.visit_chunks(self.client.chunk_list)
+        #     # input('wait')
+        #     FP16_f_adam(
+        #         self.client, fp32_param_list, fp16_param_with_grad_list, exp_avgs,
+        #         exp_avg_sqs, max_exp_avg_sqs, state_steps, group['amsgrad'],
+        #         beta1, beta2, group['lr'], group['weight_decay'], group['eps'],
+        #         self.max_param_size, self.param_grad_buff if hasattr(
+        #             self, 'param_grad_buff') else None, self.prefer_device)
 
         return loss
diff --git a/tests/test_bert.py b/tests/test_bert.py
index 5c5eba021..27714a6af 100644
--- a/tests/test_bert.py
+++ b/tests/test_bert.py
@@ -160,7 +160,7 @@ def test_bert_model(is_ckp: bool = False,
     if is_ps:
         # chunk 512 MB, good for CPU-GPU bandwidth
         client = HybridPSClient(gpu_index=0,
-                                default_chunk_size=1024 * 1024 * 512,
+                                default_chunk_size=1024 * 1024 * 8,
                                 warmup=True,
                                 is_fp16=is_fp16)
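A sanity check on the new test setting, assuming default_chunk_size counts elements rather than bytes (which is how CachedFP32Buff uses it, via torch.zeros(default_chunk_size, ...)); under that reading the removed 1024 * 1024 * 512 setting held 1 GB of fp16 payload per chunk, so the retained "chunk 512 MB" comment matched a byte count, not an element count:

```python
numel = 1024 * 1024 * 8               # new default_chunk_size
fp16_chunk_mb = numel * 2 / 2**20     # 16.0 MB of fp16 payload per chunk
fp32_staging_mb = numel * 4 / 2**20   # 32.0 MB each for the pinned CPU and
                                      # CUDA fp32 buffers in CachedFP32Buff
print(fp16_chunk_mb, fp32_staging_mb)
```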