Skip to content

Commit

Permalink
[Target][Device] Auto detect target and create device from str in tor…
Browse files Browse the repository at this point in the history
…ch style (apache#15714)

- Target auto detection: `Target.auto_detect()`.
- Target created from device: `Target.from_device("cuda")` or
  `Target.from_device(tvm.cuda())`
- create device from str: `tvm.device("cuda:0")` or tvm.device("cuda",
  0)
  • Loading branch information
LeshengJin authored Sep 13, 2023
1 parent a25843b commit d1ede36
Show file tree
Hide file tree
Showing 5 changed files with 271 additions and 2 deletions.
19 changes: 17 additions & 2 deletions python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,26 @@ def device(dev_type, dev_id=0):
assert tvm.device("cpu", 1) == tvm.cpu(1)
assert tvm.device("cuda", 0) == tvm.cuda(0)
"""
if not isinstance(dev_id, int):
raise ValueError(f"Invalid device id: {dev_id}")

if isinstance(dev_type, string_types):
dev_type = dev_type.split()[0]
if dev_type.count(":") == 0:
pass
elif dev_type.count(":") == 1:
# It will override the dev_id passed by the user.
dev_type, dev_id = dev_type.split(":")
if not dev_id.isdigit():
raise ValueError(f"Invalid device id: {dev_id}")
dev_id = int(dev_id)
else:
raise ValueError(f"Invalid device string: {dev_type}")

if dev_type not in Device.STR2MASK:
raise ValueError(f"Unknown device type {dev_type}")
dev_type = Device.STR2MASK[dev_type]
raise ValueError(f"Unknown device type: {dev_type}")

return Device(Device.STR2MASK[dev_type], dev_id)
return Device(dev_type, dev_id)


Expand Down
114 changes: 114 additions & 0 deletions python/tvm/target/detect_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Detect target."""
from typing import Union

from . import Target
from .._ffi import get_global_func
from .._ffi.runtime_ctypes import Device
from ..runtime.ndarray import device


def _detect_metal(dev: Device) -> Target:
return Target(
{
"kind": "metal",
"max_shared_memory_per_block": 32768,
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
}
)


def _detect_cuda(dev: Device) -> Target:
return Target(
{
"kind": "cuda",
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
"arch": "sm_" + dev.compute_version.replace(".", ""),
}
)


def _detect_rocm(dev: Device) -> Target:
return Target(
{
"kind": "rocm",
"mtriple": "amdgcn-and-amdhsa-hcc",
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
}
)


def _detect_vulkan(dev: Device) -> Target:
f_get_target_property = get_global_func("device_api.vulkan.get_target_property")
return Target(
{
"kind": "vulkan",
"max_threads_per_block": dev.max_threads_per_block,
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"thread_warp_size": dev.warp_size,
"supports_float16": f_get_target_property(dev, "supports_float16"),
"supports_int16": f_get_target_property(dev, "supports_int16"),
"supports_int8": f_get_target_property(dev, "supports_int8"),
"supports_16bit_buffer": f_get_target_property(dev, "supports_16bit_buffer"),
}
)


def detect_target_from_device(dev: Union[str, Device]) -> Target:
"""Detects Target associated with the given device. If the device does not exist,
there will be an Error.
Parameters
----------
dev : Union[str, Device]
The device to detect the target for.
Supported device types: ["cuda", "metal", "rocm", "vulkan"]
Returns
-------
target : Target
The detected target.
"""
if isinstance(dev, str):
dev = device(dev)
device_type = Device.MASK2STR[dev.device_type]
if device_type not in SUPPORT_DEVICE:
raise ValueError(
f"Auto detection for device `{device_type}` is not supported. "
f"Currently only supports: {SUPPORT_DEVICE.keys()}"
)
if not dev.exist:
raise ValueError(
f"Cannot detect device `{dev}`. Please make sure the device and its driver "
"is installed properly, and TVM is compiled with the driver"
)
target = SUPPORT_DEVICE[device_type](dev)
return target


SUPPORT_DEVICE = {
"cuda": _detect_cuda,
"metal": _detect_metal,
"vulkan": _detect_vulkan,
"rocm": _detect_rocm,
}
24 changes: 24 additions & 0 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
import json
import re
import warnings
from typing import Union

import tvm._ffi
from tvm._ffi import register_func as _register_func
from tvm._ffi.runtime_ctypes import Device
from tvm.runtime import Object, convert
from tvm.runtime.container import String
from tvm.ir.container import Map, Array
Expand Down Expand Up @@ -148,6 +150,28 @@ def export(self):
def with_host(self, host=None):
return _ffi_api.WithHost(self, Target(host))

@staticmethod
def from_device(device: Union[str, Device]) -> "Target":
"""Detects Target associated with the given device. If the device does not exist,
there will be an Error.
Parameters
----------
dev : Union[str, Device]
The device to detect the target for.
Supported device types: ["cuda", "metal", "rocm", "vulkan", "opencl", "cpu"]
Returns
-------
target : Target
The detected target.
"""
from .detect_target import ( # pylint: disable=import-outside-toplevel
detect_target_from_device,
)

return detect_target_from_device(device)

@staticmethod
def current(allow_none=True):
"""Returns the current target.
Expand Down
71 changes: 71 additions & 0 deletions tests/python/unittest/test_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

import tvm
import tvm.testing
from tvm._ffi.runtime_ctypes import Device


@pytest.mark.parametrize(
"dev_str, expected_device_type, expect_device_id",
[
("cpu", Device.kDLCPU, 0),
("cuda", Device.kDLCUDA, 0),
("cuda:0", Device.kDLCUDA, 0),
("cuda:3", Device.kDLCUDA, 3),
("metal:2", Device.kDLMetal, 2),
],
)
def test_device(dev_str, expected_device_type, expect_device_id):
dev = tvm.device(dev_str)
assert dev.device_type == expected_device_type
assert dev.device_id == expect_device_id


@pytest.mark.parametrize(
"dev_type, dev_id, expected_device_type, expect_device_id",
[
("cpu", 0, Device.kDLCPU, 0),
("cuda", 0, Device.kDLCUDA, 0),
(Device.kDLCUDA, 0, Device.kDLCUDA, 0),
("cuda", 3, Device.kDLCUDA, 3),
(Device.kDLMetal, 2, Device.kDLMetal, 2),
],
)
def test_device_with_dev_id(dev_type, dev_id, expected_device_type, expect_device_id):
dev = tvm.device(dev_type=dev_type, dev_id=dev_id)
assert dev.device_type == expected_device_type
assert dev.device_id == expect_device_id


@pytest.mark.parametrize(
"dev_type, dev_id",
[
("cpu:0:0", None),
("cpu:?", None),
("cpu:", None),
(Device.kDLCUDA, "?"),
],
)
def test_deive_error(dev_type, dev_id):
with pytest.raises(ValueError):
dev = tvm.device(dev_type=dev_type, dev_id=dev_id)


if __name__ == "__main__":
tvm.testing.main()
45 changes: 45 additions & 0 deletions tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,5 +488,50 @@ def test_target_features():
assert not target_with_features.features.is_missing


@tvm.testing.requires_cuda
@pytest.mark.parametrize("input_device", ["cuda", tvm.cuda()])
def test_target_from_device_cuda(input_device):
target = Target.from_device(input_device)

dev = tvm.cuda()
assert target.kind.name == "cuda"
assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
assert target.thread_warp_size == dev.warp_size
assert target.arch == "sm_" + dev.compute_version.replace(".", "")


@tvm.testing.requires_rocm
@pytest.mark.parametrize("input_device", ["rocm", tvm.rocm()])
def test_target_from_device_rocm(input_device):
target = Target.from_device(input_device)

dev = tvm.rocm()
assert target.kind.name == "rocm"
assert target.attrs["mtriple"] == "amdgcn-and-amdhsa-hcc"
assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
assert target.thread_warp_size == dev.warp_size


@tvm.testing.requires_vulkan
@pytest.mark.parametrize("input_device", ["vulkan", tvm.vulkan()])
def test_target_from_device_rocm(input_device):
target = Target.from_device(input_device)

f_get_target_property = tvm.get_global_func("device_api.vulkan.get_target_property")
dev = tvm.vulkan()
assert target.kind.name == "vulkan"
assert target.attrs["max_threads_per_block"] == dev.max_threads_per_block
assert target.max_shared_memory_per_block == dev.max_shared_memory_per_block
assert target.thread_warp_size == dev.warp_size
assert target.attrs["supports_float16"] == f_get_target_property(dev, "supports_float16")
assert target.attrs["supports_int16"] == f_get_target_property(dev, "supports_int16")
assert target.attrs["supports_int8"] == f_get_target_property(dev, "supports_int8")
assert target.attrs["supports_16bit_buffer"] == f_get_target_property(
dev, "supports_16bit_buffer"
)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit d1ede36

Please sign in to comment.