Bump transformers and torch #117
@@ -54,12 +54,12 @@ def __init__(

         # Create a list of CustomKVCache instances, one per layer
         self.kv_cache = torch.nn.ModuleList()
-        for _ in range(config.num_hidden_layers):
+        for layer in self.layers:
             layer_cache = CustomKVCache(
-                max_batch_size=self.max_batch_size,
-                max_context_length=self.max_cache_len,
-                n_heads=self.num_key_value_heads,
-                head_dim=self.head_dim,
+                max_batch_size=layer.max_batch_size,
+                max_context_length=layer.max_cache_len,
+                n_heads=layer.num_heads,
+                head_dim=layer.head_dim,
                 dtype=dtype,
             )
             self.kv_cache.append(layer_cache)

@@ -202,32 +202,29 @@ def __init__(
             layer_device_map=layer_device_map,
         )

         # make sure layer_device_map is none
         assert layer_device_map is None
         assert device is None or device == "cpu", "Device must be None or 'cpu'"

         self.cache_position = None
-        # Create a list of cache instances, one per layer
-        # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers
+        # Create a list of cache instances, one per layer.
+        # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
         self.kv_cache = torch.nn.ModuleList()
-        for layer_idx in range(config.num_hidden_layers):
-            # newer version of transfomer has is_sliding defined
-            # for HybridCache
-            if self.is_sliding[layer_idx]:
+        for layer in self.layers:
+            if layer.is_sliding:
                 # This is a sliding window layer
                 layer_cache = CustomRingKVCache(
-                    max_batch_size=self.max_batch_size,
-                    max_context_length=self.sliding_window_len,
-                    n_heads=self.num_key_value_heads,
-                    head_dim=self.head_dim,
+                    max_batch_size=layer.max_batch_size,
+                    max_context_length=layer.max_cache_len,
+                    n_heads=layer.num_heads,
+                    head_dim=layer.head_dim,
                     dtype=dtype,
                 )
             else:
                 layer_cache = CustomKVCache(
-                    max_batch_size=self.max_batch_size,
-                    max_context_length=self.max_cache_len,
-                    n_heads=self.num_key_value_heads,
-                    head_dim=self.head_dim,
+                    max_batch_size=layer.max_batch_size,
+                    max_context_length=layer.max_cache_len,
+                    n_heads=layer.num_heads,
+                    head_dim=layer.head_dim,
                     dtype=dtype,
                 )
             self.kv_cache.append(layer_cache)

Review comment on `max_context_length=layer.max_cache_len` (sliding branch): wait, what is happening here? Is this the same as sliding_window_len?
Reply: Yeah, they removed it: https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py#L357

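For context on the exchange above: newer transformers releases drop the cache-level `sliding_window_len` and instead give every entry in `cache.layers` its own geometry, so for a sliding layer `max_cache_len` is the window size. The snippet below is a minimal illustrative sketch, not code from this PR; the per-layer attribute names (`is_sliding`, `max_cache_len`, `num_heads`, `head_dim`) are taken from the diff above.

def summarize_cache_layers(cache):
    """Report (index, kind, context_length) for each layer of a layered HF cache."""
    rows = []
    for idx, layer in enumerate(cache.layers):
        kind = "sliding" if layer.is_sliding else "global"
        # For sliding layers, max_cache_len plays the role of the old
        # sliding_window_len; for global layers it is the full context length.
        rows.append((idx, kind, layer.max_cache_len))
    return rows

Dispatching on `layer.is_sliding` this way mirrors what the `ETCustomHybridCache` constructor above does when choosing between `CustomRingKVCache` and `CustomKVCache`.
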
@@ -284,7 +281,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:

         # For CustomRingKVCache, we need to handle the sequence length differently
         layer_cache = self.kv_cache[layer_idx]
-        if self.is_sliding[layer_idx]:
+        if self.layers[layer_idx].is_sliding:
             # CustomRingKVCache cache_position_manager which
             # maintains cache position for each slot in the kv cache
             # we return the max position + 1 to indicate max position

@@ -308,7 +305,7 @@ def get_layer_cache(self, layer_idx: int):

 def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     """
-    Replace all KV caches in the module with ETCustomStaticCache.
+    Replace all KV caches in the module with ETCustomStaticCache or ETCustomHybridCache.
     This modifies the model in place.

     Args:

@@ -342,18 +339,18 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     if getattr(module, "replace_cache", None) is not None:
         static_cache = ETCustomStaticCache(
             config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=cache_dtype,
         )
         module.replace_cache(static_cache)
     else:
         module.static_cache = ETCustomStaticCache(
             config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=cache_dtype,
         )
     # Dont know why we need to this even though

@@ -370,25 +367,25 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     if getattr(module, "replace_cache", None) is not None:
         hybrid_cache = ETCustomHybridCache(
             config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=cache_dtype,
         )
         module.replace_cache(hybrid_cache)
     else:
         module.cache = ETCustomHybridCache(
             config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=cache_dtype,
         )
         # Register cache attributes for each layer
         for i in range(len(module.cache.kv_cache)):
             setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
             setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
-            if module.cache.is_sliding[i]:
+            if module.cache.layers[i].is_sliding:
                 # Register cache_positions as buffer for sliding window layers
                 # This prevents it from being traced as a constant
                 module.register_buffer(

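The switch from attribute access (`cache_config.batch_size`) to `cache_config.get("batch_size")` in the two hunks above reflects `generation_config.cache_config` now behaving like a mapping in the bumped transformers version. If both styles ever needed to be supported, a small hypothetical helper (not part of this PR) could hide the difference:

def cache_config_value(cache_config, key, default=None):
    """Read a cache_config field from either a dict-like or attribute-style object."""
    if hasattr(cache_config, "get"):
        # Newer, mapping-style cache_config (what the .get(...) calls above assume)
        return cache_config.get(key, default)
    # Older attribute-style CacheConfig object
    return getattr(cache_config, key, default)
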
@@ -16,6 +16,7 @@

 import logging
 import os
+import shutil
 from abc import ABC, abstractmethod
 from pathlib import Path
 from tempfile import TemporaryDirectory

@@ -24,6 +25,7 @@
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForImageClassification,

@@ -102,6 +104,34 @@ def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):

         self.stats = Stats()

+        # Initialize cleanup tracking
+        self._temp_dir = None
+
+    def __del__(self):
+        """Clean up temporary files when the model instance is destroyed."""
+        self._cleanup_temp_resources()
+
+    def _cleanup_temp_resources(self):
+        """Clean up temporary directory and files."""
+        if hasattr(self, "_temp_dir") and self._temp_dir is not None:
+            try:
+                if hasattr(self._temp_dir, "cleanup"):
+                    # It's a TemporaryDirectory object
+                    logging.info(f"Cleaning up temporary directory: {self._temp_dir.name}")
+                    self._temp_dir.cleanup()
+                    logging.info("Temporary directory cleanup completed")
+                elif isinstance(self._temp_dir, (str, Path)):
+                    # It's a path
+                    logging.info(f"Cleaning up temporary path: {self._temp_dir}")
+                    shutil.rmtree(self._temp_dir, ignore_errors=True)
+                    logging.info("Temporary path cleanup completed")
+            except Exception as e:
+                # Log cleanup errors for debugging
+                logging.warning(f"Error during temp directory cleanup: {e}")
+                pass
+            finally:
+                self._temp_dir = None
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         """

Review comment on `__del__`: shouldn't this already happen automatically?
Reply: Yeah, probably, but it was added just to be extra sure that it's cleaned up between tests.

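On the question above: `TemporaryDirectory` does register a finalizer that removes the directory once the object is garbage collected, so the explicit hook is belt-and-braces for deterministic cleanup between tests. A small stdlib-only sketch of the difference (an assumed example, not repo code):

import os
import tempfile

td = tempfile.TemporaryDirectory(prefix="executorch_export_")
path = td.name
assert os.path.isdir(path)

# Explicit, deterministic cleanup -- what _cleanup_temp_resources() does -- rather
# than waiting for the finalizer to run whenever `td` happens to be collected.
td.cleanup()
assert not os.path.isdir(path)
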
@@ -242,7 +272,7 @@ def _export(
         inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class)
         logging.info(f"Inferred task from model class: {inferred_task}")

-        save_dir = TemporaryDirectory()
+        save_dir = TemporaryDirectory(prefix="executorch_export_")
         save_dir_path = Path(save_dir.name)

         # Export to ExecuTorch and save the pte file to the temporary directory

@@ -266,7 +296,7 @@
         for name, _ in executorch_progs.items():
             models.update(cls._from_pretrained(save_dir_path, file_name=f"{name}.pte", config=config))

-        return models
+        return models, save_dir

     def _save_pretrained(self, save_directory):
         """

@@ -298,6 +328,7 @@ def from_pretrained(
             logger.info("Offline mode: setting `local_files_only=True`")
             local_files_only = True

+        # See if model was already exported to ExecuTorch and uplaoded to the HuggingFace repo.
         _export = export
         try:
             if local_files_only and not os.path.isdir(model_id):

@@ -324,21 +355,21 @@
                 if export:
                     logger.warning(
                         f"The model {model_id} was already converted to the ExecuTorch IR but got `export=True`, the model will be converted to ExecuTorch once again. "
-                        # "Don't forget to save the resulting model with `.save_pretrained()`"
                     )
                     _export = True
                 else:
                     logger.warning(
                         f"No ExecuTorch files were found for {model_id}, setting `export=True` to convert the model to the ExecuTorch IR. "
-                        # "Don't forget to save the resulting model with `.save_pretrained()`"
                     )
         except Exception as exception:
             logger.warning(
                 f"Could not infer whether the model was already converted or not to the ExecuTorch IR, keeping `export={export}`.\n{exception}"
             )

+        temp_dir = None
         if _export:
-            models_dict = cls._export(
+            logging.info(f"Exporting {model_id} to ExecuTorch program...")
+            models_dict, temp_dir = cls._export(
                 model_id=model_id,
                 config=config,
                 revision=revision,

@@ -351,6 +382,9 @@
                 **kwargs,
             )
         else:
+            logging.info(
+                f"Pre-exported `.pte` artifact already exists in HuggingFace repo or provided file path for {model_id}, skipping export."
+            )
             models_dict = {}
             for pte_file in pte_files:
                 models_dict.update(

@@ -368,7 +402,14 @@
                 )
             )

-        return cls(models_dict, config)
+        model_instance = cls(models_dict, config)
+
+        # Store the TemporaryDirectory reference to prevent GC
+        if temp_dir is not None:
+            model_instance._temp_dir = temp_dir
+            logging.info(f"Stored temp directory reference in model: {temp_dir.name}")
+
+        return model_instance


 class ExecuTorchModelForSeq2SeqLM(ExecuTorchModelBase):

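Why `_export` now returns `save_dir` and the model stores it as `_temp_dir`: once the last reference to a `TemporaryDirectory` goes away, its finalizer may delete the directory that the freshly loaded `.pte` files still live in. The sketch below (a hypothetical, stdlib-only example with a made-up `model.pte` file, not repo code) shows the failure mode that keeping the reference on the model instance avoids:

import gc
import os
import tempfile

def make_artifact():
    d = tempfile.TemporaryDirectory()
    artifact = os.path.join(d.name, "model.pte")
    open(artifact, "wb").close()
    return artifact  # `d` is dropped here, so nothing keeps the directory alive

path = make_artifact()
gc.collect()  # the TemporaryDirectory finalizer removes the directory
print(os.path.exists(path))  # typically False: the artifact vanished with the temp dir
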
Review comment (on the change from `for _ in range(config.num_hidden_layers)` to `for layer in self.layers`): what happened here? Like, does config not exist anymore?
Reply: It still exists; it just feels more idiomatic to iterate over the actual layers.