Skip to content

Commit

Permalink
Fix incorrect upscale tiling
Browse files Browse the repository at this point in the history
  • Loading branch information
saddam213 committed Jun 12, 2024
1 parent de788da commit 6ae4eee
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> im
private async Task<OnnxImage> UpscaleImageAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
{
var inputTensor = inputImage.GetImageTensor(_upscaleModel.NormalizeType, _upscaleModel.Channels);
var outputTensor = await RunInternalAsync(inputTensor, inputImage.Height, inputImage.Width, cancellationToken);
var outputTensor = await RunInternalAsync(inputTensor, cancellationToken);
return new OnnxImage(outputTensor, _upscaleModel.NormalizeType);
}

Expand All @@ -164,10 +164,7 @@ public async Task<DenseTensor<float>> UpscaleTensorAsync(DenseTensor<float> inpu
if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne)
inputTensor.NormalizeOneOneToZeroOne();

var height = inputTensor.Dimensions[2];
var width = inputTensor.Dimensions[3];
var result = await RunInternalAsync(inputTensor, height, width, cancellationToken);

var result = await RunInternalAsync(inputTensor, cancellationToken);
if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne)
result.NormalizeZeroOneToOneOne();

Expand All @@ -181,9 +178,9 @@ public async Task<DenseTensor<float>> UpscaleTensorAsync(DenseTensor<float> inpu
/// <param name="inputTensor">The input tensor.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
private async Task<DenseTensor<float>> RunInternalAsync(DenseTensor<float> inputTensor, int height, int width, CancellationToken cancellationToken = default)
private async Task<DenseTensor<float>> RunInternalAsync(DenseTensor<float> inputTensor, CancellationToken cancellationToken = default)
{
if (height <= _upscaleModel.TileSize && width <= _upscaleModel.TileSize)
if (inputTensor.Dimensions[2] <= _upscaleModel.SampleSize && inputTensor.Dimensions[3] <= _upscaleModel.SampleSize)
{
return await RunInferenceAsync(inputTensor, cancellationToken);
}
Expand All @@ -194,10 +191,10 @@ private async Task<DenseTensor<float>> RunInternalAsync(DenseTensor<float> input
inputTiles.Width * _upscaleModel.ScaleFactor,
inputTiles.Height * _upscaleModel.ScaleFactor,
inputTiles.Overlap * _upscaleModel.ScaleFactor,
await RunInternalAsync(inputTiles.Tile1, inputTiles.Height, inputTiles.Width, cancellationToken),
await RunInternalAsync(inputTiles.Tile2, inputTiles.Height, inputTiles.Width, cancellationToken),
await RunInternalAsync(inputTiles.Tile3, inputTiles.Height, inputTiles.Width, cancellationToken),
await RunInternalAsync(inputTiles.Tile4, inputTiles.Height, inputTiles.Width, cancellationToken)
await RunInternalAsync(inputTiles.Tile1, cancellationToken),
await RunInternalAsync(inputTiles.Tile2, cancellationToken),
await RunInternalAsync(inputTiles.Tile3, cancellationToken),
await RunInternalAsync(inputTiles.Tile4, cancellationToken)
);
return outputTiles.JoinImageTiles();
}
Expand Down

0 comments on commit 6ae4eee

Please sign in to comment.