Skip to content

Commit

Permalink
fix: for spatial unet loading
Browse files Browse the repository at this point in the history
  • Loading branch information
johnmullan committed Oct 4, 2023
1 parent e5e5abe commit bf0e057
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions hotshot_xl/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,11 +952,20 @@ def from_pretrained_spatial(cls, pretrained_model_path, subfolder=None):

config["mid_block_type"] = "UNetMidBlock3DCrossAttn"

from diffusers.utils import WEIGHTS_NAME
model = cls.from_config(config)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)

if not os.path.isfile(model_file):
model_files = [
os.path.join(pretrained_model_path, 'diffusion_pytorch_model.bin'),
os.path.join(pretrained_model_path, 'diffusion_pytorch_model.safetensors')
]

model_file = None

for fp in model_files:
if os.path.exists(fp):
model_file = fp

if not model_file:
raise RuntimeError(f"{model_file} does not exist")

state_dict = torch.load(model_file, map_location="cpu")
Expand Down

0 comments on commit bf0e057

Please sign in to comment.