diff --git a/OnnxStack.Console/Examples/StableCascadeExample.cs b/OnnxStack.Console/Examples/StableCascadeExample.cs index a9807f5..da5c91e 100644 --- a/OnnxStack.Console/Examples/StableCascadeExample.cs +++ b/OnnxStack.Console/Examples/StableCascadeExample.cs @@ -49,7 +49,7 @@ public async Task RunAsync() { SchedulerType = StableDiffusion.Enums.SchedulerType.DDPM, GuidanceScale =4f, - InferenceSteps = 60, + InferenceSteps = 20, Width = 1024, Height = 1024 }; @@ -60,7 +60,7 @@ public async Task RunAsync() // Run pipeline var result = await pipeline.RunAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback); - var image = new OnnxImage(result, ImageNormalizeType.ZeroToOne); + var image = new OnnxImage(result); // Save Image File await image.SaveAsync(Path.Combine(_outputDirectory, $"output.png")); diff --git a/OnnxStack.Core/Extensions/Extensions.cs b/OnnxStack.Core/Extensions/Extensions.cs index 3ea62c6..d405cfa 100644 --- a/OnnxStack.Core/Extensions/Extensions.cs +++ b/OnnxStack.Core/Extensions/Extensions.cs @@ -251,7 +251,7 @@ public static long[] ToLong(this int[] array) /// Normalize the data using Min-Max scaling to ensure all values are in the range [0, 1]. /// /// The values. - public static void NormalizeMinMax(this Span values) + public static Span NormalizeZeroToOne(this Span values) { float min = float.PositiveInfinity, max = float.NegativeInfinity; foreach (var val in values) @@ -265,6 +265,23 @@ public static void NormalizeMinMax(this Span values) { values[i] = (values[i] - min) / range; } + return values; + } + + + public static Span NormalizeOneToOne(this Span values) + { + float max = values[0]; + foreach (var val in values) + { + if (max < val) max = val; + } + + for (var i = 0; i < values.Length; i++) + { + values[i] = (values[i] * 2) - 1; + } + return values; } } } diff --git a/OnnxStack.Core/Extensions/OrtValueExtensions.cs b/OnnxStack.Core/Extensions/OrtValueExtensions.cs index 4c02a80..95c4fc9 100644 --- a/OnnxStack.Core/Extensions/OrtValueExtensions.cs +++ b/OnnxStack.Core/Extensions/OrtValueExtensions.cs @@ -127,6 +127,18 @@ public static DenseTensor ToDenseTensor(this OrtValue ortValue, ReadOnlyS } + /// + /// Converts Span to DenseTensor. + /// + /// The ort span value. + /// The dimensions. + /// + public static DenseTensor ToDenseTensor(this Span ortSpanValue, ReadOnlySpan dimensions) + { + return new DenseTensor(ortSpanValue.ToArray(), dimensions); + } + + /// /// Converts to array. /// TODO: Optimization diff --git a/OnnxStack.Core/Extensions/TensorExtension.cs b/OnnxStack.Core/Extensions/TensorExtension.cs index b7ae64a..fbd5904 100644 --- a/OnnxStack.Core/Extensions/TensorExtension.cs +++ b/OnnxStack.Core/Extensions/TensorExtension.cs @@ -287,7 +287,7 @@ public static DenseTensor Repeat(this DenseTensor tensor1, int cou /// The tensor. public static void NormalizeMinMax(this DenseTensor tensor) { - tensor.Buffer.Span.NormalizeMinMax(); + tensor.Buffer.Span.NormalizeZeroToOne(); } diff --git a/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs index 1dfa97e..f523397 100644 --- a/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs +++ b/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs @@ -196,21 +196,24 @@ protected override async Task> DecodeLatentsAsync(PromptOptio { latents = latents.MultiplyBy(_vaeDecoder.ScaleFactor); - var outputDim = new[] { 1, 4, 256, 256 }; + var outputDim = new[] { 1, 3, options.Height, options.Width }; var metadata = await _vaeDecoder.GetMetadataAsync(); using (var inferenceParameters = new OnnxInferenceParameters(metadata)) { inferenceParameters.AddInputTensor(latents); - inferenceParameters.AddOutputBuffer(); + inferenceParameters.AddOutputBuffer(outputDim); - var results = _vaeDecoder.RunInference(inferenceParameters); + var results = await _vaeDecoder.RunInferenceAsync(inferenceParameters); using (var imageResult = results.First()) { // Unload if required if (_memoryMode == MemoryModeType.Minimum) await _vaeDecoder.UnloadAsync(); - return imageResult.ToDenseTensor(); + return imageResult + .GetTensorMutableDataAsSpan() + .NormalizeOneToOne() + .ToDenseTensor(outputDim); } } }