Added bitsandbytes Adam optimizer
Quentin-Anthony committed Oct 24, 2021
1 parent bf70cdf commit 9fa5d46
Showing 4 changed files with 44 additions and 18 deletions.
megatron/model/word_embeddings.py (25 changes: 17 additions & 8 deletions)
@@ -43,12 +43,22 @@ def __init__(self,
         )
         self._word_embeddings_key = 'word_embeddings'

+        if neox_args.use_bnb_optimizer:
+            try:
+                import bitsandbytes as bnb
+                self.embedding_module = bnb.nn.StableEmbedding
+            except ModuleNotFoundError:
+                print("Please install bitsandbytes following https://github.com/facebookresearch/bitsandbytes.")
+                raise Exception
+        else:
+            self.embedding_module = torch.nn.Embedding
+
         # Position embedding (serial).
         self.use_pos_emb = use_pos_emb
         if self.use_pos_emb:
             self.embedding_type = neox_args.pos_emb
             if self.embedding_type == "learned":
-                self.position_embeddings = torch.nn.Embedding(
+                self.position_embeddings = self.embedding_module(
                     max_sequence_length, self.hidden_size)
                 self._position_embeddings_key = 'position_embeddings'
                 # Initialize the position embeddings.
@@ -62,7 +72,7 @@ def __init__(self,
         # token types and add them as needed.
         self._tokentype_embeddings_key = 'tokentype_embeddings'
         if self.num_tokentypes > 0:
-            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
+            self.tokentype_embeddings = self.embedding_module(self.num_tokentypes,
                                                            self.hidden_size)
             # Initialize the token-type embeddings.
             self.init_method(self.tokentype_embeddings.weight)
@@ -83,7 +93,7 @@ def add_tokentype_embeddings(self, num_tokentypes):
         print('adding embedding for {} tokentypes'.format(num_tokentypes),
               flush=True)
         self.num_tokentypes = num_tokentypes
-        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
+        self.tokentype_embeddings = self.embedding_module(num_tokentypes,
                                                        self.hidden_size)
         # Initialize the token-type embeddings.
         self.init_method(self.tokentype_embeddings.weight)
@@ -150,11 +160,10 @@ def __init__(self,
         self.init_string = init_string
         self.soft_embedding_weight = torch.nn.parameter.Parameter(self.initialize_embedding(wte))

-    def initialize_embedding(self,
-                             wte: torch.nn.Embedding):
+    def initialize_embedding(self):
         if self.init_string:
-            embeds = torch.LongTensor(self.neox_args.tokenizer.tokenize(self.init_string)).to(wte.weight.device)
-            embeds = wte(embeds)
+            embeds = torch.LongTensor(self.neox_args.tokenizer.tokenize(self.init_string)).to(self.embedding_module.weight.device)
+            embeds = self.embedding_module(embeds)
             if embeds.shape[0] >= self.n_tokens:
                 embeds = embeds[:self.n_tokens, :]  # slice
             else:
@@ -181,4 +190,4 @@ def forward(self, args: tuple):
             embedding = torch.cat((soft_embedding, embedding), dim=1)
             embedding = embedding[:, :self.neox_args.seq_length, ...]
         # otherwise, we're in incremental mode, and just want to forward the single embedding (since the soft prompt has already been cached)
-        return embedding, layer_past, attention_mask
+        return embedding, layer_past, attention_mask
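
For context, bnb.nn.StableEmbedding is a drop-in replacement for torch.nn.Embedding (same num_embeddings, embedding_dim constructor), which is why a single embedding_module class can be chosen once above and reused for the word, position, and token-type embeddings. A minimal standalone sketch, assuming bitsandbytes is installed; the embedding_cls name and the sizes are illustrative, not part of this commit:

import torch

try:
    import bitsandbytes as bnb
    embedding_cls = bnb.nn.StableEmbedding  # Embedding subclass with layer norm and 32-bit optimizer states
except ModuleNotFoundError:
    embedding_cls = torch.nn.Embedding      # plain fallback when bitsandbytes is absent

emb = embedding_cls(num_embeddings=50304, embedding_dim=1024)  # illustrative vocab/hidden sizes
tokens = torch.randint(0, 50304, (2, 8))
print(emb(tokens).shape)  # torch.Size([2, 8, 1024])

StableEmbedding is the variant the bitsandbytes authors recommend when training with their 8-bit optimizers, since embedding layers tend to be the most sensitive to reduced-precision optimizer states.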
megatron/neox_arguments/neox_args.py (7 changes: 6 additions & 1 deletion)
@@ -344,6 +344,11 @@ class NeoXArgsOptimizer(NeoXArgsTemplate):
    Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd]
    """

+    use_bnb_optimizer: bool = False
+    """
+    Whether to enable the bitsandbytes optimizers
+    """
+
    zero_stage: int = None
    """
    Zero Optimizer stage
@@ -951,4 +956,4 @@ class NeoXArgsTextgen(NeoXArgsTemplate):
    eval_tasks: list = None
    """
    Tasks to evaluate on using lm_eval_harness
-    """
+    """
megatron/training.py (25 changes: 17 additions & 8 deletions)
@@ -312,14 +312,23 @@ def get_optimizer(model, neox_args):
        )
    elif neox_args.optimizer_type.lower() == "adam":
        # Use Adam
-        try:
-            # default to apex as it's slightly faster
-            from apex.optimizers import FusedAdam as Adam
-        except ImportError:
-            # if apex isn't installed, use deepspeed's FusedAdam
-            print("WARNING: APEX not installed - defaulting to deepspeed's fused adam")
-            from deepspeed.ops.adam import FusedAdam as Adam
-        optimizer = Adam(
+        if neox_args.use_bnb_optimizer:
+            try:
+                import bitsandbytes as bnb
+                adam_optimizer = bnb.optim.Adam8bit
+            except ModuleNotFoundError:
+                print("Please install bitsandbytes following https://github.com/facebookresearch/bitsandbytes.")
+                raise Exception
+        else:
+            try:
+                # default to apex as it's slightly faster
+                from apex.optimizers import FusedAdam as Adam
+            except ImportError:
+                # if apex isn't installed, use deepspeed's FusedAdam
+                print("WARNING: APEX not installed - defaulting to deepspeed's fused adam")
+                from deepspeed.ops.adam import FusedAdam as Adam
+            adam_optimizer = Adam
+        optimizer = adam_optimizer(
            param_groups,
            weight_decay=neox_args.weight_decay,
            **neox_args.optimizer["params"],
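
A minimal sketch of the same selection pattern outside of Megatron, assuming bitsandbytes is installed and the parameters live on a CUDA device when the 8-bit path is taken; the model and hyperparameters are placeholders:

import torch

try:
    import bitsandbytes as bnb
    adam_cls = bnb.optim.Adam8bit  # Adam with 8-bit optimizer states
except ModuleNotFoundError:
    adam_cls = torch.optim.Adam    # full-precision fallback

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(1024, 1024).to(device)
optimizer = adam_cls(model.parameters(), lr=1e-4, weight_decay=0.01)

loss = model(torch.randn(4, 1024, device=device)).sum()
loss.backward()
optimizer.step()       # the 8-bit optimizer generally expects CUDA parameters
optimizer.zero_grad()

The point of Adam8bit is to hold Adam's first and second moment estimates in 8 bits, which cuts optimizer-state memory substantially while leaving the model weights and gradients untouched.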
megatron/utils.py (5 changes: 4 additions & 1 deletion)
@@ -107,6 +107,9 @@ def local_rank():
        local_rank = 0
    return int(local_rank)

+def is_bnb_available():
+    """ True if bitsandbytes optimizers are available """
+    return importlib.util.find_spec("bitsandbytes") is not None

def is_local_main():
    """ True if is the local main process """
@@ -442,4 +445,4 @@ def __next__(self):
        self.batch_count += 1
        end = time.time()
        self.total_time += end - start
-        return batch
+        return batch
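
A self-contained version of the availability helper above, for illustration; note that find_spec lives in importlib.util, so that submodule is imported explicitly here (whether megatron/utils.py already imports it is not visible in this diff):

import importlib.util

def is_bnb_available():
    """ True if the bitsandbytes package can be located without importing it """
    return importlib.util.find_spec("bitsandbytes") is not None

print("bitsandbytes available:", is_bnb_available())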
