Commit d936cec

Added documentation and refactored process_document and request_chunk_summary functions.
1 parent 4425e99 commit d936cec

3 files changed: 108 additions & 38 deletions

tools/document_summarizer/main.py

Lines changed: 0 additions & 4 deletions
@@ -6,10 +6,6 @@
 from utils import print_configs, validate_arguments
 
 
-# The endpoint we send our requests to (the server is hosted locally with LM studio)
-llm_endpoint = "http://127.0.0.1:1234/v1/chat/completions"
-
-
 def main():
     # Parse command-line arguments
     parser = argparse.ArgumentParser(

tools/document_summarizer/summarize.py

Lines changed: 89 additions & 34 deletions
@@ -1,13 +1,45 @@
 from io import TextIOWrapper
 import json
+import logging
 import os
 
 import requests
 
 from prompt import PromptChain, PromptChainBuilder
-from chunk_document import DocumentChunkSummary, chunk_content, get_token_length
+from chunk_document import DocumentChunkSummary, chunk_content
 from config import SummarizerConfig, SummaryOutputHandler
-from utils import Timer, print_summary_outcome
+from utils import Timer, calculate_summary_statistics, print_summary_outcome
+
+
+LLM_API_ENDPOINT = "http://127.0.0.1:1234/v1/chat/completions"
+
+
+def _build_request_body(
+    prompt_chain: PromptChain,
+    model: str,
+    temperature: float,
+    max_new_tokens: int,
+) -> dict:
+    """
+    Creates JSON request body that will be sent to the LLM api.
+    The schema defined below is specifically tailored to an LM studio server.
+    """
+    return {
+        "model": model,
+        "messages": [
+            {"role": "system", "content": prompt_chain.system_prompt},
+            {"role": "user", "content": prompt_chain.user_prompt},
+        ],
+        "temperature": temperature,
+        "max_tokens": max_new_tokens,
+        "stream": False,  # Use boolean instead of string
+    }
+
+
+class LLMAPIError(Exception):
+    """Exception for LLM API request related errors"""
+
+    pass
 
 
 def request_chunk_summary(
@@ -16,35 +48,62 @@ def request_chunk_summary(
     temperature: float = 0.25,
     max_new_tokens: int = -1,
 ) -> "DocumentChunkSummary":
-    """Request a summary for a `DocumentChunk`"""
-    response = requests.post(
-        "http://127.0.0.1:1234/v1/chat/completions",
-        headers={"Content-Type": "application/json"},
-        data=json.dumps(
-            {
-                "model": model,
-                "messages": [
-                    {"role": "system", "content": prompt_chain.system_prompt},
-                    {"role": "user", "content": prompt_chain.user_prompt},
-                ],
-                "temperature": temperature,
-                "max_tokens": max_new_tokens,
-                "stream": "false",
-            }
-        ),
-    ).json()
-
-    return DocumentChunkSummary(
-        content=response["choices"][0]["message"]["content"],
-        original=prompt_chain.original_chunk_content,
-    )
+    """
+    Request a summary for a DocumentChunk from the LLM API endpoint.
+
+    Args:
+        `prompt_chain`: PromptChain object containing system and user prompts
+        `model`: Name of the model to use
+        `temperature`: Float between 0 and 1 controlling randomness
+        `max_new_tokens`: Maximum number of tokens to generate (-1 for unlimited)
+
+    Returns:
+        DocumentChunkSummary containing the generated summary
+
+    Raises:
+        LLMAPIError: If API request fails
+    """
+
+    request_body = _build_request_body(prompt_chain, model, temperature, max_new_tokens)
+
+    try:
+        response = requests.post(
+            LLM_API_ENDPOINT,
+            headers={"Content-Type": "application/json"},
+            json=request_body,
+        )
+        response.raise_for_status()
+
+        response_data = response.json()
+        summary = response_data["choices"][0]["message"]["content"]
+
+        return DocumentChunkSummary(
+            content=summary,
+            original=prompt_chain.original_chunk_content,
+        )
+
+    except Exception as e:
+        logging.error(f"Failed to generate summary: {str(e)}")
+        raise LLMAPIError(f"Failed to generate summary: {str(e)}")
 
 
 def process_document(
     document: TextIOWrapper,
     summarizer_config: SummarizerConfig,
     summary_output_handler: SummaryOutputHandler,
 ):
+    """
+    Process a document by chunking it and generating summaries.
+
+    Args:
+        `document`: The input document to process
+        `summarizer_config`: Configuration for the summarization process
+        `summary_output_handler`: Handler for saving summaries
+
+    Raises:
+        `ValueError`: If document is invalid
+        `RuntimeError`: If processing fails
+    """
     file_name_without_ext = os.path.splitext(os.path.basename(document.name))[0]
 
     with Timer(
@@ -94,19 +153,15 @@ def process_document(
         summaries.append(chunk_summary)
 
     # Compute token and char lengths of the original and summaries, then print them out
-    total_token_len_chunks = sum(get_token_length(chunk.content) for chunk in chunks)
-    total_token_len_summaries = sum(
-        get_token_length(summary.content) for summary in summaries
-    )
-    total_char_len_chunks = sum(len(chunk) for chunk in chunks)
-    total_char_len_summaries = sum(len(summary) for summary in summaries)
+    stats = calculate_summary_statistics(chunks, summaries)
+
     print("")
     print_summary_outcome(
         file_name_without_ext,
-        total_token_len_chunks,
-        total_token_len_summaries,
-        total_char_len_chunks,
-        total_char_len_summaries,
+        stats["token_len_chunks"],
+        stats["token_len_summaries"],
+        stats["char_len_chunks"],
+        stats["char_len_summaries"],
     )
 
     summary_output_handler.save(file_name_without_ext, summaries, chunks)
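
For illustration, a minimal sketch of how the refactored `request_chunk_summary` might be called. This is not part of the commit: it assumes `model` is an earlier parameter of the function (the top of the signature is elided from this hunk), that `prompt_chain` is built elsewhere via PromptChainBuilder (whose API is not shown in this diff), and "local-model" is a placeholder name.

    import logging

    from summarize import LLMAPIError, request_chunk_summary

    # Assumed to be constructed via PromptChainBuilder; its API is not
    # shown in this diff.
    prompt_chain = ...

    try:
        summary = request_chunk_summary(prompt_chain, model="local-model")
        print(summary.content)
    except LLMAPIError as err:
        # The refactor wraps request and parsing failures in LLMAPIError,
        # so callers only need to catch one exception type.
        logging.error(f"Chunk summarization failed: {err}")

Because `_build_request_body` targets LM Studio's OpenAI-compatible /v1/chat/completions schema, the payload behind this call carries the `model`, `messages`, `temperature`, `max_tokens`, and `stream` keys shown in the first hunk.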

tools/document_summarizer/utils.py

Lines changed: 19 additions & 0 deletions
@@ -5,6 +5,11 @@
 from typing import Any, Dict, List, Optional
 
 from config import SummarizerConfig, SummaryOutputConfig
+from chunk_document import (
+    DocumentChunk,
+    DocumentChunkSummary,
+    get_token_length,
+)
 
 
 def validate_arguments(args, parser):
@@ -65,6 +70,20 @@ def print_summary_outcome(
     )
 
 
+# Used to log statistics of a document summary
+def calculate_summary_statistics(
+    chunks: List["DocumentChunk"], summaries: List["DocumentChunkSummary"]
+):
+    return {
+        "token_len_chunks": sum(get_token_length(chunk.content) for chunk in chunks),
+        "token_len_summaries": sum(
+            get_token_length(summary.content) for summary in summaries
+        ),
+        "char_len_chunks": sum(len(chunk) for chunk in chunks),
+        "char_len_summaries": sum(len(summary) for summary in summaries),
+    }
+
+
 @dataclass
 class Timer:
     task_label: str
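
A quick sketch of what the new helper returns; the compression-ratio lines are illustrative and not part of the commit:

    # calculate_summary_statistics returns a plain dict keyed by metric name.
    stats = calculate_summary_statistics(chunks, summaries)

    # For example, report how much the summaries compress the original chunks
    # (guarding against division by zero on an empty document).
    ratio = stats["token_len_summaries"] / max(stats["token_len_chunks"], 1)
    print(f"Summaries use {ratio:.0%} of the original token count")

Returning a dict rather than four separate values keeps the call site in process_document to a single line and makes it easy to add new metrics later.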
