Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

re: transfer oneflow_compile from onediff to oneflow #10408

Merged
merged 20 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions python/oneflow/framework/infer_compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os

import oneflow as flow
from oneflow.framework.args_tree import ArgsTree

from .transform.custom_transform import register
from .utils.patch_for_compiler import *
from .with_fx_graph import fx_node_tranform
from .with_fx_interpreter import OneFlowInterpreter
from .with_oneflow_compile import compile_from_torch


def oneflow_backend(gm, example_inputs, *args, **kwargs):
with_interp = os.getenv(
"ONEDIFF_INFER_COMPILER_USE_INTERPRETER", "False"
).lower() in ("true", "1", "t",)
if not with_interp:
transformed_fn = fx_node_tranform(gm)

def wrapped_forward(*args, **kwargs):
def input_fn(value):
if isinstance(value, torch.Tensor):
return flow.utils.tensor.from_torch(value.contiguous())
else:
return value

args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor)
out = args_tree.map_leaf(input_fn)
args = out[0]
if with_interp:
output = OneFlowInterpreter(gm, garbage_collect_values=False).run(
*args, **kwargs
)
else:
output = transformed_fn(*args, **kwargs)
if isinstance(output, tuple):
return tuple(flow.utils.tensor.to_torch(i) for i in output)
return flow.utils.tensor.to_torch(output)

return wrapped_forward
17 changes: 17 additions & 0 deletions python/oneflow/framework/infer_compiler/import_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
""" Tools for importing modules and packages"""
from .importer import LazyMocker, import_module_from_path
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import inspect
from types import FunctionType
from typing import Union


class MockEntityNameFormatter:
def __init__(self, prefix: str = "mock_", suffix: str = "_oflow"):
self.prefix = prefix
self.suffix = suffix

def _format_pkg_name(self, pkg_name: str) -> str:
if pkg_name.startswith(self.prefix) and pkg_name.endswith(self.suffix):
return pkg_name
return self.prefix + pkg_name + self.suffix

def _reverse_pkg_name(self, pkg_name: str) -> str:
assert pkg_name.startswith(self.prefix) and pkg_name.endswith(
self.suffix
), f"Package name must start with {self.prefix} and end with {self.suffix}, but got {pkg_name}"
return pkg_name[len(self.prefix) : -len(self.suffix)]

def _format_full_class_name(self, obj: Union[str, type, FunctionType]):
if isinstance(obj, type):
obj = f"{obj.__module__}.{obj.__qualname__}"

elif isinstance(obj, FunctionType):
module = inspect.getmodule(obj)
obj = f"{module.__name__}.{obj.__qualname__}"

assert isinstance(obj, str), f"obj must be str, but got {type(obj)}"

if "." in obj:
pkg_name, cls_name = obj.split(".", 1)
return f"{self._format_pkg_name(pkg_name)}.{cls_name}"
else:
return self._format_pkg_name(obj)

def format(self, entity: Union[str, type, FunctionType]) -> str:
return self._format_full_class_name(entity)

def unformat(self, mock_entity_name: str) -> str:
if "." in mock_entity_name:
pkg_name, cls_name = mock_entity_name.split(".", 1)
return f"{self._reverse_pkg_name(pkg_name)}.{cls_name}"
else: # mock_entity_name is a pkg_name
return self._reverse_pkg_name(mock_entity_name)
137 changes: 137 additions & 0 deletions python/oneflow/framework/infer_compiler/import_tools/importer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import importlib
import os
import sys
from pathlib import Path
from types import FunctionType, ModuleType
from typing import Optional, Union

from oneflow.mock_torch import DynamicMockModule

from .format_utils import MockEntityNameFormatter

if sys.version_info < (3, 8):
try:
from importlib_metadata import requires
except ImportError:
import subprocess

subprocess.check_call("pip install importlib_metadata", shell=True)
subprocess.check_call("pip install packaging", shell=True)
else:
from importlib.metadata import requires

__all__ = ["import_module_from_path", "LazyMocker", "is_need_mock"]


def is_need_mock(cls) -> bool:
assert isinstance(cls, (type, str))
main_pkg = cls.__module__.split(".")[0]
try:
pkgs = requires(main_pkg)
except Exception as e:
return True
if pkgs:
for pkg in pkgs:
pkg = pkg.split(" ")[0]
if pkg == "torch":
return True
return False
return True


def import_module_from_path(module_path: Union[str, Path]) -> ModuleType:
if isinstance(module_path, Path):
module_path = str(module_path)
module_name = os.path.basename(module_path)
if os.path.isfile(module_path):
sp = os.path.splitext(module_path)
module_name = sp[0]

if os.path.isfile(module_path):
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
module_dir = os.path.split(module_path)[0]
else:
module_spec = importlib.util.spec_from_file_location(
module_name, os.path.join(module_path, "__init__.py")
)
module_dir = module_path

module = importlib.util.module_from_spec(module_spec)
sys.modules[module_name] = module
module_spec.loader.exec_module(module)
return module


class LazyMocker:
def __init__(self, prefix: str, suffix: str, tmp_dir: Optional[Union[str, Path]]):
self.prefix = prefix
self.suffix = suffix
self.tmp_dir = tmp_dir
self.mocked_packages = set()
self.cleanup_list = []

def mock_package(self, package: str):
pass

def cleanup(self):
pass

def get_mock_entity_name(self, entity: Union[str, type, FunctionType]):
formatter = MockEntityNameFormatter(prefix=self.prefix, suffix=self.suffix)
full_obj_name = formatter.format(entity)
return full_obj_name

def mock_entity(self, entity: Union[str, type, FunctionType]):
"""Mock the entity and return the mocked entity

Example:
>>> mocker = LazyMocker(prefix="mock_", suffix="_of", tmp_dir="tmp")
>>> mocker.mock_entity("models.DemoModel")
<class 'mock_models_of.DemoModel'>
>>> cls_obj = models.DemoModel
>>> mocker.mock_entity(cls_obj)
<class 'mock_models_of.DemoModel'>
"""
return self.load_entity_with_mock(entity)

def add_mocked_package(self, package: str):
if package in self.mocked_packages:
return

self.mocked_packages.add(package)
package = sys.modules.get(package, None)

# TODO remove code below
# fix the mock error in https://github.com/siliconflow/oneflow/blob/main/python/oneflow/mock_torch/mock_importer.py#L105-L118
if package and getattr(package, "__file__", None) is not None:
pkg_path = Path(package.__file__).parents[1]
if pkg_path not in sys.path:
sys.path.append(str(pkg_path))

def load_entity_with_mock(self, entity: Union[str, type, FunctionType]):
formatter = MockEntityNameFormatter(prefix=self.prefix, suffix=self.suffix)
full_obj_name = formatter.format(entity)
attrs = full_obj_name.split(".")

# add package path to sys.path to avoid mock error
self.add_mocked_package(attrs[0])

mock_pkg = DynamicMockModule.from_package(attrs[0], verbose=False)
for name in attrs[1:]:
mock_pkg = getattr(mock_pkg, name)
return mock_pkg
26 changes: 26 additions & 0 deletions python/oneflow/framework/infer_compiler/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Module to convert PyTorch code to OneFlow."""
from .builtin_transform import (
ProxySubmodule,
default_converter,
get_attr,
map_args,
proxy_class,
torch2oflow,
)
from .custom_transform import register
from .manager import transform_mgr
Loading