Bert #158
Changes from all commits
@@ -3,31 +3,32 @@
import torch
import torch.distributed as dist
try:
    import amp_C
    import apex_C
    from apex import amp
    from apex.amp import _amp_state
    from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB
    from apex.optimizers import FusedLAMB
    from apex.parallel import DistributedDataParallel as APEX_DDP
    from apex.parallel.distributed import flat_dist_call
except ImportError:
    print("import apex error")

Review comment: The kunlunxin machines do not have apex installed, so importing it raises an error; the imports are wrapped in try/except.
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from torch.optim import Optimizer

import utils
import config
#from converter import convert_model
from .distributed_fused_lamb import _pipeline_block_reductions_patched, _pipeline_step_patched

Review comment: Same as above.

BERT_MODEL = torch.nn.Module


def convert_model(model: BERT_MODEL) -> BERT_MODEL:
    return model


def create_optimizer(model: BERT_MODEL) -> Optimizer:
    param_optimizer = list(model.named_parameters())
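The try/except above keeps the module importable on hosts without apex, but any later use of the apex symbols would still fail with a NameError. One common refinement, shown here only as a hedged sketch (the APEX_AVAILABLE flag and build_optimizer helper are illustrative, not part of this PR), is to record availability explicitly and fail with a clear message at the point of use:

    try:
        from apex.optimizers import FusedLAMB  # optional, CUDA-only dependency
        APEX_AVAILABLE = True
    except ImportError:
        FusedLAMB = None  # placeholder so module-level references stay defined
        APEX_AVAILABLE = False

    def build_optimizer(params, lr):
        # Fail at the point of use with a clear message instead of at import time.
        if not APEX_AVAILABLE:
            raise RuntimeError("apex is not installed; FusedLAMB is unavailable on this machine")
        return FusedLAMB(params, lr=lr)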
@@ -46,6 +47,7 @@ def create_optimizer(model: BERT_MODEL) -> Optimizer:
    }]

    if config.distributed_lamb:
        from .distributed_fused_lamb import _pipeline_block_reductions_patched, _pipeline_step_patched
        DistributedFusedLAMB._pipeline_block_reductions = _pipeline_block_reductions_patched
        DistributedFusedLAMB._pipeline_step = _pipeline_step_patched
        optimizer = DistributedFusedLAMB(
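The distributed_lamb branch above swaps two methods on apex's DistributedFusedLAMB class for patched versions before constructing the optimizer, and defers the import into the only branch that needs it. A minimal sketch of that deferred-import-plus-method-patching pattern with purely illustrative names (Worker and _run_patched are not from this PR or from apex):

    class Worker:
        def run(self):
            return "original implementation"

    def _run_patched(self):
        # Drop-in replacement with the same signature as the original method.
        return "patched implementation"

    def maybe_patch(use_patched: bool):
        if use_patched:
            # Rebind the class attribute so every instance created afterwards
            # picks up the patched behaviour, mirroring how the patched
            # _pipeline_block_reductions / _pipeline_step are installed above.
            Worker.run = _run_patched

    maybe_patch(True)
    assert Worker().run() == "patched implementation"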
This file was deleted.
@@ -5,9 +5,9 @@
import torch.distributed as dist

from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch_xmlir.optimizer import Lamb
from torch_xmlir.optimizer import FusedLAMB

Review comment: Replaced with the fused optimizer.
Review comment: Please provide a screenshot showing this running on the kunlunxin machine.

import torch_xmlir.core.xpu_model as xm
@@ -38,26 +38,7 @@ def create_optimizer(model: BERT_MODEL) -> Optimizer:
        0.0
    }]

    if config.distributed_lamb:
        print("distributed")
        #DistributedFusedLAMB._pipeline_block_reductions = _pipeline_block_reductions_patched
        #DistributedFusedLAMB._pipeline_step = _pipeline_step_patched
        #optimizer = DistributedFusedLAMB(optimizer_grouped_parameters, lr=config.learning_rate,
        #                                 betas=(config.opt_lamb_beta_1, config.opt_lamb_beta_2),
        #                                 eps=1e-6,
        #                                 max_grad_norm=1.0,
        #                                 overlap_reductions=config.dwu_overlap_reductions,
        #                                 dwu_group_size=config.dwu_group_size,
        #                                 dwu_num_blocks=config.dwu_num_blocks,
        #                                 dwu_num_chunks=config.dwu_num_chunks,
        #                                 dwu_num_rs_pg=config.dwu_num_rs_pg,
        #                                 dwu_num_ar_pg=config.dwu_num_ar_pg,
        #                                 dwu_num_ag_pg=config.dwu_num_ag_pg,
        #                                 use_nvlamb=False,
        #                                 e5m2_allgather=config.dwu_e5m2_allgather)
        #optimizer.set_global_scale(float(os.getenv("INIT_LOSS_SCALE", 2 ** 20)))
    else:
        optimizer = Lamb(optimizer_grouped_parameters,

Review comment: Removed the redundant code.

    optimizer = FusedLAMB(optimizer_grouped_parameters,
                          lr=config.learning_rate,
                          betas=(config.opt_lamb_beta_1,
                                 config.opt_lamb_beta_2))
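For context, the }] above closes the optimizer_grouped_parameters list that create_optimizer builds from model.named_parameters(). A minimal sketch of the usual BERT-style grouping feeding the FusedLAMB call shown in this hunk; the no_decay names and the 0.01 decay value are assumptions for illustration, not taken from this PR:

    # Assumed names for parameters that should not receive weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {  # decayed parameters
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {  # non-decayed parameters (matches the 0.0 group above)
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = FusedLAMB(optimizer_grouped_parameters,
                          lr=config.learning_rate,
                          betas=(config.opt_lamb_beta_1, config.opt_lamb_beta_2))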
@@ -71,97 +52,24 @@ def model_to_fp16(model: BERT_MODEL,


def model_to_ddp(model: BERT_MODEL) -> BERT_MODEL:
    use_ddp = dist.is_initialized()
    if use_ddp and config.use_xpu:
        from torch_xmlir.distributed import DistributedDataParallel as DDP
        model = DDP(model)

Review comment: Replaced with PyTorch's native DDP instead of the custom DDP.

    if dist.is_available() and dist.is_initialized():
        model = DDP(model, device_ids=[config.local_rank])
    return model


def backward(step: int,
             loss: torch.Tensor,
             optimizer: Optimizer,
             grad_scaler: GradScaler = None):
    if config.bypass_amp:
        loss.backward()
    elif config.distributed_lamb:
        optimizer._lazy_init_stage1()
        grad_scaler.scale(loss).backward()
        optimizer._lazy_init_stage2()
    else:
        with amp.scale_loss(loss,
                            optimizer,
                            delay_overflow_check=self.config.
                            allreduce_post_accumulation) as scaled_loss:
            scaled_loss.backward()

Review comment: Removed the redundant code.

    loss.backward()

    update_step = step % config.gradient_accumulation_steps == 0
    if update_step:
        update_model_params(loss, optimizer, grad_scaler)
    else:
        xm.mark_step()

Review comment: Eager mode does not need xm.mark_step.
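The update_step check above implements standard gradient accumulation: gradients from config.gradient_accumulation_steps micro-batches are accumulated before a single optimizer update. A minimal sketch of the same pattern in a plain training loop (train_loader, model, and criterion are placeholder names, not from this PR):

    accumulation_steps = 4  # illustrative value for config.gradient_accumulation_steps

    for step, (inputs, targets) in enumerate(train_loader, start=1):
        loss = criterion(model(inputs), targets)
        # Scale so the accumulated gradient matches a full-batch average.
        (loss / accumulation_steps).backward()

        if step % accumulation_steps == 0:
            optimizer.step()       # one parameter update per accumulation window
            optimizer.zero_grad()  # clear gradients for the next window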

def update_model_params(loss,
                        optimizer: Optimizer,
                        grad_scaler: GradScaler = None):
    if config.allreduce_post_accumulation and config.use_cuda_graph:
        assert False, "code path not tested with cuda graphs"
    if config.distributed_lamb:
        optimizer.set_global_scale(grad_scaler._get_scale_async())
        optimizer.complete_reductions()
        grad_scaler.step(optimizer)
        grad_scaler.update()

        found_inf = optimizer._overflow_buf  # GPU tensor

    elif config.allreduce_post_accumulation:
        # manually allreduce gradients after all accumulation steps
        # check for Inf/NaN
        # 1. allocate an uninitialized buffer for flattened gradient
        # torch.nn.utils.clip_grad_norm_(parameters=amp.master_params(optimizer), max_norm=1.0, norm_type=2.0)
        #scaler = _amp_state.loss_scalers[0]
        # master_grads = [p.grad for p in amp.master_params(optimizer) if p.grad is not None]
        # flat_grad_size = sum(p.numel() for p in master_grads)
        # allreduce_dtype = torch.float16 if config.allreduce_post_accumulation_fp16 else torch.float32
        # flat_raw = torch.empty(flat_grad_size, device='cpu', dtype=allreduce_dtype)
        #flat_raw = torch.empty(flat_grad_size, device='cuda', dtype=allreduce_dtype)
        # 2. combine unflattening and predivision of unscaled 'raw' gradient
        #allreduced_views = apex_C.unflatten(flat_raw, master_grads)
        #self.overflow_buf.zero_()
        #amp_C.multi_tensor_scale(65536,
        #                         self.overflow_buf,
        #                         [master_grads, allreduced_views],
        #                         scaler.loss_scale() / (
        #                             torch.distributed.get_world_size() * config.gradient_accumulation_steps))
        # 3. sum gradient across ranks. Because of the predivision, this averages the gradient
        #torch.distributed.all_reduce(flat_raw)
        # 4. combine unscaling and unflattening of allreduced gradient
        #self.overflow_buf.zero_()
        #amp_C.multi_tensor_scale(65536,
        #                         self.overflow_buf,
        #                         [allreduced_views, master_grads],
        #                         1. / scaler.loss_scale())
        # 5. update loss scale
        had_overflow = 0
        #scaler = _amp_state.loss_scalers[0]
        #old_overflow_buf = scaler._overflow_buf
        #scaler._overflow_buf = self.overflow_buf
        #had_overflow = scaler.update_scale()
        #scaler._overflow_buf = old_overflow_buf
        # 6. call optimizer step function
        if had_overflow == 0:
            # optimizer.step()
            xm.optimizer_step(optimizer, barrier=True)
        else:
            # Overflow detected, print message and clear gradients
            if utils.is_main_process():
                print("Overflow detected, reduced loss_scaler to %f" %
                      (scaler.loss_scale()))
            if _amp_state.opt_properties.master_weights:
                for param in optimizer._amp_stash.all_fp32_from_fp16_params:
                    param.grad = None
    else:
        xm.optimizer_step(optimizer, barrier=True)

Review comment: Removed the redundant code.

    optimizer.step()
    optimizer.zero_grad()

Review comment: Graph mode previously needed this xm.mark_step; now that training runs in eager mode it is no longer needed.
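For reference on the xm.mark_step discussion: in a lazy/graph execution mode, runtimes in the torch_xla style (which torch_xmlir's xm helper resembles) only execute the recorded graph when a step marker or a synchronizing optimizer step is reached, whereas in eager mode every operation runs immediately, so the plain optimizer.step()/zero_grad() pair above suffices. A rough sketch of the two variants, assuming a torch_xla-style API for the lazy path rather than the exact torch_xmlir contract:

    def update_eager(optimizer):
        # Eager mode: each op has already executed; just apply the update and clear grads.
        optimizer.step()
        optimizer.zero_grad()

    def update_lazy(optimizer, xm):
        # Lazy/graph mode: ops are only recorded until a synchronization point,
        # so the step goes through the runtime helper; barrier=True also flushes
        # the pending graph, playing the role of xm.mark_step().
        xm.optimizer_step(optimizer, barrier=True)
        optimizer.zero_grad()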