13
13
from build_tools import setup_helpers
14
14
from setuptools import setup , find_packages
15
15
16
+ import glob
17
+ from torch .utils .cpp_extension import (
18
+ CppExtension ,
19
+ BuildExtension ,
20
+ )
21
+
22
+
16
23
17
24
def _get_pytorch_version ():
18
25
if "PYTORCH_VERSION" in os .environ :
@@ -60,6 +67,50 @@ def _run_cmd(cmd):
60
67
return None
61
68
62
69
70
+ def get_extensions ():
71
+ extension = CppExtension
72
+
73
+ extra_link_args = []
74
+ extra_compile_args = {"cxx" : [
75
+ "-O3" ,
76
+ "-std=c++14" ,
77
+ "-fdiagnostics-color=always" ,
78
+ ]}
79
+ debug_mode = os .getenv ('DEBUG' , '0' ) == '1'
80
+ if debug_mode :
81
+ print ("Compiling in debug mode" )
82
+ extra_compile_args = {
83
+ "cxx" : [
84
+ "-O0" ,
85
+ "-fno-inline" ,
86
+ "-g" ,
87
+ "-std=c++14" ,
88
+ "-fdiagnostics-color=always" ,
89
+ ]}
90
+ extra_link_args = ["-O0" , "-g" ]
91
+
92
+ this_dir = os .path .dirname (os .path .abspath (__file__ ))
93
+ extensions_dir = os .path .join (this_dir , "torchrl" , "csrc" )
94
+
95
+ extension_sources = set (
96
+ os .path .join (extensions_dir , p )
97
+ for p in glob .glob (os .path .join (extensions_dir , "*.cpp" ))
98
+ )
99
+ sources = list (extension_sources )
100
+
101
+ ext_modules = [
102
+ extension (
103
+ "torchrl._torchrl" ,
104
+ sources ,
105
+ include_dirs = [this_dir ],
106
+ extra_compile_args = extra_compile_args ,
107
+ extra_link_args = extra_link_args ,
108
+ )
109
+ ]
110
+
111
+ return ext_modules
112
+
113
+
63
114
def _main ():
64
115
pytorch_package_dep = _get_pytorch_version ()
65
116
print ("-- PyTorch dependency:" , pytorch_package_dep )
@@ -71,10 +122,10 @@ def _main():
71
122
version = "0.1" ,
72
123
author = "torchrl contributors" ,
73
124
author_email = "vmoens@fb.com" ,
74
- packages = _get_packages (),
75
- ext_modules = setup_helpers . get_ext_modules (),
125
+ packages = find_packages (),
126
+ ext_modules = get_extensions (),
76
127
cmdclass = {
77
- "build_ext" : setup_helpers . CMakeBuild ,
128
+ "build_ext" : BuildExtension . with_options ( no_python_abi_suffix = True ) ,
78
129
"clean" : clean ,
79
130
},
80
131
install_requires = [pytorch_package_dep , "numpy" , "tensorboard" , "packaging" ],
0 commit comments