Skip to content

Commit

Permalink
Quick fix
Browse files Browse the repository at this point in the history
  • Loading branch information
MMqd committed Jul 23, 2023
1 parent 32abc5e commit f8d45cf
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions scripts/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(self, cache_dir=""):
self.stages = [1]
self.cache_dir = os.path.join(os.path.join(script_path, 'models'), cache_dir)

def load_pipeline(self, pipe_name: str, pipeline, pretrained_model_name_or_path, move_to_cuda = True, kwargs = {}):
def load_pipeline(self, pipe_name: str, pipeline: DiffusionPipeline, pretrained_model_name_or_path, move_to_cuda = True, kwargs = {}):
pipe = getattr(self, pipe_name, None)

if not isinstance(pipe, pipeline) or pipe is None:
Expand All @@ -126,10 +126,7 @@ def load_pipeline(self, pipe_name: str, pipeline, pretrained_model_name_or_path,
"torch_dtype": torch.float16,
"cache_dir": self.cache_dir
})
if callable(pipeline):
pipeline(**kwargs)
else:
pipe = pipeline.from_pretrained(**kwargs)#, scheduler=dpm)
pipe = pipeline.from_pretrained(**kwargs)#, scheduler=dpm)
if move_to_cuda:
pipe.to("cuda")
else:
Expand Down

0 comments on commit f8d45cf

Please sign in to comment.