|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
| 17 | +import glob |
17 | 18 | import os
|
18 | 19 | import os.path
|
| 20 | +import shutil |
19 | 21 | import sys
|
20 | 22 |
|
21 | 23 | import numpy as np
|
22 | 24 | import setuptools.command.build_ext as orig_build_ext
|
| 25 | +import setuptools.command.build_py as orig_build_py |
23 | 26 | import setuptools.command.develop as orig_develop
|
24 | 27 | import setuptools.command.install as orig_install
|
25 | 28 | from Cython.Build import cythonize
|
@@ -251,6 +254,34 @@ def run(self):
|
251 | 254 | return super().run()
|
252 | 255 |
|
253 | 256 |
|
| 257 | +class build_py(orig_build_py.build_py): |
| 258 | + def run(self): |
| 259 | + dpctl_src_dir = self.get_package_dir("dpctl") |
| 260 | + dpctl_build_dir = os.path.join(self.build_lib, "dpctl") |
| 261 | + os.makedirs(dpctl_build_dir, exist_ok=True) |
| 262 | + if IS_LIN: |
| 263 | + for fn in glob.glob(os.path.join(dpctl_src_dir, "*.so*")): |
| 264 | + # Check if the file already exists before copying. The check is |
| 265 | + # needed when dealing with symlinks. |
| 266 | + if not os.path.exists( |
| 267 | + os.path.join(dpctl_build_dir, os.path.basename(fn)) |
| 268 | + ): |
| 269 | + shutil.copy( |
| 270 | + src=fn, |
| 271 | + dst=dpctl_build_dir, |
| 272 | + follow_symlinks=False, |
| 273 | + ) |
| 274 | + elif IS_WIN: |
| 275 | + for fn in glob.glob(os.path.join(dpctl_src_dir, "*.lib")): |
| 276 | + shutil.copy(src=fn, dst=dpctl_build_dir) |
| 277 | + |
| 278 | + for fn in glob.glob(os.path.join(dpctl_src_dir, "*.dll")): |
| 279 | + shutil.copy(src=fn, dst=dpctl_build_dir) |
| 280 | + else: |
| 281 | + raise NotImplementedError("Unsupported platform") |
| 282 | + return super().run() |
| 283 | + |
| 284 | + |
254 | 285 | class install(orig_install.install):
|
255 | 286 | description = "Installs dpctl into Python prefix"
|
256 | 287 | user_options = orig_install.install.user_options + [
|
@@ -308,7 +339,22 @@ def run(self):
|
308 | 339 | else:
|
309 | 340 | self.define = ",".join((pre_d, "CYTHON_TRACE"))
|
310 | 341 | cythonize(self.distribution.ext_modules)
|
311 |
| - return super().run() |
| 342 | + ret = super().run() |
| 343 | + if IS_LIN: |
| 344 | + dpctl_build_dir = os.path.join( |
| 345 | + os.path.dirname(__file__), self.build_lib, "dpctl" |
| 346 | + ) |
| 347 | + dpctl_install_dir = os.path.join(self.install_libbase, "dpctl") |
| 348 | + for fn in glob.glob( |
| 349 | + os.path.join(dpctl_install_dir, "*DPCTLSyclInterface.so*") |
| 350 | + ): |
| 351 | + os.remove(fn) |
| 352 | + shutil.copy( |
| 353 | + src=os.path.join(dpctl_build_dir, os.path.basename(fn)), |
| 354 | + dst=dpctl_install_dir, |
| 355 | + follow_symlinks=False, |
| 356 | + ) |
| 357 | + return ret |
312 | 358 |
|
313 | 359 |
|
314 | 360 | class develop(orig_develop.develop):
|
@@ -393,6 +439,7 @@ def _get_cmdclass():
|
393 | 439 | cmdclass["install"] = install
|
394 | 440 | cmdclass["develop"] = develop
|
395 | 441 | cmdclass["build_ext"] = build_ext
|
| 442 | + cmdclass["build_py"] = build_py |
396 | 443 | return cmdclass
|
397 | 444 |
|
398 | 445 |
|
|
0 commit comments