Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion diffsynth_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
SDXLPipelineConfig,
FluxPipelineConfig,
WanPipelineConfig,
ControlNetParams,
ControlType,
)
from .pipelines import (
FluxImagePipeline,
SDXLImagePipeline,
SDImagePipeline,
WanVideoPipeline,
ControlNetParams,
)
from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux
from .models.sd import SDControlNet
Expand Down Expand Up @@ -44,6 +45,7 @@
"FluxReplaceByControlTool",
"FluxReduxRefTool",
"ControlNetParams",
"ControlType",
"fetch_model",
"fetch_modelscope_model",
"fetch_civitai_model",
Expand Down
3 changes: 2 additions & 1 deletion diffsynth_engine/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
FluxPipelineConfig,
WanPipelineConfig,
)
from .controlnet import ControlType
from .controlnet import ControlType, ControlNetParams

__all__ = [
"BaseConfig",
Expand All @@ -20,4 +20,5 @@
"FluxPipelineConfig",
"WanPipelineConfig",
"ControlType",
"ControlNetParams",
]
19 changes: 19 additions & 0 deletions diffsynth_engine/configs/controlnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from dataclasses import dataclass
from enum import Enum

import torch
import torch.nn as nn
from typing import List, Union, Optional
from PIL import Image

ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]]


# FLUX ControlType
class ControlType(Enum):
Expand All @@ -15,3 +23,14 @@ def get_in_channel(self):
return 128
elif self == ControlType.bfl_fill:
return 384


@dataclass
class ControlNetParams:
image: ImageType
scale: float = 1.0
model: Optional[nn.Module] = None
mask: Optional[ImageType] = None
control_start: float = 0
control_end: float = 1
processor_name: Optional[str] = None # only used for sdxl controlnet union now
2 changes: 0 additions & 2 deletions diffsynth_engine/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .base import BasePipeline, LoRAStateDictConverter
from .controlnet_helper import ControlNetParams
from .flux_image import FluxImagePipeline
from .sdxl_image import SDXLImagePipeline
from .sd_image import SDImagePipeline
Expand All @@ -13,5 +12,4 @@
"SDXLImagePipeline",
"SDImagePipeline",
"WanVideoPipeline",
"ControlNetParams",
]
26 changes: 0 additions & 26 deletions diffsynth_engine/pipelines/controlnet_helper.py

This file was deleted.

4 changes: 2 additions & 2 deletions diffsynth_engine/pipelines/flux_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
flux_dit_config,
flux_text_encoder_config,
)
from diffsynth_engine.configs import FluxPipelineConfig, ControlType
from diffsynth_engine.configs import FluxPipelineConfig, ControlType, ControlNetParams
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve readability and adhere to PEP 8 guidelines for line length, it's recommended to wrap this long import statement using parentheses.

Suggested change
from diffsynth_engine.configs import FluxPipelineConfig, ControlType, ControlNetParams
from diffsynth_engine.configs import (
FluxPipelineConfig, ControlType, ControlNetParams
)

from diffsynth_engine.models.basic.lora import LoRAContext
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
from diffsynth_engine.pipelines.utils import accumulate
from diffsynth_engine.tokenizers import CLIPTokenizer, T5TokenizerFast
from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
Expand Down
4 changes: 2 additions & 2 deletions diffsynth_engine/pipelines/sd_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from tqdm import tqdm
from PIL import Image, ImageOps

from diffsynth_engine.configs import SDPipelineConfig
from diffsynth_engine.configs import SDPipelineConfig, ControlNetParams
from diffsynth_engine.models.base import split_suffix
from diffsynth_engine.models.basic.lora import LoRAContext
from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
from diffsynth_engine.pipelines.utils import accumulate
from diffsynth_engine.tokenizers import CLIPTokenizer
from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
from diffsynth_engine.algorithm.sampler import EulerSampler
Expand Down
4 changes: 2 additions & 2 deletions diffsynth_engine/pipelines/sdxl_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm
from PIL import Image, ImageOps

from diffsynth_engine.configs import SDXLPipelineConfig
from diffsynth_engine.configs import SDXLPipelineConfig, ControlNetParams
from diffsynth_engine.models.base import split_suffix
from diffsynth_engine.models.basic.lora import LoRAContext
from diffsynth_engine.models.basic.timestep import TemporalTimesteps
Expand All @@ -19,7 +19,7 @@
sdxl_unet_config,
)
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
from diffsynth_engine.pipelines.utils import accumulate
from diffsynth_engine.tokenizers import CLIPTokenizer
from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
from diffsynth_engine.algorithm.sampler import EulerSampler
Expand Down
6 changes: 6 additions & 0 deletions diffsynth_engine/pipelines/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def accumulate(result, new_item):
if result is None:
return new_item
for i, item in enumerate(new_item):
result[i] += item
return result
Comment on lines +1 to +6
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve code clarity and maintainability, please add type hints to the accumulate function. This will make it easier to understand the expected types for result and new_item and what the function returns. Based on its usage, result is an optional list of tensors, and new_item is a list of tensors.

Suggested change
def accumulate(result, new_item):
if result is None:
return new_item
for i, item in enumerate(new_item):
result[i] += item
return result
from typing import List, Optional
import torch
def accumulate(result: Optional[List[torch.Tensor]], new_item: List[torch.Tensor]) -> List[torch.Tensor]:
if result is None:
return new_item
for i, item in enumerate(new_item):
result[i] += item
return result

3 changes: 1 addition & 2 deletions tests/test_pipelines/test_flux_bfl_image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import unittest

from tests.common.test_case import ImageTestCase
from diffsynth_engine.configs import FluxPipelineConfig
from diffsynth_engine.configs import FluxPipelineConfig, ControlType, ControlNetParams
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve readability and adhere to PEP 8 style for line length, please wrap this long import statement using parentheses.

Suggested change
from diffsynth_engine.configs import FluxPipelineConfig, ControlType, ControlNetParams
from diffsynth_engine.configs import (
FluxPipelineConfig, ControlType, ControlNetParams
)

from diffsynth_engine.pipelines import FluxImagePipeline
from diffsynth_engine.pipelines.flux_image import ControlType, ControlNetParams
from diffsynth_engine.processor.canny_processor import CannyProcessor
from diffsynth_engine.processor.depth_processor import DepthProcessor

Expand Down