Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow TiedLayerSpec to have multiple tied weights #4216

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions deepspeed/runtime/pipe/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import re as regex

from functools import partial
from typing import List

import torch
import torch.nn as nn
Expand Down Expand Up @@ -76,11 +77,25 @@ def build(self, log=False):

class TiedLayerSpec(LayerSpec):

def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_attr='weight', **module_kwargs):
def __init__(self,
key,
typename,
*module_args,
forward_fn=None,
tied_weight_attr='weight', # Deprecated
tied_weight_attrs: List[str] = None,
**module_kwargs):
super().__init__(typename, *module_args, **module_kwargs)
self.key = key
self.forward_fn = forward_fn
self.tied_weight_attr = tied_weight_attr
if tied_weight_attr is not None:
logger.warning(
"`tied_weight_attr` in TiedLayerSpec is deprecated, please use `tied_weight_attrs` instead."
)
self.tied_weight_attrs = [tied_weight_attr]
else:
self.tied_weight_attrs = tied_weight_attrs



class PipelineModule(nn.Module):
Expand Down Expand Up @@ -190,7 +205,7 @@ def __init__(self,
self.forward_funcs = []
self.fwd_map = {}
self.tied_modules = nn.ModuleDict()
self.tied_weight_attrs = {}
self.tied_weight_attrss = {}

# Offset the random seed by the stage ID.
#newseed = get_accelerator().initial_seed() + self._grid.get_stage_id()
Expand Down Expand Up @@ -235,7 +250,7 @@ def _build(self):
# Build and register the module if we haven't seen it before.
if layer.key not in self.tied_modules:
self.tied_modules[layer.key] = layer.build()
self.tied_weight_attrs[layer.key] = layer.tied_weight_attr
self.tied_weight_attrss[layer.key] = layer.tied_weight_attrs

if layer.forward_fn is None:
# Just use forward()
Expand Down Expand Up @@ -423,23 +438,26 @@ def _partition_layers(self, method='uniform'):
def allreduce_tied_weight_gradients(self):
'''All reduce the gradients of the tied weights between tied stages'''
for key, comm in self.tied_comms.items():
weight = getattr(self.tied_modules[key], comm['weight_attr'])
dist.all_reduce(weight.grad, group=comm['group'])
for weight_attr in comm['weight_attrs']:
weight = getattr(self.tied_modules[key], weight_attr)
dist.all_reduce(weight.grad, group=comm['group'])

def get_tied_weights_and_groups(self):
weight_group_list = []
for key, comm in self.tied_comms.items():
weight = getattr(self.tied_modules[key], comm['weight_attr'])
weight_group_list.append((weight, comm['group']))
for weight_attr in comm['weight_attrs']:
weight = getattr(self.tied_modules[key], weight_attr)
weight_group_list.append((weight, comm['group']))
return weight_group_list

def _synchronize_tied_weights(self):
for key, comm in self.tied_comms.items():
dist.broadcast(
getattr(comm['module'], comm['weight_attr']),
src=min(comm['ranks']),
group=comm['group'],
)
for weight_attr in comm['weight_attrs']:
dist.broadcast(
getattr(comm['module'], weight_attr),
src=min(comm['ranks']),
group=comm['group'],
)

def _index_tied_modules(self):
''' Build communication structures for tied modules. '''
Expand Down Expand Up @@ -478,7 +496,7 @@ def _index_tied_modules(self):
tied_comms[key] = {
'ranks': tied_ranks,
'group': group,
'weight_attr': self.tied_weight_attrs[key],
'weight_attrs': self.tied_weight_attrss[key],
'module': self.tied_modules[key],
}
# Only count the tied module once in the eyes of the FP16 optimizer
Expand Down