Skip to content

Commit

Permalink
unset torch arch list for JIT mode (#1765)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffra authored Feb 11, 2022
1 parent 4f96ffd commit 674c758
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,16 @@ def jit_load(self, verbose=True):
extra_include_paths = [
self.deepspeed_src_path(path) for path in self.include_paths()
]

# Torch will try and apply whatever CCs are in the arch list at compile time,
# we have already set the intended targets ourselves we know that will be
# needed at runtime. This prevents CC collisions such as multiple __half
# implementations. Stash arch list to reset after build.
torch_arch_list = None
if "TORCH_CUDA_ARCH_LIST" in os.environ:
torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
os.environ["TORCH_CUDA_ARCH_LIST"] = ""

op_module = load(
name=self.name,
sources=self.strip_empty_entries(sources),
Expand All @@ -443,6 +453,11 @@ def jit_load(self, verbose=True):
build_duration = time.time() - start_build
if verbose:
print(f"Time to load {self.name} op: {build_duration} seconds")

# Reset arch list so we are not silently removing it for other possible use cases
if torch_arch_list:
os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list

return op_module


Expand Down

0 comments on commit 674c758

Please sign in to comment.