Skip to content

Commit

Permalink
Merge branch 'master' into CPUAdam_on_PowerPC
Browse files Browse the repository at this point in the history
  • Loading branch information
tjruwase authored Feb 28, 2023
2 parents dad8edb + dc01cee commit d42b8e8
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 21 deletions.
4 changes: 2 additions & 2 deletions deepspeed/module_inject/containers/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
transformer_param_names[i],
prefix + param_names[i],
qkv=True,
megatron_v2=self.is_megatron_v2,
split_qkv=self.split_qkv)
megatron_v2=self.policy.is_megatron_v2,
split_qkv=self.policy.split_qkv)
for i in range(2, 4):
maybe_copy(module.attention,
sd,
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/module_inject/containers/gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
[prefix + param_names[0],
prefix + param_names[1],
prefix + param_names[2]],
split_qkv=self.split_qkv)
split_qkv=self.policy.split_qkv)
for i in range(3, 4):
maybe_copy(module.attention,
sd,
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/module_inject/containers/gptneo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
[prefix + param_names[0],
prefix + param_names[1],
prefix + param_names[2]],
split_qkv=self.split_qkv)
split_qkv=self.policy.split_qkv)
for i in range(3, 5):
maybe_copy(module.attention,
sd,
Expand Down
6 changes: 3 additions & 3 deletions deepspeed/module_inject/containers/gptneox.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
transformer_param_names[i],
prefix + param_names[i],
qkv=True,
megatron_v2=self.is_megatron_v2,
split_qkv=self.split_qkv,
heads=self.client_module.attention.num_attention_heads)
megatron_v2=self.policy.is_megatron_v2,
split_qkv=self.policy.split_qkv,
heads=self.policy.client_module.attention.num_attention_heads)
for i in range(2, 4):
maybe_copy(module.attention,
sd,
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/module_inject/containers/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
prefix + param_names[i + 1],
prefix + param_names[i + 2]
],
split_qkv=self.split_qkv)
split_qkv=self.policy.split_qkv)
for i in range(6, 8):
maybe_copy(module.attention,
sd,
Expand Down
10 changes: 3 additions & 7 deletions deepspeed/module_inject/load_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def load_model_with_checkpoint(r_module,
ckpt_mp_size,
weight_quantizer=None,
rank=0,
replace_policy=None):
container=None):
error_msgs = []

def transpose(data):
Expand Down Expand Up @@ -199,11 +199,7 @@ def load_parameters(module, prefix):
for n, child in module.named_children():
load_parameters(child, prefix + n + '.')
else:
replace_policy.load_params(module,
sd[0],
weight_quantizer,
mp_replace,
prefix)
container.load_params(module, sd[0], weight_quantizer, mp_replace, prefix)

try:
import transformers
Expand Down Expand Up @@ -274,7 +270,7 @@ def load_module_recursive(module, prefix='', level=0):
else:
load_module_recursive(
child,
prefix if (level == 0 and ckpt_type == 'pp') and replace_policy.use_load_prefix else \
prefix if (level == 0 and ckpt_type == 'pp') and container.policy.use_load_prefix else \
prefix + name + '.',
level + 1)

Expand Down
6 changes: 3 additions & 3 deletions deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def replace_fn(child, _policy, layer_id=0):
ckpt_type,
ckpt_mp_size,
quantizer,
replace_policy=container_g.policy)
container=container_g)
pbar.update(1)
else:
import gc
Expand Down Expand Up @@ -597,7 +597,7 @@ def replace_fn(child, _policy, layer_id=0):
ckpt_mp_size,
quantizer,
int(rank % tp_split_size),
replace_policy=container_g.policy)
container=container_g)
sds = [None for _ in sds]
gc.collect()

Expand All @@ -619,7 +619,7 @@ def replace_fn(child, _policy, layer_id=0):
ckpt_mp_size,
quantizer,
int(rank % tp_split_size),
replace_policy=container_g.policy)
container=container_g)
sds = [None for _ in sds]
gc.collect()
print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")
Expand Down
11 changes: 8 additions & 3 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,9 +1673,6 @@ def deepspeed_io(self,
or self.is_iterable_style_dataset(dataset)):
raise ValueError("Training data must be a torch Dataset")

if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL):
data_sampler = torch.utils.data.SequentialSampler(dataset)

if batch_size is None:
batch_size = self.train_micro_batch_size_per_gpu()

Expand All @@ -1694,6 +1691,14 @@ def deepspeed_io(self,
data_parallel_world_size = self.mpu.get_data_parallel_world_size()
data_parallel_rank = self.mpu.get_data_parallel_rank()

if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL):
data_sampler = torch.utils.data.DistributedSampler(
dataset,
num_replicas=data_parallel_world_size,
rank=data_parallel_rank,
shuffle=False,
)

deepspeed_dataloader_config = {}
if self.curriculum_learning_enabled():
deepspeed_dataloader_config = {
Expand Down

0 comments on commit d42b8e8

Please sign in to comment.