Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add preference optimization (Diffusion-DPO, MaPO) #1427

Draft
wants to merge 5 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class BaseSubsetParams:
caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0
preference: bool = False
preference_caption_prefix: Optional[str] = None
preference_caption_suffix: Optional[str] = None
non_preference_caption_prefix: Optional[str] = None
non_preference_caption_suffix: Optional[str] = None


@dataclass
Expand All @@ -100,6 +105,7 @@ class ControlNetSubsetParams(BaseSubsetParams):
conditioning_data_dir: str = None
caption_extension: str = ".caption"
cache_info: bool = False
preference: bool = False


@dataclass
Expand Down Expand Up @@ -199,6 +205,11 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"token_warmup_step": Any(float, int),
"caption_prefix": str,
"caption_suffix": str,
"preference": bool,
"preference_caption_prefix": str,
"preference_caption_suffix": str,
"non_preference_caption_prefix": str,
"non_preference_caption_suffix": str
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
Expand Down Expand Up @@ -540,14 +551,29 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask},
token_warmup_min: {subset.token_warmup_min}
token_warmup_step: {subset.token_warmup_step}
alpha_mask: {subset.alpha_mask}
preference: {subset.preference}
"""
),
" ",
)

if subset.preference:
info += indent(
dedent(

f"""\
preference_caption_prefix: {subset.preference_caption_prefix}
preference_caption_suffix: {subset.preference_caption_suffix}
non_preference_caption_prefix: {subset.non_preference_caption_prefix}
non_preference_caption_suffix: {subset.non_preference_caption_suffix}
\n"""
),
" ",
)

if is_dreambooth:
info += indent(
dedent(
Expand Down
Loading