@@ -41,6 +41,7 @@ def __init__(
4141 self .offload_mode = None
4242 self .model_names = []
4343 self ._offload_param_dict = {}
44+ self .offload_to_disk = False
4445
4546 @classmethod
4647 def from_pretrained (cls , model_path_or_config : str | BaseConfig ) -> "BasePipeline" :
@@ -228,19 +229,23 @@ def eval(self):
228229 model .eval ()
229230 return self
230231
231- def enable_cpu_offload (self , offload_mode : str ):
232- valid_offload_mode = ("cpu_offload" , "sequential_cpu_offload" )
232+ def enable_cpu_offload (self , offload_mode : str | None , offload_to_disk : bool = False ):
233+ valid_offload_mode = ("cpu_offload" , "sequential_cpu_offload" , "disable" , None )
233234 if offload_mode not in valid_offload_mode :
234235 raise ValueError (f"offload_mode must be one of { valid_offload_mode } , but got { offload_mode } " )
235236 if self .device == "cpu" or self .device == "mps" :
236237 logger .warning ("must set an non cpu device for pipeline before calling enable_cpu_offload" )
237238 return
238- if offload_mode == "cpu_offload" :
239+ if offload_mode is None or offload_mode == "disable" :
240+ self ._disable_offload ()
241+ elif offload_mode == "cpu_offload" :
239242 self ._enable_model_cpu_offload ()
240243 elif offload_mode == "sequential_cpu_offload" :
241244 self ._enable_sequential_cpu_offload ()
245+ self .offload_to_disk = offload_to_disk
242246
243- def _enable_model_cpu_offload (self ):
247+
248+ def _enable_model_cpu_offload (self ):
244249 for model_name in self .model_names :
245250 model = getattr (self , model_name )
246251 if model is not None :
@@ -253,6 +258,15 @@ def _enable_sequential_cpu_offload(self):
253258 if model is not None :
254259 enable_sequential_cpu_offload (model , self .device )
255260 self .offload_mode = "sequential_cpu_offload"
261+
def _disable_offload(self):
    """Turn offloading off and move every loaded model to ``self.device``.

    Resets ``self.offload_mode`` to ``None``, replaces
    ``self._offload_param_dict`` with a fresh empty dict, then calls
    ``.to(self.device)`` on each attribute named in ``self.model_names``
    whose value is not ``None``.
    """
    self.offload_mode = None
    self._offload_param_dict = {}
    # Lazy lookup keeps the getattr / .to() calls interleaved per model.
    loaded_models = (getattr(self, name) for name in self.model_names)
    for module in loaded_models:
        if module is None:
            continue
        module.to(self.device)
269+
256270
257271 def enable_fp8_autocast (
258272 self , model_names : List [str ], compute_dtype : torch .dtype = torch .bfloat16 , use_fp8_linear : bool = False
@@ -282,10 +296,26 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
282296 # load the needed models to device
283297 for model_name in load_model_names :
284298 model = getattr (self , model_name )
299+ if model is None :
300+ raise ValueError (f"model { model_name } is not loaded, maybe this model has been destroyed by model_lifecycle_finish function with offload_to_disk=True" )
285301 if model is not None and (p := next (model .parameters (), None )) is not None and p .device .type != self .device :
286302 model .to (self .device )
287303 # fresh the cuda cache
288304 empty_cache ()
289305
306+ def model_lifecycle_finish (self , model_names : List [str ] | None = None ):
307+ if not self .offload_to_disk or self .offload_mode is None :
308+ return
309+ for model_name in model_names :
310+ model = getattr (self , model_name )
311+ del model
312+ if model_name in self ._offload_param_dict :
313+ del self ._offload_param_dict [model_name ]
314+ setattr (self , model_name , None )
315+ print (f"model { model_name } has been deleted from memory" )
316+ logger .info (f"model { model_name } has been deleted from memory" )
317+ empty_cache ()
318+
319+
def compile(self):
    """Placeholder for pipelines without compile support.

    Raises:
        NotImplementedError: Always; the message names the concrete class.
    """
    pipeline_name = self.__class__.__name__
    raise NotImplementedError(f"{pipeline_name} does not support compile")
0 commit comments