
Commit 2763b5d

920232796 and marscrazy authored

Opt 30b (#16)

* clean codes

Co-authored-by: Zac Liu <liuguang@baai.ac.cn>

1 parent bcf24a7 · commit 2763b5d

File tree

5 files changed: +18 -65 lines changed


examples/opt/generate_opt_30b.py
Lines changed: 0 additions & 2 deletions

```diff
@@ -1,15 +1,13 @@
 from flagai.model.predictor.predictor import Predictor
 from flagai.auto_model.auto_loader import AutoLoader
 import torch
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 loader = AutoLoader(task_name="lm",
                     model_name="opt-30b-en")
 
 model = loader.get_model()
 tokenizer = loader.get_tokenizer()
 model.eval()
-model.to(device)
 
 text = "The trophy doesn’t fit in the suitcase because "
 predictor = Predictor(model, tokenizer)
```
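For context, FlagAI's single-GPU OPT examples normally finish with a sampling call on this `Predictor`. A minimal sketch of that last step, assuming the usual `predict_generate_randomsample` entry point; the length and sampling values are illustrative, not project defaults:

```python
# Hypothetical continuation of generate_opt_30b.py (not part of this
# commit): sample a completion for the prompt built above.
out = predictor.predict_generate_randomsample(text,
                                              out_max_length=100,
                                              top_k=30,
                                              top_p=0.9)
print(f"out is {out}")
```

The two deletions drop single-device placement: at 30B parameters the fp32 weights alone are roughly 120 GB (30e9 × 4 bytes), more than any single GPU holds, so placement has to happen elsewhere; the multi-GPU script below handles it per rank.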

examples/opt/opt_30b_en_mutigpu.py
Lines changed: 12 additions & 6 deletions

```diff
@@ -1,4 +1,4 @@
-# os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"
+
 import torch
 import os
 import argparse
@@ -7,8 +7,11 @@
 import random
 import numpy as np
 from flagai.model.predictor.predictor import Predictor
+import glob
+import time
+
+# run script : python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 opt_30b_en_mutigpu.py
 
-# run script : python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 glm_blank_filling_QA_ch_mutigpu.py
 os.environ["ENV_TYPE"] = "deepspeed+mpu"
 model_parallel_size = 4
 world_size = 4
@@ -58,11 +61,15 @@ def initialize_distributed():
 
 set_random_seed(123)
 
-loader = AutoLoader("lm", model_name="opt-350m-en")
+
+print(f"building model...")
+loader = AutoLoader("lm", model_name="opt-30b-en")
 model = loader.get_model()
-model.half()
 tokenizer = loader.get_tokenizer()
-# model.parallel_output = False
+model.half()
+
+model.parallel_output = False
+
 model.eval()
 model.to(device)
 
@@ -75,4 +82,3 @@ def initialize_distributed():
 if mpu.get_model_parallel_rank() == 0:
     print(f"pred is {out}")
 
-
```
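The hunk headers show the script's enclosing context, `def initialize_distributed():`. For readers without the full file, here is a minimal sketch of what such a helper typically does under `torch.distributed.launch`; the NCCL backend, `env://` rendezvous, and `mpu.initialize_model_parallel` call are assumptions in the Megatron style that FlagAI's mpu follows, not a copy of the repo's implementation:

```python
import os
import torch
from flagai import mpu  # Megatron-style model-parallel utilities

def initialize_distributed(model_parallel_size=4, world_size=4):
    # torch.distributed.launch starts one process per GPU and exports
    # RANK / LOCAL_RANK for each of them.
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)  # one process per GPU
    torch.distributed.init_process_group(backend="nccl",
                                         init_method="env://",
                                         world_size=world_size,
                                         rank=rank)
    # carve the world into model-parallel groups of the given size
    mpu.initialize_model_parallel(model_parallel_size)
```

With `model_parallel_size = world_size = 4`, each rank holds a quarter of the parameter matrices; after `model.half()` that is roughly 15 GB of weights per GPU for a 30B model (30e9 × 2 bytes ÷ 4), which is what makes the per-rank `model.to(device)` feasible here.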

flagai/model/gpt2_model.py
Lines changed: 1 addition & 17 deletions

```diff
@@ -11,12 +11,9 @@
 import torch.nn.functional as F
 
 if os.getenv('ENV_TYPE') == 'deepspeed+mpu':
-    from flagai.mpu import get_model_parallel_world_size
-    from flagai.mpu import get_cuda_rng_tracker
     from flagai.mpu.utils import divide
-if os.getenv('ENV_TYPE') == 'deepspeed+mpu':
     from flagai.mpu.random import checkpoint
-    from flagai.mpu import copy_to_model_parallel_region, gather_from_model_parallel_region
+    from flagai.mpu import copy_to_model_parallel_region, gather_from_model_parallel_region, get_model_parallel_world_size, get_cuda_rng_tracker
     from flagai.mpu.cross_entropy import vocab_parallel_cross_entropy
 
 elif os.getenv('ENV_TYPE') == 'deepspeed':
@@ -321,19 +318,6 @@ def forward(
             None,
         }
 
-        # lm_logits = self.lm_head(hidden_states)
-        # return_data = {"logits": lm_logits}
-        # if labels is not None:
-        #     # Shift so that tokens < n predict n
-        #     shift_logits = lm_logits[..., :-1, :].contiguous()
-        #     shift_labels = labels[..., 1:].contiguous()
-        #     loss_fct = nn.CrossEntropyLoss()
-        #     loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
-        #                     shift_labels.view(-1))
-        #     return_data["loss"] = loss
-
-        # return return_data
-
     def load_weights(self, checkpoint_path):
         checkpoint = torch.load(checkpoint_path,
                                 map_location=torch.device("cpu"))
```
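The second hunk deletes a commented-out loss computation that the live code path had already replaced. For reference, the pattern that dead code spelled out, the standard causal-LM objective where tokens < n predict token n, looks like this as a self-contained sketch (not code from this repo):

```python
import torch
import torch.nn as nn

def causal_lm_loss(lm_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Drop the last position's logits and the first label so that the
    # prediction at position i is scored against the token at i + 1.
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = nn.CrossEntropyLoss()
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1))
```

Under `ENV_TYPE='deepspeed+mpu'` the model computes this with `vocab_parallel_cross_entropy` over vocabulary shards instead, which is why that import survives the consolidation in the first hunk.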

flagai/model/opt_model.py
Lines changed: 0 additions & 37 deletions

```diff
@@ -265,43 +265,6 @@ def __init__(self, config, **kwargs):
         # self.config = config
         self.transformer = OPTStack(self.config)
 
-    # def forward(
-    #     self,
-    #     **data,
-    # ):
-    #     input_ids = data.get("input_ids", None)
-    #     # attention_mask = data.get("attention_mask", None)
-    #     # position_ids = data.get("position_ids", None)
-    #     labels = data.get("labels", None)
-    #     use_cache = data.get("use_cache", None)
-    #     output_attentions = data.get("output_attentions", None)
-    #     output_hidden_states = data.get("output_hidden_states", True)
-    #
-    #     transformer_outputs = self.transformer(
-    #         input_ids,
-    #         attention_mask=None,
-    #         position_ids=None,
-    #         use_cache=use_cache,
-    #         output_attentions=output_attentions,
-    #         output_hidden_states=output_hidden_states,
-    #     )
-    #     hidden_states = transformer_outputs
-    #
-    #     lm_logits = self.lm_head(hidden_states)
-    #
-    #     return_data = {"logits": lm_logits}
-    #     if labels is not None:
-    #         # Shift so that tokens < n predict n
-    #         shift_logits = lm_logits[..., :-1, :].contiguous()
-    #         shift_labels = labels[..., 1:].contiguous()
-    #         loss_fct = nn.CrossEntropyLoss()
-    #         loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
-    #                         shift_labels.view(-1))
-    #         return_data["loss"] = loss
-    #
-    #     return return_data
-
-
     def load_weights(self, checkpoint_path):
         checkpoint = torch.load(checkpoint_path,
                                 map_location=torch.device("cpu"))
```

flagai/mp_tools.py
Lines changed: 5 additions & 3 deletions

```diff
@@ -219,8 +219,6 @@ def change_pytorch_model_mp_from_1_to_n_new(model_name_brief, checkpoint: str, t
         start = ratio * i
         end = ratio * (i + 1)
         d = torch.load(filenames[i], map_location='cpu')
-        if d.get("module", None) is None:
-            d["module"] = d
 
         for j in range(start, end):
             d_new = {}
@@ -235,7 +233,11 @@ def change_pytorch_model_mp_from_1_to_n_new(model_name_brief, checkpoint: str, t
                 d_new[k] = None
             d_new['module'] = {}
             with torch.no_grad():
-                for k, v in d['module'].items():
+
+                if "module" in d:
+                    d = d["module"]
+
+                for k, v in d.items():
                     assert len(v.shape) < 3
                     flag = 0
                     for keys in trans_keys:
```
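The net effect of the two hunks: the old shim aliased a raw state dict into its own `"module"` key (`d["module"] = d`), so iterating `d['module'].items()` would also visit the `"module"` entry itself, a dict with no `.shape`, and trip the assert. The new code instead unwraps DeepSpeed-style checkpoints once, up front. A standalone sketch of that normalization pattern, using a hypothetical helper name that is not part of `flagai.mp_tools`:

```python
import torch

def load_plain_state_dict(path: str) -> dict:
    """Return a flat {param_name: tensor} mapping, whether the file is a
    raw state dict or a DeepSpeed-style {"module": state_dict} wrapper.
    Illustrative helper, not a function in this repo."""
    d = torch.load(path, map_location="cpu")
    if "module" in d:  # DeepSpeed / engine-wrapped checkpoint
        d = d["module"]
    return d
```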
