-
Notifications
You must be signed in to change notification settings - Fork 2.6k
feat(LoRA): support AI Toolkit LoRA for FLUX #8071
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
43e4b2e
to
0ca7a05
Compare
software = json.loads(metadata.get("software", "{}")) | ||
except json.JSONDecodeError: | ||
return False | ||
return software.get("name") == "ai-toolkit" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this might not be what we need, as I have another LoRA here that reports the same metadata 'software': '{"name": "ai-toolkit", "repo": "https://github.com/ostris/ai-toolkit", "version": "0.1.0"}',
but it uses transformer.single_transformer_blocks
keys instead of diffusion_model.single_blocks
keys.
to avoid disrupting already-working LoRA
This could use more inputs to test with; I've only tested this on a couple files from a single author. If anyone has a LoRA collection you can use something like this script to find which ones might be applicable: #!/usr/bin/env python
# coding: utf-8
import os
from dataclasses import dataclass
from pathlib import Path
import safetensors
INVOKEAI_ROOT = Path(os.environ.get("INVOKEAI_ROOT", "/PATH/TO/InvokeAI"))
lora_dirname = INVOKEAI_ROOT / "models" / "flux" / "lora"
lora_filenames = list(lora_dirname.glob("*.safetensors"))
@dataclass
class LoraField:
name: str
ai_toolkit_metadata: bool = False
diffusion_model_key: bool = False
def inspect_lora(lora_path):
with safetensors.safe_open(lora_path, "torch") as f:
return LoraField(
name=lora_path.stem,
ai_toolkit_metadata="ai-toolkit" in (f.metadata() or {}).get("software", ""),
diffusion_model_key=any(k.startswith("diffusion_model.") for k in f.keys())
)
records = [inspect_lora(lora_filename) for lora_filename in lora_filenames]
matching = [r for r in records if r.ai_toolkit_metadata or r.diffusion_model_key]
print("name\tmetadata\tdiffusion_model_key")
for record in matching:
print(f"{record.name}\t{record.ai_toolkit_metadata}\t{record.diffusion_model_key}") |
@jazzhaiku please review. |
Summary
https://github.com/ostris/ai-toolkit is "an all in one training suite for diffusion models."
and I guess LoRA saved by it don't look exactly like the ones we support already.
Related Issues / Discussions
QA Instructions
Merge Plan
N/A
Checklist
What's New
copy (if doing a release after this PR)