Skip to content

Commit

Permalink
Clone ModelSetConfig on pipeline creation
Browse files Browse the repository at this point in the history
  • Loading branch information
saddam213 committed May 23, 2024
1 parent dd6c1ec commit 106a4c4
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 70 deletions.
26 changes: 19 additions & 7 deletions OnnxStack.Core/Extensions/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@ public static SessionOptions GetSessionOptions(this OnnxModelConfig configuratio

public static T ApplyDefaults<T>(this T config, IOnnxModelSetConfig defaults) where T : OnnxModelConfig
{
config.DeviceId ??= defaults.DeviceId;
config.ExecutionMode ??= defaults.ExecutionMode;
config.ExecutionProvider ??= defaults.ExecutionProvider;
config.InterOpNumThreads ??= defaults.InterOpNumThreads;
config.IntraOpNumThreads ??= defaults.IntraOpNumThreads;
config.Precision ??= defaults.Precision;
return config;
return config with
{
DeviceId = config.DeviceId ?? defaults.DeviceId,
ExecutionMode = config.ExecutionMode ?? defaults.ExecutionMode,
ExecutionProvider = config.ExecutionProvider ?? defaults.ExecutionProvider,
InterOpNumThreads = config.InterOpNumThreads ?? defaults.InterOpNumThreads,
IntraOpNumThreads = config.IntraOpNumThreads ?? defaults.IntraOpNumThreads,
Precision = config.Precision ?? defaults.Precision
};
}


Expand Down Expand Up @@ -283,5 +285,15 @@ public static Span<float> NormalizeOneToOne(this Span<float> values)
}
return values;
}


public static void RemoveRange<TSource>(this List<TSource> source, IEnumerable<TSource> toRemove)
{
if (toRemove.IsNullOrEmpty())
return;

foreach (var item in toRemove)
source.Remove(item);
}
}
}
19 changes: 10 additions & 9 deletions OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,18 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
/// <returns></returns>
public static new InstaFlowPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
{
var unet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
var config = modelSet with { };
var unet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
var vaeDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
var vaeEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
var controlnet = default(UNetConditionModel);
if (modelSet.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
if (config.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));

var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
}


Expand Down
19 changes: 10 additions & 9 deletions OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,18 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
/// <returns></returns>
public static new LatentConsistencyPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
{
var unet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
var config = modelSet with { };
var unet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
var vaeDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
var vaeEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
var controlnet = default(UNetConditionModel);
if (modelSet.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
if (config.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));

var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
return new LatentConsistencyPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
return new LatentConsistencyPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
}


Expand Down
25 changes: 13 additions & 12 deletions OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public LatentConsistencyXLPipeline(PipelineOptions pipelineOptions, TokenizerMod
{
_supportedSchedulers = new List<SchedulerType>
{
SchedulerType.LCM
SchedulerType.LCM
};
_defaultSchedulerOptions = defaultSchedulerOptions ?? new SchedulerOptions
{
Expand Down Expand Up @@ -118,19 +118,20 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
/// <returns></returns>
public static new LatentConsistencyXLPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
{
var unet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
var tokenizer2 = new TokenizerModel(modelSet.Tokenizer2Config.ApplyDefaults(modelSet));
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
var textEncoder2 = new TextEncoderModel(modelSet.TextEncoder2Config.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
var config = modelSet with { };
var unet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
var tokenizer2 = new TokenizerModel(config.Tokenizer2Config.ApplyDefaults(config));
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
var textEncoder2 = new TextEncoderModel(config.TextEncoder2Config.ApplyDefaults(config));
var vaeDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
var vaeEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
var controlnet = default(UNetConditionModel);
if (modelSet.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
if (config.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));

var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
return new LatentConsistencyXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
return new LatentConsistencyXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
}


Expand Down
21 changes: 11 additions & 10 deletions OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -261,18 +261,19 @@ private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(TokenizerResult i
/// <returns></returns>
public static new StableCascadePipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
{
var priorUnet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
var decoderUnet = new UNetConditionModel(modelSet.Unet2Config.ApplyDefaults(modelSet));
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
var imageDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var imageEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
var config = modelSet with { };
var priorUnet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
var decoderUnet = new UNetConditionModel(config.Unet2Config.ApplyDefaults(config));
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
var imageDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
var imageEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
var controlnet = default(UNetConditionModel);
if (modelSet.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
if (config.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));

var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
return new StableCascadePipeline(pipelineOptions, tokenizer, textEncoder, priorUnet, decoderUnet, imageDecoder, imageEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
return new StableCascadePipeline(pipelineOptions, tokenizer, textEncoder, priorUnet, decoderUnet, imageDecoder, imageEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
}


Expand Down
27 changes: 15 additions & 12 deletions OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,12 @@ public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel t
{
DiffuserType.TextToImage,
DiffuserType.ImageToImage,
DiffuserType.ImageInpaintLegacy
DiffuserType.ImageInpaintLegacy,
DiffuserType.ControlNet,
DiffuserType.ControlNetImage
};
if (_controlNetUnet is not null)
_supportedDiffusers.AddRange(new[] { DiffuserType.ControlNet, DiffuserType.ControlNetImage });
if (_controlNetUnet is null)
_supportedDiffusers.RemoveRange(new[] { DiffuserType.ControlNet, DiffuserType.ControlNetImage });

_supportedSchedulers = new List<SchedulerType>
{
Expand Down Expand Up @@ -680,17 +682,18 @@ protected IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int req
/// <returns></returns>
public static new StableDiffusionPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
{
var unet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
var config = modelSet with { };
var unet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
var vaeDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
var vaeEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
var controlnet = default(UNetConditionModel);
if (modelSet.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
if (config.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));

var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
return new StableDiffusionPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
return new StableDiffusionPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
}


Expand Down
23 changes: 12 additions & 11 deletions OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -318,19 +318,20 @@ private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(TokenizerResult i
/// <returns></returns>
public static new StableDiffusionXLPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
{
var unet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));
var tokenizer2 = new TokenizerModel(modelSet.Tokenizer2Config.ApplyDefaults(modelSet));
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
var textEncoder2 = new TextEncoderModel(modelSet.TextEncoder2Config.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
var config = modelSet with { };
var unet = new UNetConditionModel(config.UnetConfig.ApplyDefaults(config));
var tokenizer = new TokenizerModel(config.TokenizerConfig.ApplyDefaults(config));
var tokenizer2 = new TokenizerModel(config.Tokenizer2Config.ApplyDefaults(config));
var textEncoder = new TextEncoderModel(config.TextEncoderConfig.ApplyDefaults(config));
var textEncoder2 = new TextEncoderModel(config.TextEncoder2Config.ApplyDefaults(config));
var vaeDecoder = new AutoEncoderModel(config.VaeDecoderConfig.ApplyDefaults(config));
var vaeEncoder = new AutoEncoderModel(config.VaeEncoderConfig.ApplyDefaults(config));
var controlnet = default(UNetConditionModel);
if (modelSet.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));
if (config.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));

var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
return new StableDiffusionXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
return new StableDiffusionXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
}


Expand Down

0 comments on commit 106a4c4

Please sign in to comment.