CUBLAS.jl
module CUBLAS

using CUDAapi

using CUDAdrv
using CUDAdrv: CUstream

using CUDAnative

using ..CuArrays
using ..CuArrays: libcublas, unsafe_free!, @retry_reclaim

using LinearAlgebra

using CEnum

# core library
include("libcublas_common.jl")
include("error.jl")
include("libcublas.jl")

# low-level wrappers
include("util.jl")
include("wrappers.jl")

# high-level integrations
include("linalg.jl")
# thread cache for task-local library handles
const thread_handles = Vector{Union{Nothing,cublasHandle_t}}()
const thread_xt_handles = Vector{Union{Nothing,cublasXtHandle_t}}()
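# (each vector is indexed by thread ID; an entry is `nothing` until the handle
#  for the current task has been looked up, and is reset to `nothing` whenever
#  the task or context switches, see `__init__` below)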
function handle()
    tid = Threads.threadid()
    if @inbounds thread_handles[tid] === nothing
        ctx = context()
        thread_handles[tid] = get!(task_local_storage(), (:CUBLAS, ctx)) do
            handle = cublasCreate_v2()

            # destroy the handle when the owning task is finalized,
            # but only if its context is still valid at that point
            finalizer(current_task()) do task
                CUDAdrv.isvalid(ctx) || return
                context!(ctx) do
                    cublasDestroy_v2(handle)
                end
            end

            # enable tensor math mode if our device supports it and fast math is enabled
            dev = CUDAdrv.device()
            if Base.JLOptions().fast_math == 1 && CUDAdrv.capability(dev) >= v"7.0" && version() >= v"9"
                cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)
            end

            handle
        end
    end
    @inbounds thread_handles[tid]
end
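# cuBLASXt is a separate API with its own handle type; it gets the same
# per-thread, task-local caching treatment as the regular cuBLAS handle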
function xt_handle()
    tid = Threads.threadid()
    if @inbounds thread_xt_handles[tid] === nothing
        ctx = context()
        thread_xt_handles[tid] = get!(task_local_storage(), (:CUBLASxt, ctx)) do
            handle = cublasXtCreate()
            finalizer(current_task()) do task
                CUDAdrv.isvalid(ctx) || return
                context!(ctx) do
                    cublasXtDestroy(handle)
                end
            end

            # select the devices
            # TODO: this is weird, since we typically use a single device per thread/context
            devs = convert.(Cint, CUDAdrv.devices())
            cublasXtDeviceSelect(handle, length(devs), devs)

            handle
        end
    end
    @inbounds thread_xt_handles[tid]
end
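# at module load time, size the per-thread caches and make sure stale handles
# are never reused after a context or task switch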
function __init__()
    resize!(thread_handles, Threads.nthreads())
    fill!(thread_handles, nothing)

    resize!(thread_xt_handles, Threads.nthreads())
    fill!(thread_xt_handles, nothing)

    CUDAnative.atcontextswitch() do tid, ctx
        thread_handles[tid] = nothing
        thread_xt_handles[tid] = nothing
    end

    CUDAnative.attaskswitch() do tid, task
        thread_handles[tid] = nothing
        thread_xt_handles[tid] = nothing
    end
end
end
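
# Usage sketch (not part of the library): thanks to the linalg.jl integration,
# LinearAlgebra operations on CuArrays dispatch to CUBLAS wrappers that fetch
# the cached handle via `handle()` above, so users never manage it directly.
# Assumes CuArrays is installed and a functional CUDA setup is available:
#
#   using CuArrays, LinearAlgebra
#   A = CuArray(rand(Float32, 128, 128))
#   B = CuArray(rand(Float32, 128, 128))
#   C = A * B    # routed through a CUBLAS gemm wrapper using `handle()`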