Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 106a4c4

Browse files
committed
Clone ModelSetConfig on pipeline creation
1 parent dd6c1ec commit 106a4c4

File tree

7 files changed

+90
-70
lines changed

7 files changed

+90
-70
lines changed

OnnxStack.Core/Extensions/Extensions.cs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,15 @@ public static SessionOptions GetSessionOptions(this OnnxModelConfig configuratio
6161

6262
public static T ApplyDefaults<T>(this T config, IOnnxModelSetConfig defaults) where T : OnnxModelConfig
6363
{
64-
config.DeviceId ??= defaults.DeviceId;
65-
config.ExecutionMode ??= defaults.ExecutionMode;
66-
config.ExecutionProvider ??= defaults.ExecutionProvider;
67-
config.InterOpNumThreads ??= defaults.InterOpNumThreads;
68-
config.IntraOpNumThreads ??= defaults.IntraOpNumThreads;
69-
config.Precision ??= defaults.Precision;
70-
return config;
64+
return config with
65+
{
66+
DeviceId = config.DeviceId ?? defaults.DeviceId,
67+
ExecutionMode = config.ExecutionMode ?? defaults.ExecutionMode,
68+
ExecutionProvider = config.ExecutionProvider ?? defaults.ExecutionProvider,
69+
InterOpNumThreads = config.InterOpNumThreads ?? defaults.InterOpNumThreads,
70+
IntraOpNumThreads = config.IntraOpNumThreads ?? defaults.IntraOpNumThreads,
71+
Precision = config.Precision ?? defaults.Precision
72+
};
7173
}
7274

7375

@@ -283,5 +285,15 @@ public static Span<float> NormalizeOneToOne(this Span<float> values)
283285
}
284286
return values;
285287
}
288+
289+
290+
public static void RemoveRange<TSource>(this List<TSource> source, IEnumerable<TSource> toRemove)
291+
{
292+
if (toRemove.IsNullOrEmpty())
293+
return;
294+
295+
foreach (var item in toRemove)
296+
source.Remove(item);
297+
}
286298
}
287299
}

OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,18 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
7676
/// <returns></returns>
7777
public static new InstaFlowPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
7878
{
79-
var unet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
80-
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
81-
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
82-
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
83-
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
79+
var config = modelSet with { };
80+
var unet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
81+
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
82+
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
83+
var vaeDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
84+
var vaeEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
8485
var controlnet = default(UNetConditionModel);
85-
if (modelSet.ControlNetUnetConfig is not null)
86-
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
86+
if (config.ControlNetUnetConfig is not null)
87+
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));
8788

88-
var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
89-
return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
89+
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
90+
return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
9091
}
9192

9293

OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,18 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
127127
/// <returns></returns>
128128
public static new LatentConsistencyPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
129129
{
130-
var unet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
131-
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
132-
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
133-
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
134-
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
130+
var config = modelSet with { };
131+
var unet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
132+
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
133+
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
134+
var vaeDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
135+
var vaeEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
135136
var controlnet = default(UNetConditionModel);
136-
if (modelSet.ControlNetUnetConfig is not null)
137-
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
137+
if (config.ControlNetUnetConfig is not null)
138+
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));
138139

139-
var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
140-
return new LatentConsistencyPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
140+
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
141+
return new LatentConsistencyPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
141142
}
142143

143144

OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public LatentConsistencyXLPipeline(PipelineOptions pipelineOptions, TokenizerMod
3737
{
3838
_supportedSchedulers = new List<SchedulerType>
3939
{
40-
SchedulerType.LCM
40+
SchedulerType.LCM
4141
};
4242
_defaultSchedulerOptions = defaultSchedulerOptions ?? new SchedulerOptions
4343
{
@@ -118,19 +118,20 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
118118
/// <returns></returns>
119119
public static new LatentConsistencyXLPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
120120
{
121-
var unet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
122-
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
123-
var tokenizer2 = new TokenizerModel(modelSet.Tokenizer2Config.ApplyDefaults(modelSet));
124-
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
125-
var textEncoder2 = new TextEncoderModel(modelSet.TextEncoder2Config.ApplyDefaults(modelSet));
126-
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
127-
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
121+
var config = modelSet with { };
122+
var unet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
123+
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
124+
var tokenizer2 = new TokenizerModel(config.Tokenizer2Config.ApplyDefaults(config));
125+
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
126+
var textEncoder2 = new TextEncoderModel(config.TextEncoder2Config.ApplyDefaults(config));
127+
var vaeDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
128+
var vaeEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
128129
var controlnet = default(UNetConditionModel);
129-
if (modelSet.ControlNetUnetConfig is not null)
130-
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
130+
if (config.ControlNetUnetConfig is not null)
131+
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));
131132

132-
var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
133-
return new LatentConsistencyXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
133+
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
134+
return new LatentConsistencyXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
134135
}
135136

136137

OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -261,18 +261,19 @@ private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(TokenizerResult i
261261
/// <returns></returns>
262262
public static new StableCascadePipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
263263
{
264-
var priorUnet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
265-
var decoderUnet = new UNetConditionModel(modelSet.Unet2Config.ApplyDefaults(modelSet));
266-
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
267-
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
268-
var imageDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
269-
var imageEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
264+
var config = modelSet with { };
265+
var priorUnet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
266+
var decoderUnet = new UNetConditionModel(config.Unet2Config.ApplyDefaults(config));
267+
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
268+
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
269+
var imageDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
270+
var imageEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
270271
var controlnet = default(UNetConditionModel);
271-
if (modelSet.ControlNetUnetConfig is not null)
272-
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
272+
if (config.ControlNetUnetConfig is not null)
273+
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));
273274

274-
var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
275-
return new StableCascadePipeline(pipelineOptions, tokenizer, textEncoder, priorUnet, decoderUnet, imageDecoder, imageEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
275+
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
276+
return new StableCascadePipeline(pipelineOptions, tokenizer, textEncoder, priorUnet, decoderUnet, imageDecoder, imageEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
276277
}
277278

278279

OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,12 @@ public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel t
6363
{
6464
DiffuserType.TextToImage,
6565
DiffuserType.ImageToImage,
66-
DiffuserType.ImageInpaintLegacy
66+
DiffuserType.ImageInpaintLegacy,
67+
DiffuserType.ControlNet,
68+
DiffuserType.ControlNetImage
6769
};
68-
if (_controlNetUnet is not null)
69-
_supportedDiffusers.AddRange(new[] { DiffuserType.ControlNet, DiffuserType.ControlNetImage });
70+
if (_controlNetUnet is null)
71+
_supportedDiffusers.RemoveRange(new[] { DiffuserType.ControlNet, DiffuserType.ControlNetImage });
7072

7173
_supportedSchedulers = new List<SchedulerType>
7274
{
@@ -680,17 +682,18 @@ protected IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int req
680682
/// <returns></returns>
681683
public static new StableDiffusionPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
682684
{
683-
var unet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
684-
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
685-
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
686-
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
687-
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
685+
var config = modelSet with { };
686+
var unet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
687+
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
688+
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
689+
var vaeDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
690+
var vaeEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
688691
var controlnet = default(UNetConditionModel);
689-
if (modelSet.ControlNetUnetConfig is not null)
690-
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
692+
if (config.ControlNetUnetConfig is not null)
693+
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));
691694

692-
var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
693-
return new StableDiffusionPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
695+
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
696+
return new StableDiffusionPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
694697
}
695698

696699

OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -318,19 +318,20 @@ private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(TokenizerResult i
318318
/// <returns></returns>
319319
public static new StableDiffusionXLPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
320320
{
321-
var unet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
322-
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
323-
var tokenizer2 = new TokenizerModel(modelSet.Tokenizer2Config.ApplyDefaults(modelSet));
324-
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
325-
var textEncoder2 = new TextEncoderModel(modelSet.TextEncoder2Config.ApplyDefaults(modelSet));
326-
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
327-
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
321+
var config = modelSet with { };
322+
var unet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
323+
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
324+
var tokenizer2 = new TokenizerModel(config.Tokenizer2Config.ApplyDefaults(config));
325+
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
326+
var textEncoder2 = new TextEncoderModel(config.TextEncoder2Config.ApplyDefaults(config));
327+
var vaeDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
328+
var vaeEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
328329
var controlnet = default(UNetConditionModel);
329-
if (modelSet.ControlNetUnetConfig is not null)
330-
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
330+
if (config.ControlNetUnetConfig is not null)
331+
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));
331332

332-
var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
333-
return new StableDiffusionXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
333+
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
334+
return new StableDiffusionXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
334335
}
335336

336337

0 commit comments

Comments
 (0)