Skip to content

Commit

Permalink
Support default and controlnet Unet in pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
saddam213 committed May 3, 2024
1 parent 6938b70 commit 4134ece
Show file tree
Hide file tree
Showing 14 changed files with 182 additions and 80 deletions.
7 changes: 4 additions & 3 deletions OnnxStack.Console/Examples/ControlNetExample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ public async Task RunAsync()
var controlImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\OpenPose.png");

// Create ControlNet
var controlNet = ControlNetModel.Create("D:\\Repositories\\controlnet_onnx\\controlnet\\openpose.onnx", ControlNetType.OpenPose, DiffuserPipelineType.StableDiffusion);
var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\openpose.onnx", ControlNetType.OpenPose, DiffuserPipelineType.StableDiffusion);

// Create Pipeline
var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Repositories\\stable_diffusion_onnx", ModelType.ControlNet);
var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Models\\stable-diffusion-v1-5-onnx");

// Prompt
var promptOptions = new PromptOptions
Expand All @@ -48,7 +48,8 @@ public async Task RunAsync()
InputContolImage = controlImage
};


// Preload (optional)
await pipeline.LoadAsync(true);

// Run pipeline
var result = await pipeline.RunAsync(promptOptions, controlNet: controlNet, progressCallback: OutputHelpers.ProgressCallback);
Expand Down
10 changes: 6 additions & 4 deletions OnnxStack.Console/Examples/ControlNetFeatureExample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Models;
using OnnxStack.StableDiffusion.Pipelines;
using SixLabors.ImageSharp;

namespace OnnxStack.Console.Runner
{
Expand Down Expand Up @@ -35,7 +34,7 @@ public async Task RunAsync()
var inputImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp");

// Create Annotation pipeline
var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutput: true);
var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Models\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutput: true);

// Create Depth Image
var controlImage = await annotationPipeline.RunAsync(inputImage);
Expand All @@ -44,10 +43,10 @@ public async Task RunAsync()
await controlImage.SaveAsync(Path.Combine(_outputDirectory, $"Depth.png"));

// Create ControlNet
var controlNet = ControlNetModel.Create("D:\\Repositories\\controlnet_onnx\\controlnet\\depth.onnx", ControlNetType.Depth, DiffuserPipelineType.StableDiffusion);
var controlNet = ControlNetModel.Create("D:\\Models\\controlnet_onnx\\controlnet\\depth.onnx", ControlNetType.Depth, DiffuserPipelineType.StableDiffusion);

// Create Pipeline
var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Repositories\\stable_diffusion_onnx", ModelType.ControlNet);
var pipeline = StableDiffusionPipeline.CreatePipeline("D:\\Models\\stable-diffusion-v1-5-onnx");

// Prompt
var promptOptions = new PromptOptions
Expand All @@ -57,6 +56,9 @@ public async Task RunAsync()
InputContolImage = controlImage
};

// Preload (optional)
await pipeline.LoadAsync(true);

// Run pipeline
var result = await pipeline.RunAsync(promptOptions, controlNet: controlNet, progressCallback: OutputHelpers.ProgressCallback);

Expand Down
42 changes: 25 additions & 17 deletions OnnxStack.Console/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,26 @@
"BlankTokenId": 49407,
"TokenizerLimit": 77,
"TokenizerLength": 768,
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5-onnx\\cliptokenizer.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\cliptokenizer.onnx"
},
"TextEncoderConfig": {
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5-onnx\\text_encoder\\model.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\text_encoder\\model.onnx"
},
"UnetConfig": {
"ModelType": "Base",
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5-onnx\\unet\\model.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\unet\\model.onnx"
},
"VaeDecoderConfig": {
"ScaleFactor": 0.18215,
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5-onnx\\vae_decoder\\model.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\vae_decoder\\model.onnx"
},
"VaeEncoderConfig": {
"ScaleFactor": 0.18215,
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5-onnx\\vae_encoder\\model.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\vae_encoder\\model.onnx"
},
"ControlNetConfig": {
"ModelType": "ControlNet",
"OnnxModelPath": "D:\\Models\\stable-diffusion-v1-5-onnx\\controlnet\\model.onnx"
}
},
{
Expand All @@ -70,22 +74,26 @@
"BlankTokenId": 49407,
"TokenizerLimit": 77,
"TokenizerLength": 768,
"OnnxModelPath": "D:\\Repositories\\LCM_Dreamshaper_v7-onnx\\tokenizer\\model.onnx"
"OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\tokenizer\\model.onnx"
},
"TextEncoderConfig": {
"OnnxModelPath": "D:\\Repositories\\LCM_Dreamshaper_v7-onnx\\text_encoder\\model.onnx"
"OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\text_encoder\\model.onnx"
},
"UnetConfig": {
"ModelType": "Base",
"OnnxModelPath": "D:\\Repositories\\LCM_Dreamshaper_v7-onnx\\unet\\model.onnx"
"OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\unet\\model.onnx"
},
"VaeDecoderConfig": {
"ScaleFactor": 0.18215,
"OnnxModelPath": "D:\\Repositories\\LCM_Dreamshaper_v7-onnx\\vae_decoder\\model.onnx"
"OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\vae_decoder\\model.onnx"
},
"VaeEncoderConfig": {
"ScaleFactor": 0.18215,
"OnnxModelPath": "D:\\Repositories\\LCM_Dreamshaper_v7-onnx\\vae_encoder\\model.onnx"
"OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\vae_encoder\\model.onnx"
},
"ControlNetConfig": {
"ModelType": "ControlNet",
"OnnxModelPath": "D:\\Models\\LCM_Dreamshaper_v7-onnx\\controlnet\\model.onnx"
}
},
{
Expand All @@ -108,32 +116,32 @@
"BlankTokenId": 49407,
"TokenizerLimit": 77,
"TokenizerLength": 768,
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\tokenizer\\model.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\tokenizer\\model.onnx"
},
"Tokenizer2Config": {
"PadTokenId": 1,
"BlankTokenId": 49407,
"TokenizerLimit": 77,
"TokenizerLength": 1280,
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\tokenizer_2\\model.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\tokenizer_2\\model.onnx"
},
"TextEncoderConfig": {
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\text_encoder\\model.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\text_encoder\\model.onnx"
},
"TextEncoder2Config": {
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\text_encoder_2\\model.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\text_encoder_2\\model.onnx"
},
"UnetConfig": {
"ModelType": "Base",
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\unet\\model.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\unet\\model.onnx"
},
"VaeDecoderConfig": {
"ScaleFactor": 0.13025,
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\vae_decoder\\model.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\vae_decoder\\model.onnx"
},
"VaeEncoderConfig": {
"ScaleFactor": 0.13025,
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-xl-base-1.0-onnx\\vae_encoder\\model.onnx"
"OnnxModelPath": "D:\\Models\\stable-diffusion-xl-base-1.0-onnx\\vae_encoder\\model.onnx"
}
}
]
Expand Down
6 changes: 3 additions & 3 deletions OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public record StableDiffusionModelSet : IOnnxModelSetConfig

[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public UNetConditionModelConfig UnetConfig { get; set; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public UNetConditionModelConfig Unet2Config { get; set; }

[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public AutoEncoderModelConfig VaeDecoderConfig { get; set; }
Expand All @@ -44,11 +46,9 @@ public record StableDiffusionModelSet : IOnnxModelSetConfig
public AutoEncoderModelConfig VaeEncoderConfig { get; set; }

[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public UNetConditionModelConfig DecoderUnetConfig { get; set; }
public UNetConditionModelConfig ControlNetUnetConfig { get; set; }

[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public SchedulerOptions SchedulerOptions { get; set; }


}
}
1 change: 0 additions & 1 deletion OnnxStack.StableDiffusion/Enums/ModelType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ public enum ModelType
{
Base = 0,
Refiner = 1,
ControlNet = 2,
Turbo = 3,
Inpaint = 4
}
Expand Down
23 changes: 16 additions & 7 deletions OnnxStack.StableDiffusion/Helpers/ModelFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,9 @@ public static StableDiffusionModelSet CreateModelSet(string modelFolder, Diffuse
var vaeEncoderPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx");
var controlNetPath = Path.Combine(modelFolder, "controlNet", "model.onnx");

// Some repositories have the ControlNet in the unet folder, some in the controlnet folder
if (modelType == ModelType.ControlNet && File.Exists(controlNetPath))
unetPath = controlNetPath;

var diffusers = modelType switch
{
ModelType.Inpaint => new List<DiffuserType> { DiffuserType.ImageInpaint },
ModelType.ControlNet => new List<DiffuserType> { DiffuserType.ControlNet, DiffuserType.ControlNetImage },
_ => new List<DiffuserType> { DiffuserType.TextToImage, DiffuserType.ImageToImage, DiffuserType.ImageInpaintLegacy }
};

Expand Down Expand Up @@ -79,6 +74,19 @@ public static StableDiffusionModelSet CreateModelSet(string modelFolder, Diffuse
OnnxModelPath = vaeEncoderPath
};

var contronNetConfig = default(UNetConditionModelConfig);
if (File.Exists(controlNetPath))
{
diffusers.Add(DiffuserType.ControlNet);
diffusers.Add(DiffuserType.ControlNetImage);
contronNetConfig = new UNetConditionModelConfig
{
ModelType = modelType,
OnnxModelPath = controlNetPath
};
}


// SDXL Pipelines
if (pipeline == DiffuserPipelineType.StableDiffusionXL || pipeline == DiffuserPipelineType.LatentConsistencyXL)
{
Expand Down Expand Up @@ -129,7 +137,8 @@ public static StableDiffusionModelSet CreateModelSet(string modelFolder, Diffuse
TextEncoder2Config = textEncoder2Config,
UnetConfig = unetConfig,
VaeDecoderConfig = vaeDecoderConfig,
VaeEncoderConfig = vaeEncoderConfig
VaeEncoderConfig = vaeEncoderConfig,
ControlNetUnetConfig = contronNetConfig
};
return configuration;
}
Expand Down Expand Up @@ -201,7 +210,7 @@ public static StableDiffusionModelSet CreateStableCascadeModelSet(string modelFo
TextEncoderConfig = textEncoderConfig,
TextEncoder2Config = textEncoder2Config,
UnetConfig = priorUnetConfig,
DecoderUnetConfig = decoderUnetConfig,
Unet2Config = decoderUnetConfig,
VaeDecoderConfig = vqganConfig,
VaeEncoderConfig = imageEncoderConfig
};
Expand Down
2 changes: 1 addition & 1 deletion OnnxStack.StableDiffusion/Pipelines/Base/IPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public interface IPipeline
/// Loads the pipeline.
/// </summary>
/// <returns></returns>
Task LoadAsync();
Task LoadAsync(bool controlNet = false);


/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ protected PipelineBase(PipelineOptions pipelineOptions, ILogger logger)
/// Loads the pipeline.
/// </summary>
/// <returns></returns>
public abstract Task LoadAsync();
public abstract Task LoadAsync(bool controlNet = false);


/// <summary>
Expand Down
12 changes: 8 additions & 4 deletions OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ public sealed class InstaFlowPipeline : StableDiffusionPipeline
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public InstaFlowPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List<DiffuserType> diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
: base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger)
public InstaFlowPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List<DiffuserType> diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
: base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger)
{
_supportedDiffusers = diffusers ?? new List<DiffuserType>
{
Expand Down Expand Up @@ -62,7 +62,7 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
return diffuserType switch
{
DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
_ => throw new NotImplementedException()
};
}
Expand All @@ -81,8 +81,12 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
var controlnet = default(UNetConditionModel);
if (modelSet.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));

var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
}


Expand Down
14 changes: 9 additions & 5 deletions OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ public sealed class LatentConsistencyPipeline : StableDiffusionPipeline
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public LatentConsistencyPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, List<DiffuserType> diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
: base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, diffusers, defaultSchedulerOptions, logger)
public LatentConsistencyPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List<DiffuserType> diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
: base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger)
{
_supportedSchedulers = new List<SchedulerType>
{
Expand Down Expand Up @@ -112,8 +112,8 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
DiffuserType.TextToImage => new TextDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
DiffuserType.ImageToImage => new ImageDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
DiffuserType.ImageInpaintLegacy => new InpaintLegacyDiffuser(_unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _unet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
DiffuserType.ControlNet => new ControlNetDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
DiffuserType.ControlNetImage => new ControlNetImageDiffuser(controlNetModel, _controlNetUnet, _vaeDecoder, _vaeEncoder, _pipelineOptions.MemoryMode, _logger),
_ => throw new NotImplementedException()
};
}
Expand All @@ -132,8 +132,12 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
var textEncoder = new TextEncoderModel(modelSet.TextEncoderConfig.ApplyDefaults(modelSet));
var vaeDecoder = new AutoEncoderModel(modelSet.VaeDecoderConfig.ApplyDefaults(modelSet));
var vaeEncoder = new AutoEncoderModel(modelSet.VaeEncoderConfig.ApplyDefaults(modelSet));
var controlnet = default(UNetConditionModel);
if (modelSet.ControlNetUnetConfig is not null)
controlnet = new UNetConditionModel(modelSet.ControlNetUnetConfig.ApplyDefaults(modelSet));

var pipelineOptions = new PipelineOptions(modelSet.Name, modelSet.MemoryMode);
return new LatentConsistencyPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
return new LatentConsistencyPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, modelSet.Diffusers, modelSet.SchedulerOptions, logger);
}


Expand Down
Loading

0 comments on commit 4134ece

Please sign in to comment.