Merged

64 commits
994b40c  run the base  (Mar 16, 2023)
0390f6e  working on dist ppo  (Mar 16, 2023)
c1df61b  sync  (Mar 20, 2023)
32837c3  Merge remote-tracking branch 'upstream/main' into chatgpt_dist_ppo  (Mar 20, 2023)
518f837  detached trainer  (Mar 20, 2023)
b707ba2  update detached trainer. no maker update function  (Mar 20, 2023)
1311924  facing init problem  (Mar 21, 2023)
29976fa  1 maker 1 trainer detached run. but no model update  (Mar 21, 2023)
ea4761a  Merge branch 'hpcaitech:main' into detached_ppo  (CsRic, Mar 21, 2023)
523e209  facing cuda problem  (Mar 22, 2023)
45361c2  fix save functions  (Mar 22, 2023)
886cc98  Merge branch 'detached_ppo' of https://github.com/CsRic/ColossalAI in…  (Mar 22, 2023)
42aa4c7  verified maker update  (Mar 22, 2023)
26d82b5  nothing  (Mar 23, 2023)
517ff22  add ignore  (Mar 23, 2023)
dc91d58  Merge remote-tracking branch 'upstream/main' into detached_ppo  (Mar 23, 2023)
b91348d  analyize loss issue  (Mar 24, 2023)
b882fdd  fix detached ppo loss issue  (Mar 24, 2023)
ebd2be9  remove some debug codes  (Mar 24, 2023)
650ec5b  facing 2m1t stuck issue  (Mar 27, 2023)
b40974b  Merge remote-tracking branch 'upstream/main' into detached_ppo  (Mar 27, 2023)
f791fb7  2m1t verified  (Mar 27, 2023)
f468724  do not use torchrun  (Mar 27, 2023)
12b94f7  working on 2m2t  (Mar 27, 2023)
05df7d7  working on 2m2t  (Mar 27, 2023)
0773697  initialize strategy in ray actor env  (Mar 28, 2023)
9451a54  facing actor's init order issue  (Mar 28, 2023)
9626518  facing ddp model update issue (need unwarp ddp)  (Mar 28, 2023)
d637032  unwrap ddp actor  (Mar 29, 2023)
7dadc80  checking 1m2t stuck problem  (Mar 29, 2023)
09f611d  nothing  (Mar 29, 2023)
459639c  set timeout for trainer choosing. It solves the stuck problem!  (Mar 29, 2023)
65363e1  delete some debug output  (Mar 29, 2023)
2f8036b  rename to sync with upstream  (Mar 30, 2023)
7e5c8f2  rename to sync with upstream  (Mar 30, 2023)
c36d58a  merge upstream  (Mar 30, 2023)
c0649c3  coati rename  (Mar 30, 2023)
db29760  Merge branch 'hpcaitech:main' into detached_ppo  (CsRic, Mar 31, 2023)
3c6f68c  nothing  (Mar 31, 2023)
334956a  merge  (Mar 31, 2023)
35e4602  I am going to detach the replaybuffer from trainer and make it a Ray …  (Apr 3, 2023)
04069cd  experience_maker_holder performs target-revolving _send_experience() …  (Apr 4, 2023)
117f08c  Merge 'upstream/main'  (Apr 4, 2023)
b0a002e  Merge remote-tracking branch 'upstream/main' into detached_ppo  (Apr 12, 2023)
9051002  Merge remote-tracking branch 'upstream/main' into detached_ppo  (Apr 13, 2023)
3a4d0e7  move code to ray subfolder  (Apr 13, 2023)
95a6c72  working on pipeline inference  (Apr 13, 2023)
19fab46  apply comments  (Apr 14, 2023)
9937c1f  Merge remote-tracking branch 'upstream/main' into detached_ppo  (Apr 14, 2023)
0980832  working on pipeline strategy. in progress.  (Apr 14, 2023)
8ac2775  remove pipeline code. clean this branch  (Apr 17, 2023)
48792f4  update remote parameters by state_dict. no test  (Apr 17, 2023)
342fd59  nothing  (Apr 18, 2023)
1a2786c  Merge remote-tracking branch  (Apr 18, 2023)
ebc6275  state_dict sharding transfer  (Apr 18, 2023)
1ce0182  Merge remote-tracking branch 'upstream/main' into detached_ppo  (Apr 18, 2023)
0573521  Merge remote-tracking branch 'upstream/main' into detached_ppo_debug  (Apr 19, 2023)
59dc0db  merge debug branch  (Apr 19, 2023)
df451b2  gemini _unwrap_model fix  (Apr 19, 2023)
744a011  simplify code  (Apr 19, 2023)
3a2eb83  simplify code & fix LoRALinear AttributeError  (Apr 20, 2023)
056a887  Merge remote-tracking branch 'upstream/main' into detached_ppo  (Apr 20, 2023)
a252e9b  critic unwrapped state_dict  (Apr 20, 2023)
1c5ac0d  Merge remote-tracking branch 'ver217/dev/chat-ray' into detached_ppo  (Apr 20, 2023)
8 changes: 7 additions & 1 deletion applications/Chat/coati/models/lora.py
@@ -61,7 +61,13 @@ def T(w):
         if self.merge_weights and self.merged:
             # Make sure that the weights are not merged
             if self.r > 0:
-                self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+                if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
+                    # csric: temporary fix
+                    self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
+                    self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
+                    self.reset_parameters()
+                else:
+                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
             self.merged = False

     def eval(self):
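The hasattr guard added above works around an AttributeError raised when train() runs on a LoRALinear whose low-rank factors are no longer present, which can happen once the adapter has been merged into the base weight and the factors dropped (the _get_model_state_dict_shard helper added later in this PR calls strategy.merge_lora_weight before sending a state_dict). Below is a minimal, self-contained sketch of that failure mode; SimpleLoRALinear is a simplified stand-in written for illustration, not the coati implementation.

import torch.nn as nn


class SimpleLoRALinear(nn.Linear):

    def __init__(self, in_features, out_features, r=4, scaling=1.0):
        super().__init__(in_features, out_features)
        self.r = r
        self.scaling = scaling
        self.merged = True
        # lora_A / lora_B are intentionally absent here, as if the adapter had
        # been merged into self.weight and the factors dropped before transfer.

    def train(self, mode: bool = True):
        nn.Linear.train(self, mode)
        if self.merged and self.r > 0:
            if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
                # re-create empty factors instead of raising AttributeError,
                # mirroring the guard added in the diff above
                self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, self.in_features)))
                self.lora_B = nn.Parameter(self.weight.new_zeros((self.out_features, self.r)))
            else:
                self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
            self.merged = False
        return self


layer = SimpleLoRALinear(8, 8)
layer.train()    # without the hasattr guard this would raise AttributeError
print(layer.lora_A.shape, layer.lora_B.shape)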
147 changes: 147 additions & 0 deletions applications/Chat/coati/ray/example/1m1t_quantize.py
@@ -0,0 +1,147 @@
import argparse
import pandas as pd
import torch
import ray
import os
import socket

from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer

from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer


def get_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def get_local_ip():
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]


def main(args):
    master_addr = str(get_local_ip())
    # trainer_env_info
    trainer_port = str(get_free_port())
    env_info_trainer = {'local_rank': '0',
                        'rank': '0',
                        'world_size': '1',
                        'master_port': trainer_port,
                        'master_addr': master_addr}

    # maker_env_info
    maker_port = str(get_free_port())
    env_info_maker = {'local_rank': '0',
                      'rank': '0',
                      'world_size': '1',
                      'master_port': maker_port,
                      'master_addr': master_addr}

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure Trainer
    trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
        experience_maker_holder_name_list=["maker1"],
        strategy=args.trainer_strategy,
        model=args.model,
        env_info=env_info_trainer,
        pretrained=args.pretrain,
        lora_rank=args.lora_rank,
        train_batch_size=args.train_batch_size,
        buffer_limit=16,
        experience_batch_size=args.experience_batch_size,
        max_epochs=args.max_epochs,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    # configure Experience Maker
    experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
        detached_trainer_name_list=["trainer1"],
        strategy=args.maker_strategy,
        env_info=env_info_maker,
        experience_batch_size=args.experience_batch_size,
        kl_coef=0.1,
        # kwargs:
        max_length=128,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug=args.debug,
    )

    # a 'jump wire' to set quantized initial_model and reward_model

    # trainer send its actor and critic to experience holders.
    # ray.get(trainer_ref.initialize_remote_makers.remote())

    # configure sampler
    dataset = pd.read_csv(args.prompt_path)['prompt']

    def tokenize_fn(texts):
        # MUST padding to max length to ensure inputs of all ranks have the same length
        # Different length may lead to hang when using gemini, as different generation steps
        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
        return {k: v.cuda() for k, v in batch.items()}

    trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
    num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3    # +3 for fault tolerance
    maker_done_ref = experience_holder_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)

    ray.get([trainer_done_ref, maker_done_ref])

    # save model checkpoint after fitting
    trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                                     only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('prompt_path')
    parser.add_argument('--trainer_strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--maker_strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")

    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
    ray.init(namespace=os.environ["RAY_NAMESPACE"])
    main(args)
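For context, this example expects a CSV with a 'prompt' column (it reads pd.read_csv(args.prompt_path)['prompt']) and a Ray namespace taken from the RAY_NAMESPACE environment variable. A minimal sketch of preparing such an input follows; the file name, prompts, and launch command are illustrative assumptions, not part of the PR.

# Hypothetical helper, not part of the PR: builds a prompt CSV in the format the
# example above consumes via pd.read_csv(args.prompt_path)['prompt'].
import pandas as pd

prompts = [
    "Explain reinforcement learning from human feedback in one paragraph.",
    "Write a short poem about distributed training.",
]
pd.DataFrame({"prompt": prompts}).to_csv("prompts.csv", index=False)

# The script reads the Ray namespace from the environment, so a run might look like
# (assuming a Ray cluster is already up):
#   RAY_NAMESPACE=coati python 1m1t_quantize.py prompts.csv --model gpt2 --debug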
17 changes: 11 additions & 6 deletions applications/Chat/coati/ray/src/detached_trainer_base.py
@@ -24,6 +24,7 @@ class DetachedTrainer(ABC):
         data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
+
     '''

     def __init__(self,
@@ -45,6 +46,11 @@ def __init__(self,
         self.generate_kwargs = generate_kwargs
         self.target_holder_name_list = experience_maker_holder_name_list
         self.target_holder_list = []
+
+        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+            self._debug = True
+        else:
+            self._debug = False

     def update_target_holder_list(self, experience_maker_holder_name_list):
         self.target_holder_name_list = experience_maker_holder_name_list
@@ -63,13 +69,13 @@ def training_step(self, experience: Experience) -> Dict[str, Any]:
     def _learn(self):
         pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
         for _ in pbar:
-            if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+            if self._debug:
                 print("[trainer] sampling exp")
             experience = self._buffer_sample()
-            if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+            if self._debug:
                 print("[trainer] training step")
             metrics = self.training_step(experience)
-            if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+            if self._debug:
                 print("[trainer] step over")
             pbar.set_postfix(metrics)

@@ -88,15 +94,14 @@ def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timest
     @ray.method(concurrency_group="buffer_length")
     def buffer_get_length(self):
         # called by ExperienceMakerHolder
-        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+        if self._debug:
             print("[trainer] telling length")
         return self.detached_replay_buffer.get_length()

     @ray.method(concurrency_group="buffer_append")
     def buffer_append(self, experience: Experience):
         # called by ExperienceMakerHolder
-        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
-            # print(f"[trainer] receiving exp. Current buffer length: {self.detached_replay_buffer.get_length()}")
+        if self._debug:
             print(f"[trainer] receiving exp.")
         self.detached_replay_buffer.append(experience)
110 changes: 87 additions & 23 deletions applications/Chat/coati/ray/src/detached_trainer_ppo.py
@@ -14,11 +14,13 @@
 import ray


-from .utils import is_rank_0, get_cuda_actor_critic_from_args, get_strategy_from_args, set_dist_env
+from .utils import is_rank_0, get_actor_from_args, get_critic_from_args, get_strategy_from_args, set_dist_env, \
+    state_dict_to

 from .detached_trainer_base import DetachedTrainer


-@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append":1, "buffer_sample":1,"model_io": 1, "compute": 1})
+@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1})
 class DetachedPPOTrainer(DetachedTrainer):
     '''
     Detached Trainer for PPO algorithm
@@ -44,9 +46,12 @@ def __init__(self,
                  experience_maker_holder_name_list: List[str],
                  strategy: str,
                  model: str,
-                 env_info: Dict[str, str] = None,
                  pretrained: str = None,
                  lora_rank: int = 0,
+                 cr_model: str = None,    # if not None, use below cr settings for critic
+                 cr_pretrained: str = None,
+                 cr_lora_rank: int = 0,
+                 env_info: Dict[str, str] = None,
                  train_batch_size: int = 8,
                  buffer_limit: int = 0,
                  buffer_cpu_offload: bool = True,
@@ -63,24 +68,32 @@ def __init__(self,
         # configure strategy
         self.strategy = get_strategy_from_args(strategy)
         # configure models, loss and optimizers
+        if cr_model is None:
+            cr_model = model
+            cr_pretrained = pretrained
+            cr_lora_rank = lora_rank
+
         with self.strategy.model_init_context():
-            self.actor, self.critic = get_cuda_actor_critic_from_args(model, pretrained, lora_rank)
+            self.actor = get_actor_from_args(model, pretrained, lora_rank)
+            self.critic = get_critic_from_args(cr_model, cr_pretrained, cr_lora_rank)

         if strategy != 'colossalai_gemini':
-            self.actor.to(torch.float16).to(torch.cuda.current_device())
-            self.critic.to(torch.float16).to(torch.cuda.current_device())
+            self.actor.to(torch.cuda.current_device())    # .to(torch.float16)
+            self.critic.to(torch.cuda.current_device())    # .to(torch.float16)

         if strategy.startswith('colossalai'):
-            self.actor_optim = HybridAdam(self.actor.parameters(), lr=5e-6)
-            self.critic_optim = HybridAdam(self.critic.parameters(), lr=5e-6)
+            self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
+            self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
         else:
-            self.actor_optim = Adam(self.actor.parameters(), lr=5e-6)
-            self.critic_optim = Adam(self.critic.parameters(), lr=5e-6)
+            self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
+            self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)

         (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
             self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
-        generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)

         # configure trainer
+        generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)

@@ -94,25 +107,69 @@ def __init__(self,
                          callbacks=callbacks,
                          **generate_kwargs)

+        # for remote maker initialization
+        self._model_str = model
+        self._cr_model_str = cr_model
+        self._pretrained = pretrained
+        self._cr_pretrained = cr_pretrained
+
     @ray.method(concurrency_group="model_io")
-    def _update_remote_makers(self):
+    def _update_remote_makers(self, **config):
         # TODO: balance duties
         if is_rank_0():
             self.update_target_holder_list(self.target_holder_name_list)
-            for target_holder in self.target_holder_list:
-                # TODO: reduce malloc
-                with torch.no_grad():
-                    ray.get(target_holder.update_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
+            with torch.no_grad():
+                # actor:
+                # mark start
+                for target_holder in self.target_holder_list:
+                    target_holder.update_experience_maker.remote(chunk_start=True)
+                # sending loop
+                for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_actor(self.actor), **config):
+                    for target_holder in self.target_holder_list:
+                        target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard)
+                # mark end
+                for target_holder in self.target_holder_list:
+                    target_holder.update_experience_maker.remote(chunk_end=True)
+                # critic
+                # mark start
+                for target_holder in self.target_holder_list:
+                    target_holder.update_experience_maker.remote(chunk_start=True)
+                # sending loop
+                for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
+                    for target_holder in self.target_holder_list:
+                        target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard)
+                # mark end
+                for target_holder in self.target_holder_list:
+                    target_holder.update_experience_maker.remote(chunk_end=True)

     @ray.method(concurrency_group="model_io")
-    def initialize_remote_makers(self):
+    def initialize_remote_makers(self, **config):
         # TODO: balance duties
         if is_rank_0():
             self.update_target_holder_list(self.target_holder_name_list)
-            for target_holder in self.target_holder_list:
-                # TODO: reduce malloc
-                with torch.no_grad():
-                    ray.get(target_holder.initialize_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
+            with torch.no_grad():
+                # actor / initial_model:
+                # mark start
+                for target_holder in self.target_holder_list:
+                    target_holder.initialize_experience_maker.remote(actor_model=self._model_str, actor_pretrained=self._pretrained, chunk_start=True)
+                # sending loop
+                for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_actor(self.actor), **config):
+                    for target_holder in self.target_holder_list:
+                        target_holder.initialize_experience_maker.remote(actor_state_dict=state_dict_shard)
+                # mark end
+                for target_holder in self.target_holder_list:
+                    target_holder.initialize_experience_maker.remote(actor_model=self._model_str, chunk_end=True)
+                # critic / reward_model:
+                # mark start
+                for target_holder in self.target_holder_list:
+                    target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str, critic_pretrained=self._cr_pretrained, chunk_start=True)
+                # sending loop
+                for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
+                    for target_holder in self.target_holder_list:
+                        target_holder.initialize_experience_maker.remote(critic_state_dict=state_dict_shard)
+                # mark end
+                for target_holder in self.target_holder_list:
+                    target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str, chunk_end=True)

     @ray.method(concurrency_group="compute")
     def training_step(self, experience: Experience) -> Dict[str, float]:
@@ -177,6 +234,14 @@ def _get_unwrapped_critic(self):
         elif isinstance(self.strategy, NaiveStrategy):
             return self.critic

+    def _get_model_state_dict_shard(self, model: torch.nn.Module, **config):
+        try:
+            self.strategy.merge_lora_weight(model)
+        except AttributeError:
+            pass
+        for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
+            yield state_dict_to(state_dict)
+

 def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
     origin_model = strategy._unwrap_actor(actor)
@@ -189,4 +254,3 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
     new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn

     return new_kwargs
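_update_remote_makers and initialize_remote_makers above stream model weights to the makers as a sequence of state_dict shards bracketed by chunk_start and chunk_end calls. The receiving side (ExperienceMakerHolder.update_experience_maker / initialize_experience_maker) is not included in this excerpt; the sketch below shows one way such a receiver could reassemble the shards. The class name and method signature here are invented for illustration and are not the coati API.

# Illustrative receiver for the chunked state_dict protocol used above; a sketch,
# not the actual ExperienceMakerHolder implementation.
from typing import Dict, Optional

import torch


class ChunkedStateDictReceiver:

    def __init__(self, model: torch.nn.Module):
        self.model = model
        self._pending: Dict[str, torch.Tensor] = {}
        self._receiving = False

    def update(self,
               new_state_dict: Optional[Dict[str, torch.Tensor]] = None,
               chunk_start: bool = False,
               chunk_end: bool = False) -> None:
        # mark start: begin accumulating shards
        if chunk_start:
            self._pending = {}
            self._receiving = True
        # sending loop: each call may carry a partial state_dict
        if new_state_dict is not None and self._receiving:
            self._pending.update(new_state_dict)
        # mark end: load the accumulated state_dict into the model
        if chunk_end:
            self.model.load_state_dict(self._pending, strict=False)
            self._pending = {}
            self._receiving = False


# usage sketch: one call with chunk_start=True, one call per shard, then one call
# with chunk_end=True, mirroring the sender loops above.
model = torch.nn.Linear(4, 4)
receiver = ChunkedStateDictReceiver(model)
receiver.update(chunk_start=True)
for name, tensor in model.state_dict().items():
    receiver.update(new_state_dict={name: tensor.clone()})
receiver.update(chunk_end=True)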
