13 changes: 7 additions & 6 deletions applications/Chat/coati/ray/detached_trainer_base.py
@@ -43,14 +43,15 @@ def __init__(self,
self.callbacks = callbacks
self.target_holder_name_list = experience_maker_holder_name_list
self.target_holder_list = []

self._is_target_holder_initialized = False
self._debug = debug

def update_target_holder_list(self, experience_maker_holder_name_list):
self.target_holder_name_list = experience_maker_holder_name_list
self.target_holder_list = []
for name in self.target_holder_name_list:
self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
def update_target_holder_list(self):
# target_holder_list may legitimately be empty, so track initialization with an explicit flag
if not self._is_target_holder_initialized:
for name in self.target_holder_name_list:
self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
self._is_target_holder_initialized = True

@abstractmethod
def _update_remote_makers(self, fully_update: bool = False, **kwargs):
41 changes: 20 additions & 21 deletions applications/Chat/coati/ray/detached_trainer_ppo.py
@@ -17,6 +17,7 @@
get_actor_from_args,
get_critic_from_args,
get_model_numel,
get_rank,
get_strategy_from_args,
is_rank_0,
set_dist_env,
@@ -102,38 +103,36 @@ def __init__(
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug)
if self._debug:
print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')

@ray.method(concurrency_group="model_io")
@torch.no_grad()
def _update_remote_makers(self, fully_update: bool = False, **config):
# TODO: balance duties
if is_rank_0():
self.update_target_holder_list(self.target_holder_name_list)
# mark start, ensure order
tasks = []
for target_holder in self.target_holder_list:
tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
ray.get(tasks)
self.update_target_holder_list()
# mark start, ensure order
tasks = []
for target_holder in self.target_holder_list:
tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
ray.get(tasks)
# sending loop
tasks = []
for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_model(self.actor), **config):
if is_rank_0():
for target_holder in self.target_holder_list:
tasks.append(
target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
fully_update=fully_update))
for target_holder in self.target_holder_list:
tasks.append(
target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
fully_update=fully_update))
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
if is_rank_0():
for target_holder in self.target_holder_list:
tasks.append(
target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
fully_update=fully_update))
ray.get(tasks)
if is_rank_0():
# mark end
for target_holder in self.target_holder_list:
target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)
tasks.append(
target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
fully_update=fully_update))
ray.get(tasks)
# mark end
for target_holder in self.target_holder_list:
target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)

@ray.method(concurrency_group="compute")
def training_step(self, experience: Experience) -> Dict[str, float]:
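Taken together, the rewritten `_update_remote_makers` has every trainer rank drive a three-phase push to each experience-maker holder: mark `chunk_start`, stream actor and then critic state-dict shards, and finally mark `chunk_end`. The snippet below is a minimal, hypothetical sketch of a receiver that honors this calling convention for a single sender; it is not the actual `ExperienceMakerHolder.update_experience_maker`, and the shard buffering and `load_state_dict(strict=False)` details are assumptions.

```python
# Hypothetical single-sender receiver for the chunk_start / shard / chunk_end
# protocol used by _update_remote_makers. Illustrative only; the buffering and
# load_state_dict(strict=False) calls are assumptions, not the real holder code.
from typing import Dict, Optional

import torch
import torch.nn as nn


class ShardedUpdateReceiver:

    def __init__(self, actor: nn.Module, critic: nn.Module):
        self.actor = actor
        self.critic = critic
        self._actor_shards: Dict[str, torch.Tensor] = {}
        self._critic_shards: Dict[str, torch.Tensor] = {}

    def update_experience_maker(self,
                                new_actor_state_dict: Optional[dict] = None,
                                new_critic_state_dict: Optional[dict] = None,
                                fully_update: bool = False,
                                chunk_start: bool = False,
                                chunk_end: bool = False) -> None:
        if chunk_start:
            # a fresh round of shards is about to arrive; drop any stale buffers
            self._actor_shards.clear()
            self._critic_shards.clear()
        if new_actor_state_dict is not None:
            self._actor_shards.update(new_actor_state_dict)
        if new_critic_state_dict is not None:
            self._critic_shards.update(new_critic_state_dict)
        if chunk_end:
            # all shards delivered; apply the accumulated weights in one shot
            self.actor.load_state_dict(self._actor_shards, strict=False)
            self.critic.load_state_dict(self._critic_shards, strict=False)
            self._actor_shards.clear()
            self._critic_shards.clear()
```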
57 changes: 13 additions & 44 deletions applications/Chat/coati/ray/experience_maker_holder.py
@@ -19,7 +19,7 @@
from torch import Tensor
from tqdm import tqdm

from .utils import get_model_numel, is_rank_0, set_dist_env
from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env


@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
@@ -50,8 +50,8 @@ def __init__(
if env_info:
set_dist_env(env_info=env_info)
self.target_trainer_list = []
for name in detached_trainer_name_list:
self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
assert len(detached_trainer_name_list) > 0
self._detached_trainer_name_list = detached_trainer_name_list
self.strategy = strategy_fn()
self.buffer_cpu_offload = buffer_cpu_offload
self.kl_coef = kl_coef
@@ -81,8 +81,10 @@ def __init__(

self._target_idx = 0

if self._debug and not self._is_fully_initialized:
print('[maker] Waiting for INIT')
if self._debug:
print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
if not self._is_fully_initialized:
print(f'[maker{get_rank()}] Waiting for INIT')

def _get_ready(self):
while not self._fully_initialized():
@@ -91,10 +93,11 @@ def _get_ready(self):
def _fully_initialized(self):
return self._is_fully_initialized

def update_target_trainer_list(self, detached_trainer_name_list):
self.target_trainer_list = []
for name in detached_trainer_name_list:
self.target_trainer_list.append(ray.get_actor(name))
def _init_target_trainer_list(self):
if len(self.target_trainer_list) > 0:
return
for name in self._detached_trainer_name_list:
self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))

# copy from ../trainer/base.py
@ray.method(concurrency_group="compute")
@@ -106,43 +109,9 @@ def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experien
else:
raise ValueError(f'Unsupported input type "{type(inputs)}"')

# TODO(ver217): remove this method
@ray.method(concurrency_group="experience_io")
def _send_experience(self, experience):
if not self.target_auto_balance:
# choose the trainer in polling mannar
if not hasattr(self, "_target_idx"):
self._target_idx = 0
chosen_trainer = self.target_trainer_list[self._target_idx]
if self._debug:
print(f"[maker] sending exp to {chosen_trainer}")
chosen_trainer.buffer_append.remote(experience)
self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
else:
# choose a trainer that has the least experience batch in its detached_replay_buffer
chosen_trainer = None
min_length = None
if self._debug:
print("[maker] choosing tartget trainer")
while chosen_trainer is None:
for target_trainer in self.target_trainer_list:
try:
temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1)
if min_length is None:
min_length = temp_length
chosen_trainer = target_trainer
else:
if temp_length < min_length:
min_length = temp_length
chosen_trainer = target_trainer
except GetTimeoutError:
pass
if self._debug:
print(f"[maker] sending exp to {chosen_trainer}")
chosen_trainer.buffer_append.remote(experience)

@ray.method(concurrency_group="experience_io")
def _send_items(self, experience: Experience) -> None:
self._init_target_trainer_list()
items = split_experience_batch(experience)
items_per_trainer = [[] for _ in range(len(self.target_trainer_list))]
for item in items:
27 changes: 25 additions & 2 deletions applications/Chat/coati/ray/utils.py
@@ -18,6 +18,14 @@ def is_rank_0() -> bool:
return not dist.is_initialized() or dist.get_rank() == 0


def get_rank() -> int:
return dist.get_rank() if dist.is_initialized() else 0


def get_world_size() -> int:
return dist.get_world_size() if dist.is_initialized() else 1


def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
@@ -76,9 +84,9 @@ def get_strategy_from_args(strategy: str):
elif strategy == 'colossalai_zero2':
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
elif strategy == 'colossalai_gemini_cpu':
strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
strategy_ = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
elif strategy == 'colossalai_zero2_cpu':
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cpu')
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_
@@ -126,3 +134,18 @@ def state_dict_to(state_dict: Dict[str, Any],
def get_model_numel(model: nn.Module) -> int:
numel = sum(p.numel() for p in model.parameters())
return numel


def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
target_receivers = []
if num_senders <= num_receivers or allow_idle_sender:
# a sender fans out to one or more receivers
# each receiver has exactly one sender
for i in range(num_receivers):
if i % num_senders == sender_idx:
target_receivers.append(i)
else:
# each sender sends to exactly one receiver
# a receiver may have more than one sender
target_receivers.append(sender_idx % num_receivers)
return target_receivers
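As a quick illustration of the mapping this helper produces (the sender and receiver counts below are arbitrary examples):

```python
# Illustrative assignments produced by get_receivers_per_sender.

# Fewer senders than receivers: each sender fans out; every receiver gets one sender.
assert get_receivers_per_sender(0, 2, 4, allow_idle_sender=False) == [0, 2]
assert get_receivers_per_sender(1, 2, 4, allow_idle_sender=False) == [1, 3]

# More senders than receivers, idle senders disallowed: each sender picks exactly
# one receiver, so a receiver may be shared by several senders.
assert get_receivers_per_sender(2, 4, 2, allow_idle_sender=False) == [0]
assert get_receivers_per_sender(3, 4, 2, allow_idle_sender=False) == [1]

# More senders than receivers, idle senders allowed: surplus senders get nothing.
assert get_receivers_per_sender(2, 4, 2, allow_idle_sender=True) == []
```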
2 changes: 1 addition & 1 deletion applications/Chat/coati/trainer/strategies/colossalai.py
@@ -219,4 +219,4 @@ def get_model_state_dict_shard(self, model: nn.Module, **config):
if isinstance(module, LoraLinear):
module.merge_weights = True
module.eval()
yield from model.state_dict_shard(max_shard_size=1024)
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
31 changes: 11 additions & 20 deletions applications/Chat/coati/trainer/strategies/ddp.py
@@ -1,13 +1,12 @@
from typing import Optional

import os
import random
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.base import LM, Actor, RewardModel, Critic
from coati.models.base import LM, Actor, Critic, RewardModel
from coati.models.lora import LoraLinear
from coati.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -30,19 +29,8 @@ def __init__(self, seed: int = 42) -> None:
super().__init__()

def setup_distributed(self) -> None:
try:
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
self._try_init_dist(force=True)
self.set_seed(self.seed)
torch.cuda.set_device(local_rank)

def set_seed(self, seed: int) -> None:
random.seed(seed)
@@ -74,21 +62,25 @@ def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False
def _unwrap_actor(actor: Actor) -> nn.Module:
model: DDP = Strategy._unwrap_actor(actor)
return model.module

@staticmethod
def _unwrap_critic(critic: Critic) -> nn.Module:
model: DDP = Strategy._unwrap_critic(critic)
return model.module

def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
def save_model(self,
model: nn.Module,
path: str,
only_rank0: bool = False,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if only_rank0 and dist.get_rank() != 0:
return None

for module in model.modules():
if isinstance(module, LoraLinear):
module.merge_weights = True
module.eval()

if isinstance(model, RewardModel):
state_dict = model.state_dict()
if only_rank0 and dist.get_rank() != 0:
@@ -114,4 +106,3 @@ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = Fal

def setup_sampler(self, dataset) -> DistributedSampler:
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())

36 changes: 30 additions & 6 deletions applications/Chat/coati/trainer/strategies/naive.py
@@ -1,11 +1,13 @@
import os
from typing import Any, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from coati.replay_buffer import ReplayBuffer
from coati.models.base import LM, RewardModel
from coati.models.lora import LoraLinear
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -25,7 +27,7 @@ def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.step()

def setup_distributed(self) -> None:
pass
self._try_init_dist(force=False)

def setup_model(self, model: nn.Module) -> nn.Module:
return model
@@ -41,12 +43,16 @@ def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)

def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
def save_model(self,
model: nn.Module,
path: str,
only_rank0: bool = False,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
for module in model.modules():
if isinstance(module, LoraLinear):
module.merge_weights = True
module.eval()

if isinstance(model, RewardModel):
state_dict = model.state_dict()
torch.save(state_dict, path)
@@ -77,10 +83,28 @@ def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
state_dict = model.state_dict()
yield state_dict

def merge_lora_weight(self, model: nn.Module):
unwrapped_model = self._unwrap_model(model)
for module in unwrapped_model.modules():
if isinstance(module, LoraLinear):
module.merge_weights = True
module.eval()
module.eval()

def _try_init_dist(self, force: bool = False) -> None:
try:
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
torch.cuda.set_device(local_rank)
except KeyError as e:
if force:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
except Exception as e:
if force:
raise e
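For context, `_try_init_dist` consumes the standard torch distributed launcher variables. Below is a minimal usage sketch, assuming the strategy class in naive.py is `NaiveStrategy`; the values are placeholders that torchrun (or `set_dist_env` in the Ray workers) would normally provide.

```python
# Illustrative only: the environment _try_init_dist reads. Placeholder values;
# torchrun or set_dist_env normally populates these before setup_distributed runs.
import os

os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"

strategy = NaiveStrategy()     # hypothetical direct use of the naive strategy
strategy.setup_distributed()   # force=False: missing keys are silently tolerated
```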