Skip to content

Commit

Permalink
Add Locust testing scripts and configurations (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
milank94 authored Dec 2, 2024
1 parent 7712e0e commit 8054abb
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 1 deletion.
40 changes: 39 additions & 1 deletion locust/README.md
Original file line number Diff line number Diff line change
@@ -1 +1,39 @@
# Locust server stress testing
# Locust Load Testing

## Setup Environment

In your Python environment with model dependencies installed, install the Locust utility:

```bash
pip install -r requirements.txt
```

## Load Test

Run a test using a dataset of variable input prompts generated by [generate_prompts](../utils/prompt_generation.py#L250).

Command to run the load test:

```bash
locust --config locust_config.conf
```

## Test Configuration

You can configure Locust tests using `.conf` files. The key configurations to modify are:

- **`users`**: Maximum number of concurrent users.
- **`spawn-rate`**: Number of users spawned per second.
- **`run-time`**: Total duration of the test (e.g., `300s`, `20m`, `3h`, `1h30m`).

For more details, see the [Locust configuration guide](https://docs.locust.io/en/2.25.0/configuration.html).

### Example

To run a test with 32 users, all launched simultaneously, for a duration of 3 minutes, set the parameters:

```bash
users = 32
spawn-rate = 32
run-time = 3m
```
49 changes: 49 additions & 0 deletions locust/data_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

import os
import sys
from typing import List, Union

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from types import SimpleNamespace

from utils.prompt_generation import generate_prompts


class DataReader:
    """
    Endless iterator over a batch of generated prompts.

    The batch is produced once, at construction time, by ``generate_prompts``
    with a fixed set of arguments. Iteration wraps around to the first prompt
    whenever the batch has been fully consumed.
    """

    def __init__(self) -> None:
        # Options handed to generate_prompts, exposed as attributes.
        options = {
            "tokenizer_model": "meta-llama/Llama-3.1-70B-Instruct",
            "dataset": "random",
            "max_prompt_length": 128,
            "input_seq_len": 128,
            "num_prompts": 32,
            "distribution": "fixed",
            "template": None,
            "save_path": None,
        }
        self.args = SimpleNamespace(**options)

        # Build the prompt batch and keep the lengths returned alongside it.
        self.prompts, self.prompt_lengths = generate_prompts(self.args)

        # Cursor over the batch; replaced with a fresh one on wrap-around.
        self.data = iter(self.prompts)

    def __iter__(self):
        """Return the reader itself so it can be used directly with ``next``."""
        return self

    def __next__(self) -> Union[str, List[str]]:
        """Return the next prompt, restarting from the first one when exhausted."""
        exhausted = object()
        item = next(self.data, exhausted)
        if item is exhausted:
            # Batch consumed: rewind and take the first prompt again.
            self.data = iter(self.prompts)
            item = next(self.data)
        return item
6 changes: 6 additions & 0 deletions locust/locust_config.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
locustfile = locustfile.py
headless = true
host = http://localhost:7000
users = 32
spawn-rate = 1
run-time = 3m
67 changes: 67 additions & 0 deletions locust/locustfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

import json
import os

import jwt

from data_reader import DataReader
from locust import FastHttpUser, events, task

# Client timeouts and the API route exercised by the test
CONNECTION_TIMEOUT = 300.0
NETWORK_TIMEOUT = 300.0
API_ENDPOINT = "/v1/completions"

# Fields merged into the JSON body of every request
DEFAULT_PARAMS = dict(
    model="meta-llama/Llama-3.1-70B-Instruct",
    temperature=1.0,
    top_k=10,
    top_p=0.9,
)

# Module-level holder for the prompt iterator (None until it is assigned)
data_iter = None

def get_authorization():
    """
    Return the token string to use for request authorization.

    The ``AUTHORIZATION`` environment variable is returned verbatim when it
    is set. Otherwise a JWT is signed (HS256) with ``JWT_SECRET`` over a
    fixed debug payload and returned.

    Returns:
        str: The authorization token.

    Raises:
        ValueError: If neither ``AUTHORIZATION`` nor ``JWT_SECRET`` is set.
    """
    authorization = os.getenv("AUTHORIZATION")
    if authorization is not None:
        return authorization

    jwt_secret = os.getenv("JWT_SECRET")
    if jwt_secret is None:
        raise ValueError(
            "Neither AUTHORIZATION or JWT_SECRET environment variables are set."
        )

    payload = {"team_id": "tenstorrent", "token_id": "debug-test"}
    encoded_jwt = jwt.encode(payload, jwt_secret, algorithm="HS256")
    # Some PyJWT releases return bytes here; formatting bytes into a string
    # would yield the literal text "b'...'" and corrupt the token.
    if isinstance(encoded_jwt, bytes):
        encoded_jwt = encoded_jwt.decode("utf-8")
    return str(encoded_jwt)


# Event listener to load custom data before tests start
@events.test_start.add_listener
def load_custom_data(**_event_kwargs):
    """
    Create the shared prompt reader when the test starts.

    Registered as a ``test_start`` listener; whatever keyword arguments the
    event supplies are accepted and ignored.
    """
    global data_iter
    data_iter = DataReader()


class ServeUser(FastHttpUser):
    """User that posts prompts from the shared data iterator to the API endpoint."""

    # Client timeouts and the authorization header sent with every request.
    network_timeout = NETWORK_TIMEOUT
    connection_timeout = CONNECTION_TIMEOUT
    headers = {"Authorization": f"Bearer {get_authorization()}"}

    def post_request(self, prompt: str, max_tokens: int):
        """POST ``prompt`` with the default parameters and return the response."""
        # Same layering as a single literal: prompt first, defaults next,
        # and max_tokens last so it wins over any default of that name.
        payload = {"prompt": prompt}
        payload.update(DEFAULT_PARAMS)
        payload["max_tokens"] = max_tokens
        return self.client.post(API_ENDPOINT, json=payload, headers=self.headers)

    @task
    def dataset_test(self):
        """Send the next prompt from the data iterator with a 128-token limit."""
        self.post_request(next(data_iter), max_tokens=128)
1 change: 1 addition & 0 deletions locust/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
locust==2.25.0

0 comments on commit 8054abb

Please sign in to comment.