Skip to content

Commit

Permalink
Upscale/FeatureExtractor progress callback
Browse files Browse the repository at this point in the history
  • Loading branch information
saddam213 committed May 13, 2024
1 parent 2dcd76f commit 5eb45d8
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 7 deletions.
2 changes: 1 addition & 1 deletion OnnxStack.Core/Model/OnnxModelSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class OnnxModelSession : IDisposable
public OnnxModelSession(OnnxModelConfig configuration)
{
if (!File.Exists(configuration.OnnxModelPath))
throw new FileNotFoundException("Onnx model file not found", configuration.OnnxModelPath);
throw new FileNotFoundException($"Onnx model file not found, Path: {configuration.OnnxModelPath}", configuration.OnnxModelPath);

_configuration = configuration;
_options = configuration.GetSessionOptions();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using OnnxStack.Core.Model;
using OnnxStack.Core.Video;
using OnnxStack.FeatureExtractor.Common;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
Expand Down Expand Up @@ -95,13 +96,15 @@ public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken ca
/// </summary>
/// <param name="videoFrames">The input video.</param>
/// <returns></returns>
public async Task<OnnxVideo> RunAsync(OnnxVideo video, CancellationToken cancellationToken = default)
public async Task<OnnxVideo> RunAsync(OnnxVideo video, Action<OnnxImage, OnnxImage> progressCallback = default, CancellationToken cancellationToken = default)
{
var timestamp = _logger?.LogBegin("Extracting OnnxVideo features...");
var featureFrames = new List<OnnxImage>();
foreach (var videoFrame in video.Frames)
{
featureFrames.Add(await RunAsync(videoFrame, cancellationToken));
var result = await RunAsync(videoFrame, cancellationToken);
featureFrames.Add(result);
progressCallback?.Invoke(videoFrame, result);
}
_logger?.LogEnd("Extracting OnnxVideo features complete.", timestamp);
return new OnnxVideo(video.Info, featureFrames);
Expand Down
6 changes: 4 additions & 2 deletions OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,15 @@ public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken ca
/// <param name="inputVideo">The input video.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
public async Task<OnnxVideo> RunAsync(OnnxVideo inputVideo, CancellationToken cancellationToken = default)
public async Task<OnnxVideo> RunAsync(OnnxVideo inputVideo, Action<OnnxImage, OnnxImage> progressCallback = default, CancellationToken cancellationToken = default)
{
var timestamp = _logger?.LogBegin("Upscale OnnxVideo..");
var upscaledFrames = new List<OnnxImage>();
foreach (var videoFrame in inputVideo.Frames)
{
upscaledFrames.Add(await UpscaleImageAsync(videoFrame, cancellationToken));
var result = await UpscaleImageAsync(videoFrame, cancellationToken);
upscaledFrames.Add(result);
progressCallback?.Invoke(videoFrame, result);
}

var firstFrame = upscaledFrames.First();
Expand Down
2 changes: 1 addition & 1 deletion OnnxStack.UI/Services/FeatureExtractorService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public async Task<OnnxVideo> GenerateAsync(FeatureExtractorModelSet model, OnnxV
if (!_pipelines.TryGetValue(model, out var pipeline))
throw new Exception("Pipeline not found or is unsupported");

return await pipeline.RunAsync(inputVideo, cancellationToken);
return await pipeline.RunAsync(inputVideo, cancellationToken: cancellationToken);
}


Expand Down
2 changes: 1 addition & 1 deletion OnnxStack.UI/Services/UpscaleService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public async Task<OnnxVideo> GenerateAsync(UpscaleModelSet model, OnnxVideo inpu
if (!_pipelines.TryGetValue(model, out var pipeline))
throw new Exception("Pipeline not found or is unsupported");

return await pipeline.RunAsync(inputVideo, cancellationToken);
return await pipeline.RunAsync(inputVideo, cancellationToken: cancellationToken);
}


Expand Down

0 comments on commit 5eb45d8

Please sign in to comment.