93 changes: 68 additions & 25 deletions graph_net/torch/test_compiler.py
@@ -13,15 +13,56 @@
from contextlib import contextmanager
import time

# torch_tensorrt is optional; importing it (when available) registers its
# backend with torch.compile.
try:
    import torch_tensorrt
except ImportError:
    torch_tensorrt = None

def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:

class GraphCompilerBackend:
    """Interface for a backend under test: __call__ compiles a model, and
    synchronize() blocks until pending device work finishes so that
    wall-clock timings are meaningful."""

    def __call__(self, model):
        raise NotImplementedError()

    def synchronize(self):
        raise NotImplementedError()


class InductorBackend(GraphCompilerBackend):
def __call__(self, model):
return torch.compile(model, backend="inductor")

def synchronize(self):
torch.cuda.synchronize()


class TensorRTBackend(GraphCompilerBackend):
    def __call__(self, model):
        # The "tensorrt" backend is only registered when torch_tensorrt imports successfully.
        assert torch_tensorrt is not None, "the tensorrt backend requires torch_tensorrt"
        return torch.compile(model, backend="tensorrt")

def synchronize(self):
torch.cuda.synchronize()


registry_backend = {
"inductor": InductorBackend(),
"tensorrt": TensorRTBackend(),
"default": InductorBackend(),
}
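New backends plug in through this registry. A minimal sketch of an out-of-tree addition (the CudaGraphsBackend class is illustrative, not part of this PR; "cudagraphs" is one of PyTorch's built-in torch.compile backends):

class CudaGraphsBackend(GraphCompilerBackend):
    def __call__(self, model):
        # "cudagraphs" ships with PyTorch as a built-in torch.compile backend.
        return torch.compile(model, backend="cudagraphs")

    def synchronize(self):
        torch.cuda.synchronize()

registry_backend["cudagraphs"] = CudaGraphsBackend()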


def load_class_from_file(
args: argparse.Namespace, class_name: str
) -> Type[torch.nn.Module]:
file_path = f"{args.model_path}/model.py"
file = Path(file_path).resolve()
module_name = file.stem

    with open(file_path, "r", encoding="utf-8") as f:
        original_code = f.read()
    # Prepend the torch import, then retarget hard-coded device strings when
    # running on CUDA; replacing on modified_code keeps the prepended import.
    import_stmt = "import torch"
    modified_code = f"{import_stmt}\n{original_code}"
    if args.device == "cuda":
        modified_code = modified_code.replace("cpu", "cuda")
spec = importlib.util.spec_from_loader(module_name, loader=None)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
@@ -32,27 +73,23 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
return model_class


def get_compiler(args):
assert args.compiler == "default"
return torch.compile


def get_synchronizer_func(args):
assert args.compiler == "default"
return torch.cuda.synchronize
def get_compiler_backend(args) -> GraphCompilerBackend:
assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
return registry_backend[args.compiler]


def get_model(args):
model_class = load_class_from_file(
f"{args.model_path}/model.py", class_name="GraphModule"
)
return model_class()
model_class = load_class_from_file(args, class_name="GraphModule")
return model_class().to(torch.device(args.device))


def get_input_dict(args):
inputs_params = utils.load_converted_from_text(f"{args.model_path}")
params = inputs_params["weight_info"]
return {k: utils.replay_tensor(v) for k, v in params.items()}
return {
k: utils.replay_tensor(v).to(torch.device(args.device))
for k, v in params.items()
}


@dataclass
@@ -71,15 +108,14 @@ def naive_timer(duration_box, get_synchronizer_func):


def test_single_model(args):
compiler = get_compiler(args)
synchronizer_func = get_synchronizer_func(args)
compiler = get_compiler_backend(args)
input_dict = get_input_dict(args)
model = get_model(args)
compiled_model = compiler(model)

# eager
eager_duration_box = DurationBox(-1)
with naive_timer(eager_duration_box, synchronizer_func):
with naive_timer(eager_duration_box, compiler.synchronize):
expected_out = model(**input_dict)

# warmup
@@ -88,7 +124,7 @@

# compiled
compiled_duration_box = DurationBox(-1)
with naive_timer(compiled_duration_box, synchronizer_func):
with naive_timer(compiled_duration_box, compiler.synchronize):
compiled_out = compiled_model(**input_dict)
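The timer now receives the backend's synchronize method directly. naive_timer's body sits outside this hunk; for orientation, it is a context manager of roughly this shape (a sketch, not the actual implementation; the duration field name is a guess):

@contextmanager
def naive_timer(duration_box, get_synchronizer_func):
    get_synchronizer_func()  # drain pending device work before starting the clock
    start = time.time()
    yield
    get_synchronizer_func()  # make sure the timed work has actually finished
    duration_box.value = time.time() - start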

def print_cmp(key, func, **kwargs):
@@ -157,11 +193,11 @@ def test_multi_models(args):
cmd = "".join(
[
sys.executable,
"-m graph_net.torch.test_compiler",
f"--model-path {model_path}",
f"--compiler {args.compiler}",
f"--warmup {args.warmup}",
f"--log-prompt {args.log_prompt}",
" -m graph_net.torch.test_compiler",
f" --model-path {model_path}",
f" --compiler {args.compiler}",
f" --warmup {args.warmup}",
f" --log-prompt {args.log_prompt}",
]
)
cmd_ret = os.system(cmd)
@@ -212,6 +248,13 @@ def main(args):
default="default",
help="Path to customized compiler python file",
)
parser.add_argument(
"--device",
type=str,
required=False,
default="cpu",
help="Device for testing the compiler",
)
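With the new flag, a CUDA run of a single model looks like this (the model path is illustrative):

python -m graph_net.torch.test_compiler \
    --model-path ./samples/resnet18 \
    --compiler tensorrt \
    --device cuda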
parser.add_argument(
"--warmup", type=int, required=False, default=5, help="Number of warmup steps"
)