diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py
index b328f2c..43df77b 100644
--- a/open_diloco/train_pure_fsdp.py
+++ b/open_diloco/train_pure_fsdp.py
@@ -60,7 +60,7 @@ class Config(BaseConfig):
     per_device_train_batch_size: int = 32
     warmup_steps: int = 1000
     total_steps: int = 88_000
-    sharding_strategy: str = "SHARD_GRAD_OP"
+    sharding_strategy: str = "FULL_SHARD"
 
 
 def get_dataloader(tokenizer, world_size, rank, local_rank, config: Config) -> StatefulDataLoader:
@@ -82,6 +82,11 @@ def get_model(config: Config) -> LlamaForCausalLM:
     return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
 
 
+def get_offloaded_param(model: LlamaForCausalLM) -> list[torch.Tensor]:
+    """Clone every model parameter (onto "cuda" for now, despite the name) for the outer optimizer."""
+    return [param.data.detach().clone().to("cuda") for param in model.parameters()]
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
@@ -123,7 +128,11 @@ def train(config: Config):
 
     # Setup optimizers
     inner_optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=0.1, betas=(0.9, 0.95))
-    # outer_optimizer = torch.optim.SGD(model.parameters(), lr=config.diloco.outer_lr, momentum=0.9, nesterov=True)
+
+    cpu_model = get_offloaded_param(
+        model
+    )  # todo: in case of SHARD_GRAD_OP we need to offload the cpu model only once per node
+    outer_optimizer = torch.optim.SGD(cpu_model, lr=config.diloco.outer_lr, momentum=0.9, nesterov=True)
 
     scheduler = get_cosine_schedule_with_warmup(
         inner_optimizer,
@@ -152,6 +161,7 @@ def train(config: Config):
         with model.no_sync() if is_accumulating else nullcontext():
             outputs = model(**batch)
             loss = outputs.loss / gradient_accumulation_steps
+            loss.backward()
             loss_batch += loss.detach()
 
         model.clip_grad_norm_(1.0)  # gradient clipping
@@ -166,8 +176,22 @@ def train(config: Config):
 
         loss_batch = 0
 
-        for param in model.parameters():  # todo make this like hybrid shard is doing
-            dist.all_reduce(param, op=dist.ReduceOp.AVG, group=global_pg)
+        ### the whole section below is just a PoC. We need to benchmark and work out what is the most efficient:
+        ## do the all reduce on cpu or on gpu
+        ## do the outer optimizer step on cpu or on gpu
+
+        for param_offloaded, param in zip(
+            cpu_model, model.parameters()
+        ):  # There is only one big fat tensor in the param because of fsdp 1 bucket stuff
+            # todo check how to handle the SHARD_GRAD_OP strategy where the weights are replicated across the local devices
+            param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
+            dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg)
+
+        outer_optimizer.step()
+        outer_optimizer.zero_grad()
+
+        for param_offloaded, param in zip(cpu_model, model.parameters()):
+            param.data.copy_(param_offloaded.data)  # in place: rebinding .data would alias the outer copy
 
         outer_step += 1
 