Skip to content

Commit

Permalink
Normalize VQGAN output
Browse files Browse the repository at this point in the history
  • Loading branch information
saddam213 committed Apr 25, 2024
1 parent 995c9eb commit 8d5575a
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 8 deletions.
4 changes: 2 additions & 2 deletions OnnxStack.Console/Examples/StableCascadeExample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public async Task RunAsync()
{
SchedulerType = StableDiffusion.Enums.SchedulerType.DDPM,
GuidanceScale =4f,
InferenceSteps = 60,
InferenceSteps = 20,
Width = 1024,
Height = 1024
};
Expand All @@ -60,7 +60,7 @@ public async Task RunAsync()
// Run pipeline
var result = await pipeline.RunAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback);

var image = new OnnxImage(result, ImageNormalizeType.ZeroToOne);
var image = new OnnxImage(result);

// Save Image File
await image.SaveAsync(Path.Combine(_outputDirectory, $"output.png"));
Expand Down
19 changes: 18 additions & 1 deletion OnnxStack.Core/Extensions/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ public static long[] ToLong(this int[] array)
/// Normalize the data using Min-Max scaling to ensure all values are in the range [0, 1].
/// </summary>
/// <param name="values">The values.</param>
public static void NormalizeMinMax(this Span<float> values)
public static Span<float> NormalizeZeroToOne(this Span<float> values)
{
float min = float.PositiveInfinity, max = float.NegativeInfinity;
foreach (var val in values)
Expand All @@ -265,6 +265,23 @@ public static void NormalizeMinMax(this Span<float> values)
{
values[i] = (values[i] - min) / range;
}
return values;
}


/// <summary>
/// Rescales every value in place with <c>(v * 2) - 1</c>, which maps the range [0, 1] onto [-1, 1].
/// </summary>
/// <remarks>
/// The mapping is a fixed affine transform: no min/max scan or clamping is performed,
/// so inputs outside [0, 1] produce outputs outside [-1, 1]. An empty span is returned unchanged.
/// </remarks>
/// <param name="values">The values to rescale; the span is modified in place.</param>
/// <returns>The same span that was passed in, to allow call chaining.</returns>
public static Span<float> NormalizeOneToOne(this Span<float> values)
{
    // The transform does not depend on the data's extrema, so no pre-scan is needed;
    // an empty span simply skips the loop instead of faulting on values[0].
    for (var i = 0; i < values.Length; i++)
    {
        values[i] = (values[i] * 2) - 1;
    }
    return values;
}
}
}
12 changes: 12 additions & 0 deletions OnnxStack.Core/Extensions/OrtValueExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ public static DenseTensor<float> ToDenseTensor(this OrtValue ortValue, ReadOnlyS
}


/// <summary>
/// Builds a <see cref="DenseTensor{T}"/> of <see cref="float"/> from the contents of a span.
/// </summary>
/// <param name="ortSpanValue">The span whose elements are copied into a new array that backs the tensor.</param>
/// <param name="dimensions">The dimensions passed to the tensor constructor.</param>
/// <returns>A new tensor constructed over a copy of the span's data.</returns>
public static DenseTensor<float> ToDenseTensor(this Span<float> ortSpanValue, ReadOnlySpan<int> dimensions)
{
    // ToArray copies the elements, so the tensor does not alias the source span's memory.
    float[] tensorData = ortSpanValue.ToArray();
    var tensor = new DenseTensor<float>(tensorData, dimensions);
    return tensor;
}


/// <summary>
/// Converts to array.
/// TODO: Optimization
Expand Down
2 changes: 1 addition & 1 deletion OnnxStack.Core/Extensions/TensorExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ public static DenseTensor<float> Repeat(this DenseTensor<float> tensor1, int cou
/// <param name="tensor">The tensor.</param>
public static void NormalizeMinMax(this DenseTensor<float> tensor)
{
tensor.Buffer.Span.NormalizeMinMax();
tensor.Buffer.Span.NormalizeZeroToOne();
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,21 +196,24 @@ protected override async Task<DenseTensor<float>> DecodeLatentsAsync(PromptOptio
{
latents = latents.MultiplyBy(_vaeDecoder.ScaleFactor);

var outputDim = new[] { 1, 4, 256, 256 };
var outputDim = new[] { 1, 3, options.Height, options.Width };
var metadata = await _vaeDecoder.GetMetadataAsync();
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(latents);
inferenceParameters.AddOutputBuffer();
inferenceParameters.AddOutputBuffer(outputDim);

var results = _vaeDecoder.RunInference(inferenceParameters);
var results = await _vaeDecoder.RunInferenceAsync(inferenceParameters);
using (var imageResult = results.First())
{
// Unload if required
if (_memoryMode == MemoryModeType.Minimum)
await _vaeDecoder.UnloadAsync();

return imageResult.ToDenseTensor();
return imageResult
.GetTensorMutableDataAsSpan<float>()
.NormalizeOneToOne()
.ToDenseTensor(outputDim);
}
}
}
Expand Down

0 comments on commit 8d5575a

Please sign in to comment.