Skip to content

Commit 7157c93

Browse files
add metal to list of choices (#8282)
1 parent 1f0f8f1 commit 7157c93

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

python/tvm/driver/tvmc/runner.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""
2020
import json
2121
import logging
22-
from typing import Optional, Dict, List, Union
22+
from typing import Dict, List, Optional, Union
2323

2424
import numpy as np
2525
import tvm
@@ -30,12 +30,11 @@
3030
from tvm.relay.param_dict import load_param_dict
3131

3232
from . import common
33-
from .model import TVMCPackage, TVMCResult
3433
from .common import TVMCException
3534
from .main import register_parser
35+
from .model import TVMCPackage, TVMCResult
3636
from .result_utils import get_top_results
3737

38-
3938
# pylint: disable=invalid-name
4039
logger = logging.getLogger("TVMC")
4140

@@ -51,7 +50,7 @@ def add_run_parser(subparsers):
5150
# like 'webgpu', etc (@leandron)
5251
parser.add_argument(
5352
"--device",
54-
choices=["cpu", "cuda", "cl"],
53+
choices=["cpu", "cuda", "cl", "metal"],
5554
default="cpu",
5655
help="target device to run the compiled module. Defaults to 'cpu'",
5756
)
@@ -391,6 +390,8 @@ def run_module(
391390
dev = session.cuda()
392391
elif device == "cl":
393392
dev = session.cl()
393+
elif device == "metal":
394+
dev = session.metal()
394395
else:
395396
assert device == "cpu"
396397
dev = session.cpu()

0 commit comments

Comments
 (0)