Commit 23cd5e2

[chatgpt] update ci (#3087)

* [chatgpt] update ci
* Update test_ci.sh
* Update test_ci.sh
* Update test_ci.sh
* test
* Update train_prompts.py
* Update train_dummy.py
* add save_path
* polish
* add save path
* polish
* add save path
* polish
* delete bloom-560m test (removed because of OOM)
* add ddp test

1 parent: 169ed4d

File tree

3 files changed (+56, -8 lines)


applications/ChatGPT/examples/test_ci.sh

Lines changed: 52 additions & 6 deletions
@@ -15,11 +15,57 @@ export OMP_NUM_THREADS=8
 pip install -r ${BASE}/requirements.txt
 
 # train dummy
-for strategy in ddp colossalai_gemini colossalai_zero2; do
-    torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py --strategy ${strategy} --num_episodes 2 --max_timesteps 3 --update_timesteps 3 --max_epochs 3 --experience_batch_size 4 --train_batch_size 4
-done
+python ${BASE}/train_dummy.py --strategy naive --num_episodes 1 \
+    --max_timesteps 2 --update_timesteps 2 \
+    --max_epochs 1 --train_batch_size 2 --lora_rank 4
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
+    --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
+    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --pretrain 'facebook/opt-350m' --model opt --lora_rank 4 \
+    --save_path ${BASE}/actor_checkpoint_dummy.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'facebook/opt-350m' --model opt
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
+    --strategy ddp --num_episodes 1 --max_timesteps 2 \
+    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --pretrain 'facebook/opt-350m' --model opt --lora_rank 4 \
+    --save_path ${BASE}/actor_checkpoint_dummy.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'facebook/opt-350m' --model opt
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
+    --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
+    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --pretrain 'gpt2' --model gpt2 --lora_rank 4 \
+    --save_path ${BASE}/actor_checkpoint_dummy.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'gpt2' --model gpt2
+
+rm -rf ${BASE}/actor_checkpoint_dummy.pt
 
 # train prompts
-for strategy in ddp colossalai_gemini colossalai_zero2; do
-    torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH --strategy ${strategy} --num_episodes 2 --max_timesteps 3 --update_timesteps 3 --max_epochs 3
-done
+python ${BASE}/train_prompts.py $PROMPT_PATH --strategy naive --num_episodes 1 \
+    --max_timesteps 2 --update_timesteps 2 \
+    --max_epochs 1 --train_batch_size 2 --lora_rank 4
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
+    --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
+    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --pretrain 'facebook/opt-350m' --model opt --lora_rank 4 \
+    --save_path ${BASE}/actor_checkpoint_prompts.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'facebook/opt-350m' --model opt
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
+    --strategy ddp --num_episodes 1 --max_timesteps 2 \
+    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --pretrain 'gpt2' --model gpt2 --lora_rank 4 \
+    --save_path ${BASE}/actor_checkpoint_prompts.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
+    --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
+    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --pretrain 'gpt2' --model gpt2 --lora_rank 4 \
+    --save_path ${BASE}/actor_checkpoint_prompts.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
+
+rm -rf ${BASE}/actor_checkpoint_prompts.pt
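Each distributed case in the updated script follows the same three-step pattern: train with an explicit --save_path, verify the saved actor by running inference.py against it, then delete the checkpoint before the next case. A minimal sketch of that pattern, pulled from the commands above (the CKPT variable is introduced here for readability only):

CKPT=${BASE}/actor_checkpoint_dummy.pt
# train under one strategy and write the actor checkpoint to CKPT
torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
    --strategy ddp --num_episodes 1 --max_timesteps 2 --update_timesteps 2 \
    --max_epochs 1 --train_batch_size 2 \
    --pretrain 'facebook/opt-350m' --model opt --lora_rank 4 \
    --save_path ${CKPT}
# smoke-test the checkpoint: load it and generate
python ${BASE}/inference.py --model_path ${CKPT} --pretrain 'facebook/opt-350m' --model opt
# remove it so the next strategy starts clean
rm -rf ${CKPT}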

applications/ChatGPT/examples/train_dummy.py

Lines changed: 2 additions & 1 deletion
@@ -115,7 +115,7 @@ def main(args):
                 update_timesteps=args.update_timesteps)
 
     # save model checkpoint after fitting
-    strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True)
+    strategy.save_model(actor, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
         strategy.save_optimizer(actor_optim,
@@ -130,6 +130,7 @@ def main(args):
                         default='naive')
     parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
     parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=50)
     parser.add_argument('--max_timesteps', type=int, default=10)
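The checkpoint destination is now a command-line choice rather than a hard-coded filename, and the default keeps the old path, so existing invocations are unaffected. A hypothetical single-process run that redirects the checkpoint (the /tmp path is illustrative):

# save the trained actor somewhere other than the default
python train_dummy.py --strategy naive --num_episodes 1 \
    --max_timesteps 2 --update_timesteps 2 --max_epochs 1 \
    --train_batch_size 2 --lora_rank 4 \
    --save_path /tmp/actor_checkpoint_dummy.pt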

applications/ChatGPT/examples/train_prompts.py

Lines changed: 2 additions & 1 deletion
@@ -102,7 +102,7 @@ def tokenize_fn(texts):
             max_timesteps=args.max_timesteps,
             update_timesteps=args.update_timesteps)
     # save model checkpoint after fitting
-    strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
+    strategy.save_model(actor, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
         strategy.save_optimizer(actor_optim,
@@ -118,6 +118,7 @@ def tokenize_fn(texts):
                         default='naive')
     parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
     parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=10)
     parser.add_argument('--max_timesteps', type=int, default=10)
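train_prompts.py gets the same change. Since save_model is called with only_rank0=True, only one rank should write the file under torchrun, so a single --save_path is safe even with multiple processes. A hypothetical two-process run (the output path is illustrative; $PROMPT_PATH is the positional prompts file already used by the CI script):

torchrun --standalone --nproc_per_node=2 train_prompts.py $PROMPT_PATH \
    --strategy ddp --num_episodes 1 --max_timesteps 2 --update_timesteps 2 \
    --max_epochs 1 --train_batch_size 2 \
    --pretrain 'gpt2' --model gpt2 --lora_rank 4 \
    --save_path /tmp/actor_checkpoint_prompts.pt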
