Skip to content

Commit

Permalink
Configurable scheduler sets
Browse files Browse the repository at this point in the history
  • Loading branch information
saddam213 committed Jun 12, 2024
1 parent 106a4c4 commit e4f9bca
Show file tree
Hide file tree
Showing 16 changed files with 179 additions and 107 deletions.
67 changes: 1 addition & 66 deletions OnnxStack.Core/Extensions/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public static SessionOptions GetSessionOptions(this OnnxModelConfig configuratio
InterOpNumThreads = configuration.InterOpNumThreads.Value,
IntraOpNumThreads = configuration.IntraOpNumThreads.Value
};

switch (configuration.ExecutionProvider)
{
case ExecutionProvider.DirectML:
Expand Down Expand Up @@ -87,72 +88,6 @@ public static bool IsNullOrEmpty<TSource>(this IEnumerable<TSource> source)
}


/// <summary>
/// Splits the source sequence into buckets holding at most <paramref name="size" /> elements each.
/// </summary>
/// <typeparam name="TSource">Type of elements in the <paramref name="source" /> sequence.</typeparam>
/// <param name="source">The sequence to bucket.</param>
/// <param name="size">Maximum number of elements per bucket.</param>
/// <returns>A sequence of buckets; all are full except possibly the last.</returns>
/// <remarks>
/// Uses deferred execution — buckets and their contents are streamed as the
/// caller enumerates the result.
/// </remarks>
public static IEnumerable<IEnumerable<TSource>> Batch<TSource>(this IEnumerable<TSource> source, int size)
    => source.Batch(size, bucket => bucket);

/// <summary>
/// Splits the source sequence into buckets of at most <paramref name="size" /> elements
/// and projects each bucket through <paramref name="resultSelector" />.
/// </summary>
/// <typeparam name="TSource">Type of elements in the <paramref name="source" /> sequence.</typeparam>
/// <typeparam name="TResult">Type produced by <paramref name="resultSelector" />.</typeparam>
/// <param name="source">The sequence to bucket.</param>
/// <param name="size">Maximum number of elements per bucket.</param>
/// <param name="resultSelector">Projection applied to each bucket.</param>
/// <returns>A sequence of projected buckets; all buckets are full except possibly the last.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source" /> or <paramref name="resultSelector" /> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="size" /> is zero or negative.</exception>
/// <remarks>
/// Uses deferred execution — arguments are validated eagerly here, while the
/// actual bucketing is streamed by the iterator in <see cref="BatchImpl{TSource,TResult}" />.
/// </remarks>
public static IEnumerable<TResult> Batch<TSource, TResult>(this IEnumerable<TSource> source, int size, Func<IEnumerable<TSource>, TResult> resultSelector)
{
    // Guards live outside the iterator so they throw at call time, not on first enumeration.
    if (source is null)
        throw new ArgumentNullException(nameof(source));
    if (size <= 0)
        throw new ArgumentOutOfRangeException(nameof(size));
    if (resultSelector is null)
        throw new ArgumentNullException(nameof(resultSelector));

    return BatchImpl(source, size, resultSelector);
}


/// <summary>
/// Iterator backing <c>Batch</c>: buffers elements into fixed-size arrays and
/// yields each projected bucket as soon as it fills (plus a final partial bucket).
/// </summary>
private static IEnumerable<TResult> BatchImpl<TSource, TResult>(this IEnumerable<TSource> source, int size, Func<IEnumerable<TSource>, TResult> resultSelector)
{
    TSource[] buffer = null;
    var filled = 0;

    foreach (var element in source)
    {
        // Allocate a fresh array per bucket so a previously yielded bucket is never overwritten.
        buffer ??= new TSource[size];
        buffer[filled++] = element;

        if (filled < size)
            continue;

        // Wrap in Select so the consumer receives a streamed view rather than the raw array.
        yield return resultSelector(buffer.Select(e => e));
        buffer = null;
        filled = 0;
    }

    // Emit whatever remains as a final, possibly short, bucket.
    if (buffer != null && filled > 0)
        yield return resultSelector(buffer.Take(filled));
}


/// <summary>
/// Get the index of the specified item
/// </summary>
Expand Down
45 changes: 45 additions & 0 deletions OnnxStack.StableDiffusion/Config/SchedulerOptions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using OnnxStack.StableDiffusion.Enums;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;

namespace OnnxStack.StableDiffusion.Config
{
Expand Down Expand Up @@ -36,6 +37,7 @@ public record SchedulerOptions
/// If value is set to 0 a random seed is used.
/// </value>
[Range(0, int.MaxValue)]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public int Seed { get; set; }

/// <summary>
Expand All @@ -45,6 +47,7 @@ public record SchedulerOptions
/// The number of steps to run inference for. The more steps the longer it will take to run the inference loop but the image quality should improve.
/// </value>
[Range(5, 200)]
// NOTE(review): every sibling property in this record carries
// [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)], but here there is
// only a blank line where that attribute would sit — this looks like an accidental
// omission in the commit; confirm whether InferenceSteps should also be skipped
// when serialized at its default (30).

public int InferenceSteps { get; set; } = 30;

/// <summary>
Expand All @@ -62,34 +65,76 @@ public record SchedulerOptions
public float Strength { get; set; } = 0.6f;

// All properties below carry [JsonIgnore(WhenWritingDefault)], so they are omitted
// from serialized JSON whenever they still hold their default value.

/// <summary>Timestep count of the training schedule; default 1000. (Name-derived — confirm against scheduler implementations.)</summary>
[Range(0, int.MaxValue)]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public int TrainTimesteps { get; set; } = 1000;

/// <summary>Start value of the beta schedule; default 0.00085.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public float BetaStart { get; set; } = 0.00085f;

/// <summary>End value of the beta schedule; default 0.012.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public float BetaEnd { get; set; } = 0.012f;

/// <summary>Optional explicit beta values; null by default (presumably overrides the computed schedule — confirm).</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public IEnumerable<float> TrainedBetas { get; set; }

/// <summary>How timesteps are spaced; default <see cref="TimestepSpacingType.Linspace"/>.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public TimestepSpacingType TimestepSpacing { get; set; } = TimestepSpacingType.Linspace;

/// <summary>Beta schedule shape; default <see cref="BetaScheduleType.ScaledLinear"/>.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public BetaScheduleType BetaSchedule { get; set; } = BetaScheduleType.ScaledLinear;

/// <summary>Offset applied to scheduler steps; default 0.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public int StepsOffset { get; set; } = 0;

/// <summary>Whether to use Karras sigmas; default false.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public bool UseKarrasSigmas { get; set; } = false;

/// <summary>Variance handling mode; default <see cref="VarianceType.FixedSmall"/>.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public VarianceType VarianceType { get; set; } = VarianceType.FixedSmall;

/// <summary>Maximum sample magnitude; default 1.0. (Presumably used with <see cref="Thresholding"/> — confirm.)</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public float SampleMaxValue { get; set; } = 1.0f;

/// <summary>Whether dynamic thresholding is enabled; default false.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public bool Thresholding { get; set; } = false;

/// <summary>Whether samples are clipped; default false.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public bool ClipSample { get; set; } = false;

/// <summary>Clipping range applied when <see cref="ClipSample"/> is true; default 1.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public float ClipSampleRange { get; set; } = 1f;

/// <summary>What the model predicts; default <see cref="PredictionType.Epsilon"/>.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public PredictionType PredictionType { get; set; } = PredictionType.Epsilon;

/// <summary>Alpha transform used by the schedule; default <see cref="AlphaTransformType.Cosine"/>.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public AlphaTransformType AlphaTransformType { get; set; } = AlphaTransformType.Cosine;

/// <summary>Upper clamp for beta values; default 0.999.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public float MaximumBeta { get; set; } = 0.999f;

/// <summary>Optional explicit timestep list; null by default.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public List<int> Timesteps { get; set; }

/// <summary>Original inference step count; default 50. (Presumably consumed by LCM-style schedulers — confirm.)</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public int OriginalInferenceSteps { get; set; } = 50;

/// <summary>Aesthetic score conditioning value; default 6. (SDXL-style conditioning — confirm.)</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public float AestheticScore { get; set; } = 6f;

/// <summary>Negative aesthetic score conditioning value; default 2.5.</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public float AestheticNegativeScore { get; set; } = 2.5f;

/// <summary>Conditioning scale; default 0.7. (Presumably ControlNet strength — confirm.)</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public float ConditioningScale { get; set; } = 0.7f;

/// <summary>Secondary inference step count; default 10. (Purpose not visible here — presumably a refiner/second stage; confirm.)</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public int InferenceSteps2 { get; set; } = 10;

/// <summary>Secondary guidance scale; default 0. (Pairs with <see cref="InferenceSteps2"/> — confirm.)</summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public float GuidanceScale2 { get; set; } = 0;

[JsonIgnore]
public bool IsKarrasScheduler
{
get
Expand Down
11 changes: 7 additions & 4 deletions OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,19 @@ public record StableDiffusionModelSet : IOnnxModelSetConfig
{
public string Name { get; set; }
public bool IsEnabled { get; set; }
public int SampleSize { get; set; } = 512;
public DiffuserPipelineType PipelineType { get; set; }
public List<DiffuserType> Diffusers { get; set; } = new List<DiffuserType>();
public MemoryModeType MemoryMode { get; set; }
public int DeviceId { get; set; }
public int InterOpNumThreads { get; set; }
public int IntraOpNumThreads { get; set; }
public ExecutionMode ExecutionMode { get; set; }
public ExecutionProvider ExecutionProvider { get; set; }
public OnnxModelPrecision Precision { get; set; }
public MemoryModeType MemoryMode { get; set; }
public int SampleSize { get; set; } = 512;
public DiffuserPipelineType PipelineType { get; set; }
public List<DiffuserType> Diffusers { get; set; }

[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public List<SchedulerType> Schedulers { get; set; }

[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public TokenizerModelConfig TokenizerConfig { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ protected async Task<DenseTensor<float>> DiffuseDecodeAsync(PromptOptions prompt

// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _unet.UnloadAsync();
await _decoderUnet.UnloadAsync();

return latents;
}
Expand Down
24 changes: 23 additions & 1 deletion OnnxStack.StableDiffusion/Models/AutoEncoderModel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using OnnxStack.Core.Config;
using Microsoft.ML.OnnxRuntime;
using OnnxStack.Core.Config;
using OnnxStack.Core.Model;

namespace OnnxStack.StableDiffusion.Models
Expand All @@ -13,6 +14,27 @@ public AutoEncoderModel(AutoEncoderModelConfig configuration) : base(configurati
}

public float ScaleFactor => _configuration.ScaleFactor;


/// <summary>
/// Creates an <see cref="AutoEncoderModel"/> from an existing configuration.
/// </summary>
/// <param name="configuration">The auto-encoder configuration to wrap.</param>
/// <returns>A new <see cref="AutoEncoderModel"/> instance.</returns>
public static AutoEncoderModel Create(AutoEncoderModelConfig configuration) => new AutoEncoderModel(configuration);

/// <summary>
/// Creates an <see cref="AutoEncoderModel"/> directly from an ONNX model file,
/// using sequential execution mode and thread counts of 0 (runtime-chosen defaults — confirm).
/// </summary>
/// <param name="modelFile">Path to the ONNX model file.</param>
/// <param name="scaleFactor">Scale factor stored on the configuration.</param>
/// <param name="deviceId">Execution device id; default 0.</param>
/// <param name="executionProvider">Execution provider; default DirectML.</param>
/// <returns>A new <see cref="AutoEncoderModel"/> instance.</returns>
public static AutoEncoderModel Create(string modelFile, float scaleFactor, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
{
    AutoEncoderModelConfig config = new()
    {
        DeviceId = deviceId,
        ExecutionProvider = executionProvider,
        ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
        InterOpNumThreads = 0,
        IntraOpNumThreads = 0,
        OnnxModelPath = modelFile,
        ScaleFactor = scaleFactor
    };
    return new AutoEncoderModel(config);
}
}

public record AutoEncoderModelConfig : OnnxModelConfig
Expand Down
2 changes: 1 addition & 1 deletion OnnxStack.StableDiffusion/Models/ControlNetModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public static ControlNetModel Create(string modelFile, ControlNetType type, int
InterOpNumThreads = 0,
IntraOpNumThreads = 0,
OnnxModelPath = modelFile,
Type = type,
Type = type
};
return new ControlNetModel(configuration);
}
Expand Down
22 changes: 21 additions & 1 deletion OnnxStack.StableDiffusion/Models/TextEncoderModel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using OnnxStack.Core.Config;
using Microsoft.ML.OnnxRuntime;
using OnnxStack.Core.Config;
using OnnxStack.Core.Model;

namespace OnnxStack.StableDiffusion.Models
Expand All @@ -11,6 +12,25 @@ public TextEncoderModel(TextEncoderModelConfig configuration) : base(configurati
{
_configuration = configuration;
}

/// <summary>
/// Creates a <see cref="TextEncoderModel"/> from an existing configuration.
/// </summary>
/// <param name="configuration">The text-encoder configuration to wrap.</param>
/// <returns>A new <see cref="TextEncoderModel"/> instance.</returns>
public static TextEncoderModel Create(TextEncoderModelConfig configuration) => new TextEncoderModel(configuration);

/// <summary>
/// Creates a <see cref="TextEncoderModel"/> directly from an ONNX model file,
/// using sequential execution mode and thread counts of 0 (runtime-chosen defaults — confirm).
/// </summary>
/// <param name="modelFile">Path to the ONNX model file.</param>
/// <param name="deviceId">Execution device id; default 0.</param>
/// <param name="executionProvider">Execution provider; default DirectML.</param>
/// <returns>A new <see cref="TextEncoderModel"/> instance.</returns>
public static TextEncoderModel Create(string modelFile, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
{
    TextEncoderModelConfig config = new()
    {
        DeviceId = deviceId,
        ExecutionProvider = executionProvider,
        ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
        InterOpNumThreads = 0,
        IntraOpNumThreads = 0,
        OnnxModelPath = modelFile
    };
    return new TextEncoderModel(config);
}
}

public record TextEncoderModelConfig : OnnxModelConfig
Expand Down
27 changes: 26 additions & 1 deletion OnnxStack.StableDiffusion/Models/TokenizerModel.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using OnnxStack.Core.Config;
using Microsoft.ML.OnnxRuntime;
using OnnxStack.Core.Config;
using OnnxStack.Core.Model;
using OnnxStack.StableDiffusion.Enums;

namespace OnnxStack.StableDiffusion.Models
{
Expand All @@ -16,6 +18,29 @@ public TokenizerModel(TokenizerModelConfig configuration) : base(configuration)
public int TokenizerLength => _configuration.TokenizerLength;
public int PadTokenId => _configuration.PadTokenId;
public int BlankTokenId => _configuration.BlankTokenId;

/// <summary>
/// Creates a <see cref="TokenizerModel"/> from an existing configuration.
/// </summary>
/// <param name="configuration">The tokenizer configuration to wrap.</param>
/// <returns>A new <see cref="TokenizerModel"/> instance.</returns>
public static TokenizerModel Create(TokenizerModelConfig configuration) => new TokenizerModel(configuration);

/// <summary>
/// Creates a <see cref="TokenizerModel"/> directly from an ONNX model file,
/// using sequential execution mode and thread counts of 0 (runtime-chosen defaults — confirm).
/// </summary>
/// <param name="modelFile">Path to the ONNX model file.</param>
/// <param name="tokenizerLength">Embedding length; default 768.</param>
/// <param name="tokenizerLimit">Token limit; default 77.</param>
/// <param name="padTokenId">Padding token id; default 49407.</param>
/// <param name="blankTokenId">Blank token id; default 49407.</param>
/// <param name="deviceId">Execution device id; default 0.</param>
/// <param name="executionProvider">Execution provider; default DirectML.</param>
/// <returns>A new <see cref="TokenizerModel"/> instance.</returns>
public static TokenizerModel Create(string modelFile, int tokenizerLength = 768, int tokenizerLimit = 77, int padTokenId = 49407, int blankTokenId = 49407, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
{
    TokenizerModelConfig config = new()
    {
        DeviceId = deviceId,
        ExecutionProvider = executionProvider,
        ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
        InterOpNumThreads = 0,
        IntraOpNumThreads = 0,
        OnnxModelPath = modelFile,
        PadTokenId = padTokenId,
        BlankTokenId = blankTokenId,
        TokenizerLength = tokenizerLength,
        TokenizerLimit = tokenizerLimit
    };
    return new TokenizerModel(config);
}
}

public record TokenizerModelConfig : OnnxModelConfig
Expand Down
23 changes: 22 additions & 1 deletion OnnxStack.StableDiffusion/Models/UNetConditionModel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using OnnxStack.Core.Config;
using Microsoft.ML.OnnxRuntime;
using OnnxStack.Core.Config;
using OnnxStack.Core.Model;
using OnnxStack.StableDiffusion.Enums;

Expand All @@ -14,6 +15,26 @@ public UNetConditionModel(UNetConditionModelConfig configuration) : base(configu
}

public ModelType ModelType => _configuration.ModelType;

/// <summary>
/// Creates a <see cref="UNetConditionModel"/> from an existing configuration.
/// </summary>
/// <param name="configuration">The UNet configuration to wrap.</param>
/// <returns>A new <see cref="UNetConditionModel"/> instance.</returns>
public static UNetConditionModel Create(UNetConditionModelConfig configuration) => new UNetConditionModel(configuration);

/// <summary>
/// Creates a <see cref="UNetConditionModel"/> directly from an ONNX model file,
/// using sequential execution mode and thread counts of 0 (runtime-chosen defaults — confirm).
/// </summary>
/// <param name="modelFile">Path to the ONNX model file.</param>
/// <param name="modelType">The model type stored on the configuration.</param>
/// <param name="deviceId">Execution device id; default 0.</param>
/// <param name="executionProvider">Execution provider; default DirectML.</param>
/// <returns>A new <see cref="UNetConditionModel"/> instance.</returns>
public static UNetConditionModel Create(string modelFile, ModelType modelType, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
{
    UNetConditionModelConfig config = new()
    {
        DeviceId = deviceId,
        ExecutionProvider = executionProvider,
        ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
        InterOpNumThreads = 0,
        IntraOpNumThreads = 0,
        OnnxModelPath = modelFile,
        ModelType = modelType
    };
    return new UNetConditionModel(config);
}
}


Expand Down
9 changes: 5 additions & 4 deletions OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.Extensions.Logging;
using OnnxStack.Core;
using OnnxStack.Core.Config;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Diffusers;
using OnnxStack.StableDiffusion.Diffusers.InstaFlow;
Expand All @@ -25,14 +26,14 @@ public sealed class InstaFlowPipeline : StableDiffusionPipeline
/// <param name="vaeDecoder">The vae decoder.</param>
/// <param name="vaeEncoder">The vae encoder.</param>
/// <param name="logger">The logger.</param>
public InstaFlowPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List<DiffuserType> diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
: base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger)
public InstaFlowPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List<DiffuserType> diffusers, List<SchedulerType> schedulers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
: base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, schedulers, defaultSchedulerOptions, logger)
{
_supportedDiffusers = diffusers ?? new List<DiffuserType>
{
DiffuserType.TextToImage
};
_supportedSchedulers = new List<SchedulerType>
_supportedSchedulers = schedulers ?? new List<SchedulerType>
{
SchedulerType.InstaFlow
};
Expand Down Expand Up @@ -87,7 +88,7 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));

var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.Schedulers, config.SchedulerOptions, logger);
}


Expand Down
Loading

0 comments on commit e4f9bca

Please sign in to comment.