diff --git a/Driver.cpp b/Driver.cpp index b89efea..abfa076 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -329,7 +329,9 @@ OvpnStopVPN(_In_ POVPN_DEVICE device) { LOG_ENTER(); - OvpnFlushPeers(device); + OvpnCleanupPeerTable(device, &device->PeersByVpn6); + OvpnCleanupPeerTable(device, &device->PeersByVpn4); + OvpnCleanupPeerTable(device, &device->Peers); KIRQL kirql = ExAcquireSpinLockExclusive(&device->SpinLock); PWSK_SOCKET socket = device->Socket.Socket; @@ -762,210 +764,3 @@ OvpnEvtDeviceAdd(WDFDRIVER wdfDriver, PWDFDEVICE_INIT deviceInit) { return status; } - -_Use_decl_annotations_ -NTSTATUS -OvpnAddPeerToTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* table, OvpnPeerContext* peer) -{ - NTSTATUS status; - BOOLEAN newElem; - - auto irql = ExAcquireSpinLockExclusive(&device->SpinLock); - - RtlInsertElementGenericTable(table, (PVOID)&peer, sizeof(OvpnPeerContext*), &newElem); - - if (newElem) { - status = STATUS_SUCCESS; - InterlockedIncrement(&peer->RefCounter); - } - else { - LOG_ERROR("Unable to add new peer"); - status = STATUS_NO_MEMORY; - } - - ExReleaseSpinLockExclusive(&device->SpinLock, irql); - - return status; -} - -_Use_decl_annotations_ -VOID -OvpnFlushPeers(POVPN_DEVICE device) { - OvpnCleanupPeerTable(device, &device->PeersByVpn6); - OvpnCleanupPeerTable(device, &device->PeersByVpn4); - OvpnCleanupPeerTable(device, &device->Peers); -} - -_Use_decl_annotations_ -VOID -OvpnCleanupPeerTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* peers) -{ - auto irql = ExAcquireSpinLockExclusive(&device->SpinLock); - - while (!RtlIsGenericTableEmpty(peers)) { - PVOID ptr = RtlGetElementGenericTable(peers, 0); - OvpnPeerContext* peer = *(OvpnPeerContext**)ptr; - RtlDeleteElementGenericTable(peers, ptr); - - OvpnPeerCtxRelease(peer); - } - - ExReleaseSpinLockExclusive(&device->SpinLock, irql); -} - -_Use_decl_annotations_ -OvpnPeerContext* -OvpnGetFirstPeer(POVPN_DEVICE device) -{ - auto irql = ExAcquireSpinLockShared(&device->SpinLock); - - OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); - OvpnPeerContext* peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; - - if (peer != nullptr) { - InterlockedIncrement(&peer->RefCounter); - } - - ExReleaseSpinLockShared(&device->SpinLock, irql); - - return peer; -} - -_Use_decl_annotations_ -OvpnPeerContext* -OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId) -{ - OvpnPeerContext* peer = nullptr; - OvpnPeerContext** ptr = nullptr; - - auto kirql = ExAcquireSpinLockShared(&device->SpinLock); - - if (device->Mode == OVPN_MODE_P2P) { - ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); - } - else { - OvpnPeerContext p{}; - p.PeerId = PeerId; - - auto* pp = &p; - ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->Peers, &pp); - } - - peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; - - if (peer) { - InterlockedIncrement(&peer->RefCounter); - } - - ExReleaseSpinLockShared(&device->SpinLock, kirql); - - return peer; -} - -_Use_decl_annotations_ -OvpnPeerContext* -OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr) -{ - OvpnPeerContext* peer = nullptr; - OvpnPeerContext** ptr = nullptr; - - auto kirql = ExAcquireSpinLockShared(&device->SpinLock); - - if (device->Mode == OVPN_MODE_P2P) { - ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); - } - else { - OvpnPeerContext p{}; - p.VpnAddrs.IPv4 = addr; - - auto* pp = &p; - ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn4, &pp); - } - - peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; - if (peer) { - InterlockedIncrement(&peer->RefCounter); - } - - ExReleaseSpinLockShared(&device->SpinLock, kirql); - - return peer; -} - -_Use_decl_annotations_ -OvpnPeerContext* -OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr) -{ - OvpnPeerContext* peer = nullptr; - OvpnPeerContext** ptr = nullptr; - - auto kirql = ExAcquireSpinLockShared(&device->SpinLock); - - if (device->Mode == OVPN_MODE_P2P) { - ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); - } - else { - OvpnPeerContext p{}; - p.VpnAddrs.IPv6 = addr; - - auto* pp = &p; - ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp); - } - - peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; - if (peer) { - InterlockedIncrement(&peer->RefCounter); - } - - ExReleaseSpinLockShared(&device->SpinLock, kirql); - - return peer; -} - -VOID -OvpnDeletePeerFromTable(POVPN_DEVICE device, RTL_GENERIC_TABLE *table, OvpnPeerContext *peer, char* tableName) -{ - auto peerId = peer->PeerId; - auto pp = &peer; - - auto kirql = ExAcquireSpinLockExclusive(&device->SpinLock); - - if (RtlDeleteElementGenericTable(table, pp)) { - LOG_INFO("Peer deleted", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); - - if (InterlockedDecrement(&peer->RefCounter) == 0) { - OvpnPeerCtxFree(peer); - LOG_INFO("Peer freed", TraceLoggingValue(peerId, "peer-id")); - } - } - else { - LOG_INFO("Peer not found", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); - } - - ExReleaseSpinLockExclusive(&device->SpinLock, kirql); -} - -_Use_decl_annotations_ -NTSTATUS -OvpnDeletePeer(POVPN_DEVICE device, INT32 peerId) -{ - NTSTATUS status = STATUS_SUCCESS; - - LOG_INFO("Deleting peer", TraceLoggingValue(peerId, "peer-id")); - - // get peer from main table - OvpnPeerContext* peer = OvpnFindPeer(device, peerId); - if (peer == NULL) { - status = STATUS_NOT_FOUND; - LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peer-id")); - } - else { - OvpnDeletePeerFromTable(device, &device->PeersByVpn4, peer, "vpn4"); - OvpnDeletePeerFromTable(device, &device->PeersByVpn6, peer, "vpn6"); - OvpnDeletePeerFromTable(device, &device->Peers, peer, "peers"); - - OvpnPeerCtxRelease(peer); - } - - return status; -} diff --git a/Driver.h b/Driver.h index 881cb7e..054eebf 100644 --- a/Driver.h +++ b/Driver.h @@ -107,35 +107,3 @@ struct OVPN_DEVICE { typedef OVPN_DEVICE * POVPN_DEVICE; WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(OVPN_DEVICE, OvpnGetDeviceContext) - -struct OvpnPeerContext; - -_Must_inspect_result_ -NTSTATUS -OvpnAddPeerToTable(POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer); - -VOID -OvpnFlushPeers(_In_ POVPN_DEVICE device); - -VOID -OvpnCleanupPeerTable(_In_ POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE*); - -_Must_inspect_result_ -OvpnPeerContext* -OvpnGetFirstPeer(_In_ POVPN_DEVICE device); - -_Must_inspect_result_ -OvpnPeerContext* -OvpnFindPeer(_In_ POVPN_DEVICE device, INT32 PeerId); - -_Must_inspect_result_ -OvpnPeerContext* -OvpnFindPeerVPN4(_In_ POVPN_DEVICE device, _In_ IN_ADDR addr); - -_Must_inspect_result_ -OvpnPeerContext* -OvpnFindPeerVPN6(_In_ POVPN_DEVICE device, _In_ IN6_ADDR addr); - -_Must_inspect_result_ -NTSTATUS -OvpnDeletePeer(_In_ POVPN_DEVICE device, INT32 peerId); diff --git a/peer.cpp b/peer.cpp index fddcb7c..b6c278a 100644 --- a/peer.cpp +++ b/peer.cpp @@ -132,6 +132,181 @@ OvpnPeerCompareByVPN6Routine(RTL_GENERIC_TABLE* table, PVOID first, PVOID second return GenericGreaterThan; } +_Use_decl_annotations_ +NTSTATUS +OvpnAddPeerToTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* table, OvpnPeerContext* peer) +{ + NTSTATUS status; + BOOLEAN newElem; + + auto irql = ExAcquireSpinLockExclusive(&device->SpinLock); + + RtlInsertElementGenericTable(table, (PVOID)&peer, sizeof(OvpnPeerContext*), &newElem); + + if (newElem) { + status = STATUS_SUCCESS; + InterlockedIncrement(&peer->RefCounter); + } + else { + LOG_ERROR("Unable to add new peer"); + status = STATUS_NO_MEMORY; + } + + ExReleaseSpinLockExclusive(&device->SpinLock, irql); + + return status; +} + +_Use_decl_annotations_ +VOID +OvpnCleanupPeerTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* peers) +{ + auto irql = ExAcquireSpinLockExclusive(&device->SpinLock); + + while (!RtlIsGenericTableEmpty(peers)) { + PVOID ptr = RtlGetElementGenericTable(peers, 0); + OvpnPeerContext* peer = *(OvpnPeerContext**)ptr; + RtlDeleteElementGenericTable(peers, ptr); + + OvpnPeerCtxRelease(peer); + } + + ExReleaseSpinLockExclusive(&device->SpinLock, irql); +} + +_Use_decl_annotations_ +OvpnPeerContext* +OvpnGetFirstPeer(POVPN_DEVICE device) +{ + auto irql = ExAcquireSpinLockShared(&device->SpinLock); + + OvpnPeerContext** ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + OvpnPeerContext* peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + + if (peer != nullptr) { + InterlockedIncrement(&peer->RefCounter); + } + + ExReleaseSpinLockShared(&device->SpinLock, irql); + + return peer; +} + +_Use_decl_annotations_ +OvpnPeerContext* +OvpnFindPeer(POVPN_DEVICE device, INT32 PeerId) +{ + OvpnPeerContext* peer = nullptr; + OvpnPeerContext** ptr = nullptr; + + auto kirql = ExAcquireSpinLockShared(&device->SpinLock); + + if (device->Mode == OVPN_MODE_P2P) { + ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + } + else { + OvpnPeerContext p{}; + p.PeerId = PeerId; + + auto* pp = &p; + ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->Peers, &pp); + } + + peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + + if (peer) { + InterlockedIncrement(&peer->RefCounter); + } + + ExReleaseSpinLockShared(&device->SpinLock, kirql); + + return peer; +} + +_Use_decl_annotations_ +OvpnPeerContext* +OvpnFindPeerVPN4(POVPN_DEVICE device, IN_ADDR addr) +{ + OvpnPeerContext* peer = nullptr; + OvpnPeerContext** ptr = nullptr; + + auto kirql = ExAcquireSpinLockShared(&device->SpinLock); + + if (device->Mode == OVPN_MODE_P2P) { + ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + } + else { + OvpnPeerContext p{}; + p.VpnAddrs.IPv4 = addr; + + auto* pp = &p; + ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn4, &pp); + } + + peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + if (peer) { + InterlockedIncrement(&peer->RefCounter); + } + + ExReleaseSpinLockShared(&device->SpinLock, kirql); + + return peer; +} + +_Use_decl_annotations_ +OvpnPeerContext* +OvpnFindPeerVPN6(POVPN_DEVICE device, IN6_ADDR addr) +{ + OvpnPeerContext* peer = nullptr; + OvpnPeerContext** ptr = nullptr; + + auto kirql = ExAcquireSpinLockShared(&device->SpinLock); + + if (device->Mode == OVPN_MODE_P2P) { + ptr = (OvpnPeerContext**)RtlGetElementGenericTable(&device->Peers, 0); + } + else { + OvpnPeerContext p{}; + p.VpnAddrs.IPv6 = addr; + + auto* pp = &p; + ptr = (OvpnPeerContext**)RtlLookupElementGenericTable(&device->PeersByVpn6, &pp); + } + + peer = ptr ? (OvpnPeerContext*)*ptr : nullptr; + if (peer) { + InterlockedIncrement(&peer->RefCounter); + } + + ExReleaseSpinLockShared(&device->SpinLock, kirql); + + return peer; +} + +VOID +OvpnDeletePeerFromTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* table, OvpnPeerContext* peer, char* tableName) +{ + auto peerId = peer->PeerId; + auto pp = &peer; + + auto kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + + if (RtlDeleteElementGenericTable(table, pp)) { + LOG_INFO("Peer deleted", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); + + if (InterlockedDecrement(&peer->RefCounter) == 0) { + OvpnPeerCtxFree(peer); + LOG_INFO("Peer freed", TraceLoggingValue(peerId, "peer-id")); + } + } + else { + LOG_INFO("Peer not found", TraceLoggingValue(tableName, "table"), TraceLoggingValue(peerId, "peer-id")); + } + + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); +} + + static VOID OvpnPeerZeroStats(POVPN_STATS stats) @@ -587,3 +762,27 @@ OvpnPeerSwapKeys(POVPN_DEVICE device) return status; } +_Use_decl_annotations_ +NTSTATUS +OvpnPeerDelete(POVPN_DEVICE device, INT32 peerId) +{ + NTSTATUS status = STATUS_SUCCESS; + + LOG_INFO("Deleting peer", TraceLoggingValue(peerId, "peer-id")); + + // get peer from main table + OvpnPeerContext* peer = OvpnFindPeer(device, peerId); + if (peer == NULL) { + status = STATUS_NOT_FOUND; + LOG_WARN("Peer not found", TraceLoggingValue(peerId, "peer-id")); + } + else { + OvpnDeletePeerFromTable(device, &device->PeersByVpn4, peer, "vpn4"); + OvpnDeletePeerFromTable(device, &device->PeersByVpn6, peer, "vpn6"); + OvpnDeletePeerFromTable(device, &device->Peers, peer, "peers"); + + OvpnPeerCtxRelease(peer); + } + + return status; +} \ No newline at end of file diff --git a/peer.h b/peer.h index 800649a..37cb26f 100644 --- a/peer.h +++ b/peer.h @@ -83,6 +83,32 @@ RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByPeerIdRoutine; RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByVPN4Routine; RTL_GENERIC_COMPARE_ROUTINE OvpnPeerCompareByVPN6Routine; +_Must_inspect_result_ +NTSTATUS +OvpnAddPeerToTable(POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE* table, _In_ OvpnPeerContext* peer); + +VOID +OvpnCleanupPeerTable(_In_ POVPN_DEVICE device, _In_ RTL_GENERIC_TABLE*); + +_Must_inspect_result_ +OvpnPeerContext* +OvpnGetFirstPeer(_In_ POVPN_DEVICE device); + +_Must_inspect_result_ +OvpnPeerContext* +OvpnFindPeer(_In_ POVPN_DEVICE device, INT32 PeerId); + +_Must_inspect_result_ +OvpnPeerContext* +OvpnFindPeerVPN4(_In_ POVPN_DEVICE device, _In_ IN_ADDR addr); + +_Must_inspect_result_ +OvpnPeerContext* +OvpnFindPeerVPN6(_In_ POVPN_DEVICE device, _In_ IN6_ADDR addr); + +VOID +OvpnDeletePeerFromTable(POVPN_DEVICE device, RTL_GENERIC_TABLE* table, OvpnPeerContext* peer, char* tableName); + _Must_inspect_result_ _IRQL_requires_(PASSIVE_LEVEL) NTSTATUS @@ -128,3 +154,7 @@ _Must_inspect_result_ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS OvpnPeerSwapKeys(_In_ POVPN_DEVICE device); + +_Must_inspect_result_ +NTSTATUS +OvpnPeerDelete(_In_ POVPN_DEVICE device, INT32 peerId); \ No newline at end of file diff --git a/timer.cpp b/timer.cpp index e1839e5..f98d8ad 100644 --- a/timer.cpp +++ b/timer.cpp @@ -130,7 +130,7 @@ static BOOLEAN OvpnTimerRecv(WDFTIMER timer) WdfRequestCompleteWithInformation(request, STATUS_CONNECTION_DISCONNECTED, bytesSent); } else { - (VOID)OvpnDeletePeer(device, peerId); + (VOID)OvpnPeerDelete(device, peerId); status = WdfIoQueueRetrieveNextRequest(device->PendingNotificationRequestsQueue, &request); if (!NT_SUCCESS(status)) { diff --git a/timer.h b/timer.h index 88ee019..4fe8745 100644 --- a/timer.h +++ b/timer.h @@ -25,6 +25,8 @@ #include #include +#include "peer.h" + VOID OvpnTimerResetXmit(WDFTIMER timer);