[Core] better support offloading when side loading is enabled. #4855
First changed file:

@@ -45,6 +45,7 @@
 if is_accelerate_available():
     from accelerate import init_empty_weights
+    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
     from accelerate.utils import set_module_tensor_to_device
 
 logger = logging.get_logger(__name__)
@@ -763,6 +764,19 @@ def load_textual_inversion(
                 f" `{self.load_textual_inversion.__name__}`"
             )
 
+        # Remove any existing hooks.
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        for _, component in self.components.items():
+            if isinstance(component, nn.Module):
+                if hasattr(component, "_hf_hook"):
+                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
Review comment: nice!
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
+                    )
+                    remove_hook_from_module(component)
+
Review thread on the hook checks above:
- Doesn't one of these two hook styles hook into every sub-module as well, so shouldn't one of the checks be recursive?
- @muellerzr to help here a bit.
- You should be able to do […] CC @SunMarc too for a second glance :)
- But is it required here? Sorry for not making my comment clear.
- I guess I'm trying to understand just what we're aiming to achieve (solid guess based on context, let me know if I'm accurate): […] Is this accurate? Otherwise I may need a bit more info/context I'm missing somehow.
- We want to be able to detect if a […] Let me know if that helps?
- From what I understood from the codebase, if we have […]
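For reference, a rough sketch of what the recursive variant of this check could look like, assuming accelerate's convention of storing its hook on the private `_hf_hook` attribute of each module; the helper name `detect_offload_kind` is hypothetical and not part of this PR:

```python
import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, CpuOffload


def detect_offload_kind(component: nn.Module):
    """Walk the component and all of its sub-modules, since sequential CPU
    offload attaches an AlignDevicesHook per sub-module rather than only at
    the top level."""
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for module in component.modules():  # `modules()` also yields `component` itself
        hook = getattr(module, "_hf_hook", None)
        if isinstance(hook, CpuOffload):
            is_model_cpu_offload = True
        elif isinstance(hook, AlignDevicesHook):
            is_sequential_cpu_offload = True
    return is_model_cpu_offload, is_sequential_cpu_offload
```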
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
@@ -916,6 +930,12 @@ def load_textual_inversion(
         for token_id, embedding in token_ids_and_embeddings:
             self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
 
+        # offload back
+        if is_model_cpu_offload:
+            self.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            self.enable_sequential_cpu_offload()
+
 
 class LoraLoaderMixin:
     r"""
@@ -946,6 +966,19 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
+        # Remove any existing hooks.
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        for _, component in self.components.items():
+            if isinstance(component, nn.Module):
+                if hasattr(component, "_hf_hook"):
+                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    remove_hook_from_module(component)
+
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
         self.load_lora_into_text_encoder(
@@ -955,6 +988,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             lora_scale=self.lora_scale,
         )
 
+        # Offload back.
+        if is_model_cpu_offload:
+            self.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            self.enable_sequential_cpu_offload()
+
     @classmethod
     def lora_state_dict(
         cls,
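To make the intent of the hunks above concrete, here is a hypothetical end-to-end usage sketch: offloading that was enabled before side loading should now survive a `load_textual_inversion()` or `load_lora_weights()` call. The model and repository ids below are placeholders and not part of this PR.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder model id
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()  # installs accelerate CpuOffload hooks

# With this change, the hooks are detected and removed before the new
# parameters are loaded, and model CPU offload is re-enabled afterwards.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")  # placeholder repo id
pipe.load_lora_weights("some-user/some-lora")  # placeholder repo id

image = pipe("a photo in the loaded style").images[0]
```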
Second changed file:

@@ -1207,6 +1207,23 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # We could have accessed the unet config from `lora_state_dict()` too. We pass
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.
+
+        # Remove any existing hooks.
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+        else:
+            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        for _, component in self.components.items():
+            if isinstance(component, torch.nn.Module):
+                if hasattr(component, "_hf_hook"):
+                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
Review thread on lines +1230 to +1232:
- I think the log statement might be a bit noisy. It'd be nice if we expected the user to do additional things with the placed accelerate hooks and they should be aware if they expected some state to be maintained, but we definitely don't want the user to touch the hooks.
- I think it's relatively simple given the context the message is being raised from. If you have a better suggestion, let me know.
- Sorry, I think my main point is that the log is a bit noisy given that it leaks what is supposed to be an internal implementation detail; it's not really something that should be exposed to an end user.
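Purely as an illustration of the "noisy log" point above, and not necessarily what the PR settled on, the message could be demoted to debug level so the internal hook shuffling stays hidden from end users by default:

```python
from diffusers.utils import logging

logger = logging.get_logger(__name__)

# Emit at DEBUG instead of INFO so the hook handling stays an internal detail.
logger.debug(
    "Accelerate hooks detected. Removing them before loading the LoRA parameters; "
    "offloading will be re-enabled afterwards."
)
```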
+                    remove_hook_from_module(component)
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
@@ -1234,6 +1251,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             lora_scale=self.lora_scale,
         )
 
+        # Offload back.
+        if is_model_cpu_offload:
+            self.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            self.enable_sequential_cpu_offload()
+
     @classmethod
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
     def save_lora_weights(
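A similar hypothetical sketch for the SDXL path in this second file, this time exercising the sequential CPU offload branch (`AlignDevicesHook`); the model and LoRA ids are again placeholders:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder model id
    torch_dtype=torch.float16,
)
pipe.enable_sequential_cpu_offload()  # installs accelerate AlignDevicesHook hooks

# Hooks are removed, the LoRA weights are loaded, then sequential CPU offload
# is re-enabled by the code in the hunk above.
pipe.load_lora_weights("some-user/some-sdxl-lora")  # placeholder repo id
```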