Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

name: ci
on:
- push
Expand All @@ -11,7 +10,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.4'
- '1.6'
- '1'
- 'nightly'
os:
- ubuntu-latest
Expand All @@ -37,4 +37,4 @@ jobs:
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
- uses: julia-actions/julia-runtest@latest
13 changes: 0 additions & 13 deletions .travis.yml

This file was deleted.

3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"

[compat]
julia = "1.3"
julia = "1.6"
MKL_jll = "2022.2.0"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
6 changes: 6 additions & 0 deletions gen/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[deps]
Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"

[compat]
julia = "1.6"
16 changes: 16 additions & 0 deletions gen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Wrapping headers

This directory contains scripts that can be used to automatically generate wrappers for C headers by Intel MKL libraries.
This is done using Clang.jl.

# Usage

Either run `julia wrapper.jl` directly, or include it and call the `main()` function.
Be sure to activate the project environment in this folder, which will install `Clang.jl` and `JuliaFormatter.jl`.
The `main` function supports the boolean keyword argument `optimized` to clear the generated wrappers.

# Remark

You should always review any changes to the headers!
Specifically, verify that pointer arguments are of the correct type, and if they aren't, modify the `rewriter.jl` file and regenerate the wrappers.
The `Ref` type should be considered as an alternative to plain `Ptr` if the pointer represents a scalar or single-value argument.
9 changes: 9 additions & 0 deletions gen/mkl.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[general]
use_julia_native_enum_type = false
print_using_CEnum = false
print_enum_as_integer = true

[codegen]
use_julia_bool = true
use_ccall_macro = true
always_NUL_terminated_string = true
36 changes: 36 additions & 0 deletions gen/rewriter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
type_modifications = Dict("Cint" => "BlasInt",
"Cfloat" => "Float32",
"Cdouble" => "Float64",
"MKL_Complex8" => "ComplexF32",
"MKL_Complex16" => "ComplexF64")

cstring_modifications = Dict("transa::Cstring" => "transa::Ref{UInt8}",
"uplo::Cstring" => "uplo::Ref{UInt8}",
"diag::Cstring" => "diag::Ref{UInt8}",
"matdescra::Cstring" => "matdescra::Ptr{UInt8}")

function rewrite!(path::String)
text = read(path, String)
for (keys, vals) in type_modifications
text = replace(text, keys => vals)
end
for (keys, vals) in cstring_modifications
text = replace(text, keys => vals)
end
# Note: `job` and `idiag` are vectors in some cases, we must be careful with these two arguments.
for argument in ("job", "m", "n", "k", "job", "nnz", "nnzmax", "lval",
"lb", "mblk", "idiag", "ldabsr", "ndiag", "ldAbsr", "sort",
"alpha", "beta", "lda", "ldb", "ldc", "ierr", "info")
for T in ("BlasInt", "Float32", "Float64", "ComplexF32", "ComplexF64")
text = replace(text, "$argument::Ptr{$T}" => "$argument::Ref{$T}")
end
end
# Remove comments in libmklsparse.jl
text = replace(text, "# typedef void ( * sgemm_jit_kernel_t ) ( void * , float * , float * , float * )\n" => "")
text = replace(text, "# typedef void ( * dgemm_jit_kernel_t ) ( void * , double * , double * , double * )\n" => "")
text = replace(text, "# typedef void ( * cgemm_jit_kernel_t ) ( void * , ComplexF32 * , ComplexF32 * , ComplexF32 * )\n" => "")
text = replace(text, "# typedef void ( * zgemm_jit_kernel_t ) ( void * , ComplexF64 * , ComplexF64 * , ComplexF64 * )\n" => "")
text = replace(text, "# Skipping MacroDefinition: MKL_LONG long int\n" => "")
text = replace(text, "# Skipping MacroDefinition: MKL_DEPRECATED __attribute__ ( ( deprecated ) )\n" => "")
write(path, text)
end
43 changes: 43 additions & 0 deletions gen/wrapper.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Script to parse MKL headers and generate Julia wrappers.
using Clang
using Clang.Generators
using JuliaFormatter

include("rewriter.jl")

function wrapper(name::String, headers::Vector{String}, optimized::Bool=false)

@info "Wrapping $name"
cd(@__DIR__)
include_dir = joinpath(@__DIR__, "mkl-include-2022.2.0-intel_8748", "include")

options = load_options(joinpath(@__DIR__, "mkl.toml"))
options["general"]["library_name"] = "libmkl_rt"
options["general"]["output_file_path"] = joinpath("..", "src", "$(name).jl")
optimized && (options["general"]["output_ignorelist"] = ["MKL_Complex8",
"MKL_Complex16",
"MKLVersion"])

args = get_default_args()
push!(args, "-I$include_dir")

ctx = create_context(headers, args, options)
build!(ctx)

path = options["general"]["output_file_path"]

format_file(path, YASStyle())
optimized && rewrite!(path)
return nothing
end

function main(; optimized::Bool=false)
# TODO: Add mkl_spblas.h in the artifact MKL_Headers_jll
mkl = joinpath(@__DIR__, "mkl-include-2022.2.0-intel_8748", "include")
wrapper("libmklsparse", ["$mkl/mkl_spblas.h"], optimized)
end

# If we want to use the file as a script with `julia wrapper.jl`
if abspath(PROGRAM_FILE) == @__FILE__
main()
end
18 changes: 0 additions & 18 deletions src/BLAS/BLAS.jl

This file was deleted.

110 changes: 0 additions & 110 deletions src/BLAS/level_2_3/generator.jl

This file was deleted.

17 changes: 15 additions & 2 deletions src/MKLSparse.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
module MKLSparse

using MKL_jll
using LinearAlgebra, SparseArrays
using LinearAlgebra: BlasInt, BlasFloat, checksquare
using MKL_jll: libmkl_rt

# For testing purposes:
global const __counter = Ref(0)

function __init__()
ccall((:MKL_Set_Interface_Layer, libmkl_rt), Cint, (Cint,), Base.USE_BLAS64 ? 1 : 0)
end

include(joinpath("BLAS", "BLAS.jl"))
# Wrappers generated by Clang.jl
include("libmklsparse.jl")

# TODO: BLAS1

# BLAS2 and BLAS3
include("matdescra.jl")
include("generator.jl")
include("matmul.jl")

end # module
Loading