Skip to content

Commit c4ded2c

Browse files
authored
reformat control params and pipeline utils (#128)
* reformat control params and pipeline utils * update
1 parent 12cc587 commit c4ded2c

File tree

10 files changed

+37
-38
lines changed

10 files changed

+37
-38
lines changed

diffsynth_engine/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
SDXLPipelineConfig,
44
FluxPipelineConfig,
55
WanPipelineConfig,
6+
ControlNetParams,
7+
ControlType,
68
)
79
from .pipelines import (
810
FluxImagePipeline,
911
SDXLImagePipeline,
1012
SDImagePipeline,
1113
WanVideoPipeline,
12-
ControlNetParams,
1314
)
1415
from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux
1516
from .models.sd import SDControlNet
@@ -44,6 +45,7 @@
4445
"FluxReplaceByControlTool",
4546
"FluxReduxRefTool",
4647
"ControlNetParams",
48+
"ControlType",
4749
"fetch_model",
4850
"fetch_modelscope_model",
4951
"fetch_civitai_model",

diffsynth_engine/configs/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
FluxPipelineConfig,
99
WanPipelineConfig,
1010
)
11-
from .controlnet import ControlType
11+
from .controlnet import ControlType, ControlNetParams
1212

1313
__all__ = [
1414
"BaseConfig",
@@ -20,4 +20,5 @@
2020
"FluxPipelineConfig",
2121
"WanPipelineConfig",
2222
"ControlType",
23+
"ControlNetParams",
2324
]

diffsynth_engine/configs/controlnet.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1+
from dataclasses import dataclass
12
from enum import Enum
23

4+
import torch
5+
import torch.nn as nn
6+
from typing import List, Union, Optional
7+
from PIL import Image
8+
9+
ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]]
10+
311

412
# FLUX ControlType
513
class ControlType(Enum):
@@ -15,3 +23,14 @@ def get_in_channel(self):
1523
return 128
1624
elif self == ControlType.bfl_fill:
1725
return 384
26+
27+
28+
@dataclass
29+
class ControlNetParams:
30+
image: ImageType
31+
scale: float = 1.0
32+
model: Optional[nn.Module] = None
33+
mask: Optional[ImageType] = None
34+
control_start: float = 0
35+
control_end: float = 1
36+
processor_name: Optional[str] = None # only used for sdxl controlnet union now

diffsynth_engine/pipelines/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .base import BasePipeline, LoRAStateDictConverter
2-
from .controlnet_helper import ControlNetParams
32
from .flux_image import FluxImagePipeline
43
from .sdxl_image import SDXLImagePipeline
54
from .sd_image import SDImagePipeline
@@ -13,5 +12,4 @@
1312
"SDXLImagePipeline",
1413
"SDImagePipeline",
1514
"WanVideoPipeline",
16-
"ControlNetParams",
1715
]

diffsynth_engine/pipelines/controlnet_helper.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

diffsynth_engine/pipelines/flux_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
flux_dit_config,
1818
flux_text_encoder_config,
1919
)
20-
from diffsynth_engine.configs import FluxPipelineConfig, ControlType
20+
from diffsynth_engine.configs import FluxPipelineConfig, ControlType, ControlNetParams
2121
from diffsynth_engine.models.basic.lora import LoRAContext
2222
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
23-
from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
23+
from diffsynth_engine.pipelines.utils import accumulate
2424
from diffsynth_engine.tokenizers import CLIPTokenizer, T5TokenizerFast
2525
from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
2626
from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler

diffsynth_engine/pipelines/sd_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from tqdm import tqdm
77
from PIL import Image, ImageOps
88

9-
from diffsynth_engine.configs import SDPipelineConfig
9+
from diffsynth_engine.configs import SDPipelineConfig, ControlNetParams
1010
from diffsynth_engine.models.base import split_suffix
1111
from diffsynth_engine.models.basic.lora import LoRAContext
1212
from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config
1313
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
14-
from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
14+
from diffsynth_engine.pipelines.utils import accumulate
1515
from diffsynth_engine.tokenizers import CLIPTokenizer
1616
from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
1717
from diffsynth_engine.algorithm.sampler import EulerSampler

diffsynth_engine/pipelines/sdxl_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from tqdm import tqdm
77
from PIL import Image, ImageOps
88

9-
from diffsynth_engine.configs import SDXLPipelineConfig
9+
from diffsynth_engine.configs import SDXLPipelineConfig, ControlNetParams
1010
from diffsynth_engine.models.base import split_suffix
1111
from diffsynth_engine.models.basic.lora import LoRAContext
1212
from diffsynth_engine.models.basic.timestep import TemporalTimesteps
@@ -19,7 +19,7 @@
1919
sdxl_unet_config,
2020
)
2121
from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
22-
from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate
22+
from diffsynth_engine.pipelines.utils import accumulate
2323
from diffsynth_engine.tokenizers import CLIPTokenizer
2424
from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler
2525
from diffsynth_engine.algorithm.sampler import EulerSampler
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def accumulate(result, new_item):
2+
if result is None:
3+
return new_item
4+
for i, item in enumerate(new_item):
5+
result[i] += item
6+
return result

tests/test_pipelines/test_flux_bfl_image.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import unittest
22

33
from tests.common.test_case import ImageTestCase
4-
from diffsynth_engine.configs import FluxPipelineConfig
4+
from diffsynth_engine.configs import FluxPipelineConfig, ControlType, ControlNetParams
55
from diffsynth_engine.pipelines import FluxImagePipeline
6-
from diffsynth_engine.pipelines.flux_image import ControlType, ControlNetParams
76
from diffsynth_engine.processor.canny_processor import CannyProcessor
87
from diffsynth_engine.processor.depth_processor import DepthProcessor
98

0 commit comments

Comments
 (0)