diff --git a/.github/workflows/ci-parity.yml b/.github/workflows/ci-parity.yml
index 428f0f5a..c527b1b3 100644
--- a/.github/workflows/ci-parity.yml
+++ b/.github/workflows/ci-parity.yml
@@ -26,8 +26,11 @@ jobs:
       - name: Install LitServe
         run: |
           pip --version
-          pip install . torchvision jsonargparse uvloop -U -q -r _requirements/test.txt -U -q
+          pip install . torchvision jsonargparse uvloop tenacity -U -q -r _requirements/test.txt
           pip list
 
-      - name: Tests
+      - name: Parity test
         run: export PYTHONPATH=$PWD && python tests/parity_fastapi/main.py
+
+      - name: Streaming speed test
+        run: bash tests/perf_test/stream/run_test.sh
diff --git a/tests/perf_test/stream/run_test.sh b/tests/perf_test/stream/run_test.sh
new file mode 100644
index 00000000..03cc139b
--- /dev/null
+++ b/tests/perf_test/stream/run_test.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+# 1. Test server streams data very fast
+
+# Function to clean up server process
+cleanup() {
+    pkill -f "python tests/perf_test/stream/stream_speed/server.py"
+}
+
+# Trap script exit to run cleanup
+trap cleanup EXIT
+
+# Start the server in the background and capture its PID
+python tests/perf_test/stream/stream_speed/server.py &
+SERVER_PID=$!
+
+echo "Server started with PID $SERVER_PID"
+
+# Run your benchmark script
+echo "Preparing to run benchmark.py..."
+
+export PYTHONPATH=$PWD && python tests/perf_test/stream/stream_speed/benchmark.py
+
+# Check if benchmark.py exited successfully
+if [ $? -ne 0 ]; then
+    echo "benchmark.py failed to run successfully."
+    exit 1
+else
+    echo "benchmark.py ran successfully."
+fi
diff --git a/tests/perf_test/stream/stream_speed/benchmark.py b/tests/perf_test/stream/stream_speed/benchmark.py
new file mode 100644
index 00000000..ddeb2f2f
--- /dev/null
+++ b/tests/perf_test/stream/stream_speed/benchmark.py
@@ -0,0 +1,58 @@
+"""Consume 10K tokens from the stream endpoint and measure the speed."""
+
+import logging
+import time
+
+import requests
+from tenacity import retry, stop_after_attempt
+
+logger = logging.getLogger(__name__)
+# Configuration
+SERVER_URL = "http://0.0.0.0:8000/predict"
+TOTAL_TOKENS = 10000
+EXPECTED_TTFT = 0.005  # max allowed time to first token (seconds)
+
+# minimum tokens per second the server must sustain
+MIN_SPEED = 3600  # 3600 on GitHub CI, 10000 on M3 Pro
+
+session = requests.Session()
+
+
+def speed_test():
+    start = time.time()
+    resp = session.post(SERVER_URL, stream=True, json={"input": 1})
+    num_tokens = 0
+    ttft = None  # time to first token
+    for line in resp.iter_lines():
+        if not line:
+            continue
+        if ttft is None:
+            ttft = time.time() - start
+            print(f"Time to first token: {ttft}")
+            assert ttft < EXPECTED_TTFT, f"Expected time to first token to be less than {EXPECTED_TTFT} seconds"
+        num_tokens += 1
+    end = time.time()
+    resp.raise_for_status()
+    assert num_tokens == TOTAL_TOKENS, f"Expected {TOTAL_TOKENS} tokens, got {num_tokens}"
+    speed = num_tokens / (end - start)
+    return {"speed": speed, "time": end - start}
+
+
+@retry(stop=stop_after_attempt(10))
+def main():
+    for _ in range(10):
+        try:
+            resp = requests.get("http://localhost:8000/health")
+            if resp.status_code == 200:
+                break
+        except requests.exceptions.ConnectionError as e:
+            logger.error("Error connecting to server: %s", e)
+            time.sleep(10)
+    data = speed_test()
+    speed = data["speed"]
+    print(data)
+    assert speed >= MIN_SPEED, f"Expected streaming speed to be at least {MIN_SPEED}, got {speed}"
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/perf_test/stream/stream_speed/server.py b/tests/perf_test/stream/stream_speed/server.py
new file mode 100644
index 00000000..03e5dc80
--- /dev/null
+++ b/tests/perf_test/stream/stream_speed/server.py
@@ -0,0 +1,25 @@
+import litserve as ls
+
+
+class SimpleStreamingAPI(ls.LitAPI):
+    def setup(self, device) -> None:
+        self.model = None
+
+    def decode_request(self, request):
+        return request["input"]
+
+    def predict(self, x):
+        yield from range(10000)
+
+    def encode_response(self, output_stream):
+        for output in output_stream:
+            yield {"output": output}
+
+
+if __name__ == "__main__":
+    api = SimpleStreamingAPI()
+    server = ls.LitServer(
+        api,
+        stream=True,
+    )
+    server.run(port=8000, generate_client_file=False)