From 6ae4eee4b9bf0db8397f6f5423947fb3a92cc24b Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Thu, 13 Jun 2024 06:57:53 +1200 Subject: [PATCH] Fix incorrect upscale tiling --- .../Pipelines/ImageUpscalePipeline.cs | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs index d4112fb..9fce1fd 100644 --- a/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs +++ b/OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs @@ -148,7 +148,7 @@ public async IAsyncEnumerable RunAsync(IAsyncEnumerable im private async Task UpscaleImageAsync(OnnxImage inputImage, CancellationToken cancellationToken = default) { var inputTensor = inputImage.GetImageTensor(_upscaleModel.NormalizeType, _upscaleModel.Channels); - var outputTensor = await RunInternalAsync(inputTensor, inputImage.Height, inputImage.Width, cancellationToken); + var outputTensor = await RunInternalAsync(inputTensor, cancellationToken); return new OnnxImage(outputTensor, _upscaleModel.NormalizeType); } @@ -164,10 +164,7 @@ public async Task> UpscaleTensorAsync(DenseTensor inpu if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne) inputTensor.NormalizeOneOneToZeroOne(); - var height = inputTensor.Dimensions[2]; - var width = inputTensor.Dimensions[3]; - var result = await RunInternalAsync(inputTensor, height, width, cancellationToken); - + var result = await RunInternalAsync(inputTensor, cancellationToken); if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne) result.NormalizeZeroOneToOneOne(); @@ -181,9 +178,9 @@ public async Task> UpscaleTensorAsync(DenseTensor inpu /// The input tensor. /// The cancellation token. /// - private async Task> RunInternalAsync(DenseTensor inputTensor, int height, int width, CancellationToken cancellationToken = default) + private async Task> RunInternalAsync(DenseTensor inputTensor, CancellationToken cancellationToken = default) { - if (height <= _upscaleModel.TileSize && width <= _upscaleModel.TileSize) + if (inputTensor.Dimensions[2] <= _upscaleModel.SampleSize && inputTensor.Dimensions[3] <= _upscaleModel.SampleSize) { return await RunInferenceAsync(inputTensor, cancellationToken); } @@ -194,10 +191,10 @@ private async Task> RunInternalAsync(DenseTensor input inputTiles.Width * _upscaleModel.ScaleFactor, inputTiles.Height * _upscaleModel.ScaleFactor, inputTiles.Overlap * _upscaleModel.ScaleFactor, - await RunInternalAsync(inputTiles.Tile1, inputTiles.Height, inputTiles.Width, cancellationToken), - await RunInternalAsync(inputTiles.Tile2, inputTiles.Height, inputTiles.Width, cancellationToken), - await RunInternalAsync(inputTiles.Tile3, inputTiles.Height, inputTiles.Width, cancellationToken), - await RunInternalAsync(inputTiles.Tile4, inputTiles.Height, inputTiles.Width, cancellationToken) + await RunInternalAsync(inputTiles.Tile1, cancellationToken), + await RunInternalAsync(inputTiles.Tile2, cancellationToken), + await RunInternalAsync(inputTiles.Tile3, cancellationToken), + await RunInternalAsync(inputTiles.Tile4, cancellationToken) ); return outputTiles.JoinImageTiles(); }