Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ jobs:
name: Test
strategy:
matrix:
go_version: [1.11, 1.12, 1.13, 1.14]
os: [ubuntu-latest, windows-latest]
go_version: [1.7, 1.8, 1.9, "1.10", 1.11, 1.12, 1.13, 1.14, 1.15]
os: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.os }}
steps:
- name: Set up Go ${{ matrix.go_version }}
Expand Down
40 changes: 16 additions & 24 deletions patcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,11 @@ var (
func PatchMethod(target, redirection interface{}) (*Patch, error) {
tValue := getValueFrom(target)
rValue := getValueFrom(redirection)
err := isPatchable(&tValue, &rValue)
if err != nil {
if err := isPatchable(&tValue, &rValue); err != nil {
return nil, err
}
patch := &Patch{target: &tValue, redirection: &rValue}
err = applyPatch(patch)
if err != nil {
if err := applyPatch(patch); err != nil {
return nil, err
}
return patch, nil
Expand All @@ -54,37 +52,32 @@ func PatchInstanceMethodByName(target reflect.Type, methodName string, redirecti
if !ok {
return nil, errors.New(fmt.Sprintf("Method '%v' not found", methodName))
}
return PatchMethodByReflect(method, redirection)
return PatchMethodByReflect(method.Func, redirection)
}
func PatchMethodByReflect(target reflect.Method, redirection interface{}) (*Patch, error) {
tValue := &target.Func
func PatchMethodByReflect(target reflect.Value, redirection interface{}) (*Patch, error) {
tValue := &target
rValue := getValueFrom(redirection)
err := isPatchable(tValue, &rValue)
if err != nil {
if err := isPatchable(tValue, &rValue); err != nil {
return nil, err
}
patch := &Patch{target: tValue, redirection: &rValue}
err = applyPatch(patch)
if err != nil {
if err := applyPatch(patch); err != nil {
return nil, err
}
return patch, nil
}
func PatchMethodWithMakeFunc(target reflect.Method, fn func(args []reflect.Value) (results []reflect.Value)) (*Patch, error) {
rValue := reflect.MakeFunc(target.Type, fn)
return PatchMethodByReflect(target, rValue)
func PatchMethodWithMakeFunc(target reflect.Value, fn func(args []reflect.Value) (results []reflect.Value)) (*Patch, error) {
return PatchMethodByReflect(target, reflect.MakeFunc(target.Type(), fn))
}

func (p *Patch) Patch() error {
if p == nil {
return errors.New("patch is nil")
}
err := isPatchable(p.target, p.redirection)
if err != nil {
if err := isPatchable(p.target, p.redirection); err != nil {
return err
}
err = applyPatch(p)
if err != nil {
if err := applyPatch(p); err != nil {
return err
}
return nil
Expand All @@ -103,7 +96,7 @@ func isPatchable(target, redirection *reflect.Value) error {
if target.Type() != redirection.Type() {
return errors.New(fmt.Sprintf("the target and/or redirection doesn't have the same type: %s != %s", target.Type(), redirection.Type()))
}
if _, ok := patches[getSafePointer(target)]; ok {
if _, ok := patches[getSafeCodePointer(target)]; ok {
return errors.New("the target is already patched")
}
return nil
Expand All @@ -112,7 +105,7 @@ func isPatchable(target, redirection *reflect.Value) error {
func applyPatch(patch *Patch) error {
patchLock.Lock()
defer patchLock.Unlock()
tPointer := getSafePointer(patch.target)
tPointer := getSafeCodePointer(patch.target)
rPointer := getInternalPtrFromValue(patch.redirection)
rPointerJumpBytes, err := getJumpFuncBytes(rPointer)
if err != nil {
Expand All @@ -121,8 +114,7 @@ func applyPatch(patch *Patch) error {
tPointerBytes := getMemorySliceFromPointer(tPointer, len(rPointerJumpBytes))
targetBytes := make([]byte, len(tPointerBytes))
copy(targetBytes, tPointerBytes)
err = copyDataToPtr(tPointer, rPointerJumpBytes)
if err != nil {
if err := copyDataToPtr(tPointer, rPointerJumpBytes); err != nil {
return err
}
patch.targetBytes = targetBytes
Expand All @@ -136,7 +128,7 @@ func applyUnpatch(patch *Patch) error {
if patch.targetBytes == nil || len(patch.targetBytes) == 0 {
return errors.New("the target is not patched")
}
tPointer := getSafePointer(patch.target)
tPointer := getSafeCodePointer(patch.target)
if _, ok := patches[tPointer]; !ok {
return errors.New("the target is not patched")
}
Expand Down Expand Up @@ -164,7 +156,7 @@ func getMemorySliceFromPointer(p unsafe.Pointer, length int) []byte {
}))
}

func getSafePointer(value *reflect.Value) unsafe.Pointer {
func getSafeCodePointer(value *reflect.Value) unsafe.Pointer {
p := getInternalPtrFromValue(value)
if p != nil {
p = *(*unsafe.Pointer)(p)
Expand Down
38 changes: 38 additions & 0 deletions patcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,44 @@ func TestPatcher(t *testing.T) {
}
}

func TestPatcherUsingReflect(t *testing.T) {
patch, err := PatchMethodByReflect(reflect.ValueOf(methodA), methodB)
if err != nil {
t.Fatal(err)
}
if methodA() != 2 {
t.Fatal("The patch did not work")
}

err = patch.Unpatch()
if err != nil {
t.Fatal(err)
}
if methodA() != 1 {
t.Fatal("The unpatch did not work")
}
}

func TestPatcherUsingMakeFunc(t *testing.T) {
patch, err := PatchMethodWithMakeFunc(reflect.ValueOf(methodA), func(args []reflect.Value) (results []reflect.Value) {
return []reflect.Value{reflect.ValueOf(42)}
})
if err != nil {
t.Fatal(err)
}
if methodA() != 42 {
t.Fatal("The patch did not work")
}

err = patch.Unpatch()
if err != nil {
t.Fatal(err)
}
if methodA() != 1 {
t.Fatal("The unpatch did not work")
}
}

func TestInstancePatcher(t *testing.T) {
mStruct := myStruct{}

Expand Down
9 changes: 3 additions & 6 deletions patcher_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ func getMemorySliceFromUintptr(p uintptr, length int) []byte {
func callMProtect(addr unsafe.Pointer, length int, prot int) error {
for p := uintptr(addr) & ^(uintptr(pageSize - 1)); p < uintptr(addr)+uintptr(length); p += uintptr(pageSize) {
page := getMemorySliceFromUintptr(p, pageSize)
err := syscall.Mprotect(page, prot)
if err != nil {
if err := syscall.Mprotect(page, prot); err != nil {
return err
}
}
Expand All @@ -35,13 +34,11 @@ func callMProtect(addr unsafe.Pointer, length int, prot int) error {
func copyDataToPtr(ptr unsafe.Pointer, data []byte) error {
dataLength := len(data)
ptrByteSlice := getMemorySliceFromPointer(ptr, len(data))
err := callMProtect(ptr, dataLength, writeAccess)
if err != nil {
if err := callMProtect(ptr, dataLength, writeAccess); err != nil {
return err
}
copy(ptrByteSlice, data[:])
err = callMProtect(ptr, dataLength, readAccess)
if err != nil {
if err := callMProtect(ptr, dataLength, readAccess); err != nil {
return err
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion patcher_unsupported.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ import (
// Gets the jump function rewrite bytes
//go:nosplit
func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) {
return nil, errors.New(fmt.Sprintf("Unsupported architecture: %s", runtime.GOARCH))
return nil, errors.New(fmt.Sprintf("unsupported architecture: %s", runtime.GOARCH))
}
6 changes: 2 additions & 4 deletions patcher_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ func copyDataToPtr(ptr unsafe.Pointer, data []byte) error {
var oldPerms, tmp uint32
dataLength := len(data)
ptrByteSlice := getMemorySliceFromPointer(ptr, len(data))
err := callVirtualProtect(ptr, dataLength, pageExecuteReadAndWrite, unsafe.Pointer(&oldPerms))
if err != nil {
if err := callVirtualProtect(ptr, dataLength, pageExecuteReadAndWrite, unsafe.Pointer(&oldPerms)); err != nil {
return err
}
copy(ptrByteSlice, data[:])
err = callVirtualProtect(ptr, dataLength, oldPerms, unsafe.Pointer(&tmp))
if err != nil {
if err := callVirtualProtect(ptr, dataLength, oldPerms, unsafe.Pointer(&tmp)); err != nil {
return err
}
return nil
Expand Down
2 changes: 2 additions & 0 deletions patcher_x32.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package mpatch

import "unsafe"

const jumpLength = 7

// Gets the jump function rewrite bytes
//go:nosplit
func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) {
Expand Down
2 changes: 2 additions & 0 deletions patcher_x64.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package mpatch

import "unsafe"

const jumpLength = 12

// Gets the jump function rewrite bytes
//go:nosplit
func getJumpFuncBytes(to unsafe.Pointer) ([]byte, error) {
Expand Down