Skip to content

Commit 61a00e8

Browse files
authored
Drop cmake build system, use PyTorch C++ extensions (#239)
1 parent 5833b2f commit 61a00e8

File tree

3 files changed

+54
-174
lines changed

3 files changed

+54
-174
lines changed

CMakeLists.txt

Lines changed: 0 additions & 109 deletions
This file was deleted.

setup.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313
from build_tools import setup_helpers
1414
from setuptools import setup, find_packages
1515

16+
import glob
17+
from torch.utils.cpp_extension import (
18+
CppExtension,
19+
BuildExtension,
20+
)
21+
22+
1623

1724
def _get_pytorch_version():
1825
if "PYTORCH_VERSION" in os.environ:
@@ -60,6 +67,50 @@ def _run_cmd(cmd):
6067
return None
6168

6269

70+
def get_extensions():
71+
extension = CppExtension
72+
73+
extra_link_args = []
74+
extra_compile_args = {"cxx": [
75+
"-O3",
76+
"-std=c++14",
77+
"-fdiagnostics-color=always",
78+
]}
79+
debug_mode = os.getenv('DEBUG', '0') == '1'
80+
if debug_mode:
81+
print("Compiling in debug mode")
82+
extra_compile_args = {
83+
"cxx": [
84+
"-O0",
85+
"-fno-inline",
86+
"-g",
87+
"-std=c++14",
88+
"-fdiagnostics-color=always",
89+
]}
90+
extra_link_args = ["-O0", "-g"]
91+
92+
this_dir = os.path.dirname(os.path.abspath(__file__))
93+
extensions_dir = os.path.join(this_dir, "torchrl", "csrc")
94+
95+
extension_sources = set(
96+
os.path.join(extensions_dir, p)
97+
for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
98+
)
99+
sources = list(extension_sources)
100+
101+
ext_modules = [
102+
extension(
103+
"torchrl._torchrl",
104+
sources,
105+
include_dirs=[this_dir],
106+
extra_compile_args=extra_compile_args,
107+
extra_link_args=extra_link_args,
108+
)
109+
]
110+
111+
return ext_modules
112+
113+
63114
def _main():
64115
pytorch_package_dep = _get_pytorch_version()
65116
print("-- PyTorch dependency:", pytorch_package_dep)
@@ -71,10 +122,10 @@ def _main():
71122
version="0.1",
72123
author="torchrl contributors",
73124
author_email="vmoens@fb.com",
74-
packages=_get_packages(),
75-
ext_modules=setup_helpers.get_ext_modules(),
125+
packages=find_packages(),
126+
ext_modules=get_extensions(),
76127
cmdclass={
77-
"build_ext": setup_helpers.CMakeBuild,
128+
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
78129
"clean": clean,
79130
},
80131
install_requires=[pytorch_package_dep, "numpy", "tensorboard", "packaging"],

torchrl/csrc/CMakeLists.txt

Lines changed: 0 additions & 62 deletions
This file was deleted.

0 commit comments

Comments
 (0)