Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,151 @@ func ExampleNewCallback_cdecl() {

// Output: 83
}

func TestNewCallbackInt32Packing(t *testing.T) {
var result int32
cb := purego.NewCallback(func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12 int32) int32 {
result = a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12
return result
})

var fn func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12 int32) int32
purego.RegisterFunc(&fn, cb)

got := fn(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)
want := int32(197)

if got != want {
t.Errorf("callback returned %d, want %d", got, want)
}
}

func TestNewCallbackMixedPacking(t *testing.T) {
var gotI32_1, gotI32_2 int32
var gotI64 int64
cb := purego.NewCallback(func(r1, r2, r3, r4, r5, r6, r7, r8 int64, s1 int32, s2 int64, s3 int32) {
gotI32_1 = s1
gotI64 = s2
gotI32_2 = s3
})

var fn func(r1, r2, r3, r4, r5, r6, r7, r8 int64, s1 int32, s2 int64, s3 int32)
purego.RegisterFunc(&fn, cb)

fn(1, 2, 3, 4, 5, 6, 7, 8, 100, 200, 300)

if gotI32_1 != 100 || gotI64 != 200 || gotI32_2 != 300 {
t.Errorf("got (%d, %d, %d), want (100, 200, 300)", gotI32_1, gotI64, gotI32_2)
}
}

func TestNewCallbackSmallTypes(t *testing.T) {
var gotBool bool
var gotI8 int8
var gotU8 uint8
var gotI16 int16
var gotU16 uint16
var gotI32 int32
cb := purego.NewCallback(func(r1, r2, r3, r4, r5, r6, r7, r8 int64, b bool, i8 int8, u8 uint8, i16 int16, u16 uint16, i32 int32) {
gotBool = b
gotI8 = i8
gotU8 = u8
gotI16 = i16
gotU16 = u16
gotI32 = i32
})

var fn func(r1, r2, r3, r4, r5, r6, r7, r8 int64, b bool, i8 int8, u8 uint8, i16 int16, u16 uint16, i32 int32)
purego.RegisterFunc(&fn, cb)

fn(1, 2, 3, 4, 5, 6, 7, 8, true, -42, 200, -1000, 50000, 123456)

if !gotBool || gotI8 != -42 || gotU8 != 200 || gotI16 != -1000 || gotU16 != 50000 || gotI32 != 123456 {
t.Errorf("got (bool=%v, i8=%d, u8=%d, i16=%d, u16=%d, i32=%d), want (true, -42, 200, -1000, 50000, 123456)",
gotBool, gotI8, gotU8, gotI16, gotU16, gotI32)
}
}

func TestCallbackFromC(t *testing.T) {
libFileName := filepath.Join(t.TempDir(), "libcbpackingtest.so")

if err := buildSharedLib("CC", libFileName, filepath.Join("testdata", "libcbtest", "callback_packing_test.c")); err != nil {
t.Fatal(err)
}
defer os.Remove(libFileName)

lib, err := purego.Dlopen(libFileName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
t.Fatalf("Dlopen(%q) failed: %v", libFileName, err)
}

t.Run("int32_packing", func(t *testing.T) {
var result int32
goCallback := func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12 int32) int32 {
result = a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12
return result
}

var callCallbackInt32Packing func(uintptr) int32
purego.RegisterLibFunc(&callCallbackInt32Packing, lib, "callCallbackInt32Packing")

cb := purego.NewCallback(goCallback)
got := callCallbackInt32Packing(cb)
want := int32(197) // sum of primes: 2+3+5+7+11+13+17+19+23+29+31+37

if got != want {
t.Errorf("C called callback returned %d, want %d", got, want)
}
if result != want {
t.Errorf("callback received wrong args, sum=%d, want %d", result, want)
}
})

t.Run("mixed_packing", func(t *testing.T) {
var gotI32_1, gotI32_2 int32
var gotI64 int64
goCallback := func(r1, r2, r3, r4, r5, r6, r7, r8 int64, s1 int32, s2 int64, s3 int32) {
gotI32_1 = s1
gotI64 = s2
gotI32_2 = s3
}

var callCallbackMixedPacking func(uintptr)
purego.RegisterLibFunc(&callCallbackMixedPacking, lib, "callCallbackMixedPacking")

cb := purego.NewCallback(goCallback)
callCallbackMixedPacking(cb)

if gotI32_1 != 100 || gotI64 != 200 || gotI32_2 != 300 {
t.Errorf("callback received (%d, %d, %d), want (100, 200, 300)", gotI32_1, gotI64, gotI32_2)
}
})

t.Run("small_types", func(t *testing.T) {
var gotBool bool
var gotI8 int8
var gotU8 uint8
var gotI16 int16
var gotU16 uint16
var gotI32 int32
goCallback := func(r1, r2, r3, r4, r5, r6, r7, r8 int64, b bool, i8 int8, u8 uint8, i16 int16, u16 uint16, i32 int32) {
gotBool = b
gotI8 = i8
gotU8 = u8
gotI16 = i16
gotU16 = u16
gotI32 = i32
}

var callCallbackSmallTypes func(uintptr)
purego.RegisterLibFunc(&callCallbackSmallTypes, lib, "callCallbackSmallTypes")

cb := purego.NewCallback(goCallback)
callCallbackSmallTypes(cb)

if !gotBool || gotI8 != -42 || gotU8 != 200 || gotI16 != -1000 || gotU16 != 50000 || gotI32 != 123456 {
t.Errorf("callback received (bool=%v, i8=%d, u8=%d, i16=%d, u16=%d, i32=%d), want (true, -42, 200, -1000, 50000, 123456)",
gotBool, gotI8, gotU8, gotI16, gotU16, gotI32)
}
})
}
76 changes: 34 additions & 42 deletions func.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"math"
"reflect"
"runtime"
"strconv"
"sync"
"unsafe"

Expand Down Expand Up @@ -131,9 +130,8 @@ func RegisterFunc(fptr any, cfn uintptr) {
{
// this code checks how many registers and stack this function will use
// to avoid crashing with too many arguments
var ints int
var floats int
var stack int
var numInts, numFloats int
var stackBytes int
for i := 0; i < ty.NumIn(); i++ {
arg := ty.In(i)
switch arg.Kind() {
Expand All @@ -153,20 +151,20 @@ func RegisterFunc(fptr any, cfn uintptr) {
case reflect.String, reflect.Uintptr, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Ptr, reflect.UnsafePointer,
reflect.Slice, reflect.Bool:
if ints < numOfIntegerRegisters() {
ints++
if numInts < numOfIntegerRegisters() {
numInts++
} else {
stack++
stackBytes++
}
case reflect.Float32, reflect.Float64:
const is32bit = unsafe.Sizeof(uintptr(0)) == 4
if is32bit {
panic("purego: floats only supported on 64bit platforms")
}
if floats < numOfFloatRegisters {
floats++
if numFloats < numOfFloatRegisters {
numFloats++
} else {
stack++
stackBytes++
}
case reflect.Struct:
if runtime.GOOS != "darwin" || (runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64") {
Expand All @@ -176,15 +174,15 @@ func RegisterFunc(fptr any, cfn uintptr) {
continue
}
addInt := func(u uintptr) {
ints++
numInts++
}
addFloat := func(u uintptr) {
floats++
numFloats++
}
addStack := func(u uintptr) {
stack++
stackBytes++
}
_ = addStruct(reflect.New(arg).Elem(), &ints, &floats, &stack, addInt, addFloat, addStack, nil)
_ = addStruct(reflect.New(arg).Elem(), &numInts, &numFloats, &stackBytes, addInt, addFloat, addStack, nil)
default:
panic("purego: unsupported kind " + arg.Kind().String())
}
Expand All @@ -198,12 +196,23 @@ func RegisterFunc(fptr any, cfn uintptr) {
if runtime.GOARCH == "amd64" && outType.Size() > maxRegAllocStructSize {
// on amd64 if struct is bigger than 16 bytes allocate the return struct
// and pass it in as a hidden first argument.
ints++
numInts++
}
}

sizeOfStack := maxArgs - numOfIntegerRegisters()
if stack > sizeOfStack {
panic("purego: too many arguments")
// On Darwin ARM64, use byte-based validation since arguments pack efficiently.
// See: https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
actualStackBytes := estimateStackBytes(ty)
maxStackBytes := sizeOfStack * 8
if actualStackBytes > maxStackBytes {
panic("purego: too many stack arguments")
}
} else {
if stackBytes > sizeOfStack {
panic("purego: too many stack arguments")
}
}
}
v := reflect.MakeFunc(ty, func(args []reflect.Value) (results []reflect.Value) {
Expand Down Expand Up @@ -281,28 +290,15 @@ func RegisterFunc(fptr any, cfn uintptr) {
}
continue
}
if runtime.GOARCH == "arm64" && runtime.GOOS == "darwin" &&
(numInts >= numOfIntegerRegisters() || numFloats >= numOfFloatRegisters) && v.Kind() != reflect.Struct { // hit the stack
fields := make([]reflect.StructField, len(args[i:]))
// Check if we need to start Darwin ARM64 C-style stack packing
if runtime.GOARCH == "arm64" && runtime.GOOS == "darwin" && shouldBundleStackArgs(v, numInts, numFloats) {
// Collect and separate remaining args into register vs stack
stackArgs, newKeepAlive := collectStackArgs(args, i, numInts, numFloats,
keepAlive, addInt, addFloat, addStack, &numInts, &numFloats, &numStack)
keepAlive = newKeepAlive

for j, val := range args[i:] {
if val.Kind() == reflect.String {
ptr := strings.CString(val.String())
keepAlive = append(keepAlive, ptr)
val = reflect.ValueOf(ptr)
args[i+j] = val
}
fields[j] = reflect.StructField{
Name: "X" + strconv.Itoa(j),
Type: val.Type(),
}
}
structType := reflect.StructOf(fields)
structInstance := reflect.New(structType).Elem()
for j, val := range args[i:] {
structInstance.Field(j).Set(val)
}
placeRegisters(structInstance, addFloat, addInt)
// Bundle stack arguments with C-style packing
bundleStackArgs(stackArgs, addStack)
break
}
keepAlive = addValue(v, keepAlive, addInt, addFloat, addStack, &numInts, &numFloats, &numStack)
Expand Down Expand Up @@ -472,10 +468,6 @@ func checkStructFieldsSupported(ty reflect.Type) {
}
}

func roundUpTo8(val uintptr) uintptr {
return (val + 7) &^ 7
}

func numOfIntegerRegisters() int {
switch runtime.GOARCH {
case "arm64", "loong64":
Expand Down
Loading