Batched benchmark script and more detailed benchmark metrics #25

Closed · wants to merge 1 commit
.gitignore: 4 additions & 0 deletions
@@ -4,3 +4,7 @@
*.eggs/
*.so
build/
log/
archive/
*.csv
.vscode/
benchmark/batch_benchmark.sh: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
#!/bin/bash

mkdir -p log

MODEL_LOG_NAME="opt-13b"
MODEL="facebook/opt-13b"

for BATCH_SIZE in 8 32 128; do
for INPUT_LEN in 1 32 256 1024; do
for OUTPUT_LEN in 1 16 128; do
for TENSOR_PARALLEL_SIZE in 1 2 4; do
python benchmark_latency.py \
--model $MODEL \
--batch-size $BATCH_SIZE \
--input-len $INPUT_LEN \
--output-len $OUTPUT_LEN \
--tensor-parallel-size $TENSOR_PARALLEL_SIZE \
| tee -a log/model_${MODEL_LOG_NAME}_bs_${BATCH_SIZE}_in_${INPUT_LEN}_out_${OUTPUT_LEN}_tp_${TENSOR_PARALLEL_SIZE}.log
sleep 0.1
done
done
done
done
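As a point of reference, here is a small Python sketch (not part of the PR) that reproduces the log-file naming scheme the loop above writes into log/. parse_log.py below scans that directory, so the sweep assumes the script is run from the benchmark/ directory with enough GPUs for the largest tensor-parallel size.

model_log_name = "opt-13b"
for batch_size in (8, 32, 128):
    for input_len in (1, 32, 256, 1024):
        for output_len in (1, 16, 128):
            for tp in (1, 2, 4):
                print(f"log/model_{model_log_name}_bs_{batch_size}"
                      f"_in_{input_len}_out_{output_len}_tp_{tp}.log")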
benchmark/benchmark_latency.py: 46 additions & 16 deletions
@@ -1,3 +1,4 @@
import json
import argparse
import time
from typing import List
@@ -11,17 +12,24 @@
initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
from cacheflow.profile import set_sync_for_profiling


def main(args: argparse.Namespace):
print(json.dumps(args.__dict__))
set_sync_for_profiling()

# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')

cuda_profiler = False
ray_cluster_address = "local" if cuda_profiler else "auto"

(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
-address='local',
address=ray_cluster_address,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))

@@ -60,32 +68,55 @@ def main(args: argparse.Namespace):
sampling_params = SamplingParams.from_dict(sampling_params_dict)
input_token_ids = [0] * args.input_len

-def profile_step(profile=False):
-if profile:
def profile_step():
if cuda_profiler:
torch.cuda.cudart().cudaProfilerStart()
server.reset_timer()
for _ in range(args.batch_size):
frontend._add_query(input_token_ids, sampling_params)
server.add_sequence_groups(frontend.get_inputs())
# Prompt step
start_time = time.time()
-while True:
server.step()
end_time = time.time()
prompt_latency = end_time - start_time
# Decoding steps
num_decoding_steps = 0
start_time = time.time()
while server.has_unfinished_requests():
server.step()
if not server.has_unfinished_requests():
break
num_decoding_steps += 1
end_time = time.time()
-latency = end_time - start_time
-if profile:
decoding_latency = end_time - start_time
if cuda_profiler:
torch.cuda.cudart().cudaProfilerStop()
-return latency
server_profile_results = server.get_profile_results()
# First controller's first worker
worker_execution_latency = server_profile_results[0][0]["execution_latency"]
worker_communication_latency = server_profile_results[0][0]["communication_latency"]
return (prompt_latency, decoding_latency, num_decoding_steps,
worker_execution_latency, worker_communication_latency)

-print("Warm up step")
print("== Warm up step ==")
profile_step()

-# Benchmark.
-latencies = []
-for _ in tqdm(range(3), desc="Profile step"):
-latencies.append(profile_step())
-print(f'Avg latency: {np.mean(latencies)} seconds')

print("== Profile steps ==")
num_profile_steps = 5
for step in range(num_profile_steps):
(prompt_latency, decoding_latency, num_decoding_steps,
worker_execution_latency, worker_communication_latency) = profile_step()
decoding_latency_per_step = decoding_latency / num_decoding_steps if num_decoding_steps > 0 else 0.0
result = {
"step": step,
"prompt_latency_seconds": prompt_latency,
"decoding_latency_seconds": decoding_latency,
"decoding_latency_per_step_seconds": decoding_latency_per_step,
"num_decoding_steps": num_decoding_steps,
"worker_execution_latency_seconds": worker_execution_latency,
"worker_communication_latency_seconds": worker_communication_latency,
}
print(json.dumps(result))

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
@@ -95,5 +126,4 @@ def profile_step(profile=False):
parser.add_argument('--batch-size', type=int, default=8)
args = parser.parse_args()
args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
-print(args)
main(args)
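As a sanity check on the metrics above, a minimal sketch (with invented numbers) of reading back one of the JSON result lines that the profile loop prints; decoding_latency_per_step_seconds is just the total decoding latency divided by the number of decoding steps.

import json

line = ('{"step": 0, "prompt_latency_seconds": 0.84, "decoding_latency_seconds": 2.41, '
        '"decoding_latency_per_step_seconds": 0.1607, "num_decoding_steps": 15, '
        '"worker_execution_latency_seconds": 2.95, '
        '"worker_communication_latency_seconds": 0.12}')
result = json.loads(line)
# Per-step decoding latency = total decoding time / number of decoding steps.
per_step = result["decoding_latency_seconds"] / result["num_decoding_steps"]
assert abs(per_step - result["decoding_latency_per_step_seconds"]) < 1e-3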
benchmark/parse_log.py: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
import csv
import json
import os
from argparse import Namespace
from collections import defaultdict

import numpy as np
import pandas as pd

log_dir = 'log/'
log_files = os.listdir(log_dir)
all_results = []

for log_file in log_files:
file_path = os.path.join(log_dir, log_file)
lines = list(open(file_path).readlines())
profile_arguments = json.loads(lines[0])
results = defaultdict(list)
for line in lines:
if "prompt_latency_seconds" not in line:
continue
result = json.loads(line)
for k, v in result.items():
if k == "step":
continue
results[k].append(v)
final_result = {
"model": profile_arguments["model"],
"batch_size": profile_arguments["batch_size"],
"input_len": profile_arguments["input_len"],
"output_len": profile_arguments["output_len"],
"tensor_parallel_size": profile_arguments["tensor_parallel_size"],
}

for k, v in results.items():
final_result[k + "_mean"] = np.mean(v)
final_result[k + "_std"] = np.std(v)

all_results.append(final_result)

df = pd.DataFrame.from_records(all_results)

print(df)

df.to_csv('parse_result.csv', index=False)
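A short usage sketch, assuming batch_benchmark.sh and then parse_log.py have been run from the benchmark/ directory: the aggregated CSV can be inspected with pandas, with column names following the keys written above (the run configuration plus a _mean/_std pair per metric).

import pandas as pd

df = pd.read_csv("parse_result.csv")
# One row per log file; e.g. compare per-step decoding latency across tensor-parallel sizes.
print(df[["model", "batch_size", "input_len", "output_len", "tensor_parallel_size",
          "decoding_latency_per_step_seconds_mean",
          "decoding_latency_per_step_seconds_std"]]
      .sort_values(["batch_size", "input_len", "output_len", "tensor_parallel_size"]))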
cacheflow/master/server.py: 8 additions & 0 deletions
@@ -95,6 +95,14 @@ def has_unfinished_requests(self):
return (self.scheduler.waiting or self.scheduler.running or
self.scheduler.swapped)

def reset_timer(self):
for controller in self.controllers:
controller.reset_timer()

def get_profile_results(self):
return [controller.get_profile_results() for controller in
self.controllers]


def initialize_ray_cluster(
address: str = 'auto',
cacheflow/parallel_utils/tensor_parallel/mappings.py: 36 additions & 0 deletions
@@ -1,12 +1,16 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import time

import torch

from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
)
from cacheflow.profile import (maybe_sync_for_profiling,
add_to_communication_latency)
from .utils import split_tensor_along_last_dim


@@ -17,9 +21,16 @@ def _reduce(input_):
if get_tensor_model_parallel_world_size()==1:
return input_

maybe_sync_for_profiling()
start_time = time.time()

# All-reduce.
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

maybe_sync_for_profiling()
end_time = time.time()
add_to_communication_latency(end_time - start_time)

return input_


@@ -78,8 +89,16 @@ def _gather_along_last_dim(input_):

tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_

maybe_sync_for_profiling()
start_time = time.time()

torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())

maybe_sync_for_profiling()
end_time = time.time()
add_to_communication_latency(end_time - start_time)

# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()

@@ -99,9 +118,17 @@ def _gather_along_first_dim(input_):

output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())

maybe_sync_for_profiling()
start_time = time.time()

torch.distributed._all_gather_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())

maybe_sync_for_profiling()
end_time = time.time()
add_to_communication_latency(end_time - start_time)

return output

def _reduce_scatter_along_first_dim(input_):
@@ -119,8 +146,17 @@ def _reduce_scatter_along_first_dim(input_):

output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())

maybe_sync_for_profiling()
start_time = time.time()

torch.distributed._reduce_scatter_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())

maybe_sync_for_profiling()
end_time = time.time()
add_to_communication_latency(end_time - start_time)

return output


cacheflow/profile.py: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
import torch

# Global profile option

SYNC_FOR_PROFILING = False

def maybe_sync_for_profiling():
if SYNC_FOR_PROFILING:
torch.cuda.synchronize()

def get_sync_for_profiling():
return SYNC_FOR_PROFILING

def set_sync_for_profiling(new_value: bool = True):
global SYNC_FOR_PROFILING
SYNC_FOR_PROFILING = new_value

# Communication latency

COMMUNICATION_LATENCY = 0.0

def reset_communication_latency():
global COMMUNICATION_LATENCY
COMMUNICATION_LATENCY = 0.0

def add_to_communication_latency(latency):
global COMMUNICATION_LATENCY
COMMUNICATION_LATENCY += latency

def get_communication_latency():
return COMMUNICATION_LATENCY
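For context, a minimal sketch of how these helpers compose, mirroring the pattern used around the collectives in mappings.py; the timed region is a stand-in, and this snippet is illustrative rather than code from the PR (it assumes a CUDA device is present once syncing is enabled).

import time

from cacheflow.profile import (add_to_communication_latency, get_communication_latency,
                               maybe_sync_for_profiling, reset_communication_latency,
                               set_sync_for_profiling)

set_sync_for_profiling()      # opt in: CPU timestamps now bracket the GPU work
reset_communication_latency()

maybe_sync_for_profiling()    # drain prior GPU work before taking the start time
start = time.time()
# ... a collective such as torch.distributed.all_reduce(...) would run here ...
maybe_sync_for_profiling()    # wait for the collective to finish before the end time
add_to_communication_latency(time.time() - start)

print(get_communication_latency())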
cacheflow/worker/controller.py: 10 additions & 1 deletion
@@ -1,10 +1,11 @@
-from typing import Dict, List, Union, Tuple
from typing import Dict, List, Union, Tuple, Any

import ray

from cacheflow.master.scheduler import Scheduler
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.worker.worker import Worker
from cacheflow.profile import get_sync_for_profiling


DeviceID = Tuple[int, str, int] # rank, node resource (node IP), device id
@@ -57,6 +58,7 @@ def __init__(
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
model_path=model_path,
sync_for_profiling=get_sync_for_profiling(),
)
self.workers.append(worker)

@@ -95,3 +97,10 @@ def execute_stage(
else:
# TODO: Support pipeline parallelism.
assert False

def get_profile_results(self) -> List[Dict[str, Any]]:
return ray.get([worker.get_profile_results.remote()
for worker in self.workers])

def reset_timer(self) -> None:
ray.get([worker.reset_timer.remote() for worker in self.workers])
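To make the nesting explicit: Server.get_profile_results() returns one entry per controller, and each controller returns one dict per worker (gathered with ray.get above). A hypothetical value, with made-up numbers and the keys that benchmark_latency.py consumes:

profile_results = [
    [  # controller 0 (the only controller while pipeline parallelism is unsupported)
        {"execution_latency": 1.82, "communication_latency": 0.31},  # worker 0
        {"execution_latency": 1.80, "communication_latency": 0.33},  # worker 1
    ],
]
# benchmark_latency.py reports the first controller's first worker:
first_worker = profile_results[0][0]
print(first_worker["execution_latency"], first_worker["communication_latency"])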