-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Util] Support debug debug_compare #2142
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
"""Debug compiled models with TVM instrument""" | ||
|
||
import os | ||
from pathlib import Path | ||
from typing import Dict, List, Set, Tuple | ||
|
||
import tvm | ||
from tvm import rpc, runtime | ||
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument | ||
|
||
from mlc_llm.help import HELP | ||
from mlc_llm.support.argparse import ArgumentParser | ||
from mlc_llm.testing.debug_chat import DebugChat | ||
|
||
|
||
def _print_as_table(sorted_list):
    """Print per-function timing records as a fixed-width text table.

    The table has the columns Name, Time (ms), Count, Total time (ms) and
    Percentage (%), followed by a final "Total time" line.

    Parameters
    ----------
    sorted_list : Iterable[Tuple[str, Tuple[float, int]]]
        Records of the form ``(name, (mean_time, count))``, printed in the
        order given. ``mean_time`` is multiplied by 1000 before being shown
        in the "(ms)" columns (presumably it is in seconds -- confirm
        against the producer of the records).
    """
    # Materialize once: the records are traversed twice (total, then rows).
    records = list(sorted_list)

    print("=" * 100)
    print(
        "Name".ljust(50)
        + "Time (ms)".ljust(12)
        + "Count".ljust(8)
        + "Total time (ms)".ljust(18)
        + "Percentage (%)"
    )
    total_time = sum(mean * count for _, (mean, count) in records) * 1000
    for name, (mean, count) in records:
        time_ms = mean * 1000
        weighted_time = time_ms * count
        # Guard against a zero total (e.g. all mean times are 0.0), which
        # would otherwise raise ZeroDivisionError.
        percentage = weighted_time / total_time * 100 if total_time else 0.0
        print(
            name.ljust(50)
            + f"{time_ms:.4f}".ljust(12)
            + str(count).ljust(8)
            + f"{weighted_time:.4f}".ljust(18)
            + f"{percentage:.2f}"
        )
    print(f"Total time: {total_time:.4f} ms")
|
||
|
||
class LibCompare(LibCompareVMInstrument):
    """Debug instrument that compares library functions and optionally times them.

    Built on ``LibCompareVMInstrument``. Each function name is handed to the
    base-class ``compare`` at most once (see ``skip_instrument``); when
    ``time_eval`` is enabled the function is also benchmarked once through
    ``mod.time_evaluator`` and later calls only bump its call count.

    Parameters
    ----------
    mod: runtime.Module
        The module of interest to be validated.

    device: runtime.Device
        The device to run the target module on.

    debug_dir: Path
        Accepted for interface compatibility; not used by this class.

    time_eval: bool
        Whether to time evaluate the functions.

    rtol: float
        rtol used in validation

    atol: float
        atol used in validation

    skip_rounds: int
        Number of initial instrumented calls (other than ``shape_func*``)
        to skip before validation starts.
    """

    def __init__(  # pylint: disable=too-many-arguments, unused-argument
        self,
        mod: runtime.Module,
        device: runtime.Device,
        debug_dir: Path,
        time_eval: bool = True,
        rtol: float = 1e-2,
        atol: float = 1,
        skip_rounds: int = 0,
    ):
        super().__init__(mod, device, True, rtol, atol)
        self.skip_rounds = skip_rounds
        self.time_eval = time_eval
        # name -> (mean time reported by the evaluator, number of calls seen)
        self.time_eval_results: Dict[str, Tuple[float, int]] = {}
        # Names that have already been passed on for validation.
        self.visited: Set[str] = set()
        self.counter = 0

    def reset(self, debug_dir: Path):  # pylint: disable=unused-argument
        """Print the collected timing table and clear all per-run state.

        Note
        ----
        `debug_dir` is not used in this class.

        Parameters
        ----------
        debug_dir : Path
            Unused; kept so the method signature matches its callers.
        """

        def _negated_total(item):
            # Sort key: largest total time (mean * count) first.
            mean, count = item[1]
            return -(mean * count)

        ranked = sorted(self.time_eval_results.items(), key=_negated_total)
        _print_as_table(ranked)

        self.counter = 0
        self.visited = set()
        self.time_eval_results = {}

    def skip_instrument(self, func, name, before_run, ret_val, *args):
        """Decide whether the call to ``name`` should skip validation.

        Returns True for ``shape_func*`` names, for the first
        ``skip_rounds`` other calls, and for any name that has already
        been validated; returns False only the first time a name is seen
        after the skip rounds.
        """
        if name.startswith("shape_func"):
            return True

        if self.counter < self.skip_rounds:
            self.counter += 1
            print(f"[{self.counter}] Skip validating {name}..")
            return True

        if name not in self.visited:
            self.visited.add(name)
            return False

        # Repeat call: no re-validation, just account for it in the timings.
        if self.time_eval and name in self.time_eval_results:
            mean, count = self.time_eval_results[name]
            self.time_eval_results[name] = (mean, count + 1)
        return True

    def compare(
        self,
        name: str,
        ref_args: List[tvm.nd.NDArray],
        new_args: List[tvm.nd.NDArray],
        ret_indices: List[int],
    ):
        """Run the base-class comparison, then time ``name`` once if enabled."""
        super().compare(name, ref_args, new_args, ret_indices)

        if not self.time_eval or name in self.time_eval_results:
            return

        # cache_flush_bytes=256 * 10**6 is available as an extra argument
        # but is currently left disabled.
        evaluator = self.mod.time_evaluator(name, self.device, number=20, repeat=3)
        res = evaluator(*new_args)
        self.time_eval_results[name] = (res.mean, 1)
        print(f"Time-eval result {name} on {self.device}:\n {res}")
|
||
|
||
def get_instrument(args):
    """Get the debug instrument from the CLI arguments

    Loads the library to compare against and builds a ``LibCompare``
    instrument for it.

    * ``cmp_device == "iphone"``: connects to the RPC proxy given by the
      ``TVM_RPC_PROXY_HOST`` / ``TVM_RPC_PROXY_PORT`` environment variables
      (default ``127.0.0.1:9090``), uploads the ``.dylib`` and uses
      ``sess.metal()``.
    * ``cmp_device == "android"``: requests an ``"android"`` session from
      the tracker given by ``TVM_TRACKER_HOST`` / ``TVM_TRACKER_PORT``
      (default ``0.0.0.0:9190``), uploads the ``.so`` and uses
      ``sess.cl(0)``.
    * otherwise: loads ``cmp_lib_path`` locally and runs it on
      ``tvm.device(cmp_device)``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments. Reads ``device``, ``model_lib_path``,
        ``cmp_device``, ``cmp_lib_path``, ``time_eval`` and ``debug_dir``.
        When ``cmp_device`` is not provided, ``cmp_device`` and
        ``cmp_lib_path`` are overwritten in place with ``device`` and
        ``model_lib_path``.

    Returns
    -------
    LibCompare
        The instrument wrapping the loaded comparison library.
    """

    def _unset(value):
        # Treat both None and the literal string "none" as "not provided".
        return value is None or value == "none"

    if _unset(args.cmp_device):
        assert _unset(args.cmp_lib_path), "cmp_lib_path must be None if cmp_device is None"
        args.cmp_device = args.device
        args.cmp_lib_path = args.model_lib_path
    assert not _unset(args.cmp_lib_path), "cmp_lib_path is required when cmp_device is given"

    if args.cmp_device == "iphone":
        assert args.cmp_lib_path.endswith(".dylib"), "Require a dylib file for iPhone"
        proxy_host = os.environ.get("TVM_RPC_PROXY_HOST", "127.0.0.1")
        proxy_port = int(os.environ.get("TVM_RPC_PROXY_PORT", "9090"))
        sess = rpc.connect(proxy_host, proxy_port, "iphone")
        sess.upload(args.cmp_lib_path)
        lib = sess.load_module(os.path.basename(args.cmp_lib_path))
        cmp_device = sess.metal()
    elif args.cmp_device == "android":
        assert args.cmp_lib_path.endswith(".so"), "Require a so file for Android"
        tracker_host = os.environ.get("TVM_TRACKER_HOST", "0.0.0.0")
        tracker_port = int(os.environ.get("TVM_TRACKER_PORT", "9190"))
        tracker = rpc.connect_tracker(tracker_host, tracker_port)
        sess = tracker.request("android")
        sess.upload(args.cmp_lib_path)
        lib = sess.load_module(os.path.basename(args.cmp_lib_path))
        cmp_device = sess.cl(0)
    else:
        # Load the comparison library directly from the user-supplied path;
        # the CLI defines no artifact_path / quantization arguments.
        lib = tvm.runtime.load_module(args.cmp_lib_path)
        cmp_device = tvm.device(args.cmp_device)

    return LibCompare(
        lib,
        cmp_device,
        time_eval=args.time_eval,
        debug_dir=Path(args.debug_dir),
    )
|
||
|
||
def main():
    """The main function to start a DebugChat CLI

    Parses the command line, builds the comparison instrument with
    ``get_instrument``, runs ``DebugChat.generate`` on the prompt and then
    prints the timing table collected by the instrument.
    """
    parser = ArgumentParser("MLC LLM Chat Debug Tool")
    parser.add_argument(
        "prompt",
        type=str,
        help="The user input prompt.",
    )
    parser.add_argument(
        "--generate-len", type=int, help="Number of output tokens to generate.", required=True
    )
    parser.add_argument(
        "--model",
        type=str,
        help="An MLC model directory that contains `mlc-chat-config.json`",
        required=True,
    )
    parser.add_argument(
        "--model-lib-path",
        type=str,
        help="The full path to the model library file to use (e.g. a ``.so`` file).",
        required=True,
    )
    parser.add_argument(
        "--debug-dir",
        type=str,
        help="The output folder to store the dumped debug files.",
        required=True,
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help=HELP["device_compile"] + ' (default: "%(default)s")',
    )
    # These two default to None (not the string "none") so that the
    # `is None` checks in get_instrument recognise an omitted option.
    parser.add_argument(
        "--cmp-device",
        type=str,
        default=None,
        help="The device to run the compared library on. Falls back to --device when omitted.",
    )
    parser.add_argument(
        "--cmp-lib-path",
        type=str,
        default=None,
        help="The path to the library to compare against. "
        "Falls back to --model-lib-path when --cmp-device is omitted.",
    )
    parser.add_argument(
        "--time-eval",
        action="store_true",
        help="Whether to time evaluate the functions.",
    )
    parsed = parser.parse_args()
    instrument = get_instrument(parsed)
    debug_chat = DebugChat(
        model=parsed.model,
        model_lib_path=parsed.model_lib_path,
        debug_dir=Path(parsed.debug_dir),
        device=parsed.device,
        debug_instrument=instrument,
    )
    debug_chat.generate(parsed.prompt, parsed.generate_len)
    # Only print decode for now
    _print_as_table(
        sorted(
            instrument.time_eval_results.items(),
            key=lambda x: -(x[1][0] * x[1][1]),
        )
    )


if __name__ == "__main__":
    main()
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a little confused about where this `artifact_path` comes from. Is this supposed to be the `model` folder? Or shall we just use the `cmp_lib_path` instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was historically the path of the dist; I agree that having `cmp-lib-path` now makes more sense.