diff --git a/test/inductor/test_cpp_wrapper.py b/test/inductor/test_cpp_wrapper.py
index e2a80276eeeb13..5c08108f7118d2 100644
--- a/test/inductor/test_cpp_wrapper.py
+++ b/test/inductor/test_cpp_wrapper.py
@@ -192,6 +192,7 @@ class BaseTest(NamedTuple):
         ),
         BaseTest("test_mm_views"),
         BaseTest("test_multihead_attention", "cpu", test_cpu_repro.CPUReproTests()),
+        BaseTest("test_multi_threading"),
         BaseTest("test_profiler_mark_wrapper_call"),
         BaseTest("test_randint"),
         BaseTest("test_randn_with_dtype_and_device"),
@@ -267,6 +268,7 @@ class BaseTest(NamedTuple):
         BaseTest("test_linear2"),
         BaseTest("test_mm_views"),
         BaseTest("test_multi_device"),
+        BaseTest("test_multi_threading"),
         BaseTest("test_profiler_mark_wrapper_call"),
         BaseTest("test_reduction1"),  # Reduction
         BaseTest("test_relu"),  # multiple inputs
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index d64f5c746511ea..b3dae260c06a3e 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -12,6 +12,7 @@
 import re
 import subprocess
 import sys
+import threading
 import time
 import typing
 import unittest
@@ -2600,6 +2601,29 @@ def fn(x):
         )
         self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)

+    def test_multi_threading(self):
+        model = torch.nn.Linear(2, 3).eval()
+        inp = torch.randn(4, 2)
+
+        num_run = 3
+
+        def run_weights_sharing_model(m, inp):
+            with torch.no_grad():
+                for i in range(num_run):
+                    y = m(inp)
+
+        numb_instance = 2
+        threads = []
+        compiled_m = torch.compile(model)
+        for i in range(1, numb_instance + 1):
+            thread = threading.Thread(
+                target=run_weights_sharing_model, args=(compiled_m, inp)
+            )
+            threads.append(thread)
+            thread.start()
+        for thread in threads:
+            thread.join()
+
     @unittest.skipIf(config.is_fbcode(), "fbcode triton error, needs debugging")
     def test_adaptive_avg_pool2d_low_prec(self):
         class Model(torch.nn.Module):
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 8b6f4bea17e7d4..c6a9148f89eceb 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -1311,6 +1311,12 @@ def write_wrapper_decl(self):
                 auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, num_inputs());
                 """
             )
+        else:
+            self.prefix.splice(
+                """
+                py::gil_scoped_release release;
+                """
+            )

         if inputs_len != 0:
             for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
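
Note: the wrapper.py hunk makes the generated cpp-wrapper entry point emit pybind11's py::gil_scoped_release in the non-AOT path, so the compiled C++ code runs without holding the GIL and several Python threads can drive one compiled module concurrently, which is exactly what the new test_multi_threading test exercises. A minimal standalone sketch of that usage pattern, assuming a cpp-wrapper-enabled torch.compile build; the module shape, thread count, and iteration count are illustrative choices, not values taken from the diff:

import threading

import torch

model = torch.nn.Linear(2, 3).eval()
inp = torch.randn(4, 2)
compiled = torch.compile(model)  # one compiled module shared by all threads

def run_inference(m, x, iters=3):
    # Inference only; every thread reuses the same compiled weights.
    with torch.no_grad():
        for _ in range(iters):
            m(x)

threads = [
    threading.Thread(target=run_inference, args=(compiled, inp)) for _ in range(2)
]
for t in threads:
    t.start()
for t in threads:
    t.join()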