Skip to content

Commit

Permalink
refactor!: callback functions return StatusArgs instead of Status (#574)
Browse files Browse the repository at this point in the history
* fix!: NativePacketCallback returns StatusArgs instead of IntPtr

* remove pinned Status

* Marshal string as an IntPtr and release it after processing

* add Null check
  • Loading branch information
homuler authored May 17, 2022
1 parent ba48f78 commit d334e80
Show file tree
Hide file tree
Showing 14 changed files with 179 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace Mediapipe
{
public class CalculatorGraph : MpResourceHandle
{
public delegate IntPtr NativePacketCallback(IntPtr graphPtr, int streamId, IntPtr packetPtr);
public delegate Status PacketCallback<TPacket, TValue>(TPacket packet) where TPacket : Packet<TValue>;
public delegate Status.StatusArgs NativePacketCallback(IntPtr graphPtr, int streamId, IntPtr packetPtr);
public delegate void PacketCallback<TPacket, TValue>(TPacket packet) where TPacket : Packet<TValue>;

public CalculatorGraph() : base()
{
Expand Down Expand Up @@ -79,17 +79,16 @@ public Status ObserveOutputStream(string streamName, int streamId, NativePacketC
{
NativePacketCallback nativePacketCallback = (IntPtr graphPtr, int streamId, IntPtr packetPtr) =>
{
Status status = null;
try
{
var packet = Packet<TValue>.Create<TPacket>(packetPtr, false);
status = packetCallback(packet);
packetCallback(packet);
return Status.StatusArgs.Ok();
}
catch (Exception e)
{
status = Status.FailedPrecondition(e.ToString());
return Status.StatusArgs.Internal(e.ToString());
}
return status.mpPtr;
};
callbackHandle = GCHandle.Alloc(nativePacketCallback, GCHandleType.Pinned);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// https://opensource.org/licenses/MIT.

using System;
using System.Runtime.InteropServices;

namespace Mediapipe
{
Expand All @@ -31,6 +32,104 @@ public enum StatusCode : int
Unauthenticated = 16,
}

[StructLayout(LayoutKind.Sequential)]
public readonly struct StatusArgs
{
private readonly StatusCode _code;
private readonly IntPtr _message;

private StatusArgs(StatusCode code, string message = null)
{
_code = code;
_message = Marshal.StringToHGlobalAnsi(message);
}

public static StatusArgs Ok()
{
return new StatusArgs(StatusCode.Ok);
}

public static StatusArgs Cancelled(string message = null)
{
return new StatusArgs(StatusCode.Cancelled, message);
}

public static StatusArgs Unknown(string message = null)
{
return new StatusArgs(StatusCode.Unknown, message);
}

public static StatusArgs InvalidArgument(string message = null)
{
return new StatusArgs(StatusCode.InvalidArgument, message);
}

public static StatusArgs DeadlineExceeded(string message = null)
{
return new StatusArgs(StatusCode.DeadlineExceeded, message);
}

public static StatusArgs NotFound(string message = null)
{
return new StatusArgs(StatusCode.NotFound, message);
}

public static StatusArgs AlreadyExists(string message = null)
{
return new StatusArgs(StatusCode.AlreadyExists, message);
}

public static StatusArgs PermissionDenied(string message = null)
{
return new StatusArgs(StatusCode.PermissionDenied, message);
}

public static StatusArgs ResourceExhausted(string message = null)
{
return new StatusArgs(StatusCode.ResourceExhausted, message);
}

public static StatusArgs FailedPrecondition(string message = null)
{
return new StatusArgs(StatusCode.FailedPrecondition, message);
}

public static StatusArgs Aborted(string message = null)
{
return new StatusArgs(StatusCode.Aborted, message);
}

public static StatusArgs OutOfRange(string message = null)
{
return new StatusArgs(StatusCode.OutOfRange, message);
}

public static StatusArgs Unimplemented(string message = null)
{
return new StatusArgs(StatusCode.Unimplemented, message);
}

public static StatusArgs Internal(string message = null)
{
return new StatusArgs(StatusCode.Internal, message);
}

public static StatusArgs Unavailable(string message = null)
{
return new StatusArgs(StatusCode.Unavailable, message);
}

public static StatusArgs DataLoss(string message = null)
{
return new StatusArgs(StatusCode.DataLoss, message);
}

public static StatusArgs Unauthenticated(string message = null)
{
return new StatusArgs(StatusCode.Unauthenticated, message);
}
}

public Status(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

protected override void DeleteMpPtr()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ namespace Mediapipe

public class GlCalculatorHelper : MpResourceHandle
{
public delegate IntPtr NativeGlStatusFunction();
public delegate Status GlStatusFunction();
public delegate Status.StatusArgs NativeGlStatusFunction();
public delegate void GlFunction();

public GlCalculatorHelper() : base()
{
Expand Down Expand Up @@ -45,26 +45,20 @@ public Status RunInGlContext(NativeGlStatusFunction nativeGlStatusFunction)
return new Status(statusPtr);
}

public Status RunInGlContext(GlStatusFunction glStatusFunc)
public Status RunInGlContext(GlFunction glStatusFunc)
{
Status tmpStatus = null;

var status = RunInGlContext(() =>
return RunInGlContext(() =>
{
try
{
tmpStatus = glStatusFunc();
glStatusFunc();
return Status.StatusArgs.Ok();
}
catch (Exception e)
{
tmpStatus = Status.FailedPrecondition(e.ToString());
return Status.StatusArgs.Internal(e.ToString());
}
return tmpStatus.mpPtr;
});

if (tmpStatus != null) { tmpStatus.Dispose(); }

return status;
}

public GlTexture CreateSourceTexture(ImageFrame imageFrame)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System;
using System.Security;
using System.Runtime.InteropServices;

namespace Mediapipe
{
Expand All @@ -21,5 +23,21 @@ internal static partial class UnsafeNativeMethods
#else
"mediapipe_c";
#endif

static UnsafeNativeMethods()
{
mp_api__SetFreeHGlobal(FreeHGlobal);
}

internal delegate void FreeHGlobalDelegate(IntPtr hglobal);

[AOT.MonoPInvokeCallback(typeof(FreeHGlobalDelegate))]
private static void FreeHGlobal(IntPtr hglobal)
{
Marshal.FreeHGlobal(hglobal);
}

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern void mp_api__SetFreeHGlobal([MarshalAs(UnmanagedType.FunctionPtr)] FreeHGlobalDelegate freeHGlobal);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
// https://opensource.org/licenses/MIT.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

namespace Mediapipe.Unity
{
Expand All @@ -26,16 +24,6 @@ public OutputEventArgs(TValue value)
private static int _Counter = 0;
private static readonly GlobalInstanceTable<int, OutputStream<TPacket, TValue>> _InstanceTable = new GlobalInstanceTable<int, OutputStream<TPacket, TValue>>(20);

/// <summary>
/// Store the last <see cref="Status" /> to prevent it from being GCed while it is used by the Unmanaged Code.<br />
/// This member variable's key is a stream ID, but if the corresponding <see cref="OutputStream" /> instance exists,
/// the last <see cref="Status" /> will be stored in <see cref="_callbackStatus" /> instead.
/// </summary>
/// <remarks>
/// We may need to store multiple instances (e.g. using ring buffers) if the packet callback can be called from multiple threads at the same time.
/// </remarks>
private static readonly Dictionary<int, Status> _CallbackStatus = new Dictionary<int, Status>();

protected readonly CalculatorGraph calculatorGraph;

private readonly int _id;
Expand Down Expand Up @@ -72,14 +60,6 @@ protected TPacket referencePacket
}
}

/// <summary>
/// Store the last <see cref="Status" /> to prevent it from being GCed while it is used by the Unmanaged Code.
/// </summary>
/// <remarks>
/// We may need to store multiple instances (e.g. using ring buffers) if the packet callback can be called from multiple threads at the same time.
/// </remarks>
private Status _callbackStatus;

protected bool canTestPresence => presenceStreamName != null;

/// <summary>
Expand Down Expand Up @@ -109,7 +89,6 @@ public OutputStream(CalculatorGraph calculatorGraph, string streamName, bool obs
this.timeoutMicrosec = timeoutMicrosec;

_InstanceTable.Add(_id, this);
CompressCallbackStatus();
}

/// <summary>
Expand Down Expand Up @@ -181,7 +160,6 @@ public void RemoveAllListeners()
public void Close()
{
RemoveAllListeners();
RemoveCallbackStatus();

_poller?.Dispose();
_poller = null;
Expand Down Expand Up @@ -398,18 +376,18 @@ protected bool TryConsumePacketValue(Packet<TValue> packet, out TValue value, lo
}

[AOT.MonoPInvokeCallback(typeof(CalculatorGraph.NativePacketCallback))]
protected static IntPtr InvokeIfOutputStreamFound(IntPtr graphPtr, int streamId, IntPtr packetPtr)
protected static Status.StatusArgs InvokeIfOutputStreamFound(IntPtr graphPtr, int streamId, IntPtr packetPtr)
{
try
{
var isFound = _InstanceTable.TryGetValue(streamId, out var outputStream);
if (!isFound)
{
return GetPinnedStatusPtr(streamId, Status.FailedPrecondition($"OutputStream with id {streamId} is not found"));
return Status.StatusArgs.NotFound($"OutputStream with id {streamId} is not found");
}
if (outputStream.calculatorGraph.mpPtr != graphPtr)
{
return GetPinnedStatusPtr(streamId, Status.FailedPrecondition($"OutputStream is found, but is not linked to the specified CalclatorGraph"));
return Status.StatusArgs.InvalidArgument($"OutputStream is found, but is not linked to the specified CalclatorGraph");
}

outputStream.referencePacket.SwitchNativePtr(packetPtr);
Expand All @@ -419,63 +397,11 @@ protected static IntPtr InvokeIfOutputStreamFound(IntPtr graphPtr, int streamId,
}
outputStream.referencePacket.ReleaseMpResource();

return outputStream.GetPinnedStatusPtr(Status.Ok());
return Status.StatusArgs.Ok();
}
catch (Exception e)
{
return GetPinnedStatusPtr(streamId, Status.FailedPrecondition(e.ToString()));
}
}

/// <summary>
/// To prevent <paramref name="status" /> from being GCed, store it in <see cref="_CallbackStatus" />.
/// </summary>
/// <remarks>
/// Prefer the instance method with the same name.
/// </remarks>
protected static IntPtr GetPinnedStatusPtr(int streamId, Status status)
{
lock (((ICollection)_CallbackStatus).SyncRoot)
{
_CallbackStatus[streamId] = status;
return status.mpPtr;
}
}

/// <summary>
/// To prevent <paramref name="status" /> from being GCed, store it in <see cref="_callbackStatus" />.
/// </summary>
protected IntPtr GetPinnedStatusPtr(Status status)
{
_callbackStatus = status;
return _callbackStatus.mpPtr;
}

protected void RemoveCallbackStatus()
{
_callbackStatus?.Dispose();
_callbackStatus = null;

lock (((ICollection)_CallbackStatus).SyncRoot)
{
if (_CallbackStatus.TryGetValue(_id, out var status))
{
status.Dispose();
}
var _ = _CallbackStatus.Remove(_id);
}
}

protected static void CompressCallbackStatus()
{
lock (((ICollection)_CallbackStatus).SyncRoot)
{
var deadKeys = _CallbackStatus.Where(x => !_InstanceTable.ContainsKey(x.Key)).Select(x => x.Key).ToArray();

foreach (var key in deadKeys)
{
var _ = _CallbackStatus.Remove(key);
}
return Status.StatusArgs.Internal(e.ToString());
}
}
}
Expand Down
Loading

0 comments on commit d334e80

Please sign in to comment.