Skip to content

Commit

Permalink
inductor cpp wrapper: add GIL release and acquire (pytorch#111888)
Browse files Browse the repository at this point in the history
Support inference with multiple instances (running in different threads of the same process), as in pytorch#93524 (comment).

Pull Request resolved: pytorch#111888
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/desertfire
  • Loading branch information
chunyuan-w authored and pytorchmergebot committed Oct 31, 2023
1 parent bb97ce4 commit f50ec34
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
2 changes: 2 additions & 0 deletions test/inductor/test_cpp_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class BaseTest(NamedTuple):
),
BaseTest("test_mm_views"),
BaseTest("test_multihead_attention", "cpu", test_cpu_repro.CPUReproTests()),
BaseTest("test_multi_threading"),
BaseTest("test_profiler_mark_wrapper_call"),
BaseTest("test_randint"),
BaseTest("test_randn_with_dtype_and_device"),
Expand Down Expand Up @@ -267,6 +268,7 @@ class BaseTest(NamedTuple):
BaseTest("test_linear2"),
BaseTest("test_mm_views"),
BaseTest("test_multi_device"),
BaseTest("test_multi_threading"),
BaseTest("test_profiler_mark_wrapper_call"),
BaseTest("test_reduction1"), # Reduction
BaseTest("test_relu"), # multiple inputs
Expand Down
24 changes: 24 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import re
import subprocess
import sys
import threading
import time
import typing
import unittest
Expand Down Expand Up @@ -2600,6 +2601,29 @@ def fn(x):
)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)

def test_multi_threading(self):
model = torch.nn.Linear(2, 3).eval()
inp = torch.randn(4, 2)

num_run = 3

def run_weights_sharing_model(m, inp):
with torch.no_grad():
for i in range(num_run):
y = m(inp)

numb_instance = 2
threads = []
compiled_m = torch.compile(model)
for i in range(1, numb_instance + 1):
thread = threading.Thread(
target=run_weights_sharing_model, args=(compiled_m, inp)
)
threads.append(thread)
thread.start()
for thread in threads:
thread.join()

@unittest.skipIf(config.is_fbcode(), "fbcode triton error, needs debugging")
def test_adaptive_avg_pool2d_low_prec(self):
class Model(torch.nn.Module):
Expand Down
6 changes: 6 additions & 0 deletions torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,12 @@ def write_wrapper_decl(self):
auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, num_inputs());
"""
)
else:
self.prefix.splice(
"""
py::gil_scoped_release release;
"""
)

if inputs_len != 0:
for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
Expand Down

0 comments on commit f50ec34

Please sign in to comment.