consistently use safe_load
kijai committed Dec 18, 2024
1 parent fb34e91 commit ec0fdf1
Showing 2 changed files with 7 additions and 12 deletions.
13 changes: 4 additions & 9 deletions hyvideo/modules/fp8_optimization.py
@@ -1,5 +1,4 @@
 import os
-
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
@@ -85,22 +84,18 @@ def convert_fp8_linear(module, original_dtype):
     script_directory = os.path.dirname(os.path.abspath(__file__))
 
     # loading fp8 mapping file
-    #fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
     fp8_map_path = os.path.join(script_directory,"fp8_map.safetensors")
     if os.path.exists(fp8_map_path):
-        #fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
-        fp8_map = load_torch_file(fp8_map_path)
+        fp8_map = load_torch_file(fp8_map_path, safe_load=True)
     else:
         raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
 
-    fp8_layers = []
+    #fp8_layers = []
     for key, layer in module.named_modules():
         if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
-            fp8_layers.append(key)
+            #fp8_layers.append(key)
             original_forward = layer.forward
-            layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
+            #layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
             setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
             setattr(layer, "original_forward", original_forward)
             setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))


6 changes: 3 additions & 3 deletions nodes.py
@@ -280,7 +280,7 @@ def loadmodel(self, model, base_precision, load_device, quantization,
         base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[base_precision]
 
         model_path = folder_paths.get_full_path_or_raise("diffusion_models", model)
-        sd = load_torch_file(model_path, device=transformer_load_device)
+        sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True)
 
         in_channels = out_channels = 16
         factor_kwargs = {"device": transformer_load_device, "dtype": base_dtype}
@@ -503,7 +503,7 @@ def loadmodel(self, model_name, precision, compile_args=None):
         with open(os.path.join(script_directory, 'configs', 'hy_vae_config.json')) as f:
             vae_config = json.load(f)
         model_path = folder_paths.get_full_path("vae", model_name)
-        vae_sd = load_torch_file(model_path)
+        vae_sd = load_torch_file(model_path, safe_load=True)
 
         vae = AutoencoderKLCausal3D.from_config(vae_config)
         vae.load_state_dict(vae_sd)
@@ -1010,7 +1010,7 @@ def INPUT_TYPES(s):
 
     def load(self, embeds):
         embed_path = folder_paths.get_full_path_or_raise("hyvid_embeds", embeds)
-        loaded_tensors = load_torch_file(embed_path)
+        loaded_tensors = load_torch_file(embed_path, safe_load=True)
         # Reconstruct original dictionary with None for missing keys
         prompt_embeds_dict = {
             "prompt_embeds": loaded_tensors.get("prompt_embeds", None),
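Note on the change itself: load_torch_file is ComfyUI's checkpoint loader in comfy.utils, and passing safe_load=True asks it to avoid running arbitrary pickled code when the file is not a safetensors archive. As a rough, hypothetical sketch of the distinction (this is only the general idea, not the actual comfy.utils implementation; load_checkpoint_safely is an invented name):

import torch
import safetensors.torch

def load_checkpoint_safely(path, device="cpu"):
    # .safetensors files hold raw tensor data only, so loading them
    # never executes pickled Python code.
    if path.lower().endswith(".safetensors"):
        return safetensors.torch.load_file(path, device=device)
    # Legacy .pt/.pth checkpoints are pickles; weights_only=True (recent
    # PyTorch) restricts unpickling to tensors and plain containers.
    return torch.load(path, map_location=device, weights_only=True)

With the flag passed at every call site above, the fp8 scale map, the diffusion model, the VAE, and the cached prompt embeds all go through the same safe loading path.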
