Skip to content

Move error handling to C code #265

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 148 additions & 46 deletions cmd/mkcgo/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,13 @@ func generateGo(src *mkcgo.Source, w io.Writer) {
// This block outputs C header includes and forward declarations for loader functions.
fmt.Fprintf(w, "/*\n")
fmt.Fprintf(w, "#cgo CFLAGS: -Wno-attributes\n\n")
if *includeHeader != "" {
fmt.Fprintf(w, "#include \"%s\"\n", *includeHeader)
}
for _, file := range src.Files {
fmt.Fprintf(w, "#include %q\n", file)
}
fmt.Fprintf(w, "\n")
for _, tag := range src.Tags() {
fmt.Fprintf(w, "void __mkcgoLoad_%s(void* handle);\n", tag)
fmt.Fprintf(w, "void __mkcgoUnload_%s();\n", tag)
}
fmt.Fprintf(w, "\n")
for _, fn := range src.Funcs {
if fn.Optional {
fmt.Fprintf(w, "int %s_Available();\n", fn.ImportName)
}
if *includeHeader != "" {
fmt.Fprintf(w, "#include \"%s\"\n", *includeHeader)
}
fmt.Fprintf(w, "#include \"%s\"\n", autogeneratedFileName(".h"))
fmt.Fprintf(w, "*/\n")
fmt.Fprintf(w, "import \"C\"\n")
fmt.Fprintf(w, "import \"unsafe\"\n\n")
Expand All @@ -56,10 +46,14 @@ func generateGo(src *mkcgo.Source, w io.Writer) {
fmt.Fprintf(w, "}\n\n")
}

typedefs := make(map[string]string, len(src.TypeDefs))
for _, def := range src.TypeDefs {
typedefs[def.Name] = def.Type
}
// Generate error wrapper noescape function, which hides the
// error state pointer from the Go garbage collector.
// An instance of https://github.com/golang/go/blob/d704ef76068eb7da15520b08dc7df98f45f85ffa/src/runtime/stubs.go#L194-L201
fmt.Fprintf(w, "//go:nosplit\n")
fmt.Fprintf(w, "func %s(p *C.%s) *C.%s {\n", mkcgoNoEscape, mkcgoErrState, mkcgoErrState)
fmt.Fprintf(w, "\tx := uintptr(unsafe.Pointer(p))\n")
fmt.Fprintf(w, "\treturn (*C.%s)(unsafe.Pointer(x ^ 0))\n", mkcgoErrState)
fmt.Fprintf(w, "}\n\n")

// Generate function wrappers.
for _, fn := range src.Funcs {
Expand All @@ -73,7 +67,7 @@ func generateGo(src *mkcgo.Source, w io.Writer) {
fmt.Fprintf(w, "\treturn C.%s_Available() != 0\n", fn.ImportName)
fmt.Fprintf(w, "}\n\n")
}
generateGoFn(typedefs, fn, w)
generateGoFn(fn, w)
}
}

Expand All @@ -87,11 +81,15 @@ func generateGo124(src *mkcgo.Source, w io.Writer) {
// This block outputs C header includes and forward declarations for loader functions.
fmt.Fprintf(w, "/*\n")
for _, fn := range src.Funcs {
name := fn.CName
if fnNeedErrWrapper(fn) {
name = fnCErrWrapperName(fn)
}
if fn.NoEscape {
fmt.Fprintf(w, "#cgo noescape %s\n", fn.CName)
fmt.Fprintf(w, "#cgo noescape %s\n", name)
}
if fn.NoCallback {
fmt.Fprintf(w, "#cgo nocallback %s\n", fn.CName)
fmt.Fprintf(w, "#cgo nocallback %s\n", name)
}
}
fmt.Fprintf(w, "*/\n")
Expand Down Expand Up @@ -148,20 +146,71 @@ func generateGoAliases(funcs []*mkcgo.Func, w io.Writer) {
}
}

// generateC creates the C source file content.
func generateC(src *mkcgo.Source, w io.Writer) {
// generateCHeader generates C header file content with
// the C functions defined in the autogenerated C source file.
func generateCHeader(src *mkcgo.Source, w io.Writer) {
// Header and includes.
fmt.Fprintf(w, "// Code generated by mkcgo. DO NOT EDIT.\n\n")

fmt.Fprintf(w, "#ifndef MKCGO_H // only include this header once\n")
fmt.Fprintf(w, "#define MKCGO_H\n\n")

for _, file := range src.Files {
fmt.Fprintf(w, "#include %q\n", file)
}
if *includeHeader != "" {
fmt.Fprintf(w, "#include \"%s\"\n", *includeHeader)
}
for _, file := range src.Files {
fmt.Fprintf(w, "#include %q\n", file)
fmt.Fprintf(w, "\n")

// Custom types
fmt.Fprintf(w, "typedef void* %s;\n", mkcgoErrState)
fmt.Fprintf(w, "%s mkcgo_err_retrieve();\n", mkcgoErrState)
fmt.Fprintf(w, "void mkcgo_err_free(%s);\n", mkcgoErrState)
fmt.Fprintf(w, "void mkcgo_err_clear();\n\n")

// Add forward declarations for loader functions.
for _, tag := range src.Tags() {
fmt.Fprintf(w, "void __mkcgoLoad_%s(void* handle);\n", tag)
fmt.Fprintf(w, "void __mkcgoUnload_%s();\n", tag)
}
fmt.Fprintf(w, "\n")

// Add forward declarations for optional functions.
for _, fn := range src.Funcs {
if fn.Optional {
fmt.Fprintf(w, "int %s_Available();\n", fn.ImportName)
}
}
fmt.Fprintf(w, "\n")

// Add forward declarations for function wrappers returning errors.
for _, fn := range src.Funcs {
if !fnNeedErrWrapper(fn) {
continue
}
fmt.Fprintf(w, "%s %s(%s);\n", fn.Ret.Type, fnCErrWrapperName(fn), fnCErrWrapperParams(fn, false))
}
fmt.Fprintf(w, "\n")
fmt.Fprintf(w, "#endif // MKCGO_H\n")
}

// generateC creates the C source file content.
func generateC(src *mkcgo.Source, w io.Writer) {
// Header and includes.
fmt.Fprintf(w, "// Code generated by mkcgo. DO NOT EDIT.\n\n")

fmt.Fprintf(w, "#include <stddef.h>\n")
fmt.Fprintf(w, "#include <stdlib.h>\n")
fmt.Fprintf(w, "#include <stdint.h>\n")
fmt.Fprintf(w, "#include <stdio.h>\n")
for _, file := range src.Files {
fmt.Fprintf(w, "#include %q\n", file)
}
if *includeHeader != "" {
fmt.Fprintf(w, "#include \"%s\"\n", *includeHeader)
}
fmt.Fprintf(w, "#include \"%s\"\n", autogeneratedFileName(".h"))
fmt.Fprintf(w, "\n")

// Platform-specific includes.
Expand Down Expand Up @@ -238,6 +287,10 @@ func generateC(src *mkcgo.Source, w io.Writer) {
}

// Generate C function wrappers.
typedefs := make(map[string]string, len(src.TypeDefs))
for _, def := range src.TypeDefs {
typedefs[def.Name] = def.Type
}
for _, fn := range src.Funcs {
if fn.Variadic() {
// cgo doesn't support variadic functions
Expand All @@ -250,11 +303,12 @@ func generateC(src *mkcgo.Source, w io.Writer) {
fmt.Fprintf(w, "}\n\n")
}
generateCFn(fn, w)
generateCFnErrorWrapper(typedefs, fn, w)
}
}

// generateGoFn generates Go function f.
func generateGoFn(typedefs map[string]string, fn *mkcgo.Func, w io.Writer) {
func generateGoFn(fn *mkcgo.Func, w io.Writer) {
fnCall := fmt.Sprintf("C.%s(%s)", fn.CName, fnToGoArgs(fn))
// Function definition
fmt.Fprintf(w, "func %s(%s)", fn.GoName, fnToGoParams(fn))
Expand Down Expand Up @@ -296,21 +350,13 @@ func generateGoFn(typedefs map[string]string, fn *mkcgo.Func, w io.Writer) {
fmt.Fprintf(w, "}\n\n")
return
}
fmt.Fprintf(w, "\t_ret := C.%s(%s)\n", fn.CName, fnToGoArgs(fn))

// Error handling
errCond := "<= 0"
if fn.ErrCond != "" {
errCond = fn.ErrCond
} else if strings.Contains(goType, "unsafe.Pointer") {
errCond = "== nil"
} else if typ, ok := typedefs[goType]; ok && typ == "void*" {
errCond = "== nil"
fmt.Fprintf(w, "\tvar _err C.%s\n", mkcgoErrState)
fmt.Fprintf(w, "\t_ret := C.%s(", fnCErrWrapperName(fn))
args := fnToGoArgs(fn)
if len(args) > 0 {
args += ", "
}
fmt.Fprintf(w, "\tvar _err error\n")
fmt.Fprintf(w, "\tif _ret %s {\n", errCond)
fmt.Fprintf(w, "\t\t_err = newOpenSSLError(\"%s\")\n", fn.CName)
fmt.Fprintf(w, "\t}\n")
fmt.Fprintf(w, "%s%s(&_err))\n", args, mkcgoNoEscape)

// Return the value
fmt.Fprintf(w, "\treturn ")
Expand All @@ -322,16 +368,38 @@ func generateGoFn(typedefs map[string]string, fn *mkcgo.Func, w io.Writer) {
} else {
fmt.Fprintf(w, "_ret")
}
fmt.Fprintf(w, ", _err\n")
fmt.Fprintf(w, ", newMkcgoErr(%q, _err)\n", fn.CName)
fmt.Fprintf(w, "}\n\n")
}

func generateCFn(fn *mkcgo.Func, w io.Writer) {
fmt.Fprintf(w, "%s %s(%s) {\n\t", fn.Ret.Type, fn.CName, fnToCArgs(fn, true))
fmt.Fprintf(w, "%s %s(%s) {\n\t", fn.Ret.Type, fn.CName, fnToCArgs(fn, true, true))
if !retIsVoid(fn.Ret) {
fmt.Fprintf(w, "return ")
}
fmt.Fprintf(w, "_g_%s(%s);\n", fn.ImportName, fnToCArgs(fn, false))
fmt.Fprintf(w, "_g_%s(%s);\n", fn.ImportName, fnToCArgs(fn, false, true))
fmt.Fprintf(w, "}\n\n")
}

// generateCFnErrorWrapper generates C function wrapper for function f
// that returns an error state.
func generateCFnErrorWrapper(typedefs map[string]string, fn *mkcgo.Func, w io.Writer) {
if !fnNeedErrWrapper(fn) {
return
}
fmt.Fprintf(w, "%s %s(%s) {\n", fn.Ret.Type, fnCErrWrapperName(fn), fnCErrWrapperParams(fn, true))
fmt.Fprintf(w, "\tmkcgo_err_clear();\n") // clear any previous error
fmt.Fprintf(w, "\t%s _ret = _g_%s(%s);\n", fn.Ret.Type, fn.ImportName, fnToCArgs(fn, false, true))
errCond := "<= 0"
if fn.ErrCond != "" {
errCond = fn.ErrCond
} else if strings.Contains(fn.Ret.Type, "*") {
errCond = "== NULL"
} else if typ, ok := typedefs[fn.Ret.Type]; ok && typ == "void*" {
errCond = "== NULL"
}
fmt.Fprintf(w, "\tif (_ret %s) *_err_state = mkcgo_err_retrieve();\n", errCond)
fmt.Fprintf(w, "\treturn _ret;\n")
fmt.Fprintf(w, "}\n\n")
}

Expand Down Expand Up @@ -436,12 +504,15 @@ func cTypeToGo(t string, cgo bool) (string, bool) {
}

// paramToC returns C source code of parameter p.
func paramToC(i int, p *mkcgo.Param, addType bool) string {
func paramToC(i int, p *mkcgo.Param, addType, addName bool) string {
if p.Type == "..." {
return ""
}
var s string
if addType {
s += p.Type
}
if p.Type != "void" && p.Type != "..." {
if addName && p.Type != "void" {
if len(s) > 0 {
s += " "
}
Expand Down Expand Up @@ -470,9 +541,9 @@ func fnToGoArgs(fn *mkcgo.Func) string {
}

// fnToCArgs returns source code for C parameters for function f.
func fnToCArgs(fn *mkcgo.Func, addType bool) string {
func fnToCArgs(fn *mkcgo.Func, addType, addName bool) string {
return join(fn.Params, func(i int, p *mkcgo.Param) string {
return paramToC(i, p, addType)
return paramToC(i, p, addType, addName)
}, ", ")
}

Expand All @@ -492,3 +563,34 @@ func join(ps []*mkcgo.Param, fn func(int, *mkcgo.Param) string, sep string) stri
}
return strings.Join(params, sep)
}

const mkcgoNoEscape = "mkcgoNoEscape"
const mkcgoErrState = "mkcgo_err_state"

// fnCErrWrapperParams returns source code for C parameters for function f
// with the error state added as the last parameter.
func fnCErrWrapperParams(fn *mkcgo.Func, addName bool) string {
errArg := mkcgoErrState + " *"
if addName {
errArg += "_err_state"
}
args := fnToCArgs(fn, true, addName)
if len(args) == 0 {
args = errArg
} else if args == "void" {
args = errArg
} else {
args += ", " + errArg
}
return args
}

// fnCErrWrapperName returns the name of the error wrapper function for function f.
func fnCErrWrapperName(fn *mkcgo.Func) string {
return "_mkcgo_err_" + fn.CName
}

// fnNeedErrWrapper reports whether function fn needs an error wrapper.
func fnNeedErrWrapper(fn *mkcgo.Func) bool {
return !fn.NoError && !retIsVoid(fn.Ret)
}
38 changes: 23 additions & 15 deletions cmd/mkcgo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,45 +40,53 @@ func main() {
log.Fatal(err)
}

var gobuf, go124buf, cbuf bytes.Buffer
var gobuf, go124buf, hbuf, cbuf bytes.Buffer
generateGo(src, &gobuf)
generateGo124(src, &go124buf)
generateCHeader(src, &hbuf)
generateC(src, &cbuf)

// Format the generated Go source code.
godata := goformat(gobuf.Bytes())
go124data := goformat(go124buf.Bytes())

var baseName string
if *fileName == "" {
baseName = "mkcgo"
} else {
baseName = strings.TrimSuffix(*fileName, ".go")
}

for _, d := range []struct {
name string
data []byte
suffix string
data []byte
}{
{baseName + ".go", godata},
{baseName + "_go124.go", go124data},
{baseName + ".c", cbuf.Bytes()},
{".go", godata},
{"_go124.go", go124data},
{".h", hbuf.Bytes()},
{".c", cbuf.Bytes()},
} {
name := autogeneratedFileName(d.suffix)
var err error
if *fileName == "" {
// Write output. If no explicit output file is specified,
// // write both Go and C output to stdout.
os.Stdout.WriteString("// === " + d.name + " ===\n\n")
os.Stdout.WriteString("// === " + name + " ===\n\n")
_, err = os.Stdout.Write(d.data)
} else {
err = os.WriteFile(d.name, d.data, 0o644)
err = os.WriteFile(name, d.data, 0o644)
}
if err != nil {
log.Fatal(err)
}
}
}

// autogeneratedFileName returns the name of the autogenerated file
// using the provided suffix.
func autogeneratedFileName(suffix string) string {
var baseName string
if *fileName == "" {
baseName = "mkcgo"
} else {
baseName = strings.TrimSuffix(*fileName, ".go")
}
return baseName + suffix
}

func writeTempSourceFile(data []byte) (string, error) {
f, err := os.CreateTemp("", "mkcgo-generated-*.go")
if err != nil {
Expand Down
Loading
Loading