Skip to content

Commit 7e2eaba

Browse files
Support --tensor-includes for dpctl module, add Dpctl_TENSOR_INCLUDE_DIR in cmake integration
1 parent e80d3a1 commit 7e2eaba

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

cmake/FindDpctl.cmake

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ if(NOT Dpctl_FOUND)
4040
OUTPUT_STRIP_TRAILING_WHITESPACE
4141
ERROR_QUIET
4242
)
43-
4443
endif()
4544
endif()
4645

@@ -49,6 +48,12 @@ find_path(Dpctl_INCLUDE_DIR
4948
PATHS "${_dpctl_include_dir}" "${PYTHON_INCLUDE_DIR}"
5049
PATH_SUFFIXES dpctl/include
5150
)
51+
get_filename_component(_dpctl_dir ${_dpctl_include_dir} DIRECTORY)
52+
53+
find_path(Dpctl_TENSOR_INCLUDE_DIR
54+
kernels utils
55+
PATHS "${_dpctl_dir}/tensor/libtensor/include"
56+
)
5257

5358
set(Dpctl_INCLUDE_DIRS ${Dpctl_INCLUDE_DIR})
5459

@@ -57,8 +62,9 @@ set(Dpctl_INCLUDE_DIRS ${Dpctl_INCLUDE_DIR})
5762
include(FindPackageHandleStandardArgs)
5863
find_package_handle_standard_args(Dpctl
5964
REQUIRED_VARS
60-
Dpctl_INCLUDE_DIR
65+
Dpctl_INCLUDE_DIR Dpctl_TENSOR_INCLUDE_DIR
6166
VERSION_VAR Dpctl_VERSION
6267
)
6368

6469
mark_as_advanced(Dpctl_INCLUDE_DIR)
70+
mark_as_advanced(Dpctl_TENSOR_INCLUDE_DIR)

dpctl/__main__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ def print_includes() -> None:
3535
print("-I " + dpctl.get_include())
3636

3737

38+
def print_tensor_includes() -> None:
39+
"Prints include flags for dpctl and SyclInterface library"
40+
dpctl_dir = _dpctl_dir()
41+
libtensor_dir = os.path.join(dpctl_dir, "tensor", "libtensor", "include")
42+
print("-I " + libtensor_dir)
43+
44+
3845
def print_cmake_dir() -> None:
3946
"Prints directory with FindDpctl.cmake"
4047
dpctl_dir = _dpctl_dir()
@@ -75,7 +82,12 @@ def main() -> None:
7582
parser.add_argument(
7683
"--includes",
7784
action="store_true",
78-
help="Include flags dpctl headers.",
85+
help="Include flags for dpctl headers.",
86+
)
87+
parser.add_argument(
88+
"--tensor-includes",
89+
action="store_true",
90+
help="Include flags for dpctl libtensor headers.",
7991
)
8092
parser.add_argument(
8193
"--cmakedir",
@@ -128,6 +140,8 @@ def main() -> None:
128140
return
129141
if args.includes:
130142
print_includes()
143+
if args.tensor_includes:
144+
print_tensor_includes()
131145
if args.cmakedir:
132146
print_cmake_dir()
133147
if args.library:

0 commit comments

Comments
 (0)