update check for triton (#2641)
grimoire authored Oct 24, 2024
1 parent f4e0343 commit eaa4e6f
Showing 3 changed files with 15 additions and 5 deletions.
8 changes: 5 additions & 3 deletions lmdeploy/pytorch/check_env/__init__.py
@@ -66,6 +66,10 @@ def check_env_triton(device: str):
     from packaging import version
     logger = get_logger('lmdeploy')
 
+    msg = (
+        'Please ensure that your device is functioning properly with <Triton>.\n'  # noqa: E501
+        'You can verify your environment by running '
+        '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')
     try:
         logger.debug('Checking <Triton> environment.')
         import torch
@@ -87,11 +91,9 @@ def check_env_triton(device: str):
                 'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. \n'  # noqa: E501
                 'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209'  # noqa: E501
                 ' or reinstall the driver.')
-        else:
-            msg = None
         _handle_exception(e, 'Triton', logger, msg)
     except Exception as e:
-        _handle_exception(e, 'Triton', logger)
+        _handle_exception(e, 'Triton', logger, msg)
 
     if device == 'cuda':
         device_cap = torch.cuda.get_device_capability()
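The gist of this change: the generic hint message is now built before the try block, so both the RuntimeError handler and the catch-all handler can forward it to _handle_exception; previously the catch-all passed no hint and the RuntimeError branch needed an explicit `msg = None` fallback. A minimal sketch of the resulting flow follows; the import locations and the smoke-test body inside the try are assumptions for illustration, not the exact source.

# Sketch of the revised check, not the lmdeploy source verbatim.
# Import locations below are assumptions for illustration.
from lmdeploy.pytorch.check_env import _handle_exception  # helper seen in the diff above
from lmdeploy.utils import get_logger  # assumed helper location


def check_triton_sketch(device: str):
    logger = get_logger('lmdeploy')
    # Hint built up front, so it is in scope for every handler below.
    msg = ('Please ensure that your device is functioning properly with <Triton>.\n'
           'You can verify your environment by running '
           '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')
    try:
        logger.debug('Checking <Triton> environment.')
        import torch

        from lmdeploy.pytorch.check_env.triton_custom_add import custom_add
        a = torch.tensor([1, 2], device='cuda')
        b = a.new_tensor([3, 4], device='cuda')
        torch.testing.assert_close(custom_add(a, b), a + b)
    except RuntimeError as e:
        # The real code substitutes a driver/nvcc-mismatch hint for specific
        # errors; either way msg is always defined, so no `msg = None` fallback.
        _handle_exception(e, 'Triton', logger, msg)
    except Exception as e:
        # Previously called without msg; the hint is now forwarded here too.
        _handle_exception(e, 'Triton', logger, msg)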
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/check_env/triton_custom_add.py
@@ -23,3 +23,11 @@ def custom_add(a, b):
     grid = (triton.cdiv(size, BLOCK), )
     _add_kernel[grid](a, b, c, size, BLOCK=BLOCK)
     return c
+
+
+if __name__ == '__main__':
+    a = torch.tensor([1, 2], device='cuda')
+    b = a.new_tensor([3, 4], device='cuda')
+    c = custom_add(a, b)
+    torch.testing.assert_close(c, a + b)
+    print('Done.')
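The launcher in this hunk calls a Triton kernel, _add_kernel, defined earlier in triton_custom_add.py and not shown here. For orientation, a self-contained sketch of such a module follows, ending with the same __main__ smoke test the commit adds; the kernel body is a textbook Triton element-wise add and is an assumption, not necessarily the exact lmdeploy kernel.

# Hypothetical stand-alone version of triton_custom_add.py.  The kernel body
# is a standard Triton vector add and may differ from lmdeploy's _add_kernel.
import torch
import triton
import triton.language as tl


@triton.jit
def _add_kernel(a_ptr, b_ptr, c_ptr, size, BLOCK: tl.constexpr):
    # Each program instance handles one BLOCK-sized slice of the vectors.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < size
    a = tl.load(a_ptr + offs, mask=mask)
    b = tl.load(b_ptr + offs, mask=mask)
    tl.store(c_ptr + offs, a + b, mask=mask)


def custom_add(a, b):
    """Launch the kernel; mirrors the launcher shown in the diff above."""
    c = torch.empty_like(a)
    size = c.numel()
    BLOCK = 16
    grid = (triton.cdiv(size, BLOCK), )
    _add_kernel[grid](a, b, c, size, BLOCK=BLOCK)
    return c


if __name__ == '__main__':
    # Same smoke test as the one added by this commit: run a tiny add on the
    # GPU and check it against plain torch addition.
    a = torch.tensor([1, 2], device='cuda')
    b = a.new_tensor([3, 4], device='cuda')
    c = custom_add(a, b)
    torch.testing.assert_close(c, a + b)
    print('Done.')

Running the module directly is exactly what the new hint in check_env/__init__.py asks users to do when the Triton check fails.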
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/kernels/cuda/pagedattention.py
@@ -1153,9 +1153,9 @@ def _get_block_d(Lk):
     if not is_decoding:
         BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq)
         if _nv_cap[0] < 8:
-            BLOCK_M = max(16, min(BLOCK, 8192 // BLOCK_DMODEL))
+            BLOCK_M = max(16, 8192 // BLOCK_DMODEL)
         else:
-            BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL))
+            BLOCK_M = max(16, 16384 // BLOCK_DMODEL)
         num_warps = 4
         num_stages = 2
         kv_head = k.shape[h_dim]
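The pagedattention change drops the min(BLOCK, ...) cap, so the prefill tile size BLOCK_M is now set only by the per-architecture budget (8192 below compute capability 8, 16384 otherwise) divided by BLOCK_DMODEL. A rough before/after illustration, with BLOCK = 64 and BLOCK_DMODEL = 128 as assumed example values:

# Illustrative only: how BLOCK_M changes once the min(BLOCK, ...) cap is gone.
# BLOCK = 64 and BLOCK_DMODEL = 128 are assumed example values.
BLOCK = 64
BLOCK_DMODEL = 128

for arch, budget in (('sm < 8', 8192), ('sm >= 8', 16384)):
    old = max(16, min(BLOCK, budget // BLOCK_DMODEL))
    new = max(16, budget // BLOCK_DMODEL)
    print(f'{arch}: old BLOCK_M = {old}, new BLOCK_M = {new}')
# sm < 8: old BLOCK_M = 64, new BLOCK_M = 64
# sm >= 8: old BLOCK_M = 64, new BLOCK_M = 128

Whenever the budget allows a larger tile than BLOCK, the old expression clipped BLOCK_M at BLOCK; the new one lets it grow to what the budget permits.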
