Refactoring
MMqd committed Jul 23, 2023
1 parent 9e8cf35 commit 39586c3
Showing 3 changed files with 80 additions and 48 deletions.
scripts/abstract_model.py (22 changes: 12 additions & 10 deletions)
@@ -110,41 +110,43 @@ class AbstractModel():

def __init__(self, cache_dir="", version="0"):
self.stages = [1]
self.cache_dir = os.path.join(os.path.join(script_path, 'models'), cache_dir)
self.models_path = os.path.join(script_path, 'models')
self.cache_dir = os.path.join(self.models_path, cache_dir)
self.version = version
self.sd_checkpoint_info = KandinskyCheckpointInfo(version=self.version)
self.sd_model_hash = self.sd_checkpoint_info.shorthash

def load_pipeline(self, pipe_name: str, pipeline: DiffusionPipeline, pretrained_model_name_or_path, move_to_cuda = True, kwargs = {}):
def load_pipeline(self, pipe_name: str, pipeline: DiffusionPipeline, pretrained_model_name_or_path, move_to_cuda = True, kwargs = {}, enable_sequential_cpu_offload = True):
pipe = getattr(self, pipe_name, None)

if not isinstance(pipe, pipeline) or pipe is None:
if pipe is not None:
pipe = None
gc.collect()
devices.torch_gc()
kwargs.update({
new_kwargs = {
"pretrained_model_name_or_path": pretrained_model_name_or_path,
"variant": "fp16",
"torch_dtype": torch.float16,
"cache_dir": self.cache_dir,
"resume_download": True,
#"local_files_only": True,
"low_cpu_mem_usage": True
})
}
new_kwargs.update(kwargs)
kwargs = new_kwargs

pipe = pipeline.from_pretrained(**kwargs)#, scheduler=dpm)
gc.collect()
devices.torch_gc()

if move_to_cuda:
pipe.to("cuda")
else:
elif enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
#pipe.enable_sequential_cpu_offload()
pipe.enable_attention_slicing(self.attention_type)
#pipe.unet.to(memory_format=torch.channels_last)
setattr(self, pipe_name, pipe)
elif move_to_cuda:
pipe.to("cuda")
else:
elif enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()

return pipe
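A minimal sketch (not part of the diff) of the argument-merging pattern the refactored load_pipeline uses: the defaults are assembled first, then the caller's kwargs are layered on top so they win on any key collision, and the new enable_sequential_cpu_offload flag decides whether a pipeline that stays off CUDA gets sequential CPU offload. The helper name build_from_pretrained_kwargs is illustrative only.

import torch

def build_from_pretrained_kwargs(pretrained_model_name_or_path, cache_dir, caller_kwargs):
    # Illustrative helper, not part of the commit: mirrors the defaults that
    # load_pipeline passes to DiffusionPipeline.from_pretrained().
    defaults = {
        "pretrained_model_name_or_path": pretrained_model_name_or_path,
        "variant": "fp16",
        "torch_dtype": torch.float16,
        "cache_dir": cache_dir,
        "resume_download": True,
        "low_cpu_mem_usage": True,
    }
    defaults.update(caller_kwargs)  # caller-supplied kwargs override the defaults
    return defaults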
scripts/kandinsky.py (93 changes: 67 additions & 26 deletions)
@@ -22,7 +22,7 @@

class KandinskyModel(AbstractModel):
def __init__(self, cache_dir="", version="2.1"):
AbstractModel.__init__(self, cache_dir="Kandinsky", version=version)
AbstractModel.__init__(self, cache_dir="kandinsky22", version=version)
self.image_encoder = None
self.pipe_prior = None
self.pipe = None
@@ -60,61 +60,98 @@ def mix_images(self, p, generation_parameters, b, result_images):
self.pipe.to("cpu")
return result_images



def load_encoder(self):
if self.version == "2.1":
if self.pipe_prior is None:
self.pipe_prior = self.load_pipeline("pipe_prior", KandinskyPriorPipeline, f"kandinsky-community/kandinsky-{self.version}-prior".replace(".", "-"))
elif self.version == "2.2":
if self.image_encoder is None:
if self.low_vram:
encoder_torch_type = torch.float32
else:
encoder_torch_type = torch.float16

# self.image_encoder = self.load_pipeline("image_encoder", CLIPVisionModelWithProjection, "kandinsky-community/kandinsky-2-2-prior",
# move_to_cuda=False, kwargs={"subfolder": 'image_encoder', "torch_dtype": encoder_torch_type}, enable_sequential_cpu_offload=False)

self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
'kandinsky-community/kandinsky-2-2-prior',
subfolder='image_encoder',
cache_dir=os.path.join(self.models_path, "kandinsky22"),
low_cpu_mem_usage=True
# local_files_only=True
torch_dtype=encoder_torch_type,
low_cpu_mem_usage=True,
resume_download=True,
# local_files_only=True
)

self.image_encoder.to("cpu" if self.low_vram else "cuda")

if self.low_vram:
self.image_encoder.to("cpu")
else:
self.image_encoder = self.image_encoder.half().to("cuda")

# self.pipe_prior = self.load_pipeline("pipe_prior", KandinskyV22PriorPipeline, "kandinsky-community/kandinsky-2-2-prior",
# move_to_cuda=False, kwargs={"image_encoder": self.image_encoder, "torch_dtype": encoder_torch_type}, enable_sequential_cpu_offload=False)

self.pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
'kandinsky-community/kandinsky-2-2-prior',
image_encoder=self.image_encoder,
torch_dtype=torch.float32,
torch_dtype=encoder_torch_type,
cache_dir=os.path.join(self.models_path, "kandinsky22"),
low_cpu_mem_usage=True
# local_files_only=True
low_cpu_mem_usage=True,
resume_download=True,
# local_files_only=True
)

self.image_encoder.to("cpu" if self.low_vram else "cuda")
if self.low_vram:
self.pipe_prior.to("cpu")
else:
self.pipe_prior.to("cuda")

# self.unet = self.load_pipeline("unet", UNet2DConditionModel, "kandinsky-community/kandinsky-2-2-decoder",
# move_to_cuda=False, kwargs={"subfolder": 'unet'}, enable_sequential_cpu_offload=False).half().to("cuda")

# self.pipe = self.load_pipeline("pipe", KandinskyV22Pipeline, "kandinsky-community/kandinsky-2-2-decoder",
# move_to_cuda=False, kwargs={"unet": self.unet}, enable_sequential_cpu_offload=False).to("cuda")

self.unet = UNet2DConditionModel.from_pretrained(
'kandinsky-community/kandinsky-2-2-decoder',
subfolder='unet',
cache_dir=os.path.join(self.models_path, "kandinsky22"),
torch_dtype=torch.float16,
low_cpu_mem_usage=True
# local_files_only=True
low_cpu_mem_usage=True,
resume_download=True,
# local_files_only=True
).half().to("cuda")

self.pipe = KandinskyV22Pipeline.from_pretrained(
'kandinsky-community/kandinsky-2-2-decoder',
unet=self.unet,
torch_dtype=torch.float16,
cache_dir=os.path.join(self.models_path, "kandinsky22"),
low_cpu_mem_usage=True
# local_files_only=True
low_cpu_mem_usage=True,
resume_download=True,
# local_files_only=True
).to("cuda")

def run_encoder(self, prior_settings_dict):
self.main_model_to_cpu()
return self.pipe_prior(**prior_settings_dict).to_tuple()

def encoder_to_cpu(self):
pass
#self.image_encoder.to("cpu")
#self.pipe_prior.to("cpu")
#self.pipe.to("cuda")
#self.unet.to("cuda")
if self.low_vram:
if self.pipe is not None:
self.pipe.to("cpu")

if self.unet is not None:
self.unet.to("cpu")

if self.image_encoder is not None:
self.image_encoder.to("cuda")

if self.pipe_prior is not None:
self.pipe_prior.to("cuda")

def unload(self):
if self.image_encoder is not None:
@@ -137,11 +137,18 @@ def unload(self):
torch.cuda.empty_cache()

def main_model_to_cpu(self):
pass
#self.pipe.to("cpu")
#self.unet.to("cpu")
#self.image_encoder.to("cuda")
#self.pipe_prior.to("cuda")
if self.low_vram:
if self.pipe is not None:
self.pipe.to("cuda")

if self.unet is not None:
self.unet.to("cuda")

if self.image_encoder is not None:
self.image_encoder.to("cpu")

if self.pipe_prior is not None:
self.pipe_prior.to("cpu")

def sd_processing_to_dict_encoder(self, p: StableDiffusionProcessing):
torch.manual_seed(0)
@@ -174,9 +218,6 @@ def cleanup_on_error(self):
def txt2img(self, p, generation_parameters, b):
if self.version == "2.1":
self.pipe = self.load_pipeline("pipe", KandinskyPipeline, f"kandinsky-community/kandinsky-{self.version}".replace(".", "-"), move_to_cuda=move_to_cuda)
#else:
# self.unet.to("cuda")
# self.pipe.to("cuda")

result_images = self.pipe(**generation_parameters, num_images_per_prompt=p.batch_size).images
return self.mix_images(p, generation_parameters, b, result_images)
scripts/kandinsky_script.py (13 changes: 1 addition & 12 deletions)
@@ -33,18 +33,7 @@ def reload_model():

def unload_kandinsky_model():
if getattr(shared, "kandinsky_model", None) is not None:
getattr(shared, "kandinsky_model", None).unload()
#if getattr(shared.kandinsky_model, "pipe_prior", None) is not None:
# del shared.kandinsky_model.pipe_prior
# devices.torch_gc()
# gc.collect()
# torch.cuda.empty_cache()

#if getattr(shared.kandinsky_model, "pipe", None) is not None:
# del shared.kandinsky_model.pipe
# devices.torch_gc()
# gc.collect()
# torch.cuda.empty_cache()
shared.kandinsky_model.unload()

del shared.kandinsky_model
print("Unloaded Kandinsky model")