
Commit 5af6463

Add InternalStorage and add ShardingOptimizerStage2 (#37489)
1 parent 8bb1038 commit 5af6463

File tree

8 files changed: +785, -11 lines changed


python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,6 +12,6 @@
 # See the License for the specific language governing permissions and
 from .hybrid_parallel_optimizer import HybridParallelOptimizer
 from .hybrid_parallel_gradscaler import HybridParallelGradScaler
-from .dygraph_sharding_optimizer import DygraphShardingOptimizer
+# from .dygraph_sharding_optimizer import DygraphShardingOptimizer
 
 __all__ = []

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

Lines changed: 266 additions & 10 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -11,16 +11,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# Taken and modified from fairscale:
+# https://github.com/facebookresearch/fairscale/blob/main/fairscale/optim/oss.py
+# Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e
 
-######
+import numpy as np
+from itertools import chain
 from functools import reduce
+from collections import OrderedDict
 
 import paddle
 from paddle import framework
+import paddle.distributed as dist
+from paddle.optimizer import Optimizer
+
 from ...utils.log_util import logger
+from ...utils.internal_storage import ParamStorage
+from ...meta_parallel.sharding.sharding_utils import Type
+
+# CUDA alignment 256 bytes
+alignment = {"gpu": 256, }
+align = {
+    Type.fp16.value: 2,
+    Type.fp32.value: 4,
+}
+
+__all__ = ["ShardingOptimizerStage2"]
 
 
-def _is_trainable(param: paddle.Tensor) -> bool:
+def _is_trainable(param):
     return not param.stop_gradient
 
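Note (illustration only, not part of the patch): the `alignment` and `align` tables added above feed the padding logic in `rank_buffer_size`, further down in this diff, which rounds each parameter up so that it starts on a 256-byte CUDA boundary inside the fused buffer. A minimal standalone sketch of that arithmetic; string keys stand in for `Type.fp16.value` / `Type.fp32.value`, and the shape is invented for illustration:

import numpy as np

# Mirrors the constants introduced in this hunk.
alignment = {"gpu": 256}        # required byte alignment per device
align = {"fp16": 2, "fp32": 4}  # bytes per element for each dtype

def padded_numel(shape, dtype="fp16", device="gpu"):
    # Same arithmetic as rank_buffer_size: pad the parameter's element
    # count so the next parameter starts on a 256-byte boundary.
    numel = int(np.prod(shape))
    size_bytes = numel * align[dtype]
    remaining = size_bytes % alignment[device]
    pad_bytes = 0 if remaining == 0 else alignment[device] - remaining
    return numel + pad_bytes // align[dtype]

# A hypothetical [1000, 3] fp16 parameter: 3000 elements = 6000 bytes,
# padded by 144 bytes (72 fp16 elements) up to 6144 bytes = 24 * 256.
print(padded_numel([1000, 3]))  # 3072
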
@@ -41,13 +60,8 @@ class DygraphShardingOptimizer(object):
     # 3. dynamic trainable params, which is the case between pretraining and finetuning
     # 4. option to choose fuse comm (more GPU MEM need) or un-fuse comm
 
-    def __init__(
-            self,
-            hcg,
-            user_defined_strategy,
-            params,
-            inner_optimizer_class,
-            **inner_optimizer_kargs, ):
+    def __init__(self, hcg, user_defined_strategy, params,
+                 inner_optimizer_class, **inner_optimizer_kargs):
 
         if not isinstance(params, list):
             raise TypeError(

@@ -196,3 +210,245 @@ def _grad_clip(self):
 
     def __getattr__(self, item):
         return getattr(self._inner_optimizer, item)
+
+
+class ShardingOptimizerStage2(Optimizer):
+    """
+    A wrapper for Sharding Stage2 Optimizer in Dygraph.
+
+    .. warning: ShardingOptimizer encapsulates the optimization strategy and integrates it into the optimizer.
+
+    .. ZeRO: https://arxiv.org/pdf/1910.02054.pdf
+
+    """
+
+    # TODO (Baibaifan)
+    # Feature Notes:
+    # 1. Unified memory for parameters and parameters.grad to InternalStorage.
+    # 2. Support the segmentation of optimizer parameters and partial updating of parameters.
+    # 3. Dynamically adjust training parameters and models.
+    # 4. Support offload function.
+    # 5. Support the establishment of independent communication groups.
+    # 6. Broadcast_fp16 is not supported now.
+    def __init__(self,
+                 params,
+                 optim,
+                 group,
+                 broadcast_fp16=False,
+                 offload=False,
+                 device="gpu",
+                 accumulation_steps=None,
+                 **kw):
+
+        super().__init__(optim._learning_rate, params, kw)
+
+        # Segmentation information
+        self._dtype_rank_params = OrderedDict(
+        )  # {dtype: [param1, param2]} device, rank, params
+        self._param2rank = {}
+        self._segment_params = []
+        self._rank_buffer_size = {}  # {dtype: {rank: numel+alignment}}
+        self._param2align = {}  # {param.name: align}
+
+        # Default information
+        self._optim_defaults = kw
+        self._optim = optim
+        self._local_params = params
+        self._default_device = device
+        self._accumulation_steps = accumulation_steps
+
+        assert group is not None, "Distributed communication group must be given"
+        self.group = group
+        self.world_size = group.nranks
+        self.rank = group.rank
+
+        self.broadcast_fp16 = broadcast_fp16
+        self.param_storages = {}  # {dtype: {rank: InternalStorage}}
+        self.offload = offload  # Using for offload
+
+        # Update optimizer parameters and adjust parameter storage and use according to rank.
+        self.update_opt_status()
+
+    def update_opt_status(self):
+        """Update optimizer status and parameter storage information, and special functions to be developed.
+        """
+        # func 1
+        self._integration_params()
+
+        # func 2 TODO
+
+    # Segment helpers
+
+    def segment_params(self):
+        """
+        Divide all optimizer parameters equally across ranks.
+        """
+        if len(self._segment_params) == 0:
+            self._segment_params, param_lists = [
+                [] for _ in range(self.world_size)
+            ], [[] for _ in range(self.world_size)]
+            sizes = [0] * self.world_size
+            for param in self._local_params:
+                # Add this param to the rank with the smallest size.
+                rank = sizes.index(min(sizes))
+                param_lists[rank].append(param)
+
+                # Statistical real numels
+                sizes[rank] += np.prod(param.shape) if param.trainable else 0
+
+            for rank, params in enumerate(param_lists):
+                # param_group_rank = copy.copy(params)
+                self._segment_params[rank].extend(params)
+        return self._segment_params
+
+    @property
+    def local_params(self):
+        return self._local_params
+
+    @property
+    def accumulation_steps(self):
+        return self._accumulation_steps
+
+    @property
+    def param2rank(self):
+        """Map the params to the rank which owns them"""
+        if len(self._param2rank) == 0:
+            for rank, params in enumerate(self.segment_params()):
+                for param in params:
+                    self._param2rank[param.name] = rank
+        return self._param2rank
+
+    @property
+    def dtype_rank_params(self):
+        """
+        Divide the parameters into groups according to rank and dtype.
+        """
+        if len(self._dtype_rank_params) == 0:
+            # Assign the parameters of each rank according to the type
+            for param in self._local_params:
+                if param.dtype not in self._dtype_rank_params.keys():
+                    self._dtype_rank_params[
+                        param.dtype] = [[] for _ in range(self.world_size)]
+                self._dtype_rank_params[param.dtype][self.param2rank[
+                    param.name]].append(param)
+
+            # Sort per rank params by size
+            for dtype in self._dtype_rank_params.keys():
+                for rank_params in self._dtype_rank_params[dtype]:
+                    rank_params.sort(key=lambda x: np.prod(x.shape))
+
+        return self._dtype_rank_params
+
+    @property
+    def rank_buffer_size(self):
+        """
+        Count the memory size of the parameters corresponding to rank under the corresponding dtype.
+        """
+        # CUDA alignment 256 bytes
+        if len(self._rank_buffer_size) == 0:
+            for dtype in self.dtype_rank_params.keys():
+                if dtype not in self._rank_buffer_size.keys():
+                    self._rank_buffer_size[dtype] = {}
+                for dst_rank, per_rank_params in enumerate(
+                        self.dtype_rank_params[dtype]):
+                    if dst_rank not in self._rank_buffer_size[dtype].keys():
+                        self._rank_buffer_size[dtype][dst_rank] = 0
+                    for param in per_rank_params:
+                        if not param.trainable:
+                            continue
+                        size = np.prod(param.shape) * align[dtype]
+                        remaining = size % alignment[self._default_device]
+                        ali = 0 if remaining == 0 else alignment[
+                            self._default_device] - remaining
+                        align_ = ali // align[dtype]
+                        self._rank_buffer_size[dtype][dst_rank] += np.prod(
+                            param.shape) + align_
+                        self._param2align[param.name] = align_
+
+        return self._rank_buffer_size
+
+    def _integration_params(self):
+        """
+        Integrate the parameters into a continuous memory according to rank, and support the update of training parameters.
+        """
+
+        for dtype, per_rank_params in self.dtype_rank_params.items():
+            if dtype not in self.param_storages.keys():
+                self.param_storages[dtype] = {}
+
+            for dst_rank, params in enumerate(per_rank_params):
+                if len(params) > 0:
+
+                    # Merge all the trainable params in a single InternalStorage
+                    trainable_params = list(
+                        filter(lambda x: x.trainable, params))
+                    if trainable_params:
+                        param_storage = ParamStorage(
+                            size=self.rank_buffer_size[dtype][dst_rank],
+                            dtype=dtype,
+                            device=self._default_device)
+
+                        param_storage.add_rank_params(trainable_params,
+                                                      self._param2align)
+                        self.param_storages[dtype][dst_rank] = param_storage
+
+        # Clear the InternalStorage keys which are not in use anymore
+        dtype_in_use = list(self.dtype_rank_params.keys())
+        dtype_to_pop = list(
+            filter(lambda x: x not in dtype_in_use, self.param_storages.keys()))
+        for d in dtype_to_pop:
+            self.param_storages.pop(d)
+
+    def step(self):
+        """
+        A wrapper for Optimizer's step function to finish the update operation of the optimizer.
+        """
+
+        # Synchronize optimizer parameters for the current rank
+        if len(self.dtype_rank_params.keys(
+        )) == 1 and Type.fp32.value in self.dtype_rank_params.keys():
+            self._optim._parameter_list = self.dtype_rank_params[
+                Type.fp32.value][self.rank]
+        elif len(self.dtype_rank_params.keys(
+        )) == 1 and Type.fp16.value in self.dtype_rank_params.keys():
+            self._optim._parameter_list = self.dtype_rank_params[
+                Type.fp16.value][self.rank]
+        else:
+            self._optim._parameter_list = self.dtype_rank_params[
+                Type.fp16.value][self.rank] + self.dtype_rank_params[
+                    Type.fp32.value][self.rank]
+
+        # Run the optimizer of the current rank step
+        self._optim.step()
+
+        # Synchronize all the updated shards in between the ranks
+        self._broadcast_params()
+
+        # Return full parameters to optimizer parameters
+        self._optim._parameter_list = self._local_params
+
+    def clear_cache(self):
+        self._segment_params.clear()
+        self._dtype_rank_params.clear()
+        self._param2rank.clear()
+
+    @paddle.no_grad()
+    def _broadcast_params(self):
+        """Broadcast the parameters of the current rank to each rank"""
+
+        assert self._default_device == "gpu", "Only supported gpu"
+
+        # Exchange all the shards with the other ranks
+        for dtype_per_rank in self.param_storages.values():
+            for dst_rank, internal_storage in dtype_per_rank.items():
+                dist.broadcast(
+                    tensor=internal_storage.buffer,
+                    src=dst_rank,
+                    group=self.group,
+                    use_calc_stream=True)
+
+                # Multi stream operation will be supported later
+                dist.wait(
+                    tensor=internal_storage.buffer,
+                    group=self.group,
+                    use_calc_stream=True)
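
Note (illustration only, not part of the patch): segment_params above assigns parameters to ranks with a simple greedy rule, always handing the next parameter to the rank that currently owns the fewest elements. A self-contained sketch with invented shapes for a 2-rank group:

import numpy as np

# Hypothetical parameter shapes, purely for illustration.
param_shapes = {
    "embedding": [50000, 128],
    "fc1.w": [128, 512],
    "fc1.b": [512],
    "fc2.w": [512, 128],
    "fc2.b": [128],
}

world_size = 2
buckets = [[] for _ in range(world_size)]
sizes = [0] * world_size

# Same greedy rule as segment_params: give the next parameter to the
# rank that currently holds the fewest elements.
for name, shape in param_shapes.items():
    rank = sizes.index(min(sizes))
    buckets[rank].append(name)
    sizes[rank] += int(np.prod(shape))

for rank, names in enumerate(buckets):
    print(rank, sizes[rank], names)
# rank 0 keeps the large embedding; rank 1 collects the small fc layers,
# so the element counts stay as balanced as the greedy rule allows.
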
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .sharding_utils import GpuInfo
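
Note (illustration only, not part of the patch): putting the pieces of this commit together, a minimal sketch of how ShardingOptimizerStage2 might be constructed, assuming a communication group from paddle.distributed.new_group and an AdamW inner optimizer; the model is a placeholder, the import path mirrors the file modified above, and gradient production/reduction is outside the scope of this class:

import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ShardingOptimizerStage2

dist.init_parallel_env()
group = dist.new_group(list(range(dist.get_world_size())))

model = paddle.nn.Linear(1024, 1024)  # placeholder model
inner_opt = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters())

# The wrapper partitions the parameters across the ranks of `group`,
# fuses each rank's shard into a contiguous ParamStorage buffer, and
# updates only the local shard when step() is called.
opt = ShardingOptimizerStage2(
    params=model.parameters(), optim=inner_opt, group=group)

# ... after gradients have been computed and reduced elsewhere:
opt.step()  # update the local shard, then broadcast updated shards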
