Commit 86a4e2a

Merge branch 'main' into compile_utils

2 parents: 83ba712 + 6760300

File tree

5 files changed: 92 additions & 31 deletions

examples/dreambooth/test_dreambooth_lora_sana.py (+42, -0)
examples/dreambooth/train_dreambooth_lora_sana.py (+14, -4)
src/diffusers/pipelines/wan/pipeline_wan_vace.py (+1, -1)
tests/models/test_modeling_common.py (+31, -26)
tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_upscale.py (+4, -0)

examples/dreambooth/test_dreambooth_lora_sana.py

Lines changed: 42 additions & 0 deletions
@@ -13,13 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import logging
 import os
 import sys
 import tempfile

 import safetensors

+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+

 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -204,3 +207,42 @@ def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_mult
         run_command(self._launch_args + resume_run_args)

         self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+    def test_dreambooth_lora_sana_with_metadata(self):
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --checkpointing_steps=2
+                --max_sequence_length 166
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
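
Note: the metadata block this test asserts on can be inspected outside the test harness with the same safetensors/diffusers calls. A minimal sketch; the file path is illustrative:

# Minimal sketch: read the LoRA adapter metadata out of a saved safetensors file.
# Mirrors the calls used in the test above; the path is illustrative.
import json

import safetensors.torch

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

with safetensors.torch.safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

metadata.pop("format", None)  # "format" is written by safetensors itself, not by diffusers
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
    adapter_config = json.loads(raw)
    # Keys are prefixed per component, e.g. "transformer.r" and "transformer.lora_alpha".
    print(adapter_config["transformer.r"], adapter_config["transformer.lora_alpha"])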

examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 14 additions & 4 deletions
@@ -52,6 +52,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -323,9 +324,13 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
-
     parser.add_argument(
         "--with_prior_preservation",
         default=False,
@@ -1023,7 +1028,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1039,10 +1044,11 @@ def unwrap_model(model):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
-
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1052,6 +1058,7 @@ def save_model_hook(models, weights, output_dir):
             SanaPipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )

     def load_model_hook(models, input_dir):
@@ -1507,15 +1514,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
+        modules_to_save = {}
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer

         SanaPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )

     # Final inference
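
Note: exposing --lora_alpha matters because PEFT scales the LoRA update by lora_alpha / r by default; the old lora_alpha=args.rank pinned that ratio at 1. A minimal sketch of the resulting config under the new flag, with illustrative values and an assumed target_modules list (the script computes its own):

# Minimal sketch: how --lora_alpha feeds the LoRA config. With rank=4 and
# lora_alpha=8, PEFT's default scaling (lora_alpha / r) doubles the update
# strength relative to the old lora_alpha=rank behavior.
from peft import LoraConfig

rank, lora_alpha = 4, 8  # illustrative values, matching the new test above
transformer_lora_config = LoraConfig(
    r=rank,
    lora_alpha=lora_alpha,
    lora_dropout=0.0,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # assumed; the script derives its own list
)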

src/diffusers/pipelines/wan/pipeline_wan_vace.py

Lines changed: 1 addition & 1 deletion
@@ -593,7 +593,7 @@ def prepare_masks(
             num_ref_images = len(reference_images_batch)
             if num_ref_images > 0:
                 mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :])
-                mask_ = torch.cat([mask_, mask_padding], dim=1)
+                mask_ = torch.cat([mask_padding, mask_], dim=1)
             mask_list.append(mask_)
         return torch.stack(mask_list)
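
Note: the fix only changes where the zero padding for reference images lands on the frame axis, presumably because reference frames are prepended to the video frames, so the mask padding must be prepended too to keep mask and frames aligned. A minimal sketch with illustrative shapes:

# Minimal sketch of the padding order, with illustrative tensor shapes.
import torch

num_ref_images = 2
mask_ = torch.ones(1, 8, 4, 4)  # (batch, frames, height, width): mask for 8 video frames

# Zero mask entries covering the reference-image slots.
mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :])

# Before the fix: torch.cat([mask_, mask_padding], dim=1) appended the padding,
# misaligning the mask. After the fix the padding comes first:
mask_ = torch.cat([mask_padding, mask_], dim=1)
assert mask_.shape == (1, 8 + num_ref_images, 4, 4)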

tests/models/test_modeling_common.py

Lines changed: 31 additions & 26 deletions
@@ -30,6 +30,7 @@
 from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
+import pytest
 import requests_mock
 import safetensors.torch
 import torch
@@ -938,8 +939,9 @@ def recursive_check(tuple_object, dict_object):

     @require_torch_accelerator_with_training
     def test_enable_disable_gradient_checkpointing(self):
+        # Skip test if model does not support gradient checkpointing
         if not self.model_class._supports_gradient_checkpointing:
-            return  # Skip test if model does not support gradient checkpointing
+            pytest.skip("Gradient checkpointing is not supported.")

         init_dict, _ = self.prepare_init_args_and_inputs_for_common()

@@ -957,8 +959,9 @@ def test_enable_disable_gradient_checkpointing(self):

     @require_torch_accelerator_with_training
     def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
+        # Skip test if model does not support gradient checkpointing
         if not self.model_class._supports_gradient_checkpointing:
-            return  # Skip test if model does not support gradient checkpointing
+            pytest.skip("Gradient checkpointing is not supported.")

         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1015,8 +1018,9 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
     def test_gradient_checkpointing_is_applied(
         self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
     ):
+        # Skip test if model does not support gradient checkpointing
         if not self.model_class._supports_gradient_checkpointing:
-            return  # Skip test if model does not support gradient checkpointing
+            pytest.skip("Gradient checkpointing is not supported.")

         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

@@ -1073,7 +1077,7 @@ def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
         model = self.model_class(**init_dict).to(torch_device)

         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

         torch.manual_seed(0)
         output_no_lora = model(**inputs_dict, return_dict=False)[0]
@@ -1128,7 +1132,7 @@ def test_lora_wrong_adapter_name_raises_error(self):
         model = self.model_class(**init_dict).to(torch_device)

         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

         denoiser_lora_config = LoraConfig(
             r=4,
@@ -1159,7 +1163,7 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
         model = self.model_class(**init_dict).to(torch_device)

         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

         denoiser_lora_config = LoraConfig(
             r=rank,
@@ -1196,7 +1200,7 @@ def test_lora_adapter_wrong_metadata_raises_error(self):
         model = self.model_class(**init_dict).to(torch_device)

         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")

         denoiser_lora_config = LoraConfig(
             r=4,
@@ -1233,10 +1237,10 @@ def test_lora_adapter_wrong_metadata_raises_error(self):

     @require_torch_accelerator
     def test_cpu_offload(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return

         model = model.to(torch_device)

@@ -1263,10 +1267,10 @@ def test_cpu_offload(self):

     @require_torch_accelerator
     def test_disk_offload_without_safetensors(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return

         model = model.to(torch_device)

@@ -1296,10 +1300,10 @@ def test_disk_offload_without_safetensors(self):

     @require_torch_accelerator
     def test_disk_offload_with_safetensors(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return

         model = model.to(torch_device)

@@ -1324,10 +1328,10 @@ def test_disk_offload_with_safetensors(self):

     @require_torch_multi_accelerator
     def test_model_parallelism(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return

         model = model.to(torch_device)

@@ -1426,10 +1430,10 @@ def test_sharded_checkpoints_with_variant(self):

     @require_torch_accelerator
     def test_sharded_checkpoints_device_map(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return
         model = model.to(torch_device)

         torch.manual_seed(0)
@@ -1497,7 +1501,7 @@ def test_variant_sharded_ckpt_right_format(self):
     def test_layerwise_casting_training(self):
         def test_fn(storage_dtype, compute_dtype):
             if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
-                return
+                pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
             init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

             model = self.model_class(**init_dict)
@@ -1617,6 +1621,9 @@ def get_memory_usage(storage_dtype, compute_dtype):
     @parameterized.expand([False, True])
     @require_torch_accelerator
     def test_group_offloading(self, record_stream):
+        if not self.model_class._supports_group_offloading:
+            pytest.skip("Model does not support group offloading.")
+
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         torch.manual_seed(0)

@@ -1633,8 +1640,6 @@ def run_forward(model):
             return model(**inputs_dict)[0]

         model = self.model_class(**init_dict)
-        if not getattr(model, "_supports_group_offloading", True):
-            return

         model.to(torch_device)
         output_without_group_offloading = run_forward(model)
@@ -1670,13 +1675,13 @@ def run_forward(model):
     @require_torch_accelerator
     @torch.no_grad()
     def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
+        if not self.model_class._supports_group_offloading:
+            pytest.skip("Model does not support group offloading.")
+
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)

-        if not getattr(model, "_supports_group_offloading", True):
-            return
-
         model.to(torch_device)
         model.eval()
         _ = model(**inputs_dict)[0]
@@ -1698,13 +1703,13 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
     @require_torch_accelerator
     @torch.no_grad()
     def test_group_offloading_with_disk(self, record_stream, offload_type):
+        if not self.model_class._supports_group_offloading:
+            pytest.skip("Model does not support group offloading.")
+
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)

-        if not getattr(model, "_supports_group_offloading", True):
-            return
-
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
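
Note: the recurring change in this file replaces bare return with pytest.skip so that unsupported configurations show up as skipped in the report instead of counting as silent passes. A minimal sketch of the pattern, with illustrative names, assuming the suite runs under pytest:

# Minimal sketch: pytest.skip vs. bare return in a tester mixin. Illustrative names.
import pytest


class ModelTesterMixinSketch:
    model_class = None  # concrete test classes set this to a model class

    def test_gradient_checkpointing(self):
        # A bare `return` here would be reported as a pass; pytest.skip marks
        # the test as skipped and records the reason.
        if not getattr(self.model_class, "_supports_gradient_checkpointing", False):
            pytest.skip("Gradient checkpointing is not supported.")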

tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_upscale.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,10 @@
 import onnxruntime as ort


+# TODO: (Dhruv) Update hub_checkpoint repo_id
+@unittest.skip(
+    "There is a potential backdoor vulnerability in the hub_checkpoint. Skip running this test until resolved"
+)
 class OnnxStableDiffusionUpscalePipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
     # TODO: is there an appropriate internal test set?
     hub_checkpoint = "ssube/stable-diffusion-x4-upscaler-onnx"
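
Note: applied at class level, @unittest.skip disables every test in the class and surfaces the reason in the report. A minimal sketch with illustrative names:

# Minimal sketch: a class-level unittest.skip skips all tests in the class.
import unittest


@unittest.skip("Checkpoint is untrusted; skip until a safe repo_id is available.")
class ExampleSkippedTests(unittest.TestCase):
    def test_never_runs(self):
        self.fail("unreachable: the whole class is skipped")


if __name__ == "__main__":
    unittest.main()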
