README-generate-answers.md: 11 changes (9 additions & 2 deletions)
@@ -11,6 +11,7 @@ The default configuration file is [./src/generate_answers/eval_config.yaml](/src/generate_answers/eval_config.yaml)
lightspeed-core service configuration. `display_name` is a nice short model name.
- `models_to_evaluate` -- list of model names (`display_name`) for answer generation.

Example:
```yaml
lightspeed_url: "http://localhost:8080"
models:
@@ -35,7 +36,7 @@ models_to_evaluate:
## Running
`pdm run generate_answers -h`

-```shell
+```
Usage: generate_answers [OPTIONS]

Generate answers from LLMs by connection to LightSpeed core service.
@@ -54,4 +55,10 @@ Options:
-f, --force-overwrite Overwrite the output file if it exists
-v, --verbose Increase the logging level to DEBUG
-h, --help Show this message and exit.
-```
+```

## Results
The results are stored as a dataframe in JSON format. The file can be read with `pandas.read_json`.
The columns are:
- `id`, `question` -- from the input file
- `<model_name>_answers` -- for each configured model
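
As a quick sanity check of the new Results section, here is a minimal sketch of loading the output with `pandas.read_json`. The filename and the `granite` model name are placeholders, not values from this PR; the real output path comes from the CLI options.

```python
import pandas as pd

# Hypothetical output file; the actual path is whatever was passed on the CLI.
df = pd.read_json("answers.json")

print(df.columns.tolist())    # e.g. ["id", "question", "granite_answers"]
print(df.loc[0, "question"])  # question text carried over from the input file
```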
pyproject.toml: 1 change (1 addition & 0 deletions)
@@ -26,6 +26,7 @@ dependencies = [
"road-core @ git+https://github.com/road-core/service.git",
"matplotlib>=3.10.1",
"ragas>=0.2.15",
"tenacity>=9.1.2",
]
requires-python = ">=3.11.1,<=3.12.8"
readme = "README.md"
src/generate_answers/generate_answers.py: 5 changes (4 additions & 1 deletion)
@@ -165,6 +165,9 @@ def main( # pylint: disable=R0913,R0917,R0914

qna_df = read_func(input_filename)

# Remove empty questions
qna_df = qna_df[qna_df[_QUESTION_COL].notna() & (qna_df[_QUESTION_COL] != "")]

# Generate the answers
# Parallelize this? pytorch Dataset?
for model, ls_client in evaluators:
Expand All @@ -177,7 +180,7 @@ def main( # pylint: disable=R0913,R0917,R0914

generate_answer_func = partial(ls_client.get_answer, skip_cache=False)

-qna_df[output_column] = qna_df[_QUESTION_COL].progress_apply(
+qna_df[output_column] = qna_df[_QUESTION_COL].progress_apply(  # type: ignore
generate_answer_func
)

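For reviewers, a small self-contained sketch of what the added empty-question filter does, assuming `_QUESTION_COL == "question"` (the toy dataframe below is made up):

```python
import pandas as pd

# Toy input; assumes _QUESTION_COL == "question".
qna_df = pd.DataFrame({"id": [1, 2, 3], "question": ["What is X?", None, ""]})

# Same predicate as the added line: keep rows whose question is neither NaN nor empty.
qna_df = qna_df[qna_df["question"].notna() & (qna_df["question"] != "")]

print(qna_df)  # only the "What is X?" row survives
```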
src/generate_answers/ls_response.py: 8 changes (7 additions & 1 deletion)
@@ -7,6 +7,7 @@

from diskcache import Cache
from httpx import Client
from tenacity import retry, stop_after_attempt, wait_exponential

logger = logging.getLogger(__name__)

@@ -45,6 +46,11 @@ def _get_cached_answer(self, query: str) -> str | None:
key = self._get_cache_key(query)
return cast(str | None, self.cache.get(key))

# Wait 2^x seconds between retries, starting at 4 seconds
# and capping at 100 seconds thereafter
@retry(
stop=stop_after_attempt(10), wait=wait_exponential(multiplier=1, min=4, max=100)
)
def get_answer(self, query: str, skip_cache: bool = False) -> str:
"""Get LLM answer for query."""
if not skip_cache:
@@ -69,7 +75,7 @@ def get_answer(self, query: str, skip_cache: bool = False) -> str:
"Status: %d, query='%s', response='%s'",
response.status_code,
query,
-json.dumps(response.json()),
+response.text,
)
raise RuntimeError(response)

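Per the tenacity docs, `wait_exponential` computes `multiplier * 2**attempt_number` and clamps the result into `[min, max]`, so the decorator added above yields roughly the following schedule. This is a sketch of the arithmetic, not output from the library itself:

```python
# Approximate wait schedule for wait_exponential(multiplier=1, min=4, max=100);
# tenacity clamps multiplier * 2**attempt_number into [min, max].
for attempt in range(1, 10):
    wait = min(max(1 * 2 ** attempt, 4), 100)
    print(f"retry {attempt}: wait {wait}s")
# -> 4, 4, 8, 16, 32, 64, 100, 100, 100; stop_after_attempt(10) bounds total retries
```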