Skip to content

Commit

Permalink
完成异步操作
Browse files Browse the repository at this point in the history
  • Loading branch information
xljiulang committed Oct 9, 2022
1 parent d5ad9da commit 8c4963c
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 147 deletions.
6 changes: 3 additions & 3 deletions App/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ internal class Program
{
static async Task Main(string[] args)
{
var filter = Filter.True.And(f => f.IsUdp);
var filter = Filter.True.And(f => f.IsTcp);
using var divert = new WinDivert(filter, WinDivertLayer.Network);
using var packet = new WinDivertPacket();
var addr = new WinDivertAddress();

while (true)
{
var recvLength = await divert.RecvAsync(packet, ref addr);
var result = packet.GetParseResult();

var checkState = packet.CalcChecksums(ref addr);
// var sendLength = divert.Send(packet, ref addr);
var sendLength = await divert.SendAsync(packet, ref addr);

Console.WriteLine(result.Protocol);
}
Expand Down
202 changes: 71 additions & 131 deletions WindivertDotnet/WinDivert.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System;
using System.ComponentModel;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -13,8 +12,13 @@ namespace WindivertDotnet
[DebuggerDisplay("Filter = {Filter}, Layer = {Layer}")]
public partial class WinDivert : IDisposable
{
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private readonly WinDivertHandle handle;
private unsafe readonly static IOCompletionCallback completionCallback = new(IOCompletionCallback);

[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private readonly Lazy<ThreadPoolBoundHandle> boundHandle;

private unsafe readonly static IOCompletionCallback ioCompletionCallback = new(IOCompletionCallback);

/// <summary>
/// 获取过滤器
Expand All @@ -29,21 +33,14 @@ public partial class WinDivert : IDisposable
/// <summary>
/// 获取软件版本
/// </summary>
/// <exception cref="Win32Exception"></exception>
public Version Version
{
get
{
var major = (int)this.GetParam(WinDivertParam.VersionMajor);
var minor = (int)this.GetParam(WinDivertParam.VersionMinor);
return new Version(major, minor);
}
}
public Version Version { get; }


/// <summary>
/// 获取或设置列队的容量大小
/// </summary>
/// <exception cref="Win32Exception"></exception>
/// <exception cref="InvalidOperationException"></exception>
public long QueueLength
{
get => this.GetParam(WinDivertParam.QueueLength);
Expand All @@ -54,6 +51,7 @@ public long QueueLength
/// 获取或设自动丢弃数据包之前可以排队的最短时长
/// </summary>
/// <exception cref="Win32Exception"></exception>
/// <exception cref="InvalidOperationException"></exception>
public TimeSpan QueueTime
{
get => TimeSpan.FromMilliseconds(this.GetParam(WinDivertParam.QueueTime));
Expand All @@ -64,6 +62,7 @@ public TimeSpan QueueTime
/// 获取或设置存储在数据包队列中的最大字节数
/// </summary>
/// <exception cref="Win32Exception"></exception>
/// <exception cref="InvalidOperationException"></exception>
public long QueueSize
{
get => this.GetParam(WinDivertParam.QueueSize);
Expand Down Expand Up @@ -98,9 +97,14 @@ public WinDivert(string filter, WinDivertLayer layer, short priority = 0, WinDiv
{
throw new Win32Exception();
}
this.boundHandle = new Lazy<ThreadPoolBoundHandle>(() => ThreadPoolBoundHandle.BindHandle(this.handle));

this.Filter = filter;
this.Layer = layer;

var major = this.GetParam(WinDivertParam.VersionMajor);
var minor = this.GetParam(WinDivertParam.VersionMinor);
this.Version = new Version((int)major, (int)minor);
}

/// <summary>
Expand All @@ -112,38 +116,33 @@ public WinDivert(string filter, WinDivertLayer layer, short priority = 0, WinDiv
/// <exception cref="Win32Exception"></exception>
public int Recv(WinDivertPacket packet, ref WinDivertAddress addr)
{
var length = 0;
var result = WinDivertNative.WinDivertRecv(this.handle, packet, packet.Capacity, ref length, ref addr);
if (result == false)
{
throw new Win32Exception();
}
packet.Length = length;
return length;
return this.RecvAsync(packet, ref addr).Result;
}


/// <summary>
/// 读取数据包
/// </summary>
/// <param name="packet">数据包</param>
/// <param name="addr">地址信息</param>
/// <returns></returns>
/// <exception cref="Win32Exception"></exception>
public Task<int> RecvAsync(WinDivertPacket packet, ref WinDivertAddress addr)
{
var operation = new RecvOperation(this.handle, packet, completionCallback);
operation.RecvEx(ref addr);
return operation.Task;
var controller = new WindivertRecvController(this.handle, packet, this.boundHandle.Value, ioCompletionCallback);
controller.IoControl(ref addr);
return controller.Task;
}


private unsafe static void IOCompletionCallback(uint errorCode, uint numBytes, NativeOverlapped* pOVERLAP)
/// <summary>
/// 发送数据包
/// </summary>
/// <param name="packet">数据包</param>
/// <param name="addr">地址信息</param>
/// <returns></returns>
/// <exception cref="Win32Exception"></exception>
public int Send(WinDivertPacket packet, ref WinDivertAddress addr)
{
var operation = (RecvOperation)ThreadPoolBoundHandle.GetNativeOverlappedState(pOVERLAP);
operation.Dispose();

if (errorCode > 0)
{
operation.SetException(errorCode);
}
else
{
operation.SetResult(numBytes);
}
return this.SendAsync(packet, ref addr).Result;
}

/// <summary>
Expand All @@ -153,26 +152,47 @@ private unsafe static void IOCompletionCallback(uint errorCode, uint numBytes, N
/// <param name="addr">地址信息</param>
/// <returns></returns>
/// <exception cref="Win32Exception"></exception>
public int Send(WinDivertPacket packet, ref WinDivertAddress addr)
public Task<int> SendAsync(WinDivertPacket packet, ref WinDivertAddress addr)
{
var length = 0;
var result = WinDivertNative.WinDivertSend(this.handle, packet, packet.Length, ref length, ref addr);
if (result == false)
var controller = new WindivertSendController(this.handle, packet, this.boundHandle.Value, ioCompletionCallback);
controller.IoControl(ref addr);
return controller.Task;
}

/// <summary>
/// io完成回调
/// </summary>
/// <param name="errorCode"></param>
/// <param name="numBytes"></param>
/// <param name="pOVERLAP"></param>
private unsafe static void IOCompletionCallback(uint errorCode, uint numBytes, NativeOverlapped* pOVERLAP)
{
var controller = (WindivertController)ThreadPoolBoundHandle.GetNativeOverlappedState(pOVERLAP);
if (errorCode > 0)
{
throw new Win32Exception();
controller.SetException((int)errorCode);
}
else
{
controller.SetResult((int)numBytes);
}
packet.Length = length;
return length;
}


/// <summary>
/// 获取指定的参数值
/// </summary>
/// <param name="param"></param>
/// <returns></returns>
/// <exception cref="Win32Exception"></exception>
/// <exception cref="InvalidOperationException"></exception>
private long GetParam(WinDivertParam param)
{
if (this.boundHandle.IsValueCreated)
{
throw new InvalidOperationException();
}

var value = 0L;
var result = WinDivertNative.WinDivertGetParam(this.handle, param, ref value);
return result ? value : throw new Win32Exception();
Expand All @@ -184,8 +204,14 @@ private long GetParam(WinDivertParam param)
/// <param name="param"></param>
/// <param name="value"></param>
/// <exception cref="Win32Exception"></exception>
/// <exception cref="InvalidOperationException"></exception>
private void SetParam(WinDivertParam param, long value)
{
if (this.boundHandle.IsValueCreated)
{
throw new InvalidOperationException();
}

if (WinDivertNative.WinDivertSetParam(this.handle, param, value) == false)
{
throw new Win32Exception();
Expand All @@ -207,94 +233,8 @@ public bool Shutdown(WinDivertShutdown how)
/// </summary>
public void Dispose()
{
this.Shutdown(WinDivertShutdown.Both);
this.Shutdown(WinDivertShutdown.Both);
this.handle.Dispose();
}

private class RecvOperation : IDisposable
{
private readonly WinDivertHandle handle;
private readonly WinDivertPacket packet;

private readonly ThreadPoolBoundHandle threadPoolBoundHandle;
private readonly PreAllocatedOverlapped preAllocatedOverlapped;
private unsafe readonly NativeOverlapped* nativeOverlapped;

private readonly TaskCompletionSource<int> taskCompletionSource = new();

public Task<int> Task => this.taskCompletionSource.Task;

public unsafe RecvOperation(
WinDivertHandle handle,
WinDivertPacket packet,
IOCompletionCallback completionCallback)
{
this.handle = handle;
this.packet = packet;

this.threadPoolBoundHandle = ThreadPoolBoundHandle.BindHandle(handle);
this.preAllocatedOverlapped = new PreAllocatedOverlapped(completionCallback, this, null);
this.nativeOverlapped = this.threadPoolBoundHandle.AllocateNativeOverlapped(this.preAllocatedOverlapped);
}

public unsafe void RecvEx(ref WinDivertAddress addr)
{
var length = 0;
var addrLength = sizeof(WinDivertAddress);
var flag = WinDivertNative.WinDivertRecvEx(this.threadPoolBoundHandle.Handle, this.packet, this.packet.Capacity, ref length, 0, ref addr, &addrLength, nativeOverlapped);

if (flag == true)
{
this.SetResult(length);
this.Dispose();
return;
}

var errorCode = Marshal.GetLastWin32Error();
if (errorCode == 997)
{
return;
}

this.SetException(errorCode);
this.Dispose();
}


public void SetResult(uint numBytes)
{
this.SetResult(length: (int)numBytes);
}

public void SetResult(int length)
{
Console.WriteLine(length);
this.packet.Length = length;
this.taskCompletionSource.SetResult(length);
}

public void SetException(uint errorCode)
{
this.SetException((int)errorCode);
}

public void SetException(int errorCode)
{
var exception = new Win32Exception(errorCode);
this.taskCompletionSource.SetException(exception);
}

public unsafe void FreeNativeOverlapped(NativeOverlapped* pOVERLAP)
{
this.threadPoolBoundHandle.FreeNativeOverlapped(pOVERLAP);
}

public unsafe void Dispose()
{
this.threadPoolBoundHandle.FreeNativeOverlapped(this.nativeOverlapped);
this.threadPoolBoundHandle.Dispose();
this.preAllocatedOverlapped.Dispose();
}
}
}
}
22 changes: 9 additions & 13 deletions WindivertDotnet/WinDivertNative.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,32 @@ public static extern WinDivertHandle WinDivertOpen(
[MarshalAs(UnmanagedType.LPStr)] string filter,
WinDivertLayer layer,
short priority,
WinDivertFlag flags);

[DllImport(library, CallingConvention = CallingConvention.Cdecl, SetLastError = true)]
public static extern bool WinDivertRecv(
WinDivertHandle handle,
SafeHandle pPacket,
int packetLen,
ref int pRecvLen,
ref WinDivertAddress pAddr);
WinDivertFlag flags);


[DllImport(library, CallingConvention = CallingConvention.Cdecl, SetLastError = true)]
public static extern bool WinDivertRecvEx(
SafeHandle handle,
WinDivertHandle handle,
SafeHandle pPacket,
int packetLen,
ref int pRecvLen,
ulong flags,
ref WinDivertAddress pAddr,
int* pAddrLen,
NativeOverlapped* lpOverlapped);
NativeOverlapped* lpOverlapped);



[DllImport(library, CallingConvention = CallingConvention.Cdecl, SetLastError = true)]
public static extern bool WinDivertSend(
public static extern bool WinDivertSendEx(
WinDivertHandle handle,
SafeHandle pPacket,
int packetLen,
ref int pSendLen,
ref WinDivertAddress pAddr);
ulong flags,
ref WinDivertAddress pAddr,
int addrLen,
NativeOverlapped* lpOverlapped);


[DllImport(library, CallingConvention = CallingConvention.Cdecl, SetLastError = true)]
Expand Down
Loading

0 comments on commit 8c4963c

Please sign in to comment.