Revert "[inductor] Parallelize Max Autotune step 1: Use Popen (pytorc…
Browse files Browse the repository at this point in the history
…h#107982)"

This reverts commit d685668.

Reverted pytorch#107982 on behalf of https://github.com/masnesral due to fbcode failures ([comment](pytorch#107982 (comment)))
pytorchmergebot committed Sep 12, 2023
1 parent c36c2bf commit 2039f30
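
For context: the reverted change had TuningProcess drive the benchmarking child through subprocess.Popen, pickling requests and responses over stdin/stdout pipes; this revert restores the earlier design built on a multiprocessing "spawn" context and a pair of queues. Below is a minimal, self-contained sketch of that restored pattern (names like worker_main are illustrative, not the PyTorch source):

import multiprocessing


class Ping:  # handshake request, mirroring the Ping/Pong classes in the diff
    pass


class Pong:  # handshake reply
    pass


def worker_main(request_queue, response_queue):
    # Child loop: block on the request queue until the None sentinel arrives.
    while True:
        obj = request_queue.get()
        if obj is None:
            break  # sentinel: parent asked the child to shut down
        elif isinstance(obj, Ping):
            response_queue.put(Pong())
        else:
            response_queue.put(f"processed {obj!r}")


if __name__ == "__main__":
    # The CUDA runtime does not survive fork(), hence the "spawn" start method.
    ctx = multiprocessing.get_context("spawn")
    request_queue, response_queue = ctx.Queue(), ctx.Queue()
    proc = ctx.Process(target=worker_main, args=(request_queue, response_queue))
    proc.start()

    request_queue.put(Ping())  # warm-up handshake
    assert isinstance(response_queue.get(), Pong)

    request_queue.put("some work item")
    print(response_queue.get())  # -> "processed 'some work item'"

    request_queue.put(None)  # sentinel terminates the child
    proc.join()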
Showing 3 changed files with 94 additions and 131 deletions.
219 changes: 93 additions & 126 deletions torch/_inductor/autotune_process.py
@@ -1,14 +1,13 @@
 import dataclasses
 import logging
-import pickle
-import subprocess
-import sys
+import queue
 import time
 import warnings
 
+from multiprocessing.process import BaseProcess
+from multiprocessing.queues import Queue
 from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
+from torch import multiprocessing
 from torch._dynamo.testing import rand_strided
 
 from torch._inductor import ir
@@ -20,11 +19,9 @@
 from .utils import do_bench
 from .virtualized import V
 
-
+DEBUG = False
 EXIT_HANDLER_REGISTERED = False
 
-log = logging.getLogger(__name__)
-
 
 # Used to synchronize between parent and child processes
 class Ping:
@@ -37,62 +34,58 @@ class Pong:

 @dataclasses.dataclass
 class TuningProcess:
-    """
-    Abstraction for launching a helper process to benchmark kernels. Rather
-    than spawning the parent process, the approach Popens a new process with
-    an entry point that we control. Avoiding the spawn means we do not re-enter
-    the toplevel script. The subprocess communicates with the parent process
-    via pickling requests/responses over stdin/stdout pipes.
-    """
-
-    process: Optional["subprocess.Popen[bytes]"] = None
+    process: Optional[BaseProcess] = None
+    request_queue: Optional["Queue[Any]"] = None
+    response_queue: Optional["Queue[Any]"] = None
 
     @staticmethod
-    def process_main() -> None:
-        """
-        Entry point for the child process.
-        """
-        log.debug("Entering TuningProcess child main")
-        try:
-            TuningProcess.workloop()
-        except Exception:
-            log.exception("Exception in TuningProcess")
-
-    @staticmethod
-    def workloop() -> None:
-        """
-        Work loop for the benchmarking subprocess.
-        """
-
-        def reply(obj):
-            # Note this is subtly different than the put() method below.
-            pickle.dump(obj, sys.stdout.buffer)
-            sys.stdout.flush()
-
+    def process_main(
+        request_queue: "Queue[Any]",
+        response_queue: "Queue[Any]",
+    ) -> None:
+        print("enter child process main")
         while True:
-            obj = pickle.load(sys.stdin.buffer)
+            obj = request_queue.get()
 
             if obj is None:
-                # None is a sentinel for the child to terminate
-                break
+                break  # None is a sentinel for the child to terminate
             elif isinstance(obj, Ping):
-                reply(Pong())
+                response_queue.put(Pong())
             elif isinstance(obj, BenchmarkRequest):
-                reply(obj.benchmark())
+                response_queue.put(obj.benchmark())
             else:
                 raise RuntimeError(f"Invalid request type {type(obj)}")
 
+    def valid(self) -> bool:
+        return (
+            self.process is not None
+            and self.request_queue is not None
+            and self.response_queue is not None
+        )
+
+    def clear(self) -> None:
+        self.process = self.request_queue = self.response_queue = None

     def initialize(self) -> None:
         """
-        Create child process and do the warm up.
+        Create child process, request/response queues and do the warm up.
         """
-        if self.process is not None:
+        if self.valid():
             return
 
-        self.process = subprocess.Popen(
-            [sys.executable, "-m", "torch._inductor.autotune_process_entry"],
-            stdin=subprocess.PIPE,
-            stdout=subprocess.PIPE,
+        # cuda runtime does not work with "fork", use "spawn" to start processes.
+        ctx = multiprocessing.get_context("spawn")
+        request_queue = self.request_queue = ctx.Queue()
+        response_queue = self.response_queue = ctx.Queue()
+
+        process = self.process = ctx.Process(
+            target=self.process_main,
+            args=(
+                self.request_queue,
+                self.response_queue,
+            ),
         )
+        process.start()
 
         # register the exit handler for the parent process so it will terminate
         # the child processes
@@ -104,58 +97,18 @@ def initialize(self) -> None:
             atexit.register(lambda: self.terminate())
 
         # wait for the initialization to be done
-        self.put(Ping())
-        resp = self.get()
+        request_queue.put(Ping())
+        resp = response_queue.get()
         assert isinstance(resp, Pong)

-    def put(self, obj: Any) -> None:
-        """
-        Push a work item to the child process.
-        """
-        # In case of a prior crash, ensure the subprocess is running
-        self.initialize()
-        assert self.process is not None
-        assert self.process.stdin is not None
-        pickle.dump(obj, self.process.stdin)
-        self.process.stdin.flush()
-
-    def get(self) -> Any:
-        """
-        Get a response from the child process.
-        """
-        assert self.process is not None
-        assert self.process.stdout is not None
-        try:
-            return pickle.load(self.process.stdout)
-        except EOFError:
-            # Child crashed; clean up
-            self.close()
-            raise
-        except pickle.UnpicklingError as ex:
-            raise RuntimeError(
-                "Error deserializing response from the benchmarking subprocess. "
-                "Is the benchmark code path writing to stdout?"
-            ) from ex
-
-    def close(self) -> None:
-        """
-        Close the communication pipes from the child process.
-        """
-        if self.process is not None:
-            assert self.process.stdin is not None
-            assert self.process.stdout is not None
-            self.process.stdin.close()
-            self.process.stdout.close()
-            self.process = None
-
     def terminate(self) -> None:
-        """
-        Signal the child process to terminate and wait for it to exit.
-        """
-        if self.process is not None:
-            self.put(None)
-            self.process.wait()
-            self.close()
+        if self.valid():
+            request_queue = self.request_queue
+            assert request_queue is not None
+            request_queue.put(None)
+            process = self.process
+            assert process is not None
+            process.join()
 
 
 tuning_process = TuningProcess()
@@ -227,20 +180,18 @@ class BenchmarkRequest:
     def benchmark(
         self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
     ) -> float:
-        debug = log.isEnabledFor(logging.DEBUG)
-        if debug:
+        if DEBUG:
             start_ts = time.time()
 
         mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
-        log.debug(
-            "benchmark module key: %s, path: %s",
-            self.module_cache_key,
-            self.module_path,
-        )
+        if DEBUG:
+            print(
+                f"benchmark module key: {self.module_cache_key}, path: {self.module_path}"
+            )
 
         run = getattr(mod, self.kernel_name).run
 
-        if debug:
+        if DEBUG:
             load_elapse = time.time() - start_ts
             start_ts = time.time()
 
@@ -254,7 +205,7 @@ def benchmark(
             assert isinstance(self.output_tensor, TensorMeta)
             output_tensor = self.output_tensor.to_tensor()
 
-        if debug:
+        if DEBUG:
             create_tensor_elapse = time.time() - start_ts
             start_ts = time.time()
 
@@ -271,16 +222,12 @@ def worker() -> float:
         out = do_bench(worker)
         torch.cuda.synchronize()  # shake out any CUDA errors
 
-        if debug:
+        if DEBUG:
             bench_elapse = time.time() - start_ts
-            log.debug(
-                "InChildProcess %s: load %f, create tensor %f, bench %f",
-                self.module_cache_key,
-                load_elapse,
-                create_tensor_elapse,
-                bench_elapse,
+            print(
+                f"InChildProcess {self.module_cache_key}: load {load_elapse}, "
+                + f"create tensor {create_tensor_elapse}, bench {bench_elapse}"
             )
 
         return out
 

@@ -292,15 +239,35 @@ def benchmark_in_sub_process(
"""
assert choice.bmreq is not None
tuning_process.initialize()
assert tuning_process.valid()
process, request_queue, response_queue = (
tuning_process.process,
tuning_process.request_queue,
tuning_process.response_queue,
)
assert (
process is not None and request_queue is not None and response_queue is not None
)

request_queue.put(choice.bmreq)
while True:
try:
timing = response_queue.get(timeout=1.0)
except queue.Empty:
status = process.exitcode
if status is None:
# child process is still running
continue
# child process fail
assert status != 0

warnings.warn(
f"Fail to benchmark choice '{choice}'. It will be ignored. Please debug the root cause in case the choice can bring perf gains." # noqa: B950 line too long
)

tuning_process.put(choice.bmreq)
try:
return tuning_process.get()
except EOFError:
warnings.warn(
f"Failed to benchmark choice '{choice}'. It will be ignored. "
"Please debug the root cause in case the choice can bring perf gains.",
stacklevel=2,
)
# return INF so this choice will be ignored
return float("inf")
tuning_process.clear()

# return INF so this choice will be ignored
return float("inf")

return timing
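
The restored benchmark_in_sub_process polls the response queue with a one-second timeout so that a crashed child is noticed instead of blocking the parent forever. A standalone sketch of that polling idiom, assuming a proc/response_queue pair like the one in the earlier sketch (poll_result is a hypothetical helper, not the PyTorch source):

import queue


def poll_result(proc, response_queue, poll_interval=1.0):
    """Wait for a result, but notice if the child process dies first."""
    while True:
        try:
            return response_queue.get(timeout=poll_interval)
        except queue.Empty:
            if proc.exitcode is None:
                continue  # child still running; keep waiting
            # Child exited without replying: report a sentinel so the
            # caller can skip this work item (the autotuner uses inf).
            return float("inf")

This mirrors the loop restored in the diff above, minus the warning and the queue cleanup.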
5 changes: 0 additions & 5 deletions torch/_inductor/autotune_process_entry.py

This file was deleted.
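
The deleted autotune_process_entry.py was the module launched with "python -m" as the entry point of the reverted Popen design. A rough sketch of how that pipe protocol worked, with the child inlined via -c so the example is self-contained (illustrative only, not the removed file):

import pickle
import subprocess
import sys

# Child: unpickle requests from stdin, pickle replies to stdout, stop on None.
CHILD = """
import pickle, sys
while True:
    obj = pickle.load(sys.stdin.buffer)
    if obj is None:
        break
    pickle.dump(obj.upper(), sys.stdout.buffer)
    sys.stdout.buffer.flush()
"""

proc = subprocess.Popen(
    [sys.executable, "-c", CHILD],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
)
pickle.dump("ping", proc.stdin)  # pickled request over the stdin pipe
proc.stdin.flush()
print(pickle.load(proc.stdout))  # -> 'PING'
pickle.dump(None, proc.stdin)  # the None sentinel shuts the child down
proc.stdin.flush()
proc.wait()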

1 change: 1 addition & 0 deletions torch/_inductor/select_algorithm.py
@@ -798,6 +798,7 @@ def autotune(choice):

         # do the optional warmup
         tuning_process.initialize()
+        assert tuning_process.valid()
 
         autotune_start_ts = time.time()
         timings = self.lookup(
