Merged
2 changes: 1 addition & 1 deletion  .github/workflows/benchmark.yml
@@ -52,7 +52,7 @@ jobs:
             commit_id=$GITHUB_SHA
           fi
           commit_msg=$(git show -s --format=%s | cut -c1-70)
-          python3 benchmark_v2/run_benchmarks.py -b 32 -s 128 -n 256 --branch-name "$BRANCH_NAME" --commit-id "$commit_id" --commit-message "$commit_msg" --model-id "$MODEL_ID" --log-level INFO --push-result-to-dataset "$DATASET_ID"
+          python3 benchmark_v2/run_benchmarks.py -b 32 -s 128 -n 256 --cross-generate --branch-name "$BRANCH_NAME" --commit-id "$commit_id" --commit-message "$commit_msg" --model-id "$MODEL_ID" --log-level INFO --push-result-to-dataset "$DATASET_ID"
         env:
           HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
           PUSH_TO_HUB_TOKEN: ${{ secrets.PUSH_TO_HUB_TOKEN }}
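The only workflow change is the new `--cross-generate` switch on the CI invocation. As a purely illustrative sketch, the snippet below shows how such a boolean flag could be declared and consumed with argparse; the flag name and the list-valued `-b`/`-s`/`-n` options are taken from this diff, but the exact argparse wiring in `run_benchmarks.py` is an assumption, not the repository's code.

```python
# Hypothetical sketch (not the repository's exact code): a boolean --cross-generate
# switch consumed later as args.cross_generate, alongside list-valued size options.
import argparse

parser = argparse.ArgumentParser(description="Run transformers v2 benchmarks")
parser.add_argument("-b", "--batch-size", type=int, nargs="+", default=[1])
parser.add_argument("-s", "--sequence-length", type=int, nargs="+", default=[128])
parser.add_argument("-n", "--num-tokens-to-generate", type=int, nargs="+", default=[128])
parser.add_argument(
    "--cross-generate",
    action="store_true",
    help="Benchmark the cross product of attention implementations, compile modes and kernelization",
)

# The CI job above now opts into the cross-generated config grid:
args = parser.parse_args(["-b", "32", "-s", "128", "-n", "256", "--cross-generate"])
assert args.cross_generate and args.batch_size == [32]
```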
47 changes: 19 additions & 28 deletions  benchmark_v2/framework/benchmark_config.py
@@ -3,6 +3,8 @@
 import logging
 from typing import Any
 
+from transformers.utils.import_utils import is_flash_attn_2_available
+
 
 KERNELIZATION_AVAILABLE = False
 try:
@@ -18,6 +20,16 @@
 class BenchmarkConfig:
     """Configuration for a single benchmark scenario."""
 
+    all_attn_implementations = [
+        ("flash_attention_2", None),
+        ("eager", None),
+        ("sdpa", "math"),
+        ("sdpa", "flash_attention"),
+        ("flex_attention", None),
+    ]
+
+    all_compiled_modes = [None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]
+
     def __init__(
         self,
         warmup_iterations: int = 5,
@@ -59,6 +71,13 @@ def __init__(
     def check_validity(self, skip_validity_check: bool = False) -> None:
         if skip_validity_check:
             return
+        # Check that flash attention is installed
+        if self.attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
+            logger.warning(
+                "Flash attention is not available. Defaulting to SDPA w/ flash attention backend."
+            )
+            self.attn_implementation = "sdpa"
+            self.sdpa_backend = "flash_attention"
         # Flash attention does not support compile mode, so we turn it off # FIXME: it would be better to support it
         is_fa = self.attn_implementation == "flash_attention_2"
         is_fa |= self.attn_implementation == "sdpa" and self.sdpa_backend == "flash_attention"
@@ -163,34 +182,6 @@ def cross_generate_configs(
     return configs
 
 
-def generate_all_configs(
-    warmup_iterations: int = 5,
-    measurement_iterations: int = 20,
-    batch_size: int = 1,
-    sequence_length: int = 128,
-    num_tokens_to_generate: int = 128,
-    gpu_monitoring: bool = True,
-) -> list[BenchmarkConfig]:
-    all_attn_implementations = [
-        ("flash_attention_2", None),
-        ("eager", None),
-        ("sdpa", "math"),
-        ("sdpa", "flash_attention"),
-        ("flex_attention", None),
-    ]
-    return cross_generate_configs(
-        attn_impl_and_sdpa_backend=all_attn_implementations,
-        compiled_mode=[None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
-        kernelized=[False, KERNELIZATION_AVAILABLE],
-        warmup_iterations=warmup_iterations,
-        measurement_iterations=measurement_iterations,
-        batch_size=batch_size,
-        sequence_length=sequence_length,
-        num_tokens_to_generate=num_tokens_to_generate,
-        gpu_monitoring=gpu_monitoring,
-    )
-
-
 def generate_main_configs(
     warmup_iterations: int = 5,
     measurement_iterations: int = 20,
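The two lists that used to live inside the removed `generate_all_configs` are now class attributes of `BenchmarkConfig`, so any call site can pass them to `cross_generate_configs` directly. As a rough illustration of the grid that helper is assumed to span (its implementation is not part of this diff), here is a minimal stand-in sketch:

```python
# Minimal sketch, assuming cross_generate_configs simply takes the Cartesian
# product of its list arguments; the real helper returns BenchmarkConfig objects.
from itertools import product

all_attn_implementations = [
    ("flash_attention_2", None),
    ("eager", None),
    ("sdpa", "math"),
    ("sdpa", "flash_attention"),
    ("flex_attention", None),
]
all_compiled_modes = [None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]

grid = [
    {"attn_implementation": attn, "sdpa_backend": backend, "compile_mode": mode, "kernelize": kernelize}
    for (attn, backend), mode, kernelize in product(all_attn_implementations, all_compiled_modes, [False, True])
]
print(len(grid))  # 5 * 5 * 2 = 50 raw combinations before any validity checks
```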
12 changes: 10 additions & 2 deletions  benchmark_v2/run_benchmarks.py
@@ -23,7 +23,12 @@
 import sys
 import uuid
 
-from framework.benchmark_config import BenchmarkConfig, generate_all_configs, generate_main_configs
+from framework.benchmark_config import (
+    KERNELIZATION_AVAILABLE,
+    BenchmarkConfig,
+    cross_generate_configs,
+    generate_main_configs,
+)
 from framework.benchmark_runner import BenchmarkRunner
 
 
@@ -82,7 +87,10 @@
     # If there is only one (batch_size, sequence_length, num_tokens_to_generate), we benchmark across configs
     elif len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 1:
         if args.cross_generate:
-            benchmark_configs = generate_all_configs(
+            benchmark_configs = cross_generate_configs(
+                attn_impl_and_sdpa_backend=BenchmarkConfig.all_attn_implementations,
+                compiled_mode=[None, "default"],  # usually there is not much to gain by compiling with other modes
+                kernelized=[False, KERNELIZATION_AVAILABLE],
                 warmup_iterations=args.warmup,
                 measurement_iterations=args.iterations,
                 batch_size=args.batch_size[0],
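At this call site the compile-mode axis is deliberately trimmed to `[None, "default"]`. Under the same cross-product assumption as above, that shrinks the grid per (batch_size, sequence_length, num_tokens_to_generate) triple roughly as follows; the numbers are illustrative, not measured, and the kernelized axis collapses to a single value when kernelization is unavailable.

```python
# Back-of-the-envelope config counts, assuming a plain Cartesian product with
# duplicate kernelized entries collapsed; illustrative only.
n_attn = 5                            # BenchmarkConfig.all_attn_implementations
old_grid = n_attn * 5 * 2             # removed generate_all_configs: 5 compile modes, kernelized {False, True} when kernels are available
new_grid = n_attn * 2 * 2             # new call site: compiled_mode=[None, "default"]
new_grid_no_kernels = n_attn * 2 * 1  # kernelized=[False, False] dedupes when kernels are unavailable
print(old_grid, new_grid, new_grid_no_kernels)  # 50 20 10
```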