Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/cgen.nim
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import
lowerings, semparallel

from modulegraphs import ModuleGraph
from dynlib import libCandidates

import strutils except `%` # collides with ropes.`%`

Expand Down
2 changes: 2 additions & 0 deletions compiler/commands.nim
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ proc testCompileOptionArg*(switch, arg: string, info: TLineInfo): bool =
of "staticlib": result = contains(gGlobalOptions, optGenStaticLib) and
not contains(gGlobalOptions, optGenGuiApp)
else: localError(info, errGuiConsoleOrLibExpectedButXFound, arg)
of "dynliboverride":
result = isDynlibOverride(arg)
else: invalidCmdLineOption(passCmd1, switch, info)

proc testCompileOption*(switch: string, info: TLineInfo): bool =
Expand Down
11 changes: 0 additions & 11 deletions compiler/options.nim
Original file line number Diff line number Diff line change
Expand Up @@ -372,17 +372,6 @@ proc findModule*(modulename, currentModule: string): string =
result = findFile(m)
patchModule()

proc libCandidates*(s: string, dest: var seq[string]) =
var le = strutils.find(s, '(')
var ri = strutils.find(s, ')', le+1)
if le >= 0 and ri > le:
var prefix = substr(s, 0, le - 1)
var suffix = substr(s, ri + 1)
for middle in split(substr(s, le + 1, ri - 1), '|'):
libCandidates(prefix & middle & suffix, dest)
else:
add(dest, s)

proc canonDynlibName(s: string): string =
let start = if s.startsWith("lib"): 3 else: 0
let ende = strutils.find(s, {'(', ')', '.'})
Expand Down
32 changes: 28 additions & 4 deletions lib/pure/dynlib.nim
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,22 @@
## libraries. On POSIX this uses the ``dlsym`` mechanism, on
## Windows ``LoadLibrary``.

import strutils

type
LibHandle* = pointer ## a handle to a dynamically loaded library

{.deprecated: [TLibHandle: LibHandle].}

proc loadLib*(path: string, global_symbols=false): LibHandle
proc loadLib*(path: string, global_symbols=false): LibHandle {.gcsafe.}
## loads a library from `path`. Returns nil if the library could not
## be loaded.

proc loadLib*(): LibHandle
proc loadLib*(): LibHandle {.gcsafe.}
## gets the handle from the current executable. Returns nil if the
## library could not be loaded.

proc unloadLib*(lib: LibHandle)
proc unloadLib*(lib: LibHandle) {.gcsafe.}
## unloads the library `lib`

proc raiseInvalidLibrary*(name: cstring) {.noinline, noreturn.} =
Expand All @@ -34,7 +36,7 @@ proc raiseInvalidLibrary*(name: cstring) {.noinline, noreturn.} =
e.msg = "could not find symbol: " & $name
raise e

proc symAddr*(lib: LibHandle, name: cstring): pointer
proc symAddr*(lib: LibHandle, name: cstring): pointer {.gcsafe.}
## retrieves the address of a procedure/variable from `lib`. Returns nil
## if the symbol could not be found.

Expand All @@ -44,6 +46,28 @@ proc checkedSymAddr*(lib: LibHandle, name: cstring): pointer =
result = symAddr(lib, name)
if result == nil: raiseInvalidLibrary(name)

proc libCandidates*(s: string, dest: var seq[string]) =
## given a library name pattern `s` write possible library names to `dest`.
var le = strutils.find(s, '(')
var ri = strutils.find(s, ')', le+1)
if le >= 0 and ri > le:
var prefix = substr(s, 0, le - 1)
var suffix = substr(s, ri + 1)
for middle in split(substr(s, le + 1, ri - 1), '|'):
libCandidates(prefix & middle & suffix, dest)
else:
add(dest, s)

proc loadLibPattern*(pattern: string, global_symbols=false): LibHandle =
## loads a library with name matching `pattern`, similar to what `dlimport`
## pragma does. Returns nil if the library could not be loaded.
## Warning: this proc uses the GC and so cannot be used to load the GC.
var candidates = newSeq[string]()
libCandidates(pattern, candidates)
for c in candidates:
result = loadLib(c, global_symbols)
if not result.isNil: break

when defined(posix):
#
# =========================================================================
Expand Down
33 changes: 10 additions & 23 deletions lib/pure/net.nim
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ when defineSsl:

SslContext* = ref object
context*: SslCtx
extraInternalIndex: int
referencedData: HashSet[int]
extraInternal: SslContextExtraInternal

SslAcceptResult* = enum
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
Expand All @@ -103,6 +103,10 @@ when defineSsl:

SslServerGetPskFunc* = proc(identity: string): string

SslContextExtraInternal = ref object of RootRef
serverGetPskFunc: SslServerGetPskFunc
clientGetPskFunc: SslClientGetPskFunc

{.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
TSSLAcceptResult: SSLAcceptResult].}
Expand Down Expand Up @@ -240,11 +244,6 @@ when defineSsl:
ErrLoadBioStrings()
OpenSSL_add_all_algorithms()

type
SslContextExtraInternal = ref object of RootRef
serverGetPskFunc: SslServerGetPskFunc
clientGetPskFunc: SslClientGetPskFunc

proc raiseSSLError*(s = "") =
## Raises a new SSL error.
if s != "":
Expand All @@ -257,12 +256,6 @@ when defineSsl:
var errStr = ErrErrorString(err, nil)
raise newException(SSLError, $errStr)

proc getExtraDataIndex*(ctx: SSLContext): int =
## Retrieves unique index for storing extra data in SSLContext.
result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int
if result < 0:
raiseSSLError()

proc getExtraData*(ctx: SSLContext, index: int): RootRef =
## Retrieves arbitrary data stored inside SSLContext.
if index notin ctx.referencedData:
Expand Down Expand Up @@ -347,15 +340,11 @@ when defineSsl:
discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
newCTX.loadCertificates(certFile, keyFile)

result = SSLContext(context: newCTX, extraInternalIndex: 0,
referencedData: initSet[int]())
result.extraInternalIndex = getExtraDataIndex(result)

let extraInternal = new(SslContextExtraInternal)
result.setExtraData(result.extraInternalIndex, extraInternal)
result = SSLContext(context: newCTX, referencedData: initSet[int](),
extraInternal: new(SslContextExtraInternal))

proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex))
return ctx.extraInternal

proc destroyContext*(ctx: SSLContext) =
## Free memory referenced by SSLContext.
Expand All @@ -379,7 +368,7 @@ when defineSsl:

proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
max_psk_len: cuint): cuint {.cdecl.} =
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX)
let hintString = if hint == nil: nil else: $hint
let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
if psk.len.cuint > max_psk_len:
Expand All @@ -398,16 +387,14 @@ when defineSsl:
##
## Only used in PSK ciphersuites.
ctx.getExtraInternal().clientGetPskFunc = fun
assert ctx.extraInternalIndex == 0,
"The pskClientCallback assumes the extraInternalIndex is 0"
ctx.context.SSL_CTX_set_psk_client_callback(
if fun == nil: nil else: pskClientCallback)

proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc =
return ctx.getExtraInternal().serverGetPskFunc

proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX)
let pskString = (ctx.serverGetPskFunc)($identity)
if psk.len.cint > max_psk_len:
return 0
Expand Down
81 changes: 68 additions & 13 deletions lib/wrappers/openssl.nim
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ else:
DLLUtilName = "libcrypto.so" & versions
from posix import SocketHandle

import dynlib

type
SslStruct {.final, pure.} = object
SslPtr* = ptr SslStruct
Expand Down Expand Up @@ -185,16 +187,74 @@ const
BIO_C_DO_STATE_MACHINE = 101
BIO_C_GET_SSL = 110

proc SSL_library_init*(): cInt{.cdecl, dynlib: DLLSSLName, importc, discardable.}
proc SSL_load_error_strings*(){.cdecl, dynlib: DLLSSLName, importc.}
proc ERR_load_BIO_strings*(){.cdecl, dynlib: DLLUtilName, importc.}

proc SSLv23_client_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv23_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv2_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv3_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
proc TLSv1_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}

when compileOption("dynlibOverride", "ssl"):
proc SSL_library_init*(): cint {.cdecl, dynlib: DLLSSLName, importc, discardable.}
proc SSL_load_error_strings*() {.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv23_client_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}

proc SSLv23_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv2_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv3_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}

template OpenSSL_add_all_algorithms*() = discard
else:
# Here we're trying to stay compatible with openssl 1.0.* and 1.1.*. Some
# symbols are loaded dynamically and we don't use them if not found.
proc thisModule(): LibHandle {.inline.} =
var thisMod {.global.}: LibHandle
if thisMod.isNil: thisMod = loadLib()
result = thisMod

proc sslModule(): LibHandle {.inline.} =
var sslMod {.global.}: LibHandle
if sslMod.isNil: sslMod = loadLibPattern(DLLSSLName)
result = sslMod

proc sslSym(name: string): pointer =
var dl = thisModule()
if not dl.isNil:
result = symAddr(dl, name)
if result.isNil:
dl = sslModule()
if not dl.isNil:
result = symAddr(dl, name)

proc SSL_library_init*(): cint {.discardable.} =
let theProc = cast[proc(): cint {.cdecl.}](sslSym("SSL_library_init"))
if not theProc.isNil: result = theProc()

proc SSL_load_error_strings*() =
let theProc = cast[proc() {.cdecl.}](sslSym("SSL_load_error_strings"))
if not theProc.isNil: theProc()

proc SSLv23_client_method*(): PSSL_METHOD =
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv23_client_method"))
if not theProc.isNil: result = theProc()
else: result = TLSv1_method()

proc SSLv23_method*(): PSSL_METHOD =
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv23_method"))
if not theProc.isNil: result = theProc()
else: result = TLSv1_method()

proc SSLv2_method*(): PSSL_METHOD =
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv2_method"))
if not theProc.isNil: result = theProc()
else: result = TLSv1_method()

proc SSLv3_method*(): PSSL_METHOD =
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv3_method"))
if not theProc.isNil: result = theProc()
else: result = TLSv1_method()

proc OpenSSL_add_all_algorithms*() =
let theProc = cast[proc() {.cdecl.}](sslSym("OPENSSL_add_all_algorithms_conf"))
if not theProc.isNil: theProc()

proc ERR_load_BIO_strings*(){.cdecl, dynlib: DLLUtilName, importc.}

proc SSL_new*(context: SslCtx): SslPtr{.cdecl, dynlib: DLLSSLName, importc.}
proc SSL_free*(ssl: SslPtr){.cdecl, dynlib: DLLSSLName, importc.}
proc SSL_get_SSL_CTX*(ssl: SslPtr): SslCtx {.cdecl, dynlib: DLLSSLName, importc.}
Expand Down Expand Up @@ -261,11 +321,6 @@ proc ERR_error_string*(e: cInt, buf: cstring): cstring{.cdecl,
proc ERR_get_error*(): cInt{.cdecl, dynlib: DLLUtilName, importc.}
proc ERR_peek_last_error*(): cInt{.cdecl, dynlib: DLLUtilName, importc.}

when defined(android):
template OpenSSL_add_all_algorithms*() = discard
else:
proc OpenSSL_add_all_algorithms*(){.cdecl, dynlib: DLLUtilName, importc: "OPENSSL_add_all_algorithms_conf".}

proc OPENSSL_config*(configName: cstring){.cdecl, dynlib: DLLSSLName, importc.}

when not useWinVersion and not defined(macosx) and not defined(android):
Expand Down