From 4134ecefea3c818c8a69c5f287943cee71946b5a Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Fri, 3 May 2024 13:10:55 +1200 Subject: [PATCH] Support default and controlnet Unet in pipeline --- .../Examples/ControlNetExample.cs | 7 +- .../Examples/ControlNetFeatureExample.cs | 10 +-- OnnxStack.Console/appsettings.json | 42 +++++++----- .../Config/StableDiffusionModelSet.cs | 6 +- OnnxStack.StableDiffusion/Enums/ModelType.cs | 1 - .../Helpers/ModelFactory.cs | 23 +++++-- .../Pipelines/Base/IPipeline.cs | 2 +- .../Pipelines/Base/PipelineBase.cs | 2 +- .../Pipelines/InstaFlowPipeline.cs | 12 ++-- .../Pipelines/LatentConsistencyPipeline.cs | 14 ++-- .../Pipelines/LatentConsistencyXLPipeline.cs | 14 ++-- .../Pipelines/StableCascadePipeline.cs | 28 ++++++-- .../Pipelines/StableDiffusionPipeline.cs | 67 +++++++++++++++---- .../Pipelines/StableDiffusionXLPipeline.cs | 34 +++++++--- 14 files changed, 182 insertions(+), 80 deletions(-) diff --git a/OnnxStack.Console/Examples/ControlNetExample.cs b/OnnxStack.Console/Examples/ControlNetExample.cs index ad8d99b0..8e481c80 100644 --- a/OnnxStack.Console/Examples/ControlNetExample.cs +++ b/OnnxStack.Console/Examples/ControlNetExample.cs @@ -35,10 +35,10 @@ public async Task RunAsync() var controlImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\OpenPose.png"); // Create ControlNet - var controlNet = ControlNetModel.Create("D:\\Repositories\\controlnet_onnx\\controlnet\\openpose.onnx", ControlNetType.OpenPose, DiffuserPipelineType.StableDiffusion); + var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\openpose.onnx", ControlNetType.OpenPose, DiffuserPipelineType.StableDiffusion); // Create Pipeline - var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Repositories\\stable_diffusion_onnx", ModelType.ControlNet); + var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Models\\stable-diffusion-v1-5-onnx"); // Prompt var promptOptions = new PromptOptions @@ -48,7 +48,8 @@ public async Task RunAsync() InputContolImage = controlImage }; - + // Preload (optional) + await pipeline.LoadAsync(true); // Run pipeline var result = await pipeline.RunAsync(promptOptions, controlNet: controlNet, progressCallback: OutputHelpers.ProgressCallback); diff --git a/OnnxStack.Console/Examples/ControlNetFeatureExample.cs b/OnnxStack.Console/Examples/ControlNetFeatureExample.cs index 3c580c12..e67dd0e9 100644 --- a/OnnxStack.Console/Examples/ControlNetFeatureExample.cs +++ b/OnnxStack.Console/Examples/ControlNetFeatureExample.cs @@ -4,7 +4,6 @@ using OnnxStack.StableDiffusion.Enums; using OnnxStack.StableDiffusion.Models; using OnnxStack.StableDiffusion.Pipelines; -using SixLabors.ImageSharp; namespace OnnxStack.Console.Runner { @@ -35,7 +34,7 @@ public async Task RunAsync() var inputImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp"); // Create Annotation pipeline - var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutput: true); + var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Models\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutput: true); // Create Depth Image var controlImage = await annotationPipeline.RunAsync(inputImage); @@ -44,10 +43,10 @@ public async Task RunAsync() await controlImage.SaveAsync(Path.Combine(_outputDirectory, $"Depth.png")); // Create ControlNet - var controlNet = ControlNetModel.Create("D:\\Repositories\\controlnet_onnx\\controlnet\\depth.onnx", ControlNetType.Depth, DiffuserPipelineType.StableDiffusion); + var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\depth.onnx", ControlNetType.Depth, DiffuserPipelineType.StableDiffusion); // Create Pipeline - var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Repositories\\stable_diffusion_onnx", ModelType.ControlNet); + var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Models\\stable-diffusion-v1-5-onnx"); // Prompt var promptOptions = new PromptOptions @@ -57,6 +56,9 @@ public async Task RunAsync() InputContolImage = controlImage }; + // Preload (optional) + await pipeline.LoadAsync(true); + // Run pipeline var result = await pipeline.RunAsync(promptOptions, controlNet: controlNet, progressCallback: OutputHelpers.ProgressCallback); diff --git a/OnnxStack.Console/appsettings.json b/OnnxStack.Console/appsettings.json index c30ea88c..43ce3a09 100644 --- a/OnnxStack.Console/appsettings.json +++ b/OnnxStack.Console/appsettings.json @@ -32,22 +32,26 @@ "BlankTokenId": 49407, "TokenizerLimit": 77, "TokenizerLength": 768, - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5-onnx\\cliptokenizer.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\cliptokenizer.onnx" }, "TextEncoderConfig": { - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5-onnx\\text_encoder\\model.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\text_encoder\\model.onnx" }, "UnetConfig": { "ModelType": "Base", - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5-onnx\\unet\\model.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\unet\\model.onnx" }, "VaeDecoderConfig": { "ScaleFactor": 0.18215, - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5-onnx\\vae_decoder\\model.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\vae_decoder\\model.onnx" }, "VaeEncoderConfig": { "ScaleFactor": 0.18215, - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5-onnx\\vae_encoder\\model.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\vae_encoder\\model.onnx" + }, + "ControlNetConfig": { + "ModelType": "ControlNet", + "OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\controlnet\\model.onnx" } }, { @@ -70,22 +74,26 @@ "BlankTokenId": 49407, "TokenizerLimit": 77, "TokenizerLength": 768, - "OnnxModelPath": "D:\\Repositories\\LCM_Dreamshaper_v7-onnx\\tokenizer\\model.onnx" + "OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\tokenizer\\model.onnx" }, "TextEncoderConfig": { - "OnnxModelPath": "D:\\Repositories\\LCM_Dreamshaper_v7-onnx\\text_encoder\\model.onnx" + "OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\text_encoder\\model.onnx" }, "UnetConfig": { "ModelType": "Base", - "OnnxModelPath": "D:\\Repositories\\LCM_Dreamshaper_v7-onnx\\unet\\model.onnx" + "OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\unet\\model.onnx" }, "VaeDecoderConfig": { "ScaleFactor": 0.18215, - "OnnxModelPath": "D:\\Repositories\\LCM_Dreamshaper_v7-onnx\\vae_decoder\\model.onnx" + "OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\vae_decoder\\model.onnx" }, "VaeEncoderConfig": { "ScaleFactor": 0.18215, - "OnnxModelPath": "D:\\Repositories\\LCM_Dreamshaper_v7-onnx\\vae_encoder\\model.onnx" + "OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\vae_encoder\\model.onnx" + }, + "ControlNetConfig": { + "ModelType": "ControlNet", + "OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\controlnet\\model.onnx" } }, { @@ -108,32 +116,32 @@ "BlankTokenId": 49407, "TokenizerLimit": 77, "TokenizerLength": 768, - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\tokenizer\\model.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\tokenizer\\model.onnx" }, "Tokenizer2Config": { "PadTokenId": 1, "BlankTokenId": 49407, "TokenizerLimit": 77, "TokenizerLength": 1280, - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\tokenizer_2\\model.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\tokenizer_2\\model.onnx" }, "TextEncoderConfig": { - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\text_encoder\\model.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\text_encoder\\model.onnx" }, "TextEncoder2Config": { - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\text_encoder_2\\model.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\text_encoder_2\\model.onnx" }, "UnetConfig": { "ModelType": "Base", - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\unet\\model.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\unet\\model.onnx" }, "VaeDecoderConfig": { "ScaleFactor": 0.13025, - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\vae_decoder\\model.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\vae_decoder\\model.onnx" }, "VaeEncoderConfig": { "ScaleFactor": 0.13025, - "OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\vae_encoder\\model.onnx" + "OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\vae_encoder\\model.onnx" } } ] diff --git a/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs b/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs index 2442c56e..95c22e98 100644 --- a/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs +++ b/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs @@ -36,6 +36,8 @@ public record StableDiffusionModelSet : IOnnxModelSetConfig [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public UNetConditionModelConfig UnetConfig { get; set; } + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public UNetConditionModelConfig Unet2Config { get; set; } [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public AutoEncoderModelConfig VaeDecoderConfig { get; set; } @@ -44,11 +46,9 @@ public record StableDiffusionModelSet : IOnnxModelSetConfig public AutoEncoderModelConfig VaeEncoderConfig { get; set; } [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] - public UNetConditionModelConfig DecoderUnetConfig { get; set; } + public UNetConditionModelConfig ControlNetUnetConfig { get; set; } [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public SchedulerOptions SchedulerOptions { get; set; } - - } } diff --git a/OnnxStack.StableDiffusion/Enums/ModelType.cs b/OnnxStack.StableDiffusion/Enums/ModelType.cs index afd1b2fe..e502a7ec 100644 --- a/OnnxStack.StableDiffusion/Enums/ModelType.cs +++ b/OnnxStack.StableDiffusion/Enums/ModelType.cs @@ -4,7 +4,6 @@ public enum ModelType { Base = 0, Refiner = 1, - ControlNet = 2, Turbo = 3, Inpaint = 4 } diff --git a/OnnxStack.StableDiffusion/Helpers/ModelFactory.cs b/OnnxStack.StableDiffusion/Helpers/ModelFactory.cs index 5148c471..9ac5b8f5 100644 --- a/OnnxStack.StableDiffusion/Helpers/ModelFactory.cs +++ b/OnnxStack.StableDiffusion/Helpers/ModelFactory.cs @@ -37,14 +37,9 @@ public static StableDiffusionModelSet CreateModelSet(string modelFolder, Diffuse var vaeEncoderPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx"); var controlNetPath = Path.Combine(modelFolder, "controlNet", "model.onnx"); - // Some repositories have the ControlNet in the unet folder, some in the controlnet folder - if (modelType == ModelType.ControlNet && File.Exists(controlNetPath)) - unetPath = controlNetPath; - var diffusers = modelType switch { ModelType.Inpaint => new List { DiffuserType.ImageInpaint }, - ModelType.ControlNet => new List { DiffuserType.ControlNet, DiffuserType.ControlNetImage }, _ => new List { DiffuserType.TextToImage, DiffuserType.ImageToImage, DiffuserType.ImageInpaintLegacy } }; @@ -79,6 +74,19 @@ public static StableDiffusionModelSet CreateModelSet(string modelFolder, Diffuse OnnxModelPath = vaeEncoderPath }; + var contronNetConfig = default(UNetConditionModelConfig); + if (File.Exists(controlNetPath)) + { + diffusers.Add(DiffuserType.ControlNet); + diffusers.Add(DiffuserType.ControlNetImage); + contronNetConfig = new UNetConditionModelConfig + { + ModelType = modelType, + OnnxModelPath = controlNetPath + }; + } + + // SDXL Pipelines if (pipeline == DiffuserPipelineType.StableDiffusionXL || pipeline == DiffuserPipelineType.LatentConsistencyXL) { @@ -129,7 +137,8 @@ public static StableDiffusionModelSet CreateModelSet(string modelFolder, Diffuse TextEncoder2Config = textEncoder2Config, UnetConfig = unetConfig, VaeDecoderConfig = vaeDecoderConfig, - VaeEncoderConfig = vaeEncoderConfig + VaeEncoderConfig = vaeEncoderConfig, + ControlNetUnetConfig = contronNetConfig }; return configuration; } @@ -201,7 +210,7 @@ public static StableDiffusionModelSet CreateStableCascadeModelSet(string modelFo TextEncoderConfig = textEncoderConfig, TextEncoder2Config = textEncoder2Config, UnetConfig = priorUnetConfig, - DecoderUnetConfig = decoderUnetConfig, + Unet2Config = decoderUnetConfig, VaeDecoderConfig = vqganConfig, VaeEncoderConfig = imageEncoderConfig }; diff --git a/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs index 30dad861..805608cf 100644 --- a/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs @@ -43,7 +43,7 @@ public interface IPipeline /// Loads the pipeline. /// /// - Task LoadAsync(); + Task LoadAsync(bool controlNet = false); /// diff --git a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs index 648cdafe..b2cb8898 100644 --- a/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs +++ b/OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs @@ -66,7 +66,7 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger) /// Loads the pipeline. /// /// - public abstract Task LoadAsync(); + public abstract Task LoadAsync(bool controlNet = false); /// diff --git a/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs index 2f16f39a..527022c2 100644 --- a/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs @@ -25,8 +25,8 @@ public sealed class InstaFlowPipeline : StableDiffusionPipeline /// The vae decoder. /// The vae encoder. /// The logger. - public InstaFlowPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) - : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger) + public InstaFlowPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) + : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger) { _supportedDiffusers = diffusers ?? new List { @@ -62,7 +62,7 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe return diffuserType switch { DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), - DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), + DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), _ => throw new NotImplementedException() }; } @@ -81,8 +81,12 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet)); var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet)); var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet)); + var controlnet = default(UNetConditionModel); + if (modelSet.ControlNetUnetConfig is not null) + controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet)); + var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode); - return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger); + return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger); } diff --git a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs index 62966e31..ea17fd52 100644 --- a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs @@ -30,8 +30,8 @@ public sealed class LatentConsistencyPipeline : StableDiffusionPipeline /// The vae decoder. /// The vae encoder. /// The logger. - public LatentConsistencyPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) - : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger) + public LatentConsistencyPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) + : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger) { _supportedSchedulers = new List { @@ -112,8 +112,8 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), - DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), - DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), + DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), + DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), _ => throw new NotImplementedException() }; } @@ -132,8 +132,12 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet)); var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet)); var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet)); + var controlnet = default(UNetConditionModel); + if (modelSet.ControlNetUnetConfig is not null) + controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet)); + var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode); - return new LatentConsistencyPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger); + return new LatentConsistencyPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger); } diff --git a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs index 3e2494fe..d49d5c3b 100644 --- a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs @@ -32,8 +32,8 @@ public sealed class LatentConsistencyXLPipeline : StableDiffusionXLPipeline /// The vae decoder. /// The vae encoder. /// The logger. - public LatentConsistencyXLPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) - : base(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger) + public LatentConsistencyXLPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) + : base(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger) { _supportedSchedulers = new List { @@ -103,8 +103,8 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), - DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), - DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), + DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), + DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), _ => throw new NotImplementedException() }; } @@ -125,8 +125,12 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe var textEncoder2 = new TextEncoderModel(modelSet.TextEncoder2Config.ApplyDefaults(modelSet)); var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet)); var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet)); + var controlnet = default(UNetConditionModel); + if (modelSet.ControlNetUnetConfig is not null) + controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet)); + var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode); - return new LatentConsistencyXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger); + return new LatentConsistencyXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger); } diff --git a/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs index c6376e01..7f7a121d 100644 --- a/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs @@ -34,8 +34,8 @@ public sealed class StableCascadePipeline : StableDiffusionPipeline /// The diffusers. /// The default scheduler options. /// The logger. - public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel priorUnet, UNetConditionModel decoderUnet, AutoEncoderModel imageDecoder, AutoEncoderModel imageEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) - : base(pipelineOptions, tokenizer, textEncoder, priorUnet, imageDecoder, imageEncoder, diffusers, defaultSchedulerOptions, logger) + public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel priorUnet, UNetConditionModel decoderUnet, AutoEncoderModel imageDecoder, AutoEncoderModel imageEncoder, UNetConditionModel controlNet, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) + : base(pipelineOptions, tokenizer, textEncoder, priorUnet, imageDecoder, imageEncoder, controlNet, diffusers, defaultSchedulerOptions, logger) { _decoderUnet = decoderUnet; _supportedDiffusers = diffusers ?? new List @@ -66,17 +66,27 @@ public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tok /// public override DiffuserPipelineType PipelineType => DiffuserPipelineType.StableCascade; + /// + /// Gets the unet. + /// + public UNetConditionModel PriorUnet => _unet; - public override Task LoadAsync() + /// + /// Gets the unet. + /// + public UNetConditionModel DecoderUnet => _decoderUnet; + + + public override Task LoadAsync(bool controlNet = false) { if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum) - return base.LoadAsync(); + return base.LoadAsync(controlNet); // Preload all models into VRAM return Task.WhenAll ( _decoderUnet.LoadAsync(), - base.LoadAsync() + base.LoadAsync(controlNet) ); } @@ -252,13 +262,17 @@ private async Task GenerateEmbedsAsync(TokenizerResult i public static new StableCascadePipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default) { var priorUnet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet)); - var decoderUnet = new UNetConditionModel(modelSet.DecoderUnetConfig.ApplyDefaults(modelSet)); + var decoderUnet = new UNetConditionModel(modelSet.Unet2Config.ApplyDefaults(modelSet)); var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet)); var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet)); var imageDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet)); var imageEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet)); + var controlnet = default(UNetConditionModel); + if (modelSet.ControlNetUnetConfig is not null) + controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet)); + var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode); - return new StableCascadePipeline(pipelineOptions, tokenizer, textEncoder, priorUnet, decoderUnet, imageDecoder, imageEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger); + return new StableCascadePipeline(pipelineOptions, tokenizer, textEncoder, priorUnet, decoderUnet, imageDecoder, imageEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger); } diff --git a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs index 6f6327af..3b388d3e 100644 --- a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs @@ -24,12 +24,13 @@ namespace OnnxStack.StableDiffusion.Pipelines public class StableDiffusionPipeline : PipelineBase { protected readonly UNetConditionModel _unet; + protected readonly UNetConditionModel _controlNetUnet; protected readonly TokenizerModel _tokenizer; protected readonly TextEncoderModel _textEncoder; protected AutoEncoderModel _vaeDecoder; protected AutoEncoderModel _vaeEncoder; - protected OnnxModelSession _controlNet; + protected List _supportedDiffusers; protected IReadOnlyList _supportedSchedulers; protected SchedulerOptions _defaultSchedulerOptions; @@ -37,7 +38,7 @@ public class StableDiffusionPipeline : PipelineBase protected sealed record BatchResultInternal(SchedulerOptions SchedulerOptions, List> Result); /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The pipeline options /// The tokenizer. @@ -45,20 +46,27 @@ protected sealed record BatchResultInternal(SchedulerOptions SchedulerOptions, L /// The unet. /// The vae decoder. /// The vae encoder. + /// The control net unet. + /// The diffusers. + /// The default scheduler options. /// The logger. - public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers = default, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) : base(pipelineOptions, logger) + public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNetUnet, List diffusers = default, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) : base(pipelineOptions, logger) { _unet = unet; _tokenizer = tokenizer; _textEncoder = textEncoder; _vaeDecoder = vaeDecoder; _vaeEncoder = vaeEncoder; + _controlNetUnet = controlNetUnet; _supportedDiffusers = diffusers ?? new List { DiffuserType.TextToImage, DiffuserType.ImageToImage, DiffuserType.ImageInpaintLegacy }; + if (_controlNetUnet is not null) + _supportedDiffusers.AddRange(new[] { DiffuserType.ControlNet, DiffuserType.ControlNetImage }); + _supportedSchedulers = new List { SchedulerType.LMS, @@ -82,35 +90,61 @@ public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel t /// public override string Name => _pipelineOptions.Name; - /// /// Gets the supported diffusers. /// public override IReadOnlyList SupportedDiffusers => _supportedDiffusers; - /// /// Gets the supported schedulers. /// public override IReadOnlyList SupportedSchedulers => _supportedSchedulers; - /// /// Gets the type of the pipeline. /// public override DiffuserPipelineType PipelineType => DiffuserPipelineType.StableDiffusion; - /// /// Gets the default scheduler options. /// public override SchedulerOptions DefaultSchedulerOptions => _defaultSchedulerOptions; + /// + /// Gets the unet. + /// + public UNetConditionModel Unet => _unet; + + /// + /// Gets the control net unet. + /// + public UNetConditionModel ControlNetUnet => _controlNetUnet; + + /// + /// Gets the tokenizer. + /// + public TokenizerModel Tokenizer => _tokenizer; + + /// + /// Gets the text encoder. + /// + public TextEncoderModel TextEncoder => _textEncoder; + + /// + /// Gets the vae decoder. + /// + public AutoEncoderModel VaeDecoder => _vaeDecoder; + + /// + /// Gets the vae encoder. + /// + public AutoEncoderModel VaeEncoder => _vaeEncoder; + /// /// Loads the pipeline. /// - public override Task LoadAsync() + public override Task LoadAsync(bool controlNet = false) { if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum) return Task.CompletedTask; @@ -118,7 +152,9 @@ public override Task LoadAsync() // Preload all models into VRAM return Task.WhenAll ( - _unet.LoadAsync(), + controlNet + ? _controlNetUnet.LoadAsync() + : _unet.LoadAsync(), _tokenizer.LoadAsync(), _textEncoder.LoadAsync(), _vaeDecoder.LoadAsync(), @@ -138,6 +174,7 @@ public override async Task UnloadAsync() await Task.Yield(); _unet?.Dispose(); + _controlNetUnet?.Dispose(); _tokenizer?.Dispose(); _textEncoder?.Dispose(); _vaeDecoder?.Dispose(); @@ -479,8 +516,8 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), DiffuserType.ImageInpaint => new InpaintDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), - DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), - DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), + DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), + DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), _ => throw new NotImplementedException() }; } @@ -602,7 +639,7 @@ protected async Task GeneratePromptEmbedsAsync(Tokenizer } var promptTensor = new DenseTensor(promptEmbeddings.ToArray(), new[] { 1, promptEmbeddings.Count / _tokenizer.TokenizerLength, _tokenizer.TokenizerLength }); - var pooledTensor = new DenseTensor(pooledPromptEmbeddings.ToArray(), new[] { 1, tokenBatches.Count, _tokenizer.TokenizerLength }); + var pooledTensor = new DenseTensor(pooledPromptEmbeddings.ToArray(), new[] { 1, tokenBatches.Count, _tokenizer.TokenizerLength }); return new PromptEmbeddingsResult(promptTensor, pooledTensor); } @@ -635,8 +672,12 @@ protected IEnumerable PadWithBlankTokens(IEnumerable inputs, int req var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet)); var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet)); var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet)); + var controlnet = default(UNetConditionModel); + if (modelSet.ControlNetUnetConfig is not null) + controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet)); + var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode); - return new StableDiffusionPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger); + return new StableDiffusionPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger); } diff --git a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs index e1f86184..6a8af636 100644 --- a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs @@ -20,7 +20,7 @@ namespace OnnxStack.StableDiffusion.Pipelines public class StableDiffusionXLPipeline : StableDiffusionPipeline { protected TokenizerModel _tokenizer2; - protected OnnxModelSession _textEncoder2; + protected TextEncoderModel _textEncoder2; /// /// Initializes a new instance of the class. @@ -34,8 +34,8 @@ public class StableDiffusionXLPipeline : StableDiffusionPipeline /// The vae decoder. /// The vae encoder. /// The logger. - public StableDiffusionXLPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) - : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger) + public StableDiffusionXLPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) + : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger) { _tokenizer2 = tokenizer2; _textEncoder2 = textEncoder2; @@ -63,20 +63,32 @@ public StableDiffusionXLPipeline(PipelineOptions pipelineOptions, TokenizerModel public override DiffuserPipelineType PipelineType => DiffuserPipelineType.StableDiffusionXL; + /// + /// Gets the tokenizer2. + /// + public TokenizerModel Tokenizer2 => _tokenizer2; + + + /// + /// Gets the text encoder2. + /// + public TextEncoderModel TextEncoder2 => _textEncoder2; + + /// /// Loads the pipeline /// - public override Task LoadAsync() + public override Task LoadAsync(bool controlNet = false) { if (_pipelineOptions.MemoryMode == MemoryModeType.Minimum) - return base.LoadAsync(); + return base.LoadAsync(controlNet); // Preload all models into VRAM return Task.WhenAll ( _tokenizer2.LoadAsync(), _textEncoder2.LoadAsync(), - base.LoadAsync() + base.LoadAsync(controlNet) ); } @@ -106,8 +118,8 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), - DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), - DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), + DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), + DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger), _ => throw new NotImplementedException() }; } @@ -313,8 +325,12 @@ private async Task GenerateEmbedsAsync(TokenizerResult i var textEncoder2 = new TextEncoderModel(modelSet.TextEncoder2Config.ApplyDefaults(modelSet)); var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet)); var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet)); + var controlnet = default(UNetConditionModel); + if (modelSet.ControlNetUnetConfig is not null) + controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet)); + var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode); - return new StableDiffusionXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger); + return new StableDiffusionXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger); }