diff --git a/OnnxStack.Console/Examples/StableCascadeExample.cs b/OnnxStack.Console/Examples/StableCascadeExample.cs
index a9807f5..da5c91e 100644
--- a/OnnxStack.Console/Examples/StableCascadeExample.cs
+++ b/OnnxStack.Console/Examples/StableCascadeExample.cs
@@ -49,7 +49,7 @@ public async Task RunAsync()
{
SchedulerType = StableDiffusion.Enums.SchedulerType.DDPM,
GuidanceScale =4f,
- InferenceSteps = 60,
+ InferenceSteps = 20,
Width = 1024,
Height = 1024
};
@@ -60,7 +60,7 @@ public async Task RunAsync()
// Run pipeline
var result = await pipeline.RunAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback);
- var image = new OnnxImage(result, ImageNormalizeType.ZeroToOne);
+ var image = new OnnxImage(result);
// Save Image File
await image.SaveAsync(Path.Combine(_outputDirectory, $"output.png"));
diff --git a/OnnxStack.Core/Extensions/Extensions.cs b/OnnxStack.Core/Extensions/Extensions.cs
index 3ea62c6..d405cfa 100644
--- a/OnnxStack.Core/Extensions/Extensions.cs
+++ b/OnnxStack.Core/Extensions/Extensions.cs
@@ -251,7 +251,7 @@ public static long[] ToLong(this int[] array)
/// Normalize the data using Min-Max scaling to ensure all values are in the range [0, 1].
///
/// The values.
- public static void NormalizeMinMax(this Span values)
+ public static Span NormalizeZeroToOne(this Span values)
{
float min = float.PositiveInfinity, max = float.NegativeInfinity;
foreach (var val in values)
@@ -265,6 +265,23 @@ public static void NormalizeMinMax(this Span values)
{
values[i] = (values[i] - min) / range;
}
+ return values;
+ }
+
+
+ public static Span NormalizeOneToOne(this Span values)
+ {
+ float max = values[0];
+ foreach (var val in values)
+ {
+ if (max < val) max = val;
+ }
+
+ for (var i = 0; i < values.Length; i++)
+ {
+ values[i] = (values[i] * 2) - 1;
+ }
+ return values;
}
}
}
diff --git a/OnnxStack.Core/Extensions/OrtValueExtensions.cs b/OnnxStack.Core/Extensions/OrtValueExtensions.cs
index 4c02a80..95c4fc9 100644
--- a/OnnxStack.Core/Extensions/OrtValueExtensions.cs
+++ b/OnnxStack.Core/Extensions/OrtValueExtensions.cs
@@ -127,6 +127,18 @@ public static DenseTensor ToDenseTensor(this OrtValue ortValue, ReadOnlyS
}
+ ///
+ /// Converts Span to DenseTensor.
+ ///
+ /// The ort span value.
+ /// The dimensions.
+ ///
+ public static DenseTensor ToDenseTensor(this Span ortSpanValue, ReadOnlySpan dimensions)
+ {
+ return new DenseTensor(ortSpanValue.ToArray(), dimensions);
+ }
+
+
///
/// Converts to array.
/// TODO: Optimization
diff --git a/OnnxStack.Core/Extensions/TensorExtension.cs b/OnnxStack.Core/Extensions/TensorExtension.cs
index b7ae64a..fbd5904 100644
--- a/OnnxStack.Core/Extensions/TensorExtension.cs
+++ b/OnnxStack.Core/Extensions/TensorExtension.cs
@@ -287,7 +287,7 @@ public static DenseTensor Repeat(this DenseTensor tensor1, int cou
/// The tensor.
public static void NormalizeMinMax(this DenseTensor tensor)
{
- tensor.Buffer.Span.NormalizeMinMax();
+ tensor.Buffer.Span.NormalizeZeroToOne();
}
diff --git a/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs b/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs
index 1dfa97e..f523397 100644
--- a/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs
+++ b/OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs
@@ -196,21 +196,24 @@ protected override async Task> DecodeLatentsAsync(PromptOptio
{
latents = latents.MultiplyBy(_vaeDecoder.ScaleFactor);
- var outputDim = new[] { 1, 4, 256, 256 };
+ var outputDim = new[] { 1, 3, options.Height, options.Width };
var metadata = await _vaeDecoder.GetMetadataAsync();
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(latents);
- inferenceParameters.AddOutputBuffer();
+ inferenceParameters.AddOutputBuffer(outputDim);
- var results = _vaeDecoder.RunInference(inferenceParameters);
+ var results = await _vaeDecoder.RunInferenceAsync(inferenceParameters);
using (var imageResult = results.First())
{
// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _vaeDecoder.UnloadAsync();
- return imageResult.ToDenseTensor();
+ return imageResult
+ .GetTensorMutableDataAsSpan()
+ .NormalizeOneToOne()
+ .ToDenseTensor(outputDim);
}
}
}