Skip to content

Commit

Permalink
Add some checks to prevent overruns on broken input.
Browse files Browse the repository at this point in the history
  • Loading branch information
fancycode committed Dec 20, 2015
1 parent 9df6e7d commit bc04853
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 26 deletions.
89 changes: 71 additions & 18 deletions MemoryModule.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ typedef struct {

#define GET_HEADER_DICTIONARY(module, idx) &(module)->headers->OptionalHeader.DataDirectory[idx]
#define ALIGN_DOWN(address, alignment) (LPVOID)((uintptr_t)(address) & ~((alignment) - 1))
#define ALIGN_VALUE_UP(value, alignment) (((value) + (alignment) - 1) & ~((alignment) - 1))

#ifdef DEBUG_OUTPUT
static void
Expand All @@ -86,20 +87,30 @@ OutputLastError(const char *msg)
#endif

static BOOL
CopySections(const unsigned char *data, PIMAGE_NT_HEADERS old_headers, PMEMORYMODULE module)
CheckSize(size_t size, size_t expected) {
if (size < expected) {
SetLastError(ERROR_INVALID_DATA);
return FALSE;
}

return TRUE;
}

static BOOL
CopySections(const unsigned char *data, size_t size, PIMAGE_NT_HEADERS old_headers, PMEMORYMODULE module)
{
int i, size;
int i, section_size;
unsigned char *codeBase = module->codeBase;
unsigned char *dest;
PIMAGE_SECTION_HEADER section = IMAGE_FIRST_SECTION(module->headers);
for (i=0; i<module->headers->FileHeader.NumberOfSections; i++, section++) {
if (section->SizeOfRawData == 0) {
// section doesn't contain data in the dll itself, but may define
// uninitialized data
size = old_headers->OptionalHeader.SectionAlignment;
if (size > 0) {
section_size = old_headers->OptionalHeader.SectionAlignment;
if (section_size > 0) {
dest = (unsigned char *)VirtualAlloc(codeBase + section->VirtualAddress,
size,
section_size,
MEM_COMMIT,
PAGE_READWRITE);
if (dest == NULL) {
Expand All @@ -110,13 +121,17 @@ CopySections(const unsigned char *data, PIMAGE_NT_HEADERS old_headers, PMEMORYMO
// than page size.
dest = codeBase + section->VirtualAddress;
section->Misc.PhysicalAddress = (DWORD) (uintptr_t) dest;
memset(dest, 0, size);
memset(dest, 0, section_size);
}

// section is empty
continue;
}

if (!CheckSize(size, section->PointerToRawData + section->SizeOfRawData)) {
return FALSE;
}

// commit memory block and copy data from dll
dest = (unsigned char *)VirtualAlloc(codeBase + section->VirtualAddress,
section->SizeOfRawData,
Expand Down Expand Up @@ -285,7 +300,7 @@ ExecuteTLS(PMEMORYMODULE module)
}

static BOOL
PerformBaseRelocation(PMEMORYMODULE module, SIZE_T delta)
PerformBaseRelocation(PMEMORYMODULE module, ptrdiff_t delta)
{
unsigned char *codeBase = module->codeBase;
PIMAGE_BASE_RELOCATION relocation;
Expand Down Expand Up @@ -428,30 +443,43 @@ static void _FreeLibrary(HCUSTOMMODULE module, void *userdata)
FreeLibrary((HMODULE) module);
}

HMEMORYMODULE MemoryLoadLibrary(const void *data)
HMEMORYMODULE MemoryLoadLibrary(const void *data, size_t size)
{
return MemoryLoadLibraryEx(data, _LoadLibrary, _GetProcAddress, _FreeLibrary, NULL);
return MemoryLoadLibraryEx(data, size, _LoadLibrary, _GetProcAddress, _FreeLibrary, NULL);
}

HMEMORYMODULE MemoryLoadLibraryEx(const void *data,
#include <stdio.h>

HMEMORYMODULE MemoryLoadLibraryEx(const void *data, size_t size,
CustomLoadLibraryFunc loadLibrary,
CustomGetProcAddressFunc getProcAddress,
CustomFreeLibraryFunc freeLibrary,
void *userdata)
{
PMEMORYMODULE result;
PMEMORYMODULE result = NULL;
PIMAGE_DOS_HEADER dos_header;
PIMAGE_NT_HEADERS old_header;
unsigned char *code, *headers;
SIZE_T locationDelta;
ptrdiff_t locationDelta;
SYSTEM_INFO sysInfo;
PIMAGE_SECTION_HEADER section;
DWORD i;
size_t optionalSectionSize;
size_t lastSectionEnd = 0;
size_t alignedImageSize;

if (!CheckSize(size, sizeof(IMAGE_DOS_HEADER))) {
return NULL;
}
dos_header = (PIMAGE_DOS_HEADER)data;
if (dos_header->e_magic != IMAGE_DOS_SIGNATURE) {
SetLastError(ERROR_BAD_EXE_FORMAT);
return NULL;
}

if (!CheckSize(size, dos_header->e_lfanew + sizeof(IMAGE_NT_HEADERS))) {
return NULL;
}
old_header = (PIMAGE_NT_HEADERS)&((const unsigned char *)(data))[dos_header->e_lfanew];
if (old_header->Signature != IMAGE_NT_SIGNATURE) {
SetLastError(ERROR_BAD_EXE_FORMAT);
Expand All @@ -473,18 +501,41 @@ HMEMORYMODULE MemoryLoadLibraryEx(const void *data,
return NULL;
}

section = IMAGE_FIRST_SECTION(old_header);
optionalSectionSize = old_header->OptionalHeader.SectionAlignment;
for (i=0; i<old_header->FileHeader.NumberOfSections; i++, section++) {
size_t endOfSection;
if (section->SizeOfRawData == 0) {
// Section without data in the DLL
endOfSection = section->VirtualAddress + optionalSectionSize;
} else {
endOfSection = section->VirtualAddress + section->SizeOfRawData;
}

if (endOfSection > lastSectionEnd) {
lastSectionEnd = endOfSection;
}
}

GetNativeSystemInfo(&sysInfo);
alignedImageSize = ALIGN_VALUE_UP(old_header->OptionalHeader.SizeOfImage, sysInfo.dwPageSize);
if (alignedImageSize != ALIGN_VALUE_UP(lastSectionEnd, sysInfo.dwPageSize)) {
SetLastError(ERROR_BAD_EXE_FORMAT);
return NULL;
}

// reserve memory for image of library
// XXX: is it correct to commit the complete memory region at once?
// calling DllEntry raises an exception if we don't...
code = (unsigned char *)VirtualAlloc((LPVOID)(old_header->OptionalHeader.ImageBase),
old_header->OptionalHeader.SizeOfImage,
alignedImageSize,
MEM_RESERVE | MEM_COMMIT,
PAGE_READWRITE);

if (code == NULL) {
// try to allocate memory at arbitrary position
code = (unsigned char *)VirtualAlloc(NULL,
old_header->OptionalHeader.SizeOfImage,
alignedImageSize,
MEM_RESERVE | MEM_COMMIT,
PAGE_READWRITE);
if (code == NULL) {
Expand All @@ -506,10 +557,12 @@ HMEMORYMODULE MemoryLoadLibraryEx(const void *data,
result->getProcAddress = getProcAddress;
result->freeLibrary = freeLibrary;
result->userdata = userdata;

GetNativeSystemInfo(&sysInfo);
result->pageSize = sysInfo.dwPageSize;

if (!CheckSize(size, old_header->OptionalHeader.SizeOfHeaders)) {
goto error;
}

// commit memory for headers
headers = (unsigned char *)VirtualAlloc(code,
old_header->OptionalHeader.SizeOfHeaders,
Expand All @@ -524,12 +577,12 @@ HMEMORYMODULE MemoryLoadLibraryEx(const void *data,
result->headers->OptionalHeader.ImageBase = (uintptr_t)code;

// copy sections from DLL file block to new memory location
if (!CopySections((const unsigned char *) data, old_header, result)) {
if (!CopySections((const unsigned char *) data, size, old_header, result)) {
goto error;
}

// adjust base address of imported data
locationDelta = (SIZE_T)(code - old_header->OptionalHeader.ImageBase);
locationDelta = (ptrdiff_t)(result->headers->OptionalHeader.ImageBase - old_header->OptionalHeader.ImageBase);
if (locationDelta != 0) {
result->isRelocated = PerformBaseRelocation(result, locationDelta);
} else {
Expand Down
9 changes: 5 additions & 4 deletions MemoryModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,20 @@ typedef FARPROC (*CustomGetProcAddressFunc)(HCUSTOMMODULE, LPCSTR, void *);
typedef void (*CustomFreeLibraryFunc)(HCUSTOMMODULE, void *);

/**
* Load EXE/DLL from memory location.
* Load EXE/DLL from memory location with the given size.
*
* All dependencies are resolved using default LoadLibrary/GetProcAddress
* calls through the Windows API.
*/
HMEMORYMODULE MemoryLoadLibrary(const void *);
HMEMORYMODULE MemoryLoadLibrary(const void *, size_t);

/**
* Load EXE/DLL from memory location using custom dependency resolvers.
* Load EXE/DLL from memory location with the given size using custom dependency
* resolvers.
*
* Dependencies will be resolved using passed callback methods.
*/
HMEMORYMODULE MemoryLoadLibraryEx(const void *,
HMEMORYMODULE MemoryLoadLibraryEx(const void *, size_t,
CustomLoadLibraryFunc,
CustomGetProcAddressFunc,
CustomFreeLibraryFunc,
Expand Down
2 changes: 1 addition & 1 deletion doc/readme.rst
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ The interface is very similar to the standard methods for loading of libraries::

typedef void *HMEMORYMODULE;

HMEMORYMODULE MemoryLoadLibrary(const void *);
HMEMORYMODULE MemoryLoadLibrary(const void *, size_t);
FARPROC MemoryGetProcAddress(HMEMORYMODULE, const char *);
void MemoryFreeLibrary(HMEMORYMODULE);

Expand Down
2 changes: 1 addition & 1 deletion example/DllLoader/DllLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void LoadFromMemory(void)
assert(read == static_cast<size_t>(size));
fclose(fp);

handle = MemoryLoadLibrary(data);
handle = MemoryLoadLibrary(data, size);
if (handle == NULL)
{
_tprintf(_T("Can't load library from memory.\n"));
Expand Down
2 changes: 1 addition & 1 deletion example/DllLoader/DllLoaderLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int RunFromMemory(void)
assert(read == static_cast<size_t>(size));
fclose(fp);

handle = MemoryLoadLibrary(data);
handle = MemoryLoadLibrary(data, size);
if (handle == NULL)
{
_tprintf(_T("Can't load library from memory.\n"));
Expand Down
2 changes: 1 addition & 1 deletion tests/LoadDll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ BOOL LoadFromMemory(char *filename)
assert(read == static_cast<size_t>(size));
fclose(fp);

handle = MemoryLoadLibrary(data);
handle = MemoryLoadLibrary(data, size);
if (handle == NULL)
{
_tprintf(_T("Can't load library from memory.\n"));
Expand Down

0 comments on commit bc04853

Please sign in to comment.