-
Notifications
You must be signed in to change notification settings - Fork 53
Make sliced models HuggingFace compatible #139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
LianaMikael
wants to merge
34
commits into
main
Choose a base branch
from
liana/make_model_HF_compatible
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
d561380
Create a custom sliced class for phi2
LianaMikael e017d97
Verify model saving and loading
LianaMikael d071a92
Start adding slicing scheduler
LianaMikael c27e3d5
Correct model architecture and slice
LianaMikael 94a5720
Add additional intermediate hidden size to config
LianaMikael 2b6b7f9
Clean up and update intermediate hidden layer dim
LianaMikael a9a3488
Small fixes and perplexity computation
LianaMikael 7f4adda
Fix from_pretrained function
LianaMikael dd58719
Add tests and fix config
LianaMikael 4cb10a7
Style
LianaMikael 85b3782
ClAdd LLama class and clean up
LianaMikael 1f6c019
Add tests and update loading in run_slicegpt
LianaMikael c6ff900
Clean up tests and unnecessary script
LianaMikael 1a68d00
Remove unnecessary script
LianaMikael e6e63f5
Move sliced models to new files
LianaMikael 125cd58
Apply style
LianaMikael e9bb417
Add configs and move model saving to hf_utils
LianaMikael be7b767
Formatting
pashminacameron 40a7661
Add module imports
pashminacameron 3719480
Formatting module imports
pashminacameron f63a038
src/slicegpt/hf_utils.py
pashminacameron 6210fc1
Add intermediate_size to OPT adapter to fix tests
LianaMikael fcd5fa7
Make sparsity and new_hidden_size mandatory, fix intermediate_size in…
LianaMikael 20d1ef7
Fix config inputs
LianaMikael 6c1f0df
Fix slicing tests
LianaMikael 9821f08
Fi model saving
LianaMikael 322016c
Use ffn_dim in OPT. Don't set intermediate_size in OPTCOnfig.
pashminacameron d1e5899
Remove unnecessary params and fix scheduler when model loading
LianaMikael 16cb082
Fix tests
LianaMikael a83a523
Merge branch 'liana/make_model_HF_compatible' of https://github.com/m…
LianaMikael d2f4d2e
Update model loading
LianaMikael 981615b
Update scheduler params when loading sliced model
LianaMikael 84d775a
Update test with model loading
LianaMikael e678717
Update lm eval and fix rotate
LianaMikael File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import torch | ||
import torch.nn as nn | ||
from transformers.configuration_utils import PretrainedConfig | ||
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM, LlamaModel | ||
|
||
from slicegpt.adapters.llama_adapter import CompressedLlamaDecoderLayer, LlamaModelAdapter | ||
from slicegpt.modules import RMSN | ||
from slicegpt.rotate import slice_rotated_model | ||
from slicegpt.slicing_scheduler import SlicingScheduler | ||
|
||
|
||
class SlicedLlamaConfig(LlamaConfig): | ||
model_type = "sliced_llama" | ||
is_composition = True | ||
|
||
def __init__(self, **kwargs) -> None: | ||
self.sparsity = kwargs.pop("sparsity", None) | ||
self.new_hidden_size = kwargs.pop("new_hidden_size", None) | ||
super().__init__(**kwargs) | ||
|
||
@classmethod | ||
def from_pretrained(cls, config_path: str, sparsity: float, new_hidden_size: int) -> PretrainedConfig: | ||
kwargs = {"sparsity": sparsity, "new_hidden_size": new_hidden_size} | ||
return super().from_pretrained(config_path, **kwargs) | ||
|
||
|
||
class SlicedLlama(LlamaModel): | ||
def __init__(self, config): | ||
super().__init__(config) | ||
self.config = config | ||
self.layers = nn.ModuleList( | ||
[ | ||
CompressedLlamaDecoderLayer(config, layer_idx, replace_layernorm=True) | ||
for layer_idx in range(config.num_hidden_layers) | ||
] | ||
) | ||
self.final_layernorm = RMSN(config.hidden_size) | ||
|
||
|
||
class SlicedLlamaForCausalLM(LlamaForCausalLM): | ||
def __init__( | ||
self, | ||
config, | ||
scheduler: SlicingScheduler | None = None, | ||
*model_args, | ||
**kwargs, | ||
): | ||
super().__init__(config) | ||
self.model = SlicedLlama(config) | ||
self.model_adapter = LlamaModelAdapter(self) | ||
|
||
if scheduler: | ||
self.update_dims(scheduler) | ||
|
||
@classmethod | ||
def from_pretrained( | ||
cls, | ||
pretrained_model_name_or_path, | ||
scheduler: SlicingScheduler | None, | ||
sparsity: float, | ||
new_hidden_size: int, | ||
config_path: str, | ||
*model_args, | ||
**kwargs, | ||
): | ||
"""Overrides the from_pretrained method to accept the scheduler and returns the sliced model""" | ||
config = SlicedLlamaConfig.from_pretrained(config_path, sparsity, new_hidden_size) | ||
kwargs = {"scheduler": scheduler} | ||
model = super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs) | ||
model.load_state_dict(model.state_dict()) | ||
return model | ||
|
||
def update_dims(self, scheduler: SlicingScheduler) -> None: | ||
layers = self.model_adapter.get_layers() | ||
|
||
hidden_size = self.model_adapter.hidden_size | ||
for layer_adapter in layers: | ||
if not self.model_adapter.parallel_blocks: | ||
layer_adapter.layer.mlp_shortcut_Q = torch.nn.Parameter( | ||
torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous() | ||
) | ||
layer_adapter.layer.attn_shortcut_Q = torch.nn.Parameter( | ||
torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous() | ||
) | ||
|
||
slice_rotated_model(self.model_adapter, scheduler) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import torch | ||
import torch.nn as nn | ||
from transformers.configuration_utils import PretrainedConfig | ||
from transformers.models.phi.modeling_phi import PhiConfig, PhiForCausalLM, PhiModel | ||
|
||
from slicegpt.adapters.phi2_adapter import CompressedPhiDecoderLayer, Phi2ModelAdapter | ||
from slicegpt.modules import RMSN | ||
from slicegpt.rotate import slice_rotated_model | ||
from slicegpt.slicing_scheduler import SlicingScheduler | ||
|
||
|
||
class SlicedPhi2Config(PhiConfig): | ||
model_type = "sliced_phi2" | ||
is_composition = True | ||
|
||
def __init__(self, **kwargs) -> None: | ||
self.sparsity = kwargs.pop("sparsity", None) | ||
self.new_hidden_size = kwargs.pop("new_hidden_size", None) | ||
super().__init__(**kwargs) | ||
|
||
@classmethod | ||
def from_pretrained(cls, config_path: str, sparsity: float, new_hidden_size: int) -> PretrainedConfig: | ||
kwargs = {"sparsity": sparsity, "new_hidden_size": new_hidden_size} | ||
return super().from_pretrained(config_path, local_files_only=True, **kwargs) | ||
|
||
|
||
class SlicedPhi(PhiModel): | ||
def __init__(self, config): | ||
super().__init__(config) | ||
self.config = config | ||
self.layers = nn.ModuleList( | ||
[ | ||
CompressedPhiDecoderLayer(config, layer_idx, replace_layernorm=True) | ||
for layer_idx in range(config.num_hidden_layers) | ||
] | ||
) | ||
self.final_layernorm = RMSN(config.hidden_size) | ||
|
||
|
||
class SlicedPhiForCausalLM(PhiForCausalLM): | ||
def __init__( | ||
self, | ||
config, | ||
scheduler: SlicingScheduler | None = None, | ||
*model_args, | ||
**kwargs, | ||
): | ||
super().__init__(config) | ||
self.model = SlicedPhi(config) | ||
self.model_adapter = Phi2ModelAdapter(self) | ||
|
||
if scheduler: | ||
self.update_dims(scheduler) | ||
|
||
@classmethod | ||
def from_pretrained( | ||
cls, | ||
pretrained_model_name_or_path, | ||
scheduler: SlicingScheduler | None, | ||
sparsity: float, | ||
new_hidden_size: int, | ||
config_path: str, | ||
*model_args, | ||
**kwargs, | ||
): | ||
"""Overrides the from_pretrained method to accept the scheduler and returns the sliced model""" | ||
config = SlicedPhi2Config.from_pretrained(config_path, sparsity, new_hidden_size) | ||
kwargs = {"scheduler": scheduler} | ||
model = super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs) | ||
model.load_state_dict(model.state_dict()) | ||
return model | ||
|
||
def update_dims(self, scheduler: SlicingScheduler) -> None: | ||
layers = self.model_adapter.get_layers() | ||
|
||
hidden_size = self.model_adapter.hidden_size | ||
for layer_adapter in layers: | ||
if not self.model_adapter.parallel_blocks: | ||
layer_adapter.layer.mlp_shortcut_Q = torch.nn.Parameter( | ||
torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous() | ||
) | ||
layer_adapter.layer.attn_shortcut_Q = torch.nn.Parameter( | ||
torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous() | ||
) | ||
|
||
slice_rotated_model(self.model_adapter, scheduler) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.