Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 7cc7527

Browse files
committed
Remove legacy batch tensor duplication code
1 parent ca747ff commit 7cc7527

File tree

11 files changed

+21
-63
lines changed

11 files changed

+21
-63
lines changed

OnnxStack.Console/appsettings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"InterOpNumThreads": 0,
2727
"IntraOpNumThreads": 0,
2828
"ExecutionMode": "ORT_SEQUENTIAL",
29-
"ExecutionProvider": "Cuda",
29+
"ExecutionProvider": "DirectML",
3030
"ModelConfigurations": [
3131
{
3232
"Type": "Tokenizer",

OnnxStack.StableDiffusion/Config/PromptOptions.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ public class PromptOptions
1515
[StringLength(512)]
1616
public string NegativePrompt { get; set; }
1717

18-
public int BatchCount { get; set; } = 1; // Delete Me
19-
2018
public InputImage InputImage { get; set; }
2119

2220
public InputImage InputImageMask { get; set; }

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
119119

120120
return schedulerResult;
121121
}
122-
122+
123123

124124
/// <summary>
125125
/// Runs the stable diffusion batch loop
@@ -210,35 +210,23 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
210210
// Scale and decode the image latents with vae.
211211
latents = latents.MultiplyBy(1.0f / model.ScaleFactor);
212212

213-
var images = prompt.BatchCount > 1
214-
? latents.Split(prompt.BatchCount)
215-
: new[] { latents };
216-
var imageTensors = new List<DenseTensor<float>>();
217-
foreach (var image in images)
218-
{
219-
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
220-
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeDecoder);
213+
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
214+
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeDecoder);
221215

222-
var outputDim = new[] { 1, 3, options.Height, options.Width };
223-
var outputBuffer = new DenseTensor<float>(outputDim);
224-
using (var inputTensorValue = image.ToOrtValue())
225-
using (var outputTensorValue = outputBuffer.ToOrtValue())
216+
var outputDim = new[] { 1, 3, options.Height, options.Width };
217+
var outputBuffer = new DenseTensor<float>(outputDim);
218+
using (var inputTensorValue = latents.ToOrtValue())
219+
using (var outputTensorValue = outputBuffer.ToOrtValue())
220+
{
221+
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
222+
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
223+
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputs, outputs);
224+
using (var imageResult = results.First())
226225
{
227-
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
228-
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
229-
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputs, outputs);
230-
using (var imageResult = results.First())
231-
{
232-
imageTensors.Add(outputBuffer);
233-
}
226+
_logger?.LogEnd("End", timestamp);
227+
return outputBuffer;
234228
}
235229
}
236-
237-
var result = prompt.BatchCount > 1
238-
? imageTensors.Join()
239-
: imageTensors.FirstOrDefault();
240-
_logger?.LogEnd("End", timestamp);
241-
return result;
242230
}
243231

244232

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,7 @@ protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions m
7676
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
7777
.MultiplyBy(model.ScaleFactor);
7878

79-
var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
80-
if (prompt.BatchCount > 1)
81-
return noisySample.Repeat(prompt.BatchCount);
82-
83-
return noisySample;
79+
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
8480
}
8581
}
8682
}

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/TextDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
4848
/// <returns></returns>
4949
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
5050
{
51-
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(prompt.BatchCount), scheduler.InitNoiseSigma));
51+
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
5252
}
5353
}
5454
}

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/AnimateDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
5151
/// <returns></returns>
5252
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
5353
{
54-
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(prompt.BatchCount), scheduler.InitNoiseSigma));
54+
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
5555
}
5656
}
5757
}

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,7 @@ protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions m
7878
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
7979
.MultiplyBy(model.ScaleFactor);
8080

81-
var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
82-
if (prompt.BatchCount > 1)
83-
return noisySample.Repeat(prompt.BatchCount);
84-
85-
return noisySample;
81+
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
8682
}
8783
}
8884
}

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,6 @@ private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions
161161
});
162162

163163
imageTensor = imageTensor.MultiplyBy(modelOptions.ScaleFactor);
164-
if (promptOptions.BatchCount > 1)
165-
imageTensor = imageTensor.Repeat(promptOptions.BatchCount);
166-
167164
if (schedulerOptions.GuidanceScale > 1f)
168165
imageTensor = imageTensor.Repeat(2);
169166

@@ -232,9 +229,6 @@ private DenseTensor<float> PrepareImageMask(IModelOptions modelOptions, PromptOp
232229
{
233230
var sample = inferResult.FirstElementAs<DenseTensor<float>>();
234231
var scaledSample = sample.MultiplyBy(modelOptions.ScaleFactor);
235-
if (promptOptions.BatchCount > 1)
236-
scaledSample = scaledSample.Repeat(promptOptions.BatchCount);
237-
238232
if (schedulerOptions.GuidanceScale > 1f)
239233
scaledSample = scaledSample.Repeat(2);
240234

@@ -267,7 +261,7 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
267261
/// <returns></returns>
268262
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
269263
{
270-
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(prompt.BatchCount), scheduler.InitNoiseSigma));
264+
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
271265
}
272266

273267

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,6 @@ protected override async Task<DenseTensor<float>> PrepareLatents(IModelOptions m
173173
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
174174
.MultiplyBy(model.ScaleFactor);
175175

176-
if (prompt.BatchCount > 1)
177-
return scaledSample.Repeat(prompt.BatchCount);
178-
179176
return scaledSample;
180177
}
181178
}
@@ -214,9 +211,6 @@ private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions
214211
}
215212
});
216213

217-
if (promptOptions.BatchCount > 1)
218-
return maskTensor.Repeat(promptOptions.BatchCount);
219-
220214
return maskTensor;
221215
}
222216
}

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/TextDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
5050
/// <returns></returns>
5151
protected override Task<DenseTensor<float>> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
5252
{
53-
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(prompt.BatchCount), scheduler.InitNoiseSigma));
53+
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
5454
}
5555
}
5656
}

OnnxStack.StableDiffusion/Services/PromptService.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,6 @@ public async Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, Pro
4646
var promptEmbeddings = await GenerateEmbedsAsync(model, promptTokens, maxPromptTokenCount);
4747
var negativePromptEmbeddings = await GenerateEmbedsAsync(model, negativePromptTokens, maxPromptTokenCount);
4848

49-
// If we have a batch, repeat the prompt embeddings
50-
if (promptOptions.BatchCount > 1)
51-
{
52-
promptEmbeddings = promptEmbeddings.Repeat(promptOptions.BatchCount);
53-
negativePromptEmbeddings = negativePromptEmbeddings.Repeat(promptOptions.BatchCount);
54-
}
55-
5649
// If we are doing guided diffusion, concatenate the negative prompt embeddings
5750
// If not we ingore the negative prompt embeddings
5851
if (isGuidanceEnabled)
@@ -166,6 +159,5 @@ private static IReadOnlyCollection<NamedOnnxValue> CreateInputParameters(params
166159
{
167160
return parameters.ToList().AsReadOnly();
168161
}
169-
170162
}
171163
}

0 commit comments

Comments
 (0)