Skip to content

Move error handling to the autogenerated Go code #262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 41 additions & 37 deletions cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func loadCipher(k cipherKind, mode cipherMode) (cipher _EVP_CIPHER_PTR) {
// not created by EVP_CIPHER has negative performance
// implications, as cipher operations will have
// to fetch it on every call. Better to just fetch it once here.
cipher = go_openssl_EVP_CIPHER_fetch(nil, go_openssl_EVP_CIPHER_get0_name(cipher), nil)
cipher, _ = go_openssl_EVP_CIPHER_fetch(nil, go_openssl_EVP_CIPHER_get0_name(cipher), nil)
}
cacheCipher.Store(cacheCipherKey{k, mode}, cipher)
}()
Expand Down Expand Up @@ -178,8 +178,8 @@ func (c *evpCipher) encrypt(dst, src []byte) error {
defer go_openssl_EVP_CIPHER_CTX_free(enc_ctx)

var outl int32
if go_openssl_EVP_EncryptUpdate(enc_ctx, base(dst), &outl, base(src), int32(c.blockSize)) != 1 {
return errors.New("EncryptUpdate failed")
if _, err := go_openssl_EVP_EncryptUpdate(enc_ctx, base(dst), &outl, base(src), int32(c.blockSize)); err != nil {
return err
}
runtime.KeepAlive(c)
return nil
Expand All @@ -203,8 +203,8 @@ func (c *evpCipher) decrypt(dst, src []byte) error {
}
defer go_openssl_EVP_CIPHER_CTX_free(dec_ctx)

if go_openssl_EVP_CIPHER_CTX_set_padding(dec_ctx, 0) != 1 {
return errors.New("could not disable cipher padding")
if _, err := go_openssl_EVP_CIPHER_CTX_set_padding(dec_ctx, 0); err != nil {
return err
}

var outl int32
Expand Down Expand Up @@ -236,19 +236,19 @@ func (x *cipherCBC) CryptBlocks(dst, src []byte) {
}
if len(src) > 0 {
var outl int32
if go_openssl_EVP_CipherUpdate(x.ctx, base(dst), &outl, base(src), int32(len(src))) != 1 {
panic("crypto/cipher: CipherUpdate failed")
if _, err := go_openssl_EVP_CipherUpdate(x.ctx, base(dst), &outl, base(src), int32(len(src))); err != nil {
panic("crypto/cipher: " + err.Error())
}
runtime.KeepAlive(x)
}
}

func (x *cipherCBC) SetIV(iv []byte) {
if len(iv) != x.blockSize {
panic("cipher: incorrect length IV")
panic("crypto/cipher: incorrect length IV")
}
if go_openssl_EVP_CipherInit_ex(x.ctx, nil, nil, nil, base(iv), int32(cipherOpNone)) != 1 {
panic("cipher: unable to initialize EVP cipher ctx")
if _, err := go_openssl_EVP_CipherInit_ex(x.ctx, nil, nil, nil, base(iv), int32(cipherOpNone)); err != nil {
panic("crypto/cipher: " + err.Error())
}
}

Expand All @@ -259,8 +259,8 @@ func (c *evpCipher) newCBC(iv []byte, op cipherOp) cipher.BlockMode {
}
x := &cipherCBC{ctx: ctx, blockSize: c.blockSize}
runtime.SetFinalizer(x, (*cipherCBC).finalize)
if go_openssl_EVP_CIPHER_CTX_set_padding(x.ctx, 0) != 1 {
panic("cipher: unable to set padding")
if _, err := go_openssl_EVP_CIPHER_CTX_set_padding(x.ctx, 0); err != nil {
panic("crypto/cipher: " + err.Error())
}
return x
}
Expand All @@ -280,8 +280,8 @@ func (x *cipherCTR) XORKeyStream(dst, src []byte) {
return
}
var outl int32
if go_openssl_EVP_EncryptUpdate(x.ctx, base(dst), &outl, base(src), int32(len(src))) != 1 {
panic("crypto/cipher: EncryptUpdate failed")
if _, err := go_openssl_EVP_EncryptUpdate(x.ctx, base(dst), &outl, base(src), int32(len(src))); err != nil {
panic("crypto/cipher: " + err.Error())
}
runtime.KeepAlive(x)
}
Expand Down Expand Up @@ -455,22 +455,24 @@ func (g *cipherGCM) Seal(dst, nonce, plaintext, aad []byte) []byte {
// relying in the explicit nonce being securely set externally,
// and it also gives some interesting speed gains.
// Unfortunately we can't use it because Go expects AEAD.Seal to honor the provided nonce.
if go_openssl_EVP_EncryptInit_ex(ctx, nil, nil, nil, base(nonce)) != 1 {
panic(newOpenSSLError("EVP_EncryptInit_ex"))
if _, err := go_openssl_EVP_EncryptInit_ex(ctx, nil, nil, nil, base(nonce)); err != nil {
panic(err)
}
var outl, discard int32
if go_openssl_EVP_EncryptUpdate(ctx, nil, &discard, baseNeverEmpty(aad), int32(len(aad))) != 1 ||
go_openssl_EVP_EncryptUpdate(ctx, base(out), &outl, baseNeverEmpty(plaintext), int32(len(plaintext))) != 1 {
panic(newOpenSSLError("EVP_EncryptUpdate"))
if _, err := go_openssl_EVP_EncryptUpdate(ctx, nil, &discard, baseNeverEmpty(aad), int32(len(aad))); err != nil {
panic(err)
}
if _, err := go_openssl_EVP_EncryptUpdate(ctx, base(out), &outl, baseNeverEmpty(plaintext), int32(len(plaintext))); err != nil {
panic(err)
}
if len(plaintext) != int(outl) {
panic("cipher: incorrect length returned from GCM EncryptUpdate")
}
if go_openssl_EVP_EncryptFinal_ex(ctx, base(out[outl:]), &discard) != 1 {
panic(newOpenSSLError("EVP_EncryptFinal_ex"))
if _, err := go_openssl_EVP_EncryptFinal_ex(ctx, base(out[outl:]), &discard); err != nil {
panic(err)
}
if go_openssl_EVP_CIPHER_CTX_ctrl(ctx, _EVP_CTRL_GCM_GET_TAG, 16, unsafe.Pointer(base(out[outl:]))) != 1 {
panic(newOpenSSLError("EVP_CIPHER_CTX_ctrl"))
if _, err := go_openssl_EVP_CIPHER_CTX_ctrl(ctx, _EVP_CTRL_GCM_GET_TAG, 16, unsafe.Pointer(base(out[outl:]))); err != nil {
panic(err)
}
runtime.KeepAlive(g)
return ret
Expand Down Expand Up @@ -515,21 +517,23 @@ func (g *cipherGCM) Open(dst, nonce, ciphertext, aad []byte) (_ []byte, err erro
}
}
}()
if go_openssl_EVP_DecryptInit_ex(ctx, nil, nil, nil, base(nonce)) != 1 {
if _, err := go_openssl_EVP_DecryptInit_ex(ctx, nil, nil, nil, base(nonce)); err != nil {
return nil, errOpen
}
if go_openssl_EVP_CIPHER_CTX_ctrl(ctx, _EVP_CTRL_GCM_SET_TAG, 16, unsafe.Pointer(base(tag))) != 1 {
if _, err := go_openssl_EVP_CIPHER_CTX_ctrl(ctx, _EVP_CTRL_GCM_SET_TAG, 16, unsafe.Pointer(base(tag))); err != nil {
return nil, errOpen
}
var outl, discard int32
if go_openssl_EVP_DecryptUpdate(ctx, nil, &discard, baseNeverEmpty(aad), int32(len(aad))) != 1 ||
go_openssl_EVP_DecryptUpdate(ctx, base(out), &outl, baseNeverEmpty(ciphertext), int32(len(ciphertext))) != 1 {
if _, err := go_openssl_EVP_DecryptUpdate(ctx, nil, &discard, baseNeverEmpty(aad), int32(len(aad))); err != nil {
return nil, errOpen
}
if _, err := go_openssl_EVP_DecryptUpdate(ctx, base(out), &outl, baseNeverEmpty(ciphertext), int32(len(ciphertext))); err != nil {
return nil, errOpen
}
if len(ciphertext) != int(outl) {
return nil, errOpen
}
if go_openssl_EVP_DecryptFinal_ex(ctx, base(out[outl:]), &discard) != 1 {
if _, err := go_openssl_EVP_DecryptFinal_ex(ctx, base(out[outl:]), &discard); err != nil {
return nil, errOpen
}
runtime.KeepAlive(g)
Expand All @@ -553,9 +557,9 @@ func newCipherCtx(kind cipherKind, mode cipherMode, encrypt cipherOp, key, iv []
if cipher == nil {
panic("crypto/cipher: unsupported cipher: " + kind.String())
}
ctx := go_openssl_EVP_CIPHER_CTX_new()
if ctx == nil {
return nil, fail("unable to create EVP cipher ctx")
ctx, err := go_openssl_EVP_CIPHER_CTX_new()
if err != nil {
return nil, err
}
defer func() {
if err != nil {
Expand All @@ -566,17 +570,17 @@ func newCipherCtx(kind cipherKind, mode cipherMode, encrypt cipherOp, key, iv []
// RC4 cipher supports a variable key length.
// We need to set the key length before setting the key,
// and to do so we need to have an initialized cipher ctx.
if go_openssl_EVP_CipherInit_ex(ctx, cipher, nil, nil, nil, int32(encrypt)) != 1 {
return nil, newOpenSSLError("EVP_CipherInit_ex")
if _, err := go_openssl_EVP_CipherInit_ex(ctx, cipher, nil, nil, nil, int32(encrypt)); err != nil {
return nil, err
}
if go_openssl_EVP_CIPHER_CTX_set_key_length(ctx, int32(len(key))) != 1 {
return nil, newOpenSSLError("EVP_CIPHER_CTX_set_key_length")
if _, err := go_openssl_EVP_CIPHER_CTX_set_key_length(ctx, int32(len(key))); err != nil {
return nil, err
}
// Pass nil to the next call to EVP_CipherInit_ex to avoid resetting ctx's cipher.
cipher = nil
}
if go_openssl_EVP_CipherInit_ex(ctx, cipher, nil, base(key), base(iv), int32(encrypt)) != 1 {
return nil, newOpenSSLError("unable to initialize EVP cipher ctx")
if _, err := go_openssl_EVP_CipherInit_ex(ctx, cipher, nil, base(key), base(iv), int32(encrypt)); err != nil {
return nil, err
}
return ctx, nil
}
Expand Down
121 changes: 87 additions & 34 deletions cmd/mkcgo/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ func generateGo(src *mkcgo.Source, w io.Writer) {
fmt.Fprintf(w, "}\n\n")
}

typedefs := make(map[string]string, len(src.TypeDefs))
for _, def := range src.TypeDefs {
typedefs[def.Name] = def.Type
}

// Generate function wrappers.
for _, fn := range src.Funcs {
if fn.Variadic() {
Expand All @@ -66,7 +71,7 @@ func generateGo(src *mkcgo.Source, w io.Writer) {
fmt.Fprintf(w, "\treturn C.%s_Available() != 0\n", fn.ImportName)
fmt.Fprintf(w, "}\n\n")
}
generateGoFn(fn, w)
generateGoFn(typedefs, fn, w)
}
}

Expand Down Expand Up @@ -226,31 +231,75 @@ func generateC(src *mkcgo.Source, w io.Writer) {
}

// generateGoFn generates Go function f.
func generateGoFn(fn *mkcgo.Func, w io.Writer) {
func generateGoFn(typedefs map[string]string, fn *mkcgo.Func, w io.Writer) {
fnCall := fmt.Sprintf("C.%s(%s)", fn.CName, fnToGoArgs(fn))
// Function definition
fmt.Fprintf(w, "func %s(%s)", fn.GoName, fnToGoParams(fn))
if !retIsVoid(fn.Ret) {
fmt.Fprintf(w, " %s ", cTypeToGo(fn.Ret.Type, false))
if retIsVoid(fn.Ret) {
// Easy path, just call the C function. No need to write the return types,
// nor do error handling, nor cast the return value.
fmt.Fprintf(w, "{\n")
fmt.Fprintf(w, "\t%s\n", fnCall)
fmt.Fprintf(w, "}\n\n")
return
}
typ, _ := cTypeToGo(fn.Ret.Type, false)
if fn.NoError {
fmt.Fprintf(w, " %s ", typ)
} else {
fmt.Fprintf(w, " (%s, error) ", typ)
}
fmt.Fprintf(w, "{\n")
fmt.Fprintf(w, "\t")
var closePar int
if !retIsVoid(fn.Ret) {
fmt.Fprintf(w, "return ")
goType := cTypeToGo(fn.Ret.Type, false)
if goType != "" && goType != fn.Ret.Type {
closePar++
if goType[0] == '*' {
goType = fmt.Sprintf("(%s)(unsafe.Pointer", goType)
closePar++

// Function call
var needUnsafeCast bool
goType, needCast := cTypeToGo(fn.Ret.Type, false)
if needCast && goType[0] == '*' {
goType = fmt.Sprintf("(%s)(unsafe.Pointer", goType)
needUnsafeCast = true
}
if fn.NoError {
// No error handling, just cast the return value if necessary.
fmt.Fprintf(w, "\treturn ")
if needCast {
fmt.Fprintf(w, "%s(%s)", goType, fnCall)
if needUnsafeCast {
fmt.Fprintf(w, ")")
}
fmt.Fprintf(w, "%s(", goType)
} else {
fmt.Fprintf(w, "%s", fnCall)
}
fmt.Fprintf(w, "\n")
fmt.Fprintf(w, "}\n\n")
return
}
fmt.Fprintf(w, "C.%s(%s)", fn.CName, fnToGoArgs(fn))
if closePar > 0 {
fmt.Fprint(w, strings.Repeat(")", closePar))
fmt.Fprintf(w, "\t_ret := C.%s(%s)\n", fn.CName, fnToGoArgs(fn))

// Error handling
errCond := "<= 0"
if fn.ErrCond != "" {
errCond = fn.ErrCond
} else if strings.Contains(goType, "unsafe.Pointer") {
errCond = "== nil"
} else if typ, ok := typedefs[goType]; ok && typ == "void*" {
errCond = "== nil"
}
fmt.Fprintf(w, "\n")
fmt.Fprintf(w, "\tvar _err error\n")
fmt.Fprintf(w, "\tif _ret %s {\n", errCond)
fmt.Fprintf(w, "\t\t_err = newOpenSSLError(\"%s\")\n", fn.CName)
fmt.Fprintf(w, "\t}\n")

// Return the value
fmt.Fprintf(w, "\treturn ")
if needCast {
fmt.Fprintf(w, "%s(_ret)", goType)
if needUnsafeCast {
fmt.Fprintf(w, ")")
}
} else {
fmt.Fprintf(w, "_ret")
}
fmt.Fprintf(w, ", _err\n")
fmt.Fprintf(w, "}\n\n")
}

Expand All @@ -265,7 +314,10 @@ func generateCFn(fn *mkcgo.Func, w io.Writer) {

// paramToGo converts C parameter p to Go parameter.
func paramToGo(p *mkcgo.Param) string {
goType := cTypeToGo(p.Type, true)
goType, needCast := cTypeToGo(p.Type, true)
if !needCast {
return p.Name
}
switch {
case goType == "unsafe.Pointer" || goType == "":
return p.Name
Expand Down Expand Up @@ -318,13 +370,17 @@ var cstdTypesToCgo = map[string]string{
}

// cTypeToGo converts C type t to a Go type.
func cTypeToGo(t string, cgo bool) string {
// If cgo is true, it returns the type that can be
// passed to a cgo function call.
// It returns the Go type and a boolean that reports whether
// the type needs to be casted to goType or not.
func cTypeToGo(t string, cgo bool) (string, bool) {
t, _ = strings.CutPrefix(t, "const ")
if t == "void" {
return ""
return "", true
}
if strings.HasPrefix(t, "void*") {
return "unsafe.Pointer"
return "unsafe.Pointer", false
}
if strings.HasSuffix(t, "*") {
// Remove all trailing '*' characters.
Expand All @@ -335,29 +391,25 @@ func cTypeToGo(t string, cgo bool) string {
}
n++
}
s := cTypeToGo(t[:i+1], cgo)
s, std := cTypeToGo(t[:i+1], cgo)
if s != "" {
s = strings.Repeat("*", n) + s
}
return s
return s, std
}
if !isStdType(t) {
if cgo {
// Non-standard C types are aliased so C.<type> so they don't need to be converted.
return ""
}
return t
return t, false
}
if cgo {
if s, ok := cstdTypesToCgo[t]; ok {
t = s
}
return "C." + t
return "C." + t, true
}
if t, ok := cstdTypesToGo[t]; ok {
return t
return t, true
}
return t
return t, true
}

// paramToC returns C source code of parameter p.
Expand All @@ -382,7 +434,8 @@ func retIsVoid(r *mkcgo.Return) bool {
// fnToGoParams returns source code for function f parameters.
func fnToGoParams(fn *mkcgo.Func) string {
return join(fn.Params, func(_ int, p *mkcgo.Param) string {
return p.Name + " " + cTypeToGo(p.Type, false)
typ, _ := cTypeToGo(p.Type, false)
return p.Name + " " + typ
}, ", ")
}

Expand Down
Loading
Loading