Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 2dcd76f

Browse files
committed
Switch Unet mode
1 parent 4134ece commit 2dcd76f

File tree

13 files changed

+64
-48
lines changed

13 files changed

+64
-48
lines changed

OnnxStack.Console/Examples/ControlNetExample.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public async Task RunAsync()
3535
var controlImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\OpenPose.png");
3636

3737
// Create ControlNet
38-
var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\openpose.onnx", ControlNetType.OpenPose, DiffuserPipelineType.StableDiffusion);
38+
var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\openpose.onnx", ControlNetType.OpenPose);
3939

4040
// Create Pipeline
4141
var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Models\\stable-diffusion-v1-5-onnx");
@@ -49,7 +49,7 @@ public async Task RunAsync()
4949
};
5050

5151
// Preload (optional)
52-
await pipeline.LoadAsync(true);
52+
await pipeline.LoadAsync(UnetModeType.ControlNet);
5353

5454
// Run pipeline
5555
var result = await pipeline.RunAsync(promptOptions, controlNet: controlNet, progressCallback: OutputHelpers.ProgressCallback);

OnnxStack.Console/Examples/ControlNetFeatureExample.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public async Task RunAsync()
4343
await controlImage.SaveAsync(Path.Combine(_outputDirectory, $"Depth.png"));
4444

4545
// Create ControlNet
46-
var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\depth.onnx", ControlNetType.Depth, DiffuserPipelineType.StableDiffusion);
46+
var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\depth.onnx", ControlNetType.Depth);
4747

4848
// Create Pipeline
4949
var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Models\\stable-diffusion-v1-5-onnx");
@@ -57,7 +57,7 @@ public async Task RunAsync()
5757
};
5858

5959
// Preload (optional)
60-
await pipeline.LoadAsync(true);
60+
await pipeline.LoadAsync(UnetModeType.ControlNet);
6161

6262
// Run pipeline
6363
var result = await pipeline.RunAsync(promptOptions, controlNet: controlNet, progressCallback: OutputHelpers.ProgressCallback);
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
namespace OnnxStack.StableDiffusion.Enums
2+
{
3+
public enum UnetModeType
4+
{
5+
Default = 0,
6+
ControlNet = 1,
7+
Both = 2
8+
}
9+
}

OnnxStack.StableDiffusion/Models/ControlNetModel.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using Microsoft.ML.OnnxRuntime;
1+

2+
using Microsoft.ML.OnnxRuntime;
23
using OnnxStack.Core.Config;
34
using OnnxStack.Core.Model;
45
using OnnxStack.StableDiffusion.Enums;
@@ -15,14 +16,13 @@ public ControlNetModel(ControlNetModelConfig configuration)
1516
}
1617

1718
public ControlNetType Type => _configuration.Type;
18-
public DiffuserPipelineType PipelineType => _configuration.PipelineType;
1919

2020
public static ControlNetModel Create(ControlNetModelConfig configuration)
2121
{
2222
return new ControlNetModel(configuration);
2323
}
2424

25-
public static ControlNetModel Create(string modelFile, ControlNetType type, DiffuserPipelineType pipeline, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
25+
public static ControlNetModel Create(string modelFile, ControlNetType type, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
2626
{
2727
var configuration = new ControlNetModelConfig
2828
{
@@ -33,7 +33,6 @@ public static ControlNetModel Create(string modelFile, ControlNetType type, Diff
3333
IntraOpNumThreads = 0,
3434
OnnxModelPath = modelFile,
3535
Type = type,
36-
PipelineType = pipeline,
3736
};
3837
return new ControlNetModel(configuration);
3938
}
@@ -42,6 +41,5 @@ public static ControlNetModel Create(string modelFile, ControlNetType type, Diff
4241
public record ControlNetModelConfig : OnnxModelConfig
4342
{
4443
public ControlNetType Type { get; set; }
45-
public DiffuserPipelineType PipelineType { get; set; }
4644
}
4745
}

OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ public interface IPipeline
2020
/// </summary>
2121
string Name { get; }
2222

23+
/// <summary>
24+
/// Gets the type of the pipeline.
25+
/// </summary>
26+
DiffuserPipelineType PipelineType { get; }
2327

2428
/// <summary>
2529
/// Gets the pipelines supported diffusers.
@@ -39,11 +43,17 @@ public interface IPipeline
3943
SchedulerOptions DefaultSchedulerOptions { get; }
4044

4145

46+
/// <summary>
47+
/// Gets the current unet mode.
48+
/// </summary>
49+
UnetModeType CurrentUnetMode { get; }
50+
51+
4252
/// <summary>
4353
/// Loads the pipeline.
4454
/// </summary>
4555
/// <returns></returns>
46-
Task LoadAsync(bool controlNet = false);
56+
Task LoadAsync(UnetModeType unetMode = UnetModeType.Default);
4757

4858

4959
/// <summary>

OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,17 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
6262
public abstract SchedulerOptions DefaultSchedulerOptions { get; }
6363

6464

65+
/// <summary>
66+
/// Gets the current unet mode.
67+
/// </summary>
68+
public abstract UnetModeType CurrentUnetMode { get; }
69+
70+
6571
/// <summary>
6672
/// Loads the pipeline.
6773
/// </summary>
6874
/// <returns></returns>
69-
public abstract Task LoadAsync(bool controlNet = false);
75+
public abstract Task LoadAsync(UnetModeType unetMode = UnetModeType.Default);
7076

7177

7278
/// <summary>
@@ -96,8 +102,6 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
96102
public abstract Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions = default, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
97103

98104

99-
100-
101105
/// <summary>
102106
/// Runs the pipeline batch.
103107
/// </summary>

OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,16 @@ public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tok
7777
public UNetConditionModel DecoderUnet => _decoderUnet;
7878

7979

80-
public override Task LoadAsync(bool controlNet = false)
80+
public override Task LoadAsync(UnetModeType unetMode = UnetModeType.Default)
8181
{
8282
if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum)
83-
return base.LoadAsync(controlNet);
83+
return base.LoadAsync(unetMode);
8484

8585
// Preload all models into VRAM
8686
return Task.WhenAll
8787
(
8888
_decoderUnet.LoadAsync(),
89-
base.LoadAsync(controlNet)
89+
base.LoadAsync(unetMode)
9090
);
9191
}
9292

OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public class StableDiffusionPipeline : PipelineBase
3434
protected List<DiffuserType> _supportedDiffusers;
3535
protected IReadOnlyList<SchedulerType> _supportedSchedulers;
3636
protected SchedulerOptions _defaultSchedulerOptions;
37+
private UnetModeType _currentUnetMode;
3738

3839
protected sealed record BatchResultInternal(SchedulerOptions SchedulerOptions, List<DenseTensor<float>> Result);
3940

@@ -110,6 +111,11 @@ public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel t
110111
/// </summary>
111112
public override SchedulerOptions DefaultSchedulerOptions => _defaultSchedulerOptions;
112113

114+
/// <summary>
115+
/// Gets the current unet mode.
116+
/// </summary>
117+
public override UnetModeType CurrentUnetMode => _currentUnetMode;
118+
113119
/// <summary>
114120
/// Gets the unet.
115121
/// </summary>
@@ -144,22 +150,29 @@ public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel t
144150
/// <summary>
145151
/// Loads the pipeline.
146152
/// </summary>
147-
public override Task LoadAsync(bool controlNet = false)
153+
public override Task LoadAsync(UnetModeType unetMode = UnetModeType.Default)
148154
{
155+
_currentUnetMode = unetMode;
149156
if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum)
150157
return Task.CompletedTask;
151158

152-
// Preload all models into VRAM
153-
return Task.WhenAll
159+
var unetModels = Task.CompletedTask;
160+
if (_currentUnetMode == UnetModeType.Default)
161+
unetModels = Task.WhenAll(_unet.LoadAsync(), _controlNetUnet?.UnloadAsync() ?? Task.CompletedTask);
162+
if (_currentUnetMode == UnetModeType.ControlNet)
163+
unetModels = Task.WhenAll(_controlNetUnet.LoadAsync(), _unet.UnloadAsync());
164+
if (_currentUnetMode == UnetModeType.Both)
165+
unetModels = Task.WhenAll(_unet.LoadAsync(), _controlNetUnet?.LoadAsync() ?? Task.CompletedTask);
166+
167+
var subModels = Task.WhenAll
154168
(
155-
controlNet
156-
? _controlNetUnet.LoadAsync()
157-
: _unet.LoadAsync(),
158-
_tokenizer.LoadAsync(),
169+
_tokenizer.LoadAsync(),
159170
_textEncoder.LoadAsync(),
160171
_vaeDecoder.LoadAsync(),
161172
_vaeEncoder.LoadAsync()
162173
);
174+
175+
return Task.WhenAll(unetModels, subModels);
163176
}
164177

165178

@@ -695,4 +708,5 @@ public static StableDiffusionPipeline CreatePipeline(string modelFolder, ModelTy
695708
return CreatePipeline(ModelFactory.CreateModelSet(modelFolder, DiffuserPipelineType.StableDiffusion, modelType, deviceId, executionProvider, memoryMode), logger);
696709
}
697710
}
711+
698712
}

OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,17 @@ public StableDiffusionXLPipeline(PipelineOptions pipelineOptions, TokenizerModel
7878
/// <summary>
7979
/// Loads the pipeline
8080
/// </summary>
81-
public override Task LoadAsync(bool controlNet = false)
81+
public override Task LoadAsync(UnetModeType unetMode = UnetModeType.Default)
8282
{
8383
if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum)
84-
return base.LoadAsync(controlNet);
84+
return base.LoadAsync(unetMode);
8585

8686
// Preload all models into VRAM
8787
return Task.WhenAll
8888
(
8989
_tokenizer2.LoadAsync(),
9090
_textEncoder2.LoadAsync(),
91-
base.LoadAsync(controlNet)
91+
base.LoadAsync(unetMode)
9292
);
9393
}
9494

OnnxStack.UI/Dialogs/AddControlNetModelDialog.xaml.cs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,6 @@ public ControlNetType SelectedControlNetType
7676

7777
}
7878

79-
private DiffuserPipelineType _selectedPipelineType;
80-
81-
public DiffuserPipelineType SelectedPipelineType
82-
{
83-
get { return _selectedPipelineType; }
84-
set { _selectedPipelineType = value; NotifyPropertyChanged(); CreateModelSet(); }
85-
}
86-
87-
8879
public ControlNetModelSet ModelSetResult
8980
{
9081
get { return _modelSetResult; }
@@ -104,7 +95,7 @@ private void CreateModelSet()
10495
if (string.IsNullOrEmpty(_modelFile))
10596
return;
10697

107-
_modelSetResult = _modelFactory.CreateControlNetModelSet(ModelName.Trim(), _selectedControlNetType, _selectedPipelineType, _modelFile);
98+
_modelSetResult = _modelFactory.CreateControlNetModelSet(ModelName.Trim(), _selectedControlNetType, _modelFile);
10899

109100
// Validate
110101
ValidationResults.Add(new ValidationResult("Name", !_invalidOptions.Contains(_modelName, StringComparer.OrdinalIgnoreCase) && _modelName.Length > 2 && _modelName.Length < 50));

OnnxStack.UI/Models/UpdateControlNetModelSetViewModel.cs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ public class UpdateControlNetModelSetViewModel : INotifyPropertyChanged
2222
private ExecutionProvider _executionProvider;
2323
private string _modelFile;
2424
private ControlNetType _controlNetType;
25-
private DiffuserPipelineType _pipelineType;
2625

2726

2827
public string Name
@@ -37,12 +36,6 @@ public ControlNetType ControlNetType
3736
set { _controlNetType = value; NotifyPropertyChanged(); }
3837
}
3938

40-
public DiffuserPipelineType PipelineType
41-
{
42-
get { return _pipelineType; }
43-
set { _pipelineType = value; NotifyPropertyChanged(); }
44-
}
45-
4639
public int DeviceId
4740
{
4841
get { return _deviceId; }
@@ -86,7 +79,6 @@ public static UpdateControlNetModelSetViewModel FromModelSet(ControlNetModelSet
8679
{
8780
Name = modelset.Name,
8881
ControlNetType = modelset.ControlNetConfig.Type,
89-
PipelineType = modelset.ControlNetConfig.PipelineType,
9082
DeviceId = modelset.DeviceId,
9183
ExecutionMode = modelset.ExecutionMode,
9284
ExecutionProvider = modelset.ExecutionProvider,
@@ -111,7 +103,6 @@ public static ControlNetModelSet ToModelSet(UpdateControlNetModelSetViewModel mo
111103
ControlNetConfig = new ControlNetModelConfig
112104
{
113105
Type = modelset.ControlNetType,
114-
PipelineType = modelset.PipelineType,
115106
OnnxModelPath = modelset.ModelFile
116107
}
117108
};

OnnxStack.UI/Services/IModelFactory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public interface IModelFactory
1414

1515
UpscaleModelSet CreateUpscaleModelSet(string name, string filename, UpscaleModelTemplate modelTemplate);
1616
StableDiffusionModelSet CreateStableDiffusionModelSet(string name, string folder, StableDiffusionModelTemplate modelTemplate);
17-
ControlNetModelSet CreateControlNetModelSet(string name, ControlNetType controlNetType, DiffuserPipelineType pipelineType, string modelFilename);
17+
ControlNetModelSet CreateControlNetModelSet(string name, ControlNetType controlNetType, string modelFilename);
1818
FeatureExtractorModelSet CreateFeatureExtractorModelSet(string name, bool normalize, int sampleSize, int channels, string modelFilename);
1919
}
2020
}

OnnxStack.UI/Services/ModelFactory.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ public UpscaleModelSet CreateUpscaleModelSet(string name, string filename, Upsca
144144
}
145145

146146

147-
public ControlNetModelSet CreateControlNetModelSet(string name, ControlNetType controlNetType, DiffuserPipelineType pipelineType, string modelFilename)
147+
public ControlNetModelSet CreateControlNetModelSet(string name, ControlNetType controlNetType, string modelFilename)
148148
{
149149
return new ControlNetModelSet
150150
{
@@ -159,7 +159,6 @@ public ControlNetModelSet CreateControlNetModelSet(string name, ControlNetType c
159159
ControlNetConfig = new ControlNetModelConfig
160160
{
161161
Type = controlNetType,
162-
PipelineType = pipelineType,
163162
OnnxModelPath = modelFilename
164163
}
165164
};

0 commit comments

Comments
 (0)