Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass alloc/free functions to MemoryLoadLibraryEx #33

Merged
merged 3 commits into from
Apr 24, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions MemoryModule.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ typedef struct {
BOOL initialized;
BOOL isDLL;
BOOL isRelocated;
CustomAllocFunc alloc;
CustomFreeFunc free;
CustomLoadLibraryFunc loadLibrary;
CustomGetProcAddressFunc getProcAddress;
CustomFreeLibraryFunc freeLibrary;
Expand Down Expand Up @@ -115,10 +117,11 @@ CopySections(const unsigned char *data, size_t size, PIMAGE_NT_HEADERS old_heade
// uninitialized data
section_size = old_headers->OptionalHeader.SectionAlignment;
if (section_size > 0) {
dest = (unsigned char *)VirtualAlloc(codeBase + section->VirtualAddress,
dest = (unsigned char *)module->alloc(codeBase + section->VirtualAddress,
section_size,
MEM_COMMIT,
PAGE_READWRITE);
PAGE_READWRITE,
module->userdata);
if (dest == NULL) {
return FALSE;
}
Expand All @@ -139,10 +142,11 @@ CopySections(const unsigned char *data, size_t size, PIMAGE_NT_HEADERS old_heade
}

// commit memory block and copy data from dll
dest = (unsigned char *)VirtualAlloc(codeBase + section->VirtualAddress,
dest = (unsigned char *)module->alloc(codeBase + section->VirtualAddress,
section->SizeOfRawData,
MEM_COMMIT,
PAGE_READWRITE);
PAGE_READWRITE,
module->userdata);
if (dest == NULL) {
return FALSE;
}
Expand Down Expand Up @@ -202,7 +206,7 @@ FinalizeSection(PMEMORYMODULE module, PSECTIONFINALIZEDATA sectionData) {
(sectionData->size % module->pageSize) == 0)
) {
// Only allowed to decommit whole pages
VirtualFree(sectionData->address, sectionData->size, MEM_DECOMMIT);
module->free(sectionData->address, sectionData->size, MEM_DECOMMIT, module->userdata);
}
return TRUE;
}
Expand Down Expand Up @@ -429,6 +433,18 @@ BuildImportTable(PMEMORYMODULE module)
return result;
}

LPVOID MemoryDefaultAlloc(LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void* userdata)
{
UNREFERENCED_PARAMETER(userdata);
return VirtualAlloc(address, size, allocationType, protect);
}

BOOL MemoryDefaultFree(LPVOID lpAddress, SIZE_T dwSize, DWORD dwFreeType, void* userdata)
{
UNREFERENCED_PARAMETER(userdata);
return VirtualFree(lpAddress, dwSize, dwFreeType);
}

HCUSTOMMODULE MemoryDefaultLoadLibrary(LPCSTR filename, void *userdata)
{
HMODULE result;
Expand All @@ -455,10 +471,12 @@ void MemoryDefaultFreeLibrary(HCUSTOMMODULE module, void *userdata)

HMEMORYMODULE MemoryLoadLibrary(const void *data, size_t size)
{
return MemoryLoadLibraryEx(data, size, MemoryDefaultLoadLibrary, MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, NULL);
return MemoryLoadLibraryEx(data, size, MemoryDefaultAlloc, MemoryDefaultFree, MemoryDefaultLoadLibrary, MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, NULL);
}

HMEMORYMODULE MemoryLoadLibraryEx(const void *data, size_t size,
CustomAllocFunc allocMemory,
CustomFreeFunc freeMemory,
CustomLoadLibraryFunc loadLibrary,
CustomGetProcAddressFunc getProcAddress,
CustomFreeLibraryFunc freeLibrary,
Expand Down Expand Up @@ -535,17 +553,19 @@ HMEMORYMODULE MemoryLoadLibraryEx(const void *data, size_t size,
// reserve memory for image of library
// XXX: is it correct to commit the complete memory region at once?
// calling DllEntry raises an exception if we don't...
code = (unsigned char *)VirtualAlloc((LPVOID)(old_header->OptionalHeader.ImageBase),
code = (unsigned char *)allocMemory((LPVOID)(old_header->OptionalHeader.ImageBase),
alignedImageSize,
MEM_RESERVE | MEM_COMMIT,
PAGE_READWRITE);
PAGE_READWRITE,
userdata);

if (code == NULL) {
// try to allocate memory at arbitrary position
code = (unsigned char *)VirtualAlloc(NULL,
code = (unsigned char *)allocMemory(NULL,
alignedImageSize,
MEM_RESERVE | MEM_COMMIT,
PAGE_READWRITE);
PAGE_READWRITE,
userdata);
if (code == NULL) {
SetLastError(ERROR_OUTOFMEMORY);
return NULL;
Expand All @@ -554,13 +574,15 @@ HMEMORYMODULE MemoryLoadLibraryEx(const void *data, size_t size,

result = (PMEMORYMODULE)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(MEMORYMODULE));
if (result == NULL) {
VirtualFree(code, 0, MEM_RELEASE);
freeMemory(code, 0, MEM_RELEASE, userdata);
SetLastError(ERROR_OUTOFMEMORY);
return NULL;
}

result->codeBase = code;
result->isDLL = (old_header->FileHeader.Characteristics & IMAGE_FILE_DLL) != 0;
result->alloc = allocMemory;
result->free = freeMemory;
result->loadLibrary = loadLibrary;
result->getProcAddress = getProcAddress;
result->freeLibrary = freeLibrary;
Expand All @@ -572,10 +594,11 @@ HMEMORYMODULE MemoryLoadLibraryEx(const void *data, size_t size,
}

// commit memory for headers
headers = (unsigned char *)VirtualAlloc(code,
headers = (unsigned char *)allocMemory(code,
old_header->OptionalHeader.SizeOfHeaders,
MEM_COMMIT,
PAGE_READWRITE);
PAGE_READWRITE,
userdata);

// copy PE header to code
memcpy(headers, dos_header, old_header->OptionalHeader.SizeOfHeaders);
Expand Down Expand Up @@ -724,7 +747,7 @@ void MemoryFreeLibrary(HMEMORYMODULE mod)

if (module->codeBase != NULL) {
// release memory of library
VirtualFree(module->codeBase, 0, MEM_RELEASE);
module->free(module->codeBase, 0, MEM_RELEASE, module->userdata);
}

HeapFree(GetProcessHeap(), 0, module);
Expand Down
20 changes: 20 additions & 0 deletions MemoryModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ typedef void *HCUSTOMMODULE;
extern "C" {
#endif

typedef LPVOID (*CustomAllocFunc)(LPVOID, SIZE_T, DWORD, DWORD, void*);
typedef BOOL (*CustomFreeFunc)(LPVOID, SIZE_T, DWORD, void*);
typedef HCUSTOMMODULE (*CustomLoadLibraryFunc)(LPCSTR, void *);
typedef FARPROC (*CustomGetProcAddressFunc)(HCUSTOMMODULE, LPCSTR, void *);
typedef void (*CustomFreeLibraryFunc)(HCUSTOMMODULE, void *);
Expand All @@ -58,6 +60,8 @@ HMEMORYMODULE MemoryLoadLibrary(const void *, size_t);
* Dependencies will be resolved using passed callback methods.
*/
HMEMORYMODULE MemoryLoadLibraryEx(const void *, size_t,
CustomAllocFunc,
CustomFreeFunc,
CustomLoadLibraryFunc,
CustomGetProcAddressFunc,
CustomFreeLibraryFunc,
Expand Down Expand Up @@ -117,6 +121,22 @@ int MemoryLoadString(HMEMORYMODULE, UINT, LPTSTR, int);
*/
int MemoryLoadStringEx(HMEMORYMODULE, UINT, LPTSTR, int, WORD);

/**
* Default implementation of CustomAllocFunc that calls VirtualAlloc
* internally to allocate memory for a library
*
* This is the default as used by MemoryLoadLibrary.
*/
LPVOID MemoryDefaultAlloc(LPVOID, SIZE_T, DWORD, DWORD, void *);

/**
* Default implementation of CustomFreeFunc that calls VirtualFree
* internally to free the memory used by a library
*
* This is the default as used by MemoryLoadLibrary.
*/
BOOL MemoryDefaultFree(LPVOID, SIZE_T, DWORD, void *);

/**
* Default implementation of CustomLoadLibraryFunc that calls LoadLibraryA
* internally to load an additional libary.
Expand Down
Loading