remove hardcoded device="cuda" to support more devices #2503

Merged · 10 commits · Feb 1, 2024
Changes from all commits
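
At a high level, the PR replaces every hardcoded `device="cuda"` with a `device` value threaded through argparse flags, function signatures, and tensor allocations, and guards CUDA-only calls such as `torch.cuda.manual_seed` behind `torch.cuda.is_available()`. The snippet below is a minimal sketch of that pattern, not code from the PR; the function name, flag defaults, and shapes are illustrative only.

```python
# Minimal sketch of the pattern this PR applies (illustrative names, not PR code).
import argparse

import torch


def allocate(num_tokens: int, hidden_size: int, device: str = "cuda") -> torch.Tensor:
    # Seed CUDA only when a CUDA device is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(0)
    # Honour the caller-supplied device instead of a hardcoded "cuda".
    return torch.randn(num_tokens, hidden_size, device=device)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda"],
                        help="device type for execution, CUDA only for now")
    args = parser.parse_args()
    print(allocate(16, 128, device=args.device).shape)
```
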
7 changes: 7 additions & 0 deletions benchmarks/benchmark_latency.py
@@ -25,6 +25,7 @@ def main(args: argparse.Namespace):
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
)

sampling_params = SamplingParams(
@@ -135,5 +136,11 @@ def run_to_completion(profile_dir: Optional[str] = None):
default=None,
help=('path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'))
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
args = parser.parse_args()
main(args)
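
With the flag in place, `args.device` is forwarded into the `LLM` constructor. A hedged usage sketch follows; the model name is hypothetical, and it assumes vLLM is installed and that `LLM` forwards `device` to the engine, as this diff adds.

```python
from vllm import LLM, SamplingParams

# Hypothetical model; the device value would normally come from --device.
llm = LLM(model="facebook/opt-125m", device="cuda")
params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
```
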
10 changes: 9 additions & 1 deletion benchmarks/benchmark_throughput.py
@@ -72,6 +72,7 @@ def run_vllm(
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
device: str,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@@ -85,6 +86,7 @@
max_model_len=max_model_len,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
)

# Add the requests to the engine.
@@ -209,7 +211,7 @@ def main(args: argparse.Namespace):
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype)
args.kv_cache_dtype, args.device)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -294,6 +296,12 @@ def main(args: argparse.Namespace):
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
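
The throughput benchmark follows the same plumbing: `main()` parses `--device` and passes it positionally into `run_vllm`. Below is a rough sketch of that call chain using a stub backend; the names and defaults are hypothetical, not PR code.

```python
import argparse


def run_backend_stub(model: str, kv_cache_dtype: str, device: str) -> None:
    # Stand-in for run_vllm: `device` arrives as an explicit parameter,
    # mirroring the call-site change in main() above.
    print(f"model={model} kv_cache_dtype={kv_cache_dtype} device={device}")


parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="facebook/opt-125m")  # hypothetical default
parser.add_argument("--kv-cache-dtype", type=str, default="auto")
parser.add_argument("--device", type=str, default="cuda", choices=["cuda"])
args = parser.parse_args([])  # defaults only, for illustration
run_backend_stub(args.model, args.kv_cache_dtype, args.device)
```
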
27 changes: 18 additions & 9 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -25,30 +25,32 @@ def main(
dtype: torch.dtype,
seed: int,
do_profile: bool,
device: str = "cuda",
kv_cache_dtype: Optional[str] = None,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

scale = float(1.0 / (head_size**0.5))
query = torch.empty(num_seqs,
num_query_heads,
head_size,
dtype=dtype,
device="cuda")
device=device)
query.uniform_(-scale, scale)

assert num_query_heads % num_kv_heads == 0
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device="cuda")
device=device)

context_lens = [context_len for _ in range(num_seqs)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)

# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -59,12 +61,17 @@ def main(
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)

# Create the KV cache.
key_caches, value_caches = create_kv_caches_with_random(
NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
dtype)
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
block_size,
1,
num_kv_heads,
head_size,
kv_cache_dtype,
dtype,
device=device)
key_cache, value_cache = key_caches[0], value_caches[0]

# Prepare for the paged attention kernel.
@@ -84,7 +91,7 @@
)
max_logits = torch.empty_like(exp_sums)

def run_benchmark(num_iters: int, profile: bool = False) -> float:
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize()
if profile:
torch.cuda.cudart().cudaProfilerStart()
@@ -135,6 +142,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:

# Warmup.
print("Warming up...")
run_benchmark = run_cuda_benchmark
run_benchmark(num_iters=3, profile=False)

# Benchmark.
@@ -175,6 +183,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
args = parser.parse_args()
print(args)

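
The paged-attention benchmark now passes `device=` through to the KV-cache helper as well. Below is a simplified, device-aware stand-in for that helper; vLLM's real `create_kv_caches_with_random` uses a different cache layout, so treat this only as a sketch of the allocation pattern.

```python
from typing import List, Tuple

import torch


def create_random_kv_caches(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # Every allocation honours `device` instead of assuming "cuda".
    shape = (num_blocks, num_kv_heads, head_size, block_size)
    key_cache = torch.empty(*shape, dtype=dtype, device=device).uniform_(-1, 1)
    value_cache = torch.empty(*shape, dtype=dtype, device=device).uniform_(-1, 1)
    return [key_cache], [value_cache]
```
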
37 changes: 21 additions & 16 deletions tests/kernels/test_activation.py
@@ -7,26 +7,29 @@
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 4096, 5120, 13824] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
layer = SiluAndMul()
out = layer(x)
ref_out = layer._forward(x)
@@ -37,19 +40,20 @@ def test_silu_and_mul(
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gelu_new(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = NewGELU()
out = layer(x)
ref_out = layer._forward(x)
@@ -60,18 +64,19 @@ def test_gelu_new(
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gelu_fast(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = FastGELU()
out = layer(x)
ref_out = layer._forward(x)
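
The activation tests switch from integer device indices to full device strings and rely on `torch.set_default_device`, so individual allocations no longer need an explicit `device=`. A hedged sketch of the same pattern applied to a trivial test follows; it is not one of the PR's tests and assumes a CUDA machine so the parametrized devices exist.

```python
import pytest
import torch

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_identity(device: str) -> None:
    torch.random.manual_seed(0)
    # Seed CUDA only when a CUDA device is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(0)
    torch.set_default_device(device)
    x = torch.randn(8, 16)  # allocated on `device` via the default
    assert x.device == torch.device(device)
```
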
51 changes: 23 additions & 28 deletions tests/kernels/test_attention.py
@@ -27,7 +27,9 @@
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def ref_masked_attention(
@@ -91,7 +93,7 @@ def ref_single_query_cached_kv_attention(
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len, device=query.device).int()
position_ids = torch.arange(context_len).int()
alibi_bias = (position_ids - context_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
@@ -110,7 +112,7 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
kv_cache_factory,
version: str,
@@ -122,33 +124,28 @@
dtype: torch.dtype,
kv_cache_dtype: str,
seed: int,
device: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs,
num_query_heads,
head_size,
dtype=dtype,
device=gpu_id)
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
query.uniform_(-scale, scale)

assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device=gpu_id)
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)
context_lens = torch.tensor(context_lens, dtype=torch.int)

# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -159,13 +156,13 @@
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)
block_tables = torch.tensor(block_tables, dtype=torch.int)

# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size,
kv_cache_dtype, dtype, seed,
gpu_id)
device)
key_cache, value_cache = key_caches[0], value_caches[0]

# Call the paged attention kernel.
@@ -193,12 +190,10 @@ def test_paged_attention(
tmp_output = torch.empty(
size=(num_seqs, num_heads, num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
@@ -229,14 +224,14 @@ def test_paged_attention(
block_size, x)
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=gpu_id)
device=device)
cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache

value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=gpu_id)
device=device)
cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache

@@ -283,7 +278,7 @@ def ref_multi_query_kv_attention(
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device=query.device)
attn_mask = attn_mask.to(dtype=dtype)

ref_output = ref_masked_attention(
query[start_idx:end_idx],
@@ -303,20 +298,21 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
@@ -329,8 +325,7 @@ def test_multi_query_kv_attention(
qkv = torch.empty(num_tokens,
num_query_heads + 2 * num_kv_heads,
head_size,
dtype=dtype,
device=gpu_id)
dtype=dtype)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
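
These tests drop most explicit `device=` arguments because `torch.set_default_device` reroutes tensor factory calls, while helpers such as `kv_cache_factory` still receive the device string explicitly. A small illustration of that interaction, assuming a CUDA machine:

```python
import torch

torch.set_default_device("cuda:0")
a = torch.empty(4, 4)                # follows the default -> cuda:0
b = torch.empty(4, 4, device="cpu")  # an explicit device still wins
print(a.device, b.device)            # cuda:0 cpu
```
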