Changes from all commits (34 commits)
d561380
Create a custom sliced class for phi2
LianaMikael Apr 11, 2024
e017d97
Verify model saving and loading
LianaMikael Apr 11, 2024
d071a92
Start adding slicing scheduler
LianaMikael Apr 11, 2024
c27e3d5
Correct model architecture and slice
LianaMikael Apr 12, 2024
94a5720
Add additional intermediate hidden size to config
LianaMikael Apr 12, 2024
2b6b7f9
Clean up and update intermediate hidden layer dim
LianaMikael Apr 12, 2024
a9a3488
Small fixes and perplexity computation
LianaMikael Apr 16, 2024
7f4adda
Fix from_pretrained function
LianaMikael Apr 17, 2024
dd58719
Add tests and fix config
LianaMikael Apr 19, 2024
4cb10a7
Style
LianaMikael Apr 19, 2024
85b3782
Add Llama class and clean up
LianaMikael Apr 21, 2024
1f6c019
Add tests and update loading in run_slicegpt
LianaMikael Apr 21, 2024
c6ff900
Clean up tests and unnecessary script
LianaMikael Apr 21, 2024
1a68d00
Remove unnecessary script
LianaMikael Apr 21, 2024
e6e63f5
Move sliced models to new files
LianaMikael Apr 21, 2024
125cd58
Apply style
LianaMikael Apr 21, 2024
e9bb417
Add configs and move model saving to hf_utils
LianaMikael Apr 21, 2024
be7b767
Formatting
pashminacameron Apr 23, 2024
40a7661
Add module imports
pashminacameron Apr 23, 2024
3719480
Formatting module imports
pashminacameron Apr 23, 2024
f63a038
src/slicegpt/hf_utils.py
pashminacameron Apr 23, 2024
6210fc1
Add intermediate_size to OPT adapter to fix tests
LianaMikael Apr 23, 2024
fcd5fa7
Make sparsity and new_hidden_size mandatory, fix intermediate_size in…
LianaMikael Apr 23, 2024
20d1ef7
Fix config inputs
LianaMikael Apr 23, 2024
6c1f0df
Fix slicing tests
LianaMikael Apr 23, 2024
9821f08
Fix model saving
LianaMikael Apr 23, 2024
322016c
Use ffn_dim in OPT. Don't set intermediate_size in OPTConfig.
pashminacameron Apr 24, 2024
d1e5899
Remove unnecessary params and fix scheduler when model loading
LianaMikael Apr 24, 2024
16cb082
Fix tests
LianaMikael Apr 24, 2024
a83a523
Merge branch 'liana/make_model_HF_compatible' of https://github.com/m…
LianaMikael Apr 24, 2024
d2f4d2e
Update model loading
LianaMikael Apr 24, 2024
981615b
Update scheduler params when loading sliced model
LianaMikael Apr 24, 2024
84d775a
Update test with model loading
LianaMikael Apr 24, 2024
e678717
Update lm eval and fix rotate
LianaMikael Apr 25, 2024
16 changes: 11 additions & 5 deletions experiments/run_lm_eval.py
@@ -7,14 +7,15 @@
import os

import lm_eval
from slicegpt.model_adapter import ModelAdapter
import torch
import wandb
from lm_eval import tasks
from lm_eval import utils as lm_eval_utils
from lm_eval.api.registry import ALL_TASKS
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import initialize_tasks

import wandb
from slicegpt import gpu_utils, hf_utils, utils
from slicegpt.config import config

@@ -120,29 +121,34 @@ def eval_main(args: argparse.Namespace) -> None:
if args.sliced_model_path:
# load the sliced model
logging.info(f"Loading sliced {args.model} model from {args.sliced_model_path} with sparsity {args.sparsity}")
model_adapter, tokenizer = hf_utils.load_sliced_model(
model, tokenizer = hf_utils.load_sliced_model(
args.model,
args.sliced_model_path,
sparsity=args.sparsity,
token=args.hf_token,
round_interval=args.round_interval,
)
if isinstance(model, ModelAdapter):
model = model.model
else:
model = model.to(config.dtype)
else:
# load the original model
logging.info(f"Loading {args.model} model")
model_adapter, tokenizer = hf_utils.get_model_and_tokenizer(args.model, args.model_path, token=args.hf_token)
model = model_adapter.model

# the lm eval harness ties the weights, but this should not be done for sliced models unless the lm_head was sliced
model_adapter.model.tie_weights = lambda: None
model.model.tie_weights = lambda: None

if args.distribute_model:
# distribute model across available GPUs
gpu_utils.distribute_model(model_adapter)
else:
model_adapter.model.to(config.device)
model.to(config.device)

### LM Eval Harness ###
hflm = HFLM(pretrained=model_adapter.model, tokenizer=tokenizer, batch_size=args.batch_size)
hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=args.batch_size)

if args.tasks is None:
task_names = tasks.ALL_TASKS
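The run_lm_eval.py change above routes sliced checkpoints through hf_utils.load_sliced_model, which may now hand back either a ModelAdapter or a plain Hugging Face model, so the script unwraps whichever it receives before building the LM Eval Harness wrapper. A minimal sketch of that unwrapping; the model name, checkpoint path, and rounding value below are illustrative placeholders, not values from this PR:

```python
# Minimal sketch of the new loading flow in run_lm_eval.py.
# The model name, checkpoint directory, and round_interval are placeholders.
from slicegpt import hf_utils
from slicegpt.config import config
from slicegpt.model_adapter import ModelAdapter

loaded, tokenizer = hf_utils.load_sliced_model(
    "microsoft/phi-2",            # base model the sliced checkpoint was derived from
    "sliced_models/phi-2_0.25",   # directory holding the sliced checkpoint
    sparsity=0.25,
    round_interval=8,
)

# load_sliced_model may return a ModelAdapter (legacy path) or a plain HF model
# (the new sliced classes); unwrap accordingly before evaluation.
model = loaded.model if isinstance(loaded, ModelAdapter) else loaded.to(config.dtype)
```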
27 changes: 15 additions & 12 deletions experiments/run_slicegpt.py
@@ -8,8 +8,8 @@
import shutil

import torch
import wandb

import wandb
from slicegpt import data_utils, gpu_utils, hf_utils, layernorm_fusion, rotate, utils
from slicegpt.config import config
from slicegpt.slicing_scheduler import ConstSlicingScheduler
@@ -137,20 +137,20 @@ def slicing_main(args: argparse.Namespace) -> None:

if args.sliced_model_path:
# load the model from sliced_model_path to compute perplexity and skip rotation and slicing
model_adapter, tokenizer = hf_utils.load_sliced_model(
model, tokenizer = hf_utils.load_sliced_model(
args.model,
args.sliced_model_path,
sparsity=args.sparsity,
round_interval=args.round_interval,
token=args.hf_token,
)
model = model.to(config.dtype)
else:
# load one of the pre-trained models
model_adapter, tokenizer = hf_utils.get_model_and_tokenizer(
args.model, args.model_path, token=args.hf_token, dtype=config.dtype
)

model = model_adapter.model
model = model_adapter.model

def reset_model_device() -> None:
if args.distribute_model:
@@ -228,14 +228,17 @@ def reset_model_device() -> None:
sliced_model_dir = pathlib.Path(args.save_dir)
sliced_model_dir.mkdir(parents=True, exist_ok=True)

sliced_model_name = sliced_model_dir / f'{pathlib.Path(args.model).name}_{args.sparsity}.pt'

# Save the sliced model
torch.save(model.state_dict(), sliced_model_name)

# Save the slicing config
config_path = sliced_model_name.with_suffix('.json')
config_path.write_text(model_adapter.slicing_conf.to_json_string())
# Save the sliced model in HF format for Phi and Llama
hf_utils.save_sliced_model(
args.model,
config.dtype,
model,
scheduler,
sliced_model_dir,
args.sparsity,
new_embedding_dimension,
model_adapter.slicing_conf,
)

# If slicing a local model, also save HF config files in sliced model dir
if args.model_path:
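The saving path above replaces the old torch.save-plus-JSON flow with hf_utils.save_sliced_model, which writes the sliced Phi or Llama model in Hugging Face format. A round-trip sketch follows, under the assumption (not verified in this diff) that the directory save_sliced_model writes can be fed straight back into load_sliced_model, mirroring the load call earlier in this script:

```python
# Round-trip sketch: re-load the checkpoint that save_sliced_model just wrote.
# Assumption: save_sliced_model's output directory is directly consumable by
# load_sliced_model, exactly as args.sliced_model_path is earlier in the script.
model, tokenizer = hf_utils.load_sliced_model(
    args.model,
    str(sliced_model_dir),
    sparsity=args.sparsity,
    round_interval=args.round_interval,
    token=args.hf_token,
)
model = model.to(config.dtype)
```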
2 changes: 2 additions & 0 deletions src/slicegpt/__init__.py
@@ -4,6 +4,8 @@
from .adapters.llama_adapter import LlamaModelAdapter
from .adapters.opt_adapter import OPTModelAdapter
from .adapters.phi2_adapter import Phi2ModelAdapter
from .adapters.sliced_llama import SlicedLlama, SlicedLlamaConfig, SlicedLlamaForCausalLM
from .adapters.sliced_phi import SlicedPhi, SlicedPhi2Config, SlicedPhiForCausalLM
from .data_utils import get_dataset, prepare_dataloader
from .gpu_utils import benchmark, distribute_model, evaluate_ppl
from .hf_utils import get_model_and_tokenizer, load_sliced_model
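With the __init__.py change above, the new sliced classes are re-exported at package level, so callers can import them directly from slicegpt:

```python
# Top-level imports made available by the __init__.py change above.
from slicegpt import (
    SlicedLlama,
    SlicedLlamaConfig,
    SlicedLlamaForCausalLM,
    SlicedPhi,
    SlicedPhi2Config,
    SlicedPhiForCausalLM,
)
```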
Empty file.
6 changes: 6 additions & 0 deletions src/slicegpt/adapters/llama_adapter.py
@@ -14,6 +14,7 @@
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaDecoderLayer, LlamaForCausalLM, LlamaRMSNorm

from slicegpt.model_adapter import LayerAdapter, ModelAdapter
from slicegpt.modules import RMSN


class CompressedLlamaDecoderLayer(LlamaDecoderLayer):
@@ -23,6 +24,11 @@ class CompressedLlamaDecoderLayer(LlamaDecoderLayer):
but with the addition of a shortcut_Q attribute. This attribute is used to rotate the residual tensors.
"""

def __init__(self, config: LlamaConfig, layer_idx: int, replace_layernorm: bool = False):
super().__init__(config, layer_idx)
if replace_layernorm:
self.input_layernorm = RMSN(config.hidden_size)

def forward(
self,
hidden_states: Tensor,
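The new replace_layernorm flag swaps the decoder layer's input LayerNorm for slicegpt's parameter-free RMSN. This matters because, after the LayerNorm parameters are fused into the adjacent weight matrices, only a normalization with no learnable affine terms commutes with the orthogonal rotations that SliceGPT applies. The actual RMSN implementation is not part of this diff; the following is a rough, illustrative stand-in, not the PR's code:

```python
import torch
from torch import nn


class RMSNSketch(nn.Module):
    """Illustrative stand-in for slicegpt.modules.RMSN (not the PR's code):
    RMS normalization with no learnable affine parameters, so an orthogonal
    rotation of the hidden states passes through it unchanged."""

    def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root-mean-square over the hidden dimension only.
        return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
```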
6 changes: 6 additions & 0 deletions src/slicegpt/adapters/opt_adapter.py
@@ -14,6 +14,7 @@
from transformers.models.opt.modeling_opt import OPTConfig, OPTDecoderLayer, OPTForCausalLM

from slicegpt.model_adapter import LayerAdapter, ModelAdapter
from slicegpt.modules import RMSN


class CompressedOPTDecoderLayer(OPTDecoderLayer):
@@ -23,6 +24,11 @@ class CompressedOPTDecoderLayer(OPTDecoderLayer):
We also support the input rotation and mean subtraction in this class (if needed).
"""

def __init__(self, config: OPTConfig, replace_layernorm: bool = False):
super().__init__(config)
if replace_layernorm:
self.input_layernorm = RMSN(config.hidden_size)

def forward(
self,
hidden_states: Tensor,
6 changes: 6 additions & 0 deletions src/slicegpt/adapters/phi2_adapter.py
@@ -16,6 +16,7 @@
from transformers.models.phi.modeling_phi import PhiConfig, PhiDecoderLayer, PhiForCausalLM

from slicegpt.model_adapter import LayerAdapter, ModelAdapter
from slicegpt.modules import RMSN


class CompressedPhiDecoderLayer(PhiDecoderLayer):
@@ -25,6 +26,11 @@ class CompressedPhiDecoderLayer(PhiDecoderLayer):
but with the addition of a shortcut_Q attribute. This attribute is used to rotate the residual tensors.
"""

def __init__(self, config: PhiConfig, layer_idx: int, replace_layernorm: bool = False):
super().__init__(config, layer_idx)
if replace_layernorm:
self.input_layernorm = RMSN(config.hidden_size)

def forward(
self,
hidden_states: Tensor,
86 changes: 86 additions & 0 deletions src/slicegpt/adapters/sliced_llama.py
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM, LlamaModel

from slicegpt.adapters.llama_adapter import CompressedLlamaDecoderLayer, LlamaModelAdapter
from slicegpt.modules import RMSN
from slicegpt.rotate import slice_rotated_model
from slicegpt.slicing_scheduler import SlicingScheduler


class SlicedLlamaConfig(LlamaConfig):
model_type = "sliced_llama"
is_composition = True

def __init__(self, **kwargs) -> None:
self.sparsity = kwargs.pop("sparsity", None)
self.new_hidden_size = kwargs.pop("new_hidden_size", None)
super().__init__(**kwargs)

@classmethod
def from_pretrained(cls, config_path: str, sparsity: float, new_hidden_size: int) -> PretrainedConfig:
kwargs = {"sparsity": sparsity, "new_hidden_size": new_hidden_size}
return super().from_pretrained(config_path, **kwargs)


class SlicedLlama(LlamaModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.layers = nn.ModuleList(
[
CompressedLlamaDecoderLayer(config, layer_idx, replace_layernorm=True)
for layer_idx in range(config.num_hidden_layers)
]
)
self.final_layernorm = RMSN(config.hidden_size)


class SlicedLlamaForCausalLM(LlamaForCausalLM):
def __init__(
self,
config,
scheduler: SlicingScheduler | None = None,
*model_args,
**kwargs,
):
super().__init__(config)
self.model = SlicedLlama(config)
self.model_adapter = LlamaModelAdapter(self)

if scheduler:
self.update_dims(scheduler)

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
scheduler: SlicingScheduler | None,
sparsity: float,
new_hidden_size: int,
config_path: str,
*model_args,
**kwargs,
):
"""Overrides the from_pretrained method to accept the scheduler and returns the sliced model"""
config = SlicedLlamaConfig.from_pretrained(config_path, sparsity, new_hidden_size)
kwargs = {"scheduler": scheduler}
model = super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs)
model.load_state_dict(model.state_dict())
return model

def update_dims(self, scheduler: SlicingScheduler) -> None:
layers = self.model_adapter.get_layers()

hidden_size = self.model_adapter.hidden_size
for layer_adapter in layers:
if not self.model_adapter.parallel_blocks:
layer_adapter.layer.mlp_shortcut_Q = torch.nn.Parameter(
torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous()
)
layer_adapter.layer.attn_shortcut_Q = torch.nn.Parameter(
torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous()
)

slice_rotated_model(self.model_adapter, scheduler)
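For context, a hedged usage sketch of the new class. The paths, sparsity, and hidden size below are placeholders, and the assumption that ConstSlicingScheduler is constructed from the target embedding dimension comes from how it is used in run_slicegpt.py, not from this file:

```python
# Hedged usage sketch for SlicedLlamaForCausalLM.from_pretrained.
# Paths and sizes are placeholders; the ConstSlicingScheduler constructor
# argument (target embedding dimension) is an assumption, not shown in this diff.
from slicegpt.adapters.sliced_llama import SlicedLlamaForCausalLM
from slicegpt.slicing_scheduler import ConstSlicingScheduler

new_hidden_size = 3072                      # e.g. a 4096-dim model sliced at 25% sparsity
scheduler = ConstSlicingScheduler(new_hidden_size)

sliced = SlicedLlamaForCausalLM.from_pretrained(
    "path/to/sliced-llama-checkpoint",      # directory with the sliced weights
    scheduler=scheduler,
    sparsity=0.25,
    new_hidden_size=new_hidden_size,
    config_path="path/to/sliced-llama-checkpoint",  # where the SlicedLlamaConfig JSON lives
)
sliced.eval()
```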
86 changes: 86 additions & 0 deletions src/slicegpt/adapters/sliced_phi.py
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from transformers.configuration_utils import PretrainedConfig
from transformers.models.phi.modeling_phi import PhiConfig, PhiForCausalLM, PhiModel

from slicegpt.adapters.phi2_adapter import CompressedPhiDecoderLayer, Phi2ModelAdapter
from slicegpt.modules import RMSN
from slicegpt.rotate import slice_rotated_model
from slicegpt.slicing_scheduler import SlicingScheduler


class SlicedPhi2Config(PhiConfig):
model_type = "sliced_phi2"
is_composition = True

def __init__(self, **kwargs) -> None:
self.sparsity = kwargs.pop("sparsity", None)
self.new_hidden_size = kwargs.pop("new_hidden_size", None)
super().__init__(**kwargs)

@classmethod
def from_pretrained(cls, config_path: str, sparsity: float, new_hidden_size: int) -> PretrainedConfig:
kwargs = {"sparsity": sparsity, "new_hidden_size": new_hidden_size}
return super().from_pretrained(config_path, local_files_only=True, **kwargs)


class SlicedPhi(PhiModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.layers = nn.ModuleList(
[
CompressedPhiDecoderLayer(config, layer_idx, replace_layernorm=True)
for layer_idx in range(config.num_hidden_layers)
]
)
self.final_layernorm = RMSN(config.hidden_size)


class SlicedPhiForCausalLM(PhiForCausalLM):
def __init__(
self,
config,
scheduler: SlicingScheduler | None = None,
*model_args,
**kwargs,
):
super().__init__(config)
self.model = SlicedPhi(config)
self.model_adapter = Phi2ModelAdapter(self)

if scheduler:
self.update_dims(scheduler)

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
scheduler: SlicingScheduler | None,
sparsity: float,
new_hidden_size: int,
config_path: str,
*model_args,
**kwargs,
):
"""Overrides the from_pretrained method to accept the scheduler and returns the sliced model"""
config = SlicedPhi2Config.from_pretrained(config_path, sparsity, new_hidden_size)
kwargs = {"scheduler": scheduler}
model = super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs)
model.load_state_dict(model.state_dict())
return model

def update_dims(self, scheduler: SlicingScheduler) -> None:
layers = self.model_adapter.get_layers()

hidden_size = self.model_adapter.hidden_size
for layer_adapter in layers:
if not self.model_adapter.parallel_blocks:
layer_adapter.layer.mlp_shortcut_Q = torch.nn.Parameter(
torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous()
)
layer_adapter.layer.attn_shortcut_Q = torch.nn.Parameter(
torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous()
)

slice_rotated_model(self.model_adapter, scheduler)