Description
Describe the bug
I tried running train_dreambooth_lora_flux.py again with the merged source code, but I am still encountering an issue similar to #9237 during the log_validation stage.
I have resolved this issue by changing
autocast_ctx = nullcontext()
to
autocast_ctx = torch.autocast(accelerator.device.type, dtype=torch_dtype)
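To illustrate why this change helps, here is a minimal standalone sketch. It is not code from train_dreambooth_lora_flux.py. Based on the traceback, it assumes that a float32 tensor reaches a layer whose weights were cast to bf16. `conv_in` is a hypothetical stand-in for the VAE decoder's first convolution, and `make_autocast_ctx` and `decode_like` are hypothetical helper names. Under nullcontext() the convolution raises a dtype-mismatch RuntimeError. Under torch.autocast, conv2d is an autocast-eligible op, so its inputs are cast to bf16 and the call succeeds.

```python
from contextlib import nullcontext

import torch
import torch.nn as nn


def make_autocast_ctx(device_type: str, dtype: torch.dtype, use_autocast: bool = True):
    """Hypothetical helper returning the context the validation call runs under.

    use_autocast=False mimics the current `autocast_ctx = nullcontext()`;
    use_autocast=True mimics the proposed `torch.autocast(device_type, dtype=dtype)`.
    """
    if use_autocast:
        return torch.autocast(device_type, dtype=dtype)
    return nullcontext()


def decode_like(conv: nn.Module, latents: torch.Tensor, ctx) -> torch.Tensor:
    # Hypothetical stand-in for vae.decode(): the first op is a conv with bf16 weights.
    with ctx:
        return conv(latents)


if __name__ == "__main__":
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.bfloat16

    # Illustrative sizes; weights cast to bf16 as with --mixed_precision=bf16.
    conv_in = nn.Conv2d(16, 8, kernel_size=3, padding=1).to(device_type, dtype=torch_dtype)
    latents = torch.randn(1, 16, 8, 8, device=device_type)  # float32 input

    try:
        decode_like(conv_in, latents, make_autocast_ctx(device_type, torch_dtype, use_autocast=False))
    except RuntimeError as err:
        print("nullcontext():", err)

    out = decode_like(conv_in, latents, make_autocast_ctx(device_type, torch_dtype))
    print("torch.autocast():", out.dtype)
```

On CPU, torch.autocast also supports bfloat16 for conv2d, so the sketch runs without a GPU. The exact error text may differ between devices and PyTorch versions.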
I am currently verifying that, with this fix, the experiment is correctly uploaded to wandb; once that is confirmed, I will submit a PR with the change.
If you have any suggestions for a better solution, I would greatly appreciate your feedback!
Reproduction
CUDA_VISIBLE_DEVICES=0 accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path="/FLUX.1-dev" \
--instance_data_dir="/dataset/dog" \
--output_dir="trained-flux-dog-0928" \
--mixed_precision=bf16 \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--lr_scheduler=constant \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--checkpointing_steps=50 \
--seed=0 \
--rank=32 \
--report_to="wandb" \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25
Logs
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
[WARNING] using untested triton version (3.0.0), only 1.0.0 is known to be compatible
stderr: Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
09/28/2024 14:20:08 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Mixed precision type: bf16
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type t5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards: 50%|█████ | 1/2 [00:03<00:03, 3.92s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00, 3.60s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00, 3.65s/it]
{'axes_dims_rope'} was not found in config. Values will be initialized to default values.
wandb: Currently logged in as: timdalee (timdalee-ai). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.17.8
wandb: Run data is saved locally in /diffusers/examples/dreambooth/wandb/run-20240928_142109-n4e0rrva
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run brisk-cherry-9
wandb: ⭐️ View project at https://wandb.ai/timdalee-ai/dreambooth-flux-dev-lora
wandb: 🚀 View run at https://wandb.ai/timdalee-ai/dreambooth-flux-dev-lora/runs/n4e0rrva
09/28/2024 14:21:13 - INFO - __main__ - ***** Running training *****
09/28/2024 14:21:13 - INFO - __main__ - Num examples = 5
09/28/2024 14:21:13 - INFO - __main__ - Num batches each epoch = 5
09/28/2024 14:21:13 - INFO - __main__ - Num Epochs = 250
09/28/2024 14:21:13 - INFO - __main__ - Instantaneous batch size per device = 1
09/28/2024 14:21:13 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 4
09/28/2024 14:21:13 - INFO - __main__ - Gradient Accumulation steps = 4
09/28/2024 14:21:13 - INFO - __main__ - Total optimization steps = 500
Steps: 0%| | 0/500 [00:00<?, ?it/s]Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
Steps: 0%| | 0/500 [00:01<?, ?it/s, loss=0.559, lr=0.0001]Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
Steps: 0%| | 0/500 [00:01<?, ?it/s, loss=0.574, lr=0.0001]Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
Steps: 0%| | 0/500 [00:02<?, ?it/s, loss=0.529, lr=0.0001]Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
Steps: 0%| | 1/500 [00:02<24:27, 2.94s/it, loss=0.529, lr=0.0001]
Steps: 0%| | 1/500 [00:02<24:27, 2.94s/it, loss=0.691, lr=0.0001]Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor
Steps: 0%| | 2/500 [00:03<12:46, 1.54s/it, loss=0.691, lr=0.0001]
Steps: 0%| | 2/500 [00:03<12:46, 1.54s/it, loss=0.762, lr=0.0001]
Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards: 50%|█████ | 1/2 [00:04<00:04, 4.13s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00, 3.78s/it]
/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:49: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
def forward(ctx, input, weight, bias=None):
/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:67: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
def backward(ctx, grad_output):
Loaded scheduler as FlowMatchEulerDiscreteScheduler from `scheduler` subfolder of /FLUX.1-dev.]
Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of /FLUX.1-dev.
Loaded tokenizer_2 as T5TokenizerFast from `tokenizer_2` subfolder of /FLUX.1-dev.
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 22.50it/s]
09/28/2024 14:21:27 - INFO - __main__ - Running validation...
Generating 4 images with prompt: A photo of sks dog in a bucket.
Traceback (most recent call last):
File "/diffusers/examples/dreambooth/train_dreambooth_lora_flux.py", line 1893, in <module>
main(args)
File "/diffusers/examples/dreambooth/train_dreambooth_lora_flux.py", line 1813, in main
images = log_validation(
File "/diffusers/examples/dreambooth/train_dreambooth_lora_flux.py", line 191, in log_validation
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
File "/diffusers/examples/dreambooth/train_dreambooth_lora_flux.py", line 191, in <listcomp>
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 763, in __call__
image = self.vae.decode(latents, return_dict=False)[0]
File "/diffusers/src/diffusers/utils/accelerate_utils.py", line 46, in wrapper
return method(self, *args, **kwargs)
File "/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 326, in decode
decoded = self._decode(z).sample
File "/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 297, in _decode
dec = self.decoder(z)
File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/diffusers/src/diffusers/models/autoencoders/vae.py", line 291, in forward
sample = self.conv_in(sample)
File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 458, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same
Steps: 0%| | 2/500 [00:39<2:44:07, 19.77s/it, loss=0.755, lr=0.0001]
Traceback (most recent call last):
File "/root/miniconda3/envs/flux_diffusers/bin/accelerate", line 8, in <module>
sys.exit(main())
File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
args.func(args)
File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1106, in launch_command
simple_launcher(args)
File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/accelerate/commands/launch.py", line 704, in simple_launcher
raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/root/miniconda3/envs/flux_diffusers/bin/python', 'train_dreambooth_lora_flux.py', '--pretrained_model_name_or_path=/FLUX.1-dev', '--instance_data_dir=/dataset/dog', '--output_dir=trained-flux-dog-0928', '--mixed_precision=bf16', '--instance_prompt=a photo of sks dog', '--resolution=512', '--train_batch_size=1', '--gradient_accumulation_steps=4', '--learning_rate=1e-4', '--lr_scheduler=constant', '--lr_warmup_steps=0', '--max_train_steps=500', '--checkpointing_steps=50', '--seed=0', '--rank=32', '--validation_prompt=A photo of sks dog in a bucket', '--validation_epochs=25']' returned non-zero exit status 1.
System Info
🤗 Diffusers version: 0.31.0.dev0
Platform: Linux-4.18.0-513.11.1.el8_9.x86_64-x86_64-with-glibc2.31
Running on Google Colab?: No
Python version: 3.10.14
PyTorch version (GPU?): 2.4.0+cu121 (True)
Flax version (CPU?/GPU?/TPU?): not installed (NA)
Jax version: not installed
JaxLib version: not installed
Huggingface_hub version: 0.24.6
Transformers version: 4.44.2
Accelerate version: 0.33.0
PEFT version: 0.12.0
Bitsandbytes version: not installed
Safetensors version: 0.4.4
xFormers version: not installed
Accelerator: NVIDIA A100 80GB PCIe, 81920 MiB
Using GPU in script?:
Using distributed or parallel set-up in script?: