diff --git a/OnnxStack.Core/Extensions/Extensions.cs b/OnnxStack.Core/Extensions/Extensions.cs index ac24faf5..59de7672 100644 --- a/OnnxStack.Core/Extensions/Extensions.cs +++ b/OnnxStack.Core/Extensions/Extensions.cs @@ -19,6 +19,7 @@ public static SessionOptions GetSessionOptions(this OnnxModelConfig configuratio InterOpNumThreads = configuration.InterOpNumThreads.Value, IntraOpNumThreads = configuration.IntraOpNumThreads.Value }; + switch (configuration.ExecutionProvider) { case ExecutionProvider.DirectML: @@ -87,72 +88,6 @@ public static bool IsNullOrEmpty(this IEnumerable source) } - /// - /// Batches the source sequence into sized buckets. - /// - /// Type of elements in sequence. - /// The source sequence. - /// Size of buckets. - /// A sequence of equally sized buckets containing elements of the source collection. - /// - /// This operator uses deferred execution and streams its results (buckets and bucket content). - /// - public static IEnumerable> Batch(this IEnumerable source, int size) - { - return Batch(source, size, x => x); - } - - /// - /// Batches the source sequence into sized buckets and applies a projection to each bucket. - /// - /// Type of elements in sequence. - /// Type of result returned by . - /// The source sequence. - /// Size of buckets. - /// The projection to apply to each bucket. - /// A sequence of projections on equally sized buckets containing elements of the source collection. - /// - /// This operator uses deferred execution and streams its results (buckets and bucket content). 
- /// - public static IEnumerable Batch(this IEnumerable source, int size, Func, TResult> resultSelector) - { - if (source == null) - throw new ArgumentNullException(nameof(source)); - if (size <= 0) - throw new ArgumentOutOfRangeException(nameof(size)); - if (resultSelector == null) - throw new ArgumentNullException(nameof(resultSelector)); - return BatchImpl(source, size, resultSelector); - } - - - private static IEnumerable BatchImpl(this IEnumerable source, int size, Func, TResult> resultSelector) - { - TSource[] bucket = null; - var count = 0; - foreach (var item in source) - { - if (bucket == null) - bucket = new TSource[size]; - - bucket[count++] = item; - - // The bucket is fully buffered before it's yielded - if (count != size) - continue; - - // Select is necessary so bucket contents are streamed too - yield return resultSelector(bucket.Select(x => x)); - bucket = null; - count = 0; - } - - // Return the last bucket with all remaining elements - if (bucket != null && count > 0) - yield return resultSelector(bucket.Take(count)); - } - - /// /// Get the index of the specified item /// diff --git a/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs b/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs index ac99836d..99138337 100644 --- a/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs +++ b/OnnxStack.StableDiffusion/Config/SchedulerOptions.cs @@ -1,6 +1,7 @@ using OnnxStack.StableDiffusion.Enums; using System.Collections.Generic; using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; namespace OnnxStack.StableDiffusion.Config { @@ -36,6 +37,7 @@ public record SchedulerOptions /// If value is set to 0 a random seed is used. /// [Range(0, int.MaxValue)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public int Seed { get; set; } /// @@ -45,6 +47,7 @@ public record SchedulerOptions /// The number of steps to run inference for. 
The more steps the longer it will take to run the inference loop but the image quality should improve. /// [Range(5, 200)] + public int InferenceSteps { get; set; } = 30; /// @@ -62,34 +65,76 @@ public record SchedulerOptions public float Strength { get; set; } = 0.6f; [Range(0, int.MaxValue)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public int TrainTimesteps { get; set; } = 1000; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public float BetaStart { get; set; } = 0.00085f; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public float BetaEnd { get; set; } = 0.012f; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public IEnumerable TrainedBetas { get; set; } + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public TimestepSpacingType TimestepSpacing { get; set; } = TimestepSpacingType.Linspace; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public BetaScheduleType BetaSchedule { get; set; } = BetaScheduleType.ScaledLinear; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public int StepsOffset { get; set; } = 0; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public bool UseKarrasSigmas { get; set; } = false; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public VarianceType VarianceType { get; set; } = VarianceType.FixedSmall; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public float SampleMaxValue { get; set; } = 1.0f; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public bool Thresholding { get; set; } = false; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public bool ClipSample { get; set; } = false; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public float ClipSampleRange { get; set; } = 1f; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public 
PredictionType PredictionType { get; set; } = PredictionType.Epsilon; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public AlphaTransformType AlphaTransformType { get; set; } = AlphaTransformType.Cosine; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public float MaximumBeta { get; set; } = 0.999f; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public List Timesteps { get; set; } + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public int OriginalInferenceSteps { get; set; } = 50; + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public float AestheticScore { get; set; } = 6f; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public float AestheticNegativeScore { get; set; } = 2.5f; + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public float ConditioningScale { get; set; } = 0.7f; + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public int InferenceSteps2 { get; set; } = 10; + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public float GuidanceScale2 { get; set; } = 0; + [JsonIgnore] public bool IsKarrasScheduler { get diff --git a/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs b/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs index 95c22e98..a6bf6642 100644 --- a/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs +++ b/OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs @@ -11,16 +11,19 @@ public record StableDiffusionModelSet : IOnnxModelSetConfig { public string Name { get; set; } public bool IsEnabled { get; set; } - public int SampleSize { get; set; } = 512; - public DiffuserPipelineType PipelineType { get; set; } - public List Diffusers { get; set; } = new List(); - public MemoryModeType MemoryMode { get; set; } public int DeviceId { get; set; } public int InterOpNumThreads { get; set; } public int IntraOpNumThreads { get; set; } public 
ExecutionMode ExecutionMode { get; set; } public ExecutionProvider ExecutionProvider { get; set; } public OnnxModelPrecision Precision { get; set; } + public MemoryModeType MemoryMode { get; set; } + public int SampleSize { get; set; } = 512; + public DiffuserPipelineType PipelineType { get; set; } + public List Diffusers { get; set; } + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List Schedulers { get; set; } [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public TokenizerModelConfig TokenizerConfig { get; set; } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs index d2ac4e61..c2e3f157 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs @@ -252,7 +252,7 @@ protected async Task> DiffuseDecodeAsync(PromptOptions prompt // Unload if required if (_memoryMode == MemoryModeType.Minimum) - await _unet.UnloadAsync(); + await _decoderUnet.UnloadAsync(); return latents; } diff --git a/OnnxStack.StableDiffusion/Models/AutoEncoderModel.cs b/OnnxStack.StableDiffusion/Models/AutoEncoderModel.cs index 397e011a..e4941f0f 100644 --- a/OnnxStack.StableDiffusion/Models/AutoEncoderModel.cs +++ b/OnnxStack.StableDiffusion/Models/AutoEncoderModel.cs @@ -1,4 +1,5 @@ -using OnnxStack.Core.Config; +using Microsoft.ML.OnnxRuntime; +using OnnxStack.Core.Config; using OnnxStack.Core.Model; namespace OnnxStack.StableDiffusion.Models @@ -13,6 +14,27 @@ public AutoEncoderModel(AutoEncoderModelConfig configuration) : base(configurati } public float ScaleFactor => _configuration.ScaleFactor; + + + public static AutoEncoderModel Create(AutoEncoderModelConfig configuration) + { + return new AutoEncoderModel(configuration); + } + + public static AutoEncoderModel Create(string modelFile, float scaleFactor, int deviceId = 0, 
ExecutionProvider executionProvider = ExecutionProvider.DirectML) + { + var configuration = new AutoEncoderModelConfig + { + DeviceId = deviceId, + ExecutionProvider = executionProvider, + ExecutionMode = ExecutionMode.ORT_SEQUENTIAL, + InterOpNumThreads = 0, + IntraOpNumThreads = 0, + OnnxModelPath = modelFile, + ScaleFactor = scaleFactor + }; + return new AutoEncoderModel(configuration); + } } public record AutoEncoderModelConfig : OnnxModelConfig diff --git a/OnnxStack.StableDiffusion/Models/ControlNetModel.cs b/OnnxStack.StableDiffusion/Models/ControlNetModel.cs index 9888b0d9..4802af7f 100644 --- a/OnnxStack.StableDiffusion/Models/ControlNetModel.cs +++ b/OnnxStack.StableDiffusion/Models/ControlNetModel.cs @@ -32,7 +32,7 @@ public static ControlNetModel Create(string modelFile, ControlNetType type, int InterOpNumThreads = 0, IntraOpNumThreads = 0, OnnxModelPath = modelFile, - Type = type, + Type = type }; return new ControlNetModel(configuration); } diff --git a/OnnxStack.StableDiffusion/Models/TextEncoderModel.cs b/OnnxStack.StableDiffusion/Models/TextEncoderModel.cs index 72863c9d..d788aae2 100644 --- a/OnnxStack.StableDiffusion/Models/TextEncoderModel.cs +++ b/OnnxStack.StableDiffusion/Models/TextEncoderModel.cs @@ -1,4 +1,5 @@ -using OnnxStack.Core.Config; +using Microsoft.ML.OnnxRuntime; +using OnnxStack.Core.Config; using OnnxStack.Core.Model; namespace OnnxStack.StableDiffusion.Models @@ -11,6 +12,25 @@ public TextEncoderModel(TextEncoderModelConfig configuration) : base(configurati { _configuration = configuration; } + + public static TextEncoderModel Create(TextEncoderModelConfig configuration) + { + return new TextEncoderModel(configuration); + } + + public static TextEncoderModel Create(string modelFile, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML) + { + var configuration = new TextEncoderModelConfig + { + DeviceId = deviceId, + ExecutionProvider = executionProvider, + ExecutionMode = 
ExecutionMode.ORT_SEQUENTIAL, + InterOpNumThreads = 0, + IntraOpNumThreads = 0, + OnnxModelPath = modelFile + }; + return new TextEncoderModel(configuration); + } } public record TextEncoderModelConfig : OnnxModelConfig diff --git a/OnnxStack.StableDiffusion/Models/TokenizerModel.cs b/OnnxStack.StableDiffusion/Models/TokenizerModel.cs index f597ac15..55ecf61e 100644 --- a/OnnxStack.StableDiffusion/Models/TokenizerModel.cs +++ b/OnnxStack.StableDiffusion/Models/TokenizerModel.cs @@ -1,5 +1,7 @@ -using OnnxStack.Core.Config; +using Microsoft.ML.OnnxRuntime; +using OnnxStack.Core.Config; using OnnxStack.Core.Model; +using OnnxStack.StableDiffusion.Enums; namespace OnnxStack.StableDiffusion.Models { @@ -16,6 +18,29 @@ public TokenizerModel(TokenizerModelConfig configuration) : base(configuration) public int TokenizerLength => _configuration.TokenizerLength; public int PadTokenId => _configuration.PadTokenId; public int BlankTokenId => _configuration.BlankTokenId; + + public static TokenizerModel Create(TokenizerModelConfig configuration) + { + return new TokenizerModel(configuration); + } + + public static TokenizerModel Create(string modelFile, int tokenizerLength = 768, int tokenizerLimit = 77, int padTokenId = 49407, int blankTokenId = 49407, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML) + { + var configuration = new TokenizerModelConfig + { + DeviceId = deviceId, + ExecutionProvider = executionProvider, + ExecutionMode = ExecutionMode.ORT_SEQUENTIAL, + InterOpNumThreads = 0, + IntraOpNumThreads = 0, + OnnxModelPath = modelFile, + PadTokenId = padTokenId, + BlankTokenId = blankTokenId, + TokenizerLength = tokenizerLength, + TokenizerLimit = tokenizerLimit + }; + return new TokenizerModel(configuration); + } } public record TokenizerModelConfig : OnnxModelConfig diff --git a/OnnxStack.StableDiffusion/Models/UNetConditionModel.cs b/OnnxStack.StableDiffusion/Models/UNetConditionModel.cs index 3f32dd85..af3e43ce 100644 --- 
a/OnnxStack.StableDiffusion/Models/UNetConditionModel.cs +++ b/OnnxStack.StableDiffusion/Models/UNetConditionModel.cs @@ -1,4 +1,5 @@ -using OnnxStack.Core.Config; +using Microsoft.ML.OnnxRuntime; +using OnnxStack.Core.Config; using OnnxStack.Core.Model; using OnnxStack.StableDiffusion.Enums; @@ -14,6 +15,26 @@ public UNetConditionModel(UNetConditionModelConfig configuration) : base(configu } public ModelType ModelType => _configuration.ModelType; + + public static UNetConditionModel Create(UNetConditionModelConfig configuration) + { + return new UNetConditionModel(configuration); + } + + public static UNetConditionModel Create(string modelFile, ModelType modelType, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML) + { + var configuration = new UNetConditionModelConfig + { + DeviceId = deviceId, + ExecutionProvider = executionProvider, + ExecutionMode = ExecutionMode.ORT_SEQUENTIAL, + InterOpNumThreads = 0, + IntraOpNumThreads = 0, + OnnxModelPath = modelFile, + ModelType = modelType + }; + return new UNetConditionModel(configuration); + } } diff --git a/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs index 0794cff9..4060d436 100644 --- a/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.Logging; using OnnxStack.Core; using OnnxStack.Core.Config; +using OnnxStack.StableDiffusion.Common; using OnnxStack.StableDiffusion.Config; using OnnxStack.StableDiffusion.Diffusers; using OnnxStack.StableDiffusion.Diffusers.InstaFlow; @@ -25,14 +26,14 @@ public sealed class InstaFlowPipeline : StableDiffusionPipeline /// The vae decoder. /// The vae encoder. /// The logger. 
- public InstaFlowPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) - : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger) + public InstaFlowPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, List schedulers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) + : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, schedulers, defaultSchedulerOptions, logger) { _supportedDiffusers = diffusers ?? new List { DiffuserType.TextToImage }; - _supportedSchedulers = new List + _supportedSchedulers = schedulers ?? 
new List { SchedulerType.InstaFlow }; @@ -87,7 +88,7 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config)); var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode); - return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger); + return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.Schedulers, config.SchedulerOptions, logger); } diff --git a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs index da050895..5917f507 100644 --- a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyPipeline.cs @@ -30,10 +30,10 @@ public sealed class LatentConsistencyPipeline : StableDiffusionPipeline /// The vae decoder. /// The vae encoder. /// The logger. 
- public LatentConsistencyPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) - : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger) + public LatentConsistencyPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, List schedulers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) + : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, schedulers, defaultSchedulerOptions, logger) { - _supportedSchedulers = new List + _supportedSchedulers = schedulers ?? 
new List { SchedulerType.LCM }; @@ -138,7 +138,7 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config)); var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode); - return new LatentConsistencyPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger); + return new LatentConsistencyPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.Schedulers, config.SchedulerOptions, logger); } diff --git a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs index 04026878..9fc95d3d 100644 --- a/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/LatentConsistencyXLPipeline.cs @@ -32,10 +32,10 @@ public sealed class LatentConsistencyXLPipeline : StableDiffusionXLPipeline /// The vae decoder. /// The vae encoder. /// The logger. 
- public LatentConsistencyXLPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) - : base(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger) + public LatentConsistencyXLPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, List schedulers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) + : base(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, schedulers, defaultSchedulerOptions, logger) { - _supportedSchedulers = new List + _supportedSchedulers = schedulers ?? 
new List { SchedulerType.LCM }; @@ -131,7 +131,7 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config)); var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode); - return new LatentConsistencyXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger); + return new LatentConsistencyXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.Schedulers, config.SchedulerOptions, logger); } diff --git a/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs index 53d1e695..2912599f 100644 --- a/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/StableCascadePipeline.cs @@ -34,8 +34,8 @@ public sealed class StableCascadePipeline : StableDiffusionPipeline /// The diffusers. /// The default scheduler options. /// The logger. 
- public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel priorUnet, UNetConditionModel decoderUnet, AutoEncoderModel imageDecoder, AutoEncoderModel imageEncoder, UNetConditionModel controlNet, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) - : base(pipelineOptions, tokenizer, textEncoder, priorUnet, imageDecoder, imageEncoder, controlNet, diffusers, defaultSchedulerOptions, logger) + public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel priorUnet, UNetConditionModel decoderUnet, AutoEncoderModel imageDecoder, AutoEncoderModel imageEncoder, UNetConditionModel controlNet, List diffusers, List schedulers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) + : base(pipelineOptions, tokenizer, textEncoder, priorUnet, imageDecoder, imageEncoder, controlNet, diffusers, schedulers, defaultSchedulerOptions, logger) { _decoderUnet = decoderUnet; _supportedDiffusers = diffusers ?? new List @@ -43,7 +43,7 @@ public StableCascadePipeline(PipelineOptions pipelineOptions, TokenizerModel tok DiffuserType.TextToImage, DiffuserType.ImageToImage }; - _supportedSchedulers = new List + _supportedSchedulers = schedulers ?? 
new List { SchedulerType.DDPM, SchedulerType.DDPMWuerstchen @@ -233,9 +233,9 @@ private async Task GenerateEmbedsAsync(TokenizerResult i // The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate var tokenBatches = new List(); var attentionBatches = new List(); - foreach (var tokenBatch in inputTokens.InputIds.Batch(_tokenizer.TokenizerLimit)) + foreach (var tokenBatch in inputTokens.InputIds.Chunk(_tokenizer.TokenizerLimit)) tokenBatches.Add(PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit, _tokenizer.PadTokenId).ToArray()); - foreach (var attentionBatch in inputTokens.AttentionMask.Batch(_tokenizer.TokenizerLimit)) + foreach (var attentionBatch in inputTokens.AttentionMask.Chunk(_tokenizer.TokenizerLimit)) attentionBatches.Add(PadWithBlankTokens(attentionBatch, _tokenizer.TokenizerLimit, 1).ToArray()); var promptEmbeddings = new List(); @@ -273,7 +273,7 @@ private async Task GenerateEmbedsAsync(TokenizerResult i controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config)); var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode); - return new StableCascadePipeline(pipelineOptions, tokenizer, textEncoder, priorUnet, decoderUnet, imageDecoder, imageEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger); + return new StableCascadePipeline(pipelineOptions, tokenizer, textEncoder, priorUnet, decoderUnet, imageDecoder, imageEncoder, controlnet, config.Diffusers, config.Schedulers, config.SchedulerOptions, logger); } diff --git a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs index ec55e6a2..af217baa 100644 --- a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs @@ -51,7 +51,7 @@ protected sealed record BatchResultInternal(SchedulerOptions SchedulerOptions, L /// The diffusers. 
/// The default scheduler options. /// The logger. - public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNetUnet, List diffusers = default, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) : base(pipelineOptions, logger) + public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNetUnet, List diffusers = default, List schedulers = default, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) : base(pipelineOptions, logger) { _unet = unet; _tokenizer = tokenizer; @@ -70,7 +70,7 @@ public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel t if (_controlNetUnet is null) _supportedDiffusers.RemoveRange(new[] { DiffuserType.ControlNet, DiffuserType.ControlNetImage }); - _supportedSchedulers = new List + _supportedSchedulers = schedulers ?? 
new List { SchedulerType.LMS, SchedulerType.Euler, @@ -639,9 +639,9 @@ protected async Task GeneratePromptEmbedsAsync(Tokenizer // The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate1 var tokenBatches = new List(); var attentionBatches = new List(); - foreach (var tokenBatch in inputTokens.InputIds.Batch(_tokenizer.TokenizerLimit)) + foreach (var tokenBatch in inputTokens.InputIds.Chunk(_tokenizer.TokenizerLimit)) tokenBatches.Add(PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit, _tokenizer.PadTokenId).ToArray()); - foreach (var attentionBatch in inputTokens.AttentionMask.Batch(_tokenizer.TokenizerLimit)) + foreach (var attentionBatch in inputTokens.AttentionMask.Chunk(_tokenizer.TokenizerLimit)) attentionBatches.Add(PadWithBlankTokens(attentionBatch, _tokenizer.TokenizerLimit, 1).ToArray()); var promptEmbeddings = new List(); @@ -693,7 +693,7 @@ protected IEnumerable PadWithBlankTokens(IEnumerable inputs, int req controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config)); var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode); - return new StableDiffusionPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger); + return new StableDiffusionPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.Schedulers, config.SchedulerOptions, logger); } diff --git a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs index e38e4517..746e33ee 100644 --- a/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs +++ b/OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs @@ -34,17 +34,18 @@ public class StableDiffusionXLPipeline : StableDiffusionPipeline /// The vae decoder. /// The vae encoder. /// The logger. 
- public StableDiffusionXLPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) - : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger) + public StableDiffusionXLPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TokenizerModel tokenizer2, TextEncoderModel textEncoder, TextEncoderModel textEncoder2, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List diffusers, List schedulers = default, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default) + : base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, schedulers, defaultSchedulerOptions, logger) { _tokenizer2 = tokenizer2; _textEncoder2 = textEncoder2; - _supportedSchedulers = new List + _supportedSchedulers = schedulers ?? new List { SchedulerType.Euler, SchedulerType.EulerAncestral, SchedulerType.DDPM, - SchedulerType.KDPM2 + SchedulerType.KDPM2, + SchedulerType.DDIM }; _defaultSchedulerOptions = defaultSchedulerOptions ?? 
new SchedulerOptions { @@ -290,9 +291,9 @@ private async Task GenerateEmbedsAsync(TokenizerResult i // The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate1 var tokenBatches = new List(); var attentionBatches = new List(); - foreach (var tokenBatch in inputTokens.InputIds.Batch(_tokenizer.TokenizerLimit)) + foreach (var tokenBatch in inputTokens.InputIds.Chunk(_tokenizer.TokenizerLimit)) tokenBatches.Add(PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit, _tokenizer.PadTokenId).ToArray()); - foreach (var attentionBatch in inputTokens.AttentionMask.Batch(_tokenizer.TokenizerLimit)) + foreach (var attentionBatch in inputTokens.AttentionMask.Chunk(_tokenizer.TokenizerLimit)) attentionBatches.Add(PadWithBlankTokens(attentionBatch, _tokenizer.TokenizerLimit, 1).ToArray()); var promptEmbeddings = new List(); @@ -331,7 +332,7 @@ private async Task GenerateEmbedsAsync(TokenizerResult i controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config)); var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode); - return new StableDiffusionXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger); + return new StableDiffusionXLPipeline(pipelineOptions, tokenizer, tokenizer2, textEncoder, textEncoder2, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.Schedulers, config.SchedulerOptions, logger); } diff --git a/OnnxStack.UI/Behaviors/SliderMouseWheelBehavior.cs b/OnnxStack.UI/Behaviors/SliderMouseWheelBehavior.cs index 80ce3b39..9f084380 100644 --- a/OnnxStack.UI/Behaviors/SliderMouseWheelBehavior.cs +++ b/OnnxStack.UI/Behaviors/SliderMouseWheelBehavior.cs @@ -1,5 +1,4 @@ using Microsoft.Xaml.Behaviors; -using Newtonsoft.Json.Linq; using System.Windows.Controls; using System.Windows.Input;