Make MosaicGPT a HuggingFace PreTrainedModel #243

Merged
merged 103 commits on Mar 22, 2023
Changes from 101 commits (103 commits total)
6d6477f
squash
abhi-mosaic Feb 28, 2023
c2573ad
Merge branch 'main' into hero-3-3
abhi-mosaic Feb 28, 2023
1dd1c59
update composer commit
abhi-mosaic Feb 28, 2023
e8ac879
Add Lion optimizer (#200)
abhi-mosaic Mar 1, 2023
46dbc41
Merge branch 'main' into hero-3-3
dakinggg Mar 1, 2023
d64db2a
Torchmetrics upgrade (#203)
dakinggg Mar 2, 2023
ce5d36c
Low precision layer norm option (#205)
dakinggg Mar 2, 2023
18387de
replace speed monitor with composer one (#207)
dakinggg Mar 2, 2023
b5b57cb
add health checker callback (#206)
dakinggg Mar 2, 2023
80d1b22
Merge branch 'main' into hero-3-3
abhi-mosaic Mar 2, 2023
0ea0e2e
upgrade to release 0.13 branch
abhi-mosaic Mar 3, 2023
4677fb4
upgrade base mosaicml
mvpatel2000 Mar 3, 2023
e44f91f
add pynvml for health checker
abhi-mosaic Mar 3, 2023
b798cd0
add slack sdk
mvpatel2000 Mar 3, 2023
d09e3f5
bump composer
codestar12 Mar 6, 2023
01b1a83
add lion
codestar12 Mar 6, 2023
1b4e38c
Merge branch 'init_tests' into hero-3-3
codestar12 Mar 6, 2023
7109fb6
raise timeout
mvpatel2000 Mar 7, 2023
afc5dcf
add load_ignore_keys
abhi-mosaic Mar 7, 2023
ff7b867
remove reference to speedmonitormfu
abhi-mosaic Mar 7, 2023
e07ba98
use lpln_bias fix branch
abhi-mosaic Mar 8, 2023
5a54d0f
need to fix super().__init in StreamingTextDataset
growlix Mar 9, 2023
fc9fc4c
set keep_zip manually to deal with bug
codestar12 Mar 9, 2023
9d553d1
pin torchmetrics
codestar12 Mar 11, 2023
44c1322
requirements.txt is pip freeze from previous hero run
growlix Mar 13, 2023
a2092c9
updated requirements.txt composer version to lp-ln fix
growlix Mar 14, 2023
d21152e
attempted requirements.txt fix
growlix Mar 14, 2023
20a4584
requirements.txt fix attempt. mlnx-tools version
growlix Mar 14, 2023
b7adb8d
requirements.txt fix attempt. mlnx-tools version
growlix Mar 14, 2023
0e9ddf9
requirements.txt fix attempt. commented out mlnx-tools
growlix Mar 14, 2023
5b0ad9c
requirements.txt fix attempt. commented out pillow
growlix Mar 14, 2023
38bd96e
requirements.txt fix attempt. commented out mosaicml-examples
growlix Mar 14, 2023
0b3b685
requirements.txt fix attempt. commented out packaging
growlix Mar 14, 2023
9ef3a22
commented out mcli from requirements.txt
growlix Mar 14, 2023
684da29
commented out toml and tomli
growlix Mar 14, 2023
aa73dc8
commented out most stuff
growlix Mar 14, 2023
8cade14
merge with main
growlix Mar 14, 2023
22f6521
Fixed SpeedMonitor import
growlix Mar 14, 2023
619c372
import LionW
growlix Mar 14, 2023
c5e3e18
starter config file
dakinggg Mar 15, 2023
544d0b9
full config file
dakinggg Mar 15, 2023
2ac4d7a
move config validation into config class
dakinggg Mar 15, 2023
0c7cae5
updated mosaicml pin
growlix Mar 16, 2023
de71ea6
wip
abhi-mosaic Mar 16, 2023
e547527
Merge branch 'main' into merge-hero-changes
abhi-mosaic Mar 16, 2023
d90baaf
wip
abhi-mosaic Mar 16, 2023
4320b04
wip
abhi-mosaic Mar 16, 2023
9cf99b5
wip
abhi-mosaic Mar 16, 2023
622cae7
drop cfg.gets that dont work with a pretrained config
dakinggg Mar 16, 2023
328fe0d
address comments
abhi-mosaic Mar 16, 2023
0b93321
Apply suggestions from code review
abhi-mosaic Mar 16, 2023
b27a2d1
revert 4.25 transformers
abhi-mosaic Mar 16, 2023
885b95b
increase timeout 1800, lint
abhi-mosaic Mar 17, 2023
def4aca
propagate the config everywhere
dakinggg Mar 17, 2023
95641b8
fix required arg passing
dakinggg Mar 17, 2023
216fc20
print
dakinggg Mar 17, 2023
8f784cc
resolve true
dakinggg Mar 17, 2023
77b624a
fix MDSWriter args for streamitn 0.0.3
abhi-mosaic Mar 17, 2023
b567a7e
add test for config parsing
dakinggg Mar 17, 2023
d39268a
merge
dakinggg Mar 17, 2023
d0610fc
fix typo and wip generation
dakinggg Mar 17, 2023
6115e99
rename attention mask
dakinggg Mar 17, 2023
2fee085
wip
dakinggg Mar 17, 2023
d9d9814
merge
dakinggg Mar 17, 2023
1200b29
update
dakinggg Mar 17, 2023
5c3ba48
wip
dakinggg Mar 17, 2023
f66593f
precision
dakinggg Mar 17, 2023
e86143f
finish test
dakinggg Mar 17, 2023
a96a5eb
fix bad merge
dakinggg Mar 17, 2023
e1f12e2
pyright
dakinggg Mar 17, 2023
5e78669
pyright
dakinggg Mar 17, 2023
a3473d7
fix gpu test
dakinggg Mar 17, 2023
0132ed3
docs to the config
dakinggg Mar 17, 2023
1ade7c5
attempted merge
dakinggg Mar 18, 2023
7e0f467
pyright
dakinggg Mar 18, 2023
cac7ce3
fix duplicate arg
dakinggg Mar 18, 2023
cc22bd9
fix
dakinggg Mar 18, 2023
e68c3bd
remove type ignore
dakinggg Mar 18, 2023
864c6fa
wip attempted fixes for left padding
dakinggg Mar 18, 2023
203b8e6
gate position changing based on attnetion mask
dakinggg Mar 18, 2023
f3a30b2
pyright
dakinggg Mar 20, 2023
9e6761c
add save/from pretrained test
dakinggg Mar 20, 2023
66791d9
add comment back
dakinggg Mar 20, 2023
146ae9b
remove extra check
dakinggg Mar 20, 2023
f4380d9
fix kwargs
dakinggg Mar 21, 2023
ee26de1
pr comments
dakinggg Mar 21, 2023
ad0de73
type ignore
dakinggg Mar 21, 2023
bbd3cc4
merge
dakinggg Mar 21, 2023
375579b
fix merge
dakinggg Mar 21, 2023
9aed28b
add cache to config and drop clip check
dakinggg Mar 21, 2023
8a4ce97
add cache test
dakinggg Mar 21, 2023
dd97e33
adjustments to mgpt
dakinggg Mar 21, 2023
822615b
reset the causal masks
dakinggg Mar 21, 2023
29a6d83
check cuda for gpu test
dakinggg Mar 21, 2023
d597f86
fix bad copy paste
dakinggg Mar 21, 2023
d72a534
Merge branch 'main' into mgpt_to_hf
dakinggg Mar 21, 2023
43c1268
pyright
dakinggg Mar 21, 2023
5ab09a1
type ignore
dakinggg Mar 22, 2023
8da6cd2
comments to tests
dakinggg Mar 22, 2023
b44eef2
list -> tuple
dakinggg Mar 22, 2023
1ae6cf1
remove outdated config checks
dakinggg Mar 22, 2023
e5286bd
Update examples/llm/src/models/layers/attention.py
dakinggg Mar 22, 2023
1ef9c57
add comment
dakinggg Mar 22, 2023
146 changes: 146 additions & 0 deletions examples/llm/src/models/configuration_mosaic_gpt.py
@@ -0,0 +1,146 @@
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

"""A HuggingFace-style model configuration."""

from typing import Optional, Tuple, Union

from transformers import PretrainedConfig


class MosaicGPTConfig(PretrainedConfig):
model_type = 'mosaic_gpt'

def __init__(
self,
d_model: int = 2048,
n_heads: int = 16,
n_layers: int = 24,
mlp_ratio: int = 4,
max_seq_len: int = 2048,
vocab_size: int = 50257,
init_std: float = 0.02,
attn_pdrop: float = 0.0,
resid_pdrop: float = 0.0,
emb_pdrop: float = 0.0,
attn_impl: str = 'triton',
attn_qk_ln: bool = False,
attn_clip_qkv: Optional[float] = None,
softmax_scale: Optional[float] = None,
alibi: bool = False,
alibi_bias_max: int = 8,
init_device: str = 'cpu',
logit_scale: Optional[Union[float, str]] = None,
no_bias: bool = False,
verbose: int = 0,
param_init_fn: str = 'baseline_',
init_div_is_residual: Union[int, float, str, bool] = True,
emb_init_std: Optional[float] = None,
emb_init_uniform_lim: Optional[Union[Tuple[float, float],
float]] = None,
init_gain: float = 0,
fan_mode: str = 'fan_in',
init_nonlinearity: str = 'leaky_relu',
embedding_fraction: float = 1.0,
low_precision_layernorm: bool = False,
use_cache: bool = True,
**kwargs,
):
"""The MosaicGPT configuration class.

Args:
d_model (int): The size of the embedding dimension of the model.
n_heads (int): The number of attention heads.
n_layers (int): The number of layers in the model.
mlp_ratio (int): The ratio of the up/down scale in the MLP.
max_seq_len (int): The maximum sequence length of the model.
vocab_size (int): The size of the vocabulary.
init_std (float): The standard deviation of the normal distribution used to initialize the model,
if using the normal parameter initialization scheme.
attn_pdrop (float): The dropout probability for the attention layers.
resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
emb_pdrop (float): The dropout probability for the embedding layer.
attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
attn_qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
attn_clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
this value.
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
use the default scale of ``1/sqrt(d_keys)``.
alibi (bool): Whether to use the alibi bias instead of position embeddings.
alibi_bias_max (int): The maximum value of the alibi bias.
init_device (str): The device to use for parameter initialization.
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
no_bias (bool): If True, remove bias parameters from all layers.
verbose (int): The verbosity level. 0 is silent.
param_init_fn (str): The parameter initialization scheme to use. One of 'default_', 'baseline_', 'kaiming_uniform_',
'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 'xavier_normal_'.
init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
low_precision_layernorm (bool): Whether to use low precision layer normalization.
use_cache (bool): Whether or not the model should return the last key/value attentions.
"""
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.mlp_ratio = mlp_ratio
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.init_std = init_std
self.attn_pdrop = attn_pdrop
self.resid_pdrop = resid_pdrop
self.emb_pdrop = emb_pdrop
self.attn_impl = attn_impl
self.attn_qk_ln = attn_qk_ln
self.attn_clip_qkv = attn_clip_qkv
self.softmax_scale = softmax_scale
self.alibi = alibi
self.alibi_bias_max = alibi_bias_max
self.init_device = init_device
self.logit_scale = logit_scale
self.no_bias = no_bias
self.verbose = verbose
self.param_init_fn = param_init_fn
self.init_div_is_residual = init_div_is_residual
self.emb_init_std = emb_init_std
self.emb_init_uniform_lim = emb_init_uniform_lim
self.init_std = init_std
self.init_gain = init_gain
self.fan_mode = fan_mode
self.init_nonlinearity = init_nonlinearity
self.embedding_fraction = embedding_fraction
self.low_precision_layernorm = low_precision_layernorm
self.use_cache = use_cache
if 'name' in kwargs:
del kwargs['name']
super().__init__(**kwargs)

self._validate_config()

def _validate_config(self):
if self.d_model % self.n_heads != 0:
raise ValueError('d_model must be divisible by n_heads')
if any(prob < 0 or prob > 1
for prob in [self.attn_pdrop, self.resid_pdrop, self.emb_pdrop]):
raise ValueError(
'attn_pdrop, resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1'
)
if self.attn_impl not in ['torch', 'flash', 'triton']:
raise ValueError(f'Unknown attn_impl={self.attn_impl}')
if self.alibi and self.attn_impl not in ['torch', 'triton']:
raise NotImplementedError(
'alibi only implemented with torch and triton attention.')
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
)
if isinstance(self.logit_scale,
str) and self.logit_scale != 'inv_sqrt_d_model':
raise ValueError(
f"{self.logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
)
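
For orientation, here is a minimal usage sketch of the new configuration class (not part of the diff). It assumes the examples package is importable along the same paths the diff itself uses; save_pretrained and from_pretrained come from the PretrainedConfig base class, and attn_impl='torch' is chosen only to keep the sketch free of flash/triton dependencies.

# Minimal sketch (not part of the PR): exercising MosaicGPTConfig via the
# standard HuggingFace PretrainedConfig API. Paths and sizes are illustrative.
from examples.llm.src.models.configuration_mosaic_gpt import MosaicGPTConfig

config = MosaicGPTConfig(
    d_model=256,
    n_heads=4,
    n_layers=2,
    mlp_ratio=4,
    max_seq_len=512,
    attn_impl='torch',  # avoid flash/triton requirements for this sketch
)
# _validate_config() has already run inside __init__; e.g. d_model=250 with
# n_heads=4 would have raised ValueError('d_model must be divisible by n_heads').

config.save_pretrained('./mosaic_gpt_config')  # writes config.json
reloaded = MosaicGPTConfig.from_pretrained('./mosaic_gpt_config')
assert reloaded.d_model == 256 and reloaded.model_type == 'mosaic_gpt'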
58 changes: 42 additions & 16 deletions examples/llm/src/models/layers/attention.py
@@ -12,10 +12,21 @@
from composer.algorithms.low_precision_layernorm.low_precision_layernorm import \
LPLayerNorm
from einops import rearrange
from omegaconf import DictConfig
from torch import nn


def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
original_is_causal: bool):
if num_query_tokens != num_key_tokens:
if num_query_tokens != 1:
raise NotImplementedError(
'MosaicGPT does not support query and key with different number of tokens, unless number of query tokens is 1.'
)
else:
return False
return original_is_causal


def scaled_multihead_dot_product_attention(
query,
key,
@@ -148,6 +159,9 @@ def flash_attn_fn(
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

dropout_p = dropout_p if training else 0.0

reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)

output_unpad = flash_attn_interface.flash_attn_unpadded_func(
query_unpad,
key_unpad,
@@ -158,7 +172,7 @@
max_seqlen_k,
dropout_p,
softmax_scale=softmax_scale,
causal=is_causal,
causal=reset_is_causal,
return_attn_probs=needs_weights)

output = bert_padding.pad_input(
@@ -226,8 +240,9 @@ def triton_flash_attn_fn(
key = rearrange(key, 'b s (h d) -> b s h d', h=n_heads)
value = rearrange(value, 'b s (h d) -> b s h d', h=n_heads)

reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
attn_output = flash_attn_triton.flash_attn_func(query, key, value,
attn_bias, is_causal,
attn_bias, reset_is_causal,
softmax_scale)

output = attn_output.view(*attn_output.shape[:2], -1)
@@ -244,28 +259,38 @@ class MultiheadAttention(nn.Module):
additive bias.
"""

def __init__(self, cfg: DictConfig, device: Optional[str] = None):
def __init__(
self,
d_model: int,
n_heads: int,
attn_impl: str = 'triton',
attn_clip_qkv: Optional[float] = None,
attn_qk_ln: bool = False,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
low_precision_layernorm: bool = False,
device: Optional[str] = None,
):
super().__init__()
self.attn_impl = cfg.get('attn_impl')

self.clip_qkv = cfg.get('attn_clip_qkv')
self.attn_qk_ln = cfg.get('attn_qk_ln')
self.attn_impl = attn_impl
self.clip_qkv = attn_clip_qkv
self.attn_qk_ln = attn_qk_ln

self.d_model = cfg.d_model
self.n_heads = cfg.n_heads
self.softmax_scale = cfg.get('softmax_scale')
self.d_model = d_model
self.n_heads = n_heads
self.softmax_scale = softmax_scale
if self.softmax_scale is None:
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = cfg.get('attn_pdrop')
self.attn_dropout_p = attn_pdrop

self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
# for param init fn; enables shape based init of fused layers
fuse_splits = (cfg.d_model, 2 * cfg.d_model)
fuse_splits = (d_model, 2 * d_model)
self.Wqkv._fused = (0, fuse_splits) # type: ignore

if self.attn_qk_ln:
layernorm_class = nn.LayerNorm if not cfg.get(
'low_precision_layernorm', False) else LPLayerNorm
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
self.q_ln = layernorm_class(self.d_model, device=device)
self.k_ln = layernorm_class(self.d_model, device=device)

@@ -284,7 +309,7 @@ def __init__(self, cfg: DictConfig, device: Optional[str] = None):
'Using `attn_impl: torch`; we recommend using `attn_impl: flash`.'
)
else:
raise ValueError(f"{cfg.get('attn_impl')=} is an invalid setting.")
raise ValueError(f'{attn_impl=} is an invalid setting.')

self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
self.out_proj._is_residual = True # type: ignore
@@ -297,6 +322,7 @@ def forward(self,
is_causal=True,
needs_weights=False):
qkv = self.Wqkv(x)

if self.clip_qkv:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)

@@ -313,7 +339,7 @@ def forward(self,
key = self.k_ln(key).to(dtype)

if past_key_value is not None:
if len(past_key_value) == 0:
if len(past_key_value) != 0:
key = torch.cat([past_key_value[0], key], dim=1)
value = torch.cat([past_key_value[1], value], dim=1)

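
The new _reset_is_causal helper appears to exist so that cached generation lines up with the flash/triton kernels: once past key/values are concatenated in, a single new query token must be allowed to attend to every cached key, so the causal flag is dropped before the kernel call. A standalone sketch of that behavior follows (the helper copied verbatim from the diff above; the token counts are illustrative).

# Standalone sketch of the _reset_is_causal behavior added in this file.
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
                     original_is_causal: bool):
    if num_query_tokens != num_key_tokens:
        if num_query_tokens != 1:
            raise NotImplementedError(
                'MosaicGPT does not support query and key with different number of tokens, unless number of query tokens is 1.'
            )
        else:
            return False
    return original_is_causal

# Prefill: query and key lengths match, so causal masking is kept.
assert _reset_is_causal(128, 128, original_is_causal=True) is True
# Decoding with a KV cache: one new query token against 129 cached keys,
# so the kernel is called with causal=False and the token attends to everything.
assert _reset_is_causal(1, 129, original_is_causal=True) is False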
59 changes: 42 additions & 17 deletions examples/llm/src/models/layers/gpt_blocks.py
@@ -9,22 +9,20 @@
import torch.nn as nn
from composer.algorithms.low_precision_layernorm.low_precision_layernorm import \
LPLayerNorm
from omegaconf import DictConfig

from examples.llm.src.models.layers.attention import MultiheadAttention


class GPTMLP(nn.Module):

def __init__(self, cfg: DictConfig, device: Optional[str] = None):
def __init__(self,
d_model: int,
mlp_ratio: int,
device: Optional[str] = None):
super().__init__()
self.mlp_up = nn.Linear(cfg.d_model,
cfg.mlp_ratio * cfg.d_model,
device=device)
self.mlp_up = nn.Linear(d_model, mlp_ratio * d_model, device=device)
self.mlp_act = nn.GELU(approximate='none')
self.mlp_down = nn.Linear(cfg.mlp_ratio * cfg.d_model,
cfg.d_model,
device=device)
self.mlp_down = nn.Linear(mlp_ratio * d_model, d_model, device=device)
self.mlp_down._is_residual = True # type: ignore

def forward(self, x):
@@ -33,17 +31,44 @@ def forward(self, x):

class GPTBlock(nn.Module):

def __init__(self, cfg: DictConfig, device: Optional[str] = None):
def __init__(self,
attn_impl: str,
d_model: int,
n_heads: int,
mlp_ratio: int,
attn_clip_qkv: Optional[float] = None,
attn_qk_ln: bool = False,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
alibi: bool = False,
resid_pdrop: float = 0.0,
low_precision_layernorm: bool = False,
device: Optional[str] = None,
**kwargs):
del kwargs # unused, just to capture any extra args from the config
super().__init__()
layernorm_class = LPLayerNorm if cfg.get('low_precision_layernorm',
False) else nn.LayerNorm

self.ln_1 = layernorm_class(cfg.d_model, device=device)
self.attn = MultiheadAttention(cfg, device)
self.ln_2 = layernorm_class(cfg.d_model, device=device)
self.mlp = GPTMLP(cfg, device=device)
self.resid_attn_dropout = nn.Dropout(cfg.resid_pdrop)
self.resid_mlp_dropout = nn.Dropout(cfg.resid_pdrop)
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm

self.ln_1 = layernorm_class(d_model, device=device)
self.attn = MultiheadAttention(
attn_impl=attn_impl,
attn_clip_qkv=attn_clip_qkv,
attn_qk_ln=attn_qk_ln,
softmax_scale=softmax_scale,
attn_pdrop=attn_pdrop,
d_model=d_model,
n_heads=n_heads,
device=device,
)
self.ln_2 = layernorm_class(d_model, device=device)
self.mlp = GPTMLP(
d_model=d_model,
mlp_ratio=mlp_ratio,
device=device,
)
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
self.resid_mlp_dropout = nn.Dropout(resid_pdrop)

def forward(
self,
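
The through-line of this file is the same as in attention.py: the old DictConfig-driven constructors are replaced with explicit keyword arguments, and GPTBlock's **kwargs catch-all (deleted immediately) lets callers splat an entire config dict without the block caring about unrelated fields. A minimal sketch of that wiring follows; it assumes the examples package plus its CPU-side dependencies (torch, composer, einops) are installed, and it uses attn_impl='torch' so no GPU kernels are required.

# Minimal sketch (not part of the PR): building one GPTBlock directly from a
# MosaicGPTConfig. The **kwargs catch-all in GPTBlock.__init__ silently ignores
# config fields the block does not use (vocab_size, init_std, ...).
from examples.llm.src.models.configuration_mosaic_gpt import MosaicGPTConfig
from examples.llm.src.models.layers.gpt_blocks import GPTBlock

config = MosaicGPTConfig(d_model=256, n_heads=4, n_layers=2, mlp_ratio=4,
                         max_seq_len=512, attn_impl='torch')

block = GPTBlock(device='cpu', **config.to_dict())
n_params = sum(p.numel() for p in block.parameters())
print(f'one GPTBlock holds {n_params} parameters')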