From 53969aa9d68e230a1d9de4a0a0437a36402b16f5 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 6 Sep 2024 15:26:10 -0700
Subject: [PATCH] add depyf.optimization and tests (#58)

---
 .github/workflows/test_pytorch.yml |  1 +
 depyf/optimization.py              | 74 ++++++++++++++++++++++++++++++
 tests/test_pytorch/test_wrapper.py | 58 +++++++++++++++++++++++
 3 files changed, 133 insertions(+)
 create mode 100644 depyf/optimization.py
 create mode 100644 tests/test_pytorch/test_wrapper.py

diff --git a/.github/workflows/test_pytorch.yml b/.github/workflows/test_pytorch.yml
index a41bdbba..1940d318 100644
--- a/.github/workflows/test_pytorch.yml
+++ b/.github/workflows/test_pytorch.yml
@@ -42,6 +42,7 @@ jobs:
           echo "success"
       - name: Test with pytest
         run: |
+          coverage run --append tests/test_pytorch/test_wrapper.py
           coverage run --append tests/test_pytorch/test_mp.py
           coverage run --append tests/test_pytorch/test_no_graph.py
           coverage run --append tests/test_pytorch/test_irregular.py
diff --git a/depyf/optimization.py b/depyf/optimization.py
new file mode 100644
index 00000000..c24622a3
--- /dev/null
+++ b/depyf/optimization.py
@@ -0,0 +1,74 @@
+import os
+import sys
+from abc import abstractmethod
+from contextlib import contextmanager
+from types import CodeType
+from typing import Callable, List
+
+import torch
+
+
+class TorchCompileWrapperWithCustomDispacther:
+    """
+    A wrapper class for torch.compile, with a custom dispatch logic.
+    Subclasses should:
+    1. Implement the forward method
+    2. Implement the dispatch logic in the __call__ method
+    It can use `self.compiled_codes` to access the compiled bytecode,
+    and `with self.dispatch_to_code(index):` to dispatch to
+    the compiled code.
+    3. Implement the `__init__` method to determine how to call
+    `torch.compile` over the forward method.
+    """
+
+    def __init__(self, compiled_callable: Callable, use_custom_dispatcher: bool = True):
+        self.compiled_callable = compiled_callable
+        self.original_code_object = self.__class__.forward.__code__
+        self.compiled_codes: List[CodeType] = []
+        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
+
+        self.use_custom_dispatcher: bool = use_custom_dispatcher
+
+    def __call__(self, *args, **kwargs):
+        """Implement the dispatch logic here, beyond the torch.compile level.
+        NOTE: this function can have additional arguments beyond the forward
+        method, for directly dispatching to the compiled code.
+        """
+        return self.compiled_callable(*args, **kwargs)
+
+    @abstractmethod
+    def forward(self, *args, **kwargs):
+        ...
+
+    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
+        """Hook to save the compiled bytecode for direct execution."""
+        if old_code is not self.original_code_object:
+            return
+        frame = sys._getframe()
+        while True:
+            frame = frame.f_back
+            code_name = frame.f_code.co_name
+            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
+            if code_name == "_compile" and file_name == "convert_frame.py":
+                break
+        frame = frame.f_locals["frame"]
+        assert frame.f_code == old_code
+
+        if frame.f_locals["self"] is not self:
+            return
+
+        self.compiled_codes.append(new_code)
+
+    @contextmanager
+    def dispatch_to_code(self, index: int):
+        """Context manager to dispatch to the compiled code.
+        Why does this work? Because Dynamo guarantees that the compiled
+        bytecode has exactly the same arguments, cell variables, and free
+        variables as the original code. Therefore we can directly switch
+        the code object in the function and call it.
+
+        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
+        """  # noqa
+        self.__class__.forward.__code__ = self.compiled_codes[index]
+        yield
+        self.__class__.forward.__code__ = self.original_code_object
diff --git a/tests/test_pytorch/test_wrapper.py b/tests/test_pytorch/test_wrapper.py
new file mode 100644
index 00000000..e8cfd115
--- /dev/null
+++ b/tests/test_pytorch/test_wrapper.py
@@ -0,0 +1,58 @@
+from typing import Optional
+
+import torch
+
+from depyf.optimization import TorchCompileWrapperWithCustomDispacther
+
+
+class MyMod(torch.nn.Module):
+
+    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+        if cache is not None:
+            return x + cache
+        return x * 2
+
+
+class MyWrapper(TorchCompileWrapperWithCustomDispacther):
+
+    def __init__(self, model):
+        self.model = model
+        compiled_callable = torch.compile(self.forward, backend="eager")
+        super().__init__(compiled_callable)
+
+    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+        # this is the function to be compiled
+        return self.model(x, cache)
+
+    def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+        # let torch.compile compile twice
+        if len(self.compiled_codes) == 2:
+            dispatch_id = 0 if cache is None else 1
+            with self.dispatch_to_code(dispatch_id):
+                return self.forward(x, cache)
+        else:
+            return self.compiled_callable(x, cache)
+
+
+mod = MyMod()
+wrappers = []
+for i in range(3):
+    torch._dynamo.reset()
+    wrapper = MyWrapper(mod)
+    wrappers.append(wrapper)
+    x = torch.tensor([1])
+    wrapper(x, None)  # profile run, compile
+    # create a cache tensor
+    cache = torch.tensor([2])
+    wrapper(x, cache)  # warm up with cache, recompile
+
+    # for new input, dispatch to the compiled code directly
+    new_x = torch.tensor([3])
+    assert wrapper(new_x,
+                   None).item() == 6  # dispatch to the first compiled code
+    assert wrapper(
+        new_x, cache).item() == 5  # dispatch to the second compiled code
+
+for wrapper in wrappers:
+    # make sure they have independent compiled codes
+    assert len(wrapper.compiled_codes) == 2