
Flux does not work on MPS devices #9047

Closed
@bghira

Describe the bug

Running FluxPipeline on an MPS device fails with:

    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
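The underlying limitation can be checked without downloading the model. The sketch below uses a hypothetical helper, supports_float64, that simply tries to allocate a one-element float64 tensor on a given device. Going by the error above, MPS should refuse it. On machines where MPS is unavailable the allocation also fails, usually with a different error, and the sketch treats that as unsupported too.

```python
import torch


def supports_float64(device: str) -> bool:
    """Return True if a float64 tensor can be allocated on `device` (hypothetical helper)."""
    try:
        torch.zeros(1, dtype=torch.float64, device=device)
        return True
    except (TypeError, RuntimeError):
        # The issue reports a TypeError on MPS; a device type that is missing
        # or unsupported in this PyTorch build typically raises RuntimeError.
        return False


if __name__ == "__main__":
    for dev in ("cpu", "mps"):
        print(dev, supports_float64(dev))
```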

Reproduction

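import torch
from diffusers import FluxPipeline

# Load FLUX.1-schnell in bfloat16 from the Hub revision refs/pr/1
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision='refs/pr/1')
#pipe.enable_model_cpu_offload()  # alternative to .to('mps'); fails too
pipe.to(device='mps')  # move the whole pipeline to Apple's MPS backend

prompt = "A cat holding a sign that says hello world"
# schnell is run with few steps and no classifier-free guidance
out = pipe(
    prompt=prompt,
    guidance_scale=0.,
    height=768,
    width=1360,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
out.save("image.png")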

It also fails when CPU offload is enabled with pipe.enable_model_cpu_offload().

Logs

scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
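One possible workaround, sketched under the assumption that float32 precision is acceptable for these frequencies, is to pick the dtype from the tensor's device instead of hard-coding float64. The helpers freq_dtype_for and rope_scale are hypothetical and mirror only the failing line from the log; this is not necessarily how diffusers resolved the issue.

```python
import torch


def freq_dtype_for(device: torch.device) -> torch.dtype:
    # MPS has no float64 support, so fall back to float32 there only.
    return torch.float32 if device.type == "mps" else torch.float64


def rope_scale(pos: torch.Tensor, dim: int) -> torch.Tensor:
    # Same computation as the failing line, but with a device-appropriate dtype.
    dtype = freq_dtype_for(pos.device)
    return torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim


if __name__ == "__main__":
    print(rope_scale(torch.zeros(1, 4), 8))  # CPU tensor, so float64 is kept
```

On CPU or CUDA the float64 path is unchanged, so only MPS runs trade some precision for being able to run at all.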

System Info

diffusers installed from Git master

Who can help?

@sayakpaul

Labels

bug: Something isn't working
stale: Issues that haven't received updates