Skip to content

Commit

Permalink
Add more device types for the time estimation. (bytedance#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-ningxin authored Jul 4, 2024
1 parent cc8a773 commit 75ad160
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions python/flux/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def _get_flops():
if device_name.find("A100") >= 0 or device_name.find("A800") >= 0:
assert is_fp16
return 312
# https://resources.nvidia.com/en-us-gpu-resources/a10-datasheet-nvidia?lx=CPwSfP&ncid=no-ncid
if device_name == "NVIDIA A10": # values from the A10 datasheet linked above
return 125 if is_fp16 else 250
if device_name == "NVIDIA A30": # No doc from NVIDIA
return 165 if is_fp16 else 330
if device_name == "NVIDIA L20": # No doc from NVIDIA
return 119 if is_fp16 else 239
# https://www.nvidia.com/en-us/data-center/l4/
Expand All @@ -68,9 +73,10 @@ def _get_flops():
if device_name == "NVIDIA L40S":
return 366 if is_fp16 else 733
# https://www.nvidia.com/en-us/data-center/h100/
if device_name == "NVIDIA H100" or device_name == "NVIDIA H800":
if device_name.find("H100") >= 0 or device_name.find("H800") >= 0:
return 989 if is_fp16 else 1979

if device_name.find("H200") >= 0:
return 1979 if is_fp16 else 3958
raise Exception(f"not supported device {device_name}")

flops = M * N * K * 2
Expand Down

0 comments on commit 75ad160

Please sign in to comment.