
Commit fbd417f

[GuideLLM Refactor] Type fixes (#417)
## Summary

Various type fixes with the goal of not breaking anything.

---

- [x] "I certify that all code in this PR is my own, except as noted below."

## Use of AI

- [x] Includes AI-assisted code completion
- [ ] Includes code generated by an AI application
- [ ] Includes AI-generated tests (NOTE: AI written tests should have a docstring that includes `## WRITTEN BY AI ##`)
2 parents c3fdf88 + 48769c2 commit fbd417f

File tree: 9 files changed, +81 -67 lines

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -174,6 +174,7 @@ module = [
     "transformers.*",
     "setuptools.*",
     "setuptools_git_versioning.*",
+    "torchcodec.*"
 ]
 ignore_missing_imports = true
 
```

src/guidellm/extras/multimodal.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -230,7 +230,7 @@ def encode_video(
     else:
         raise ValueError(f"Unsupported video type: {type(video)} for {video}")
 
-    video_base64 = base64.b64encode(video).decode("utf-8")
+    video_base64 = base64.b64encode(video_bytes).decode("utf-8")
 
     return {
         "type": "video_base64",
@@ -266,8 +266,9 @@ def encode_audio(
         "audio_samples",
         "audio_seconds",
         "audio_bytes",
+        "file_name",
     ],
-    str | int | float | None,
+    str | int | float | bytes | None,
 ]:
     """Decode audio (if necessary) and re-encode to specified format."""
     samples = _decode_audio(audio, sample_rate=sample_rate, max_duration=max_duration)
@@ -338,10 +339,10 @@ def _decode_audio(  # noqa: C901, PLR0912
 
     samples: AudioSamples
 
+    data: torch.Tensor | bytes
     # HF datasets return AudioDecoder for audio column
     if isinstance(audio, AudioDecoder):
         samples = audio.get_samples_played_in_range(stop_seconds=max_duration)
-
    elif isinstance(audio, torch.Tensor):
         # If float stream assume decoded audio
         if torch.is_floating_point(audio):
```
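The `encode_video` fix above passes the decoded byte buffer (`video_bytes`) to `base64.b64encode` instead of the original `video` argument, which may be a path, URL, or tensor and is not bytes-like. A minimal sketch of the corrected pattern, with a made-up byte string standing in for real video data and an illustrative payload key:

```python
import base64

# Hypothetical stand-in for bytes read from a video file; base64.b64encode
# only accepts bytes-like input, which is why the original `video` value
# (possibly a str path or tensor) could not be passed directly.
video_bytes = b"\x00\x00\x00\x18ftypmp42"

video_base64 = base64.b64encode(video_bytes).decode("utf-8")
payload = {"type": "video_base64", "data": video_base64}  # "data" key is illustrative
print(payload)
```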

src/guidellm/mock_server/handlers/chat_completions.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -136,7 +136,7 @@ async def _handle_non_stream(self, req: ChatCompletionsRequest) -> HTTPResponse:
 
         # Token counts
         prompt_text = self.tokenizer.apply_chat_template(req.messages)
-        prompt_tokens = len(self.tokenizer(prompt_text))
+        prompt_tokens = len(self.tokenizer(prompt_text))  # type: ignore[arg-type]
         max_tokens = req.max_completion_tokens or req.max_tokens or math.inf
         completion_tokens_count = min(
             sample_number(self.config.output_tokens, self.config.output_tokens_std),
@@ -197,7 +197,7 @@ async def generate_stream(stream_response):
 
            # Token counts
            prompt_text = self.tokenizer.apply_chat_template(req.messages)
-           prompt_tokens = len(self.tokenizer(prompt_text))
+           prompt_tokens = len(self.tokenizer(prompt_text))  # type: ignore[arg-type]
            max_tokens = req.max_completion_tokens or req.max_tokens or math.inf
            completion_tokens_count = int(
                min(
```
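Both suppressions above are scoped to a single mypy error code, so unrelated errors on the same line still surface. A generic sketch of the pattern, using a purely illustrative helper (not part of GuideLLM) whose annotation is narrower than what it receives at runtime:

```python
def count_items(text: str) -> int:
    # Annotated for str, but len() accepts any sized container at runtime.
    return len(text)


words: list[str] = ["hello", "world"]

# mypy flags the argument type (list[str] is not str) even though the call
# works at runtime; the error-code-scoped ignore documents the known mismatch
# without hiding other diagnostics on the same line.
print(count_items(words))  # type: ignore[arg-type]
```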

src/guidellm/mock_server/utils.py

Lines changed: 14 additions & 19 deletions
```diff
@@ -58,12 +58,15 @@ def __call__(self, text: str | list[str], **kwargs) -> list[int]:  # noqa: ARG00
             return self.convert_tokens_to_ids(tokens)
         elif isinstance(text, list):
             # Handle batch processing
-            return [self.__call__(t) for t in text]
+            result = []
+            for t in text:
+                result.extend(self.__call__(t))
+            return result
         else:
             msg = f"text input must be of type `str` or `list[str]`, got {type(text)}"
             raise ValueError(msg)
 
-    def tokenize(self, text: TextInput, **_kwargs) -> list[str]:
+    def tokenize(self, text: TextInput, **_kwargs) -> list[str]:  # type: ignore[override]
         """
         Tokenize input text into a list of token strings.
 
@@ -76,7 +79,7 @@ def tokenize(self, text: TextInput, **_kwargs) -> list[str]:
         # Split text into tokens: words, spaces, and punctuation
         return re.findall(r"\w+|[^\w\s]|\s+", text)
 
-    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
+    def convert_tokens_to_ids(self, tokens: str | list[str]) -> list[int]:
         """
         Convert token strings to numeric token IDs.
 
@@ -87,12 +90,12 @@ def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
         :return: Single token ID or list of token IDs
         """
         if isinstance(tokens, str):
-            return hash(tokens) % self.VocabSize
+            return [hash(tokens) % self.VocabSize]
         return [hash(token) % self.VocabSize for token in tokens]
 
-    def convert_ids_to_tokens(
-        self, ids: int | list[int], _skip_special_tokens: bool = False
-    ) -> str | list[str]:
+    def convert_ids_to_tokens(  # type: ignore[override]
+        self, ids: list[int], _skip_special_tokens: bool = False
+    ) -> list[str]:
         """
         Convert numeric token IDs back to token strings.
 
@@ -102,17 +105,9 @@ def convert_ids_to_tokens(
         :param ids: Single token ID or list of token IDs to convert
         :return: Single token string or list of token strings
         """
-        if not ids and not isinstance(ids, list):
-            return ""
-        elif not ids:
+        if not ids:
             return [""]
 
-        if isinstance(ids, int):
-            fake = Faker()
-            fake.seed_instance(ids % self.VocabSize)
-
-            return fake.word()
-
         fake = Faker()
         fake.seed_instance(sum(ids) % self.VocabSize)
 
@@ -162,7 +157,7 @@ def _add_tokens(
         """
         return 0
 
-    def apply_chat_template(
+    def apply_chat_template(  # type: ignore[override]
         self,
         conversation: list,
         tokenize: bool = False,  # Changed default to False to match transformers
@@ -193,7 +188,7 @@ def apply_chat_template(
             return self.convert_tokens_to_ids(self.tokenize(formatted_text))
         return formatted_text
 
-    def decode(
+    def decode(  # type: ignore[override]
         self,
         token_ids: list[int],
         skip_special_tokens: bool = True,
@@ -255,7 +250,7 @@ def create_fake_tokens_str(
     fake = Faker()
     fake.seed_instance(seed)
 
-    tokens = []
+    tokens: list[str] = []
 
     while len(tokens) < num_tokens:
         text = fake.text(
```
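The `__call__` change above flattens batch input into a single `list[int]` instead of returning one list per string, which matches the method's annotated return type; `convert_tokens_to_ids` now always returns a list for the same reason. A reduced sketch of that flattening, with an assumed vocabulary size and the same hash-based ID scheme shown in the diff:

```python
import re

VOCAB_SIZE = 100_000  # assumed stand-in for the mock tokenizer's vocabulary size


def encode_one(text: str) -> list[int]:
    # Split into words, punctuation, and whitespace runs, then hash each token
    # into a bounded ID range, mirroring the mock tokenizer's approach.
    tokens = re.findall(r"\w+|[^\w\s]|\s+", text)
    return [hash(token) % VOCAB_SIZE for token in tokens]


def encode(text: str | list[str]) -> list[int]:
    if isinstance(text, str):
        return encode_one(text)
    # Batch input: extend into one flat list[int] rather than building a
    # list of lists, which is what the refactored __call__ now does.
    result: list[int] = []
    for item in text:
        result.extend(encode_one(item))
    return result


print(len(encode("Hello, world!")))
print(len(encode(["Hello, world!", "Second entry"])))
```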

src/guidellm/presentation/data_models.py

Lines changed: 6 additions & 8 deletions
```diff
@@ -117,25 +117,23 @@ def from_benchmarks(cls, benchmarks: list["GenerativeBenchmark"]):
             range(len(successful_requests)), min(5, len(successful_requests))
         )
         sample_prompts = [
-            successful_requests[i].request_args.replace("\n", " ").replace('"', "'")
-            if successful_requests[i].request_args is not None
-            else ""
+            req.request_args.replace("\n", " ").replace('"', "'")
+            if (req := successful_requests[i]).request_args else ""
             for i in sample_indices
         ]
         sample_outputs = [
-            successful_requests[i].output.replace("\n", " ").replace('"', "'")
-            if successful_requests[i].output is not None
-            else ""
+            req.output.replace("\n", " ").replace('"', "'")
+            if (req := successful_requests[i]).output else ""
             for i in sample_indices
         ]
 
         prompt_tokens = [
-            float(req.prompt_tokens)
+            float(req.prompt_tokens) if req.prompt_tokens is not None else -1
             for bm in benchmarks
             for req in bm.requests.successful
         ]
         output_tokens = [
-            float(req.output_tokens)
+            float(req.output_tokens) if req.output_tokens is not None else -1
             for bm in benchmarks
             for req in bm.requests.successful
         ]
```
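The comprehension rewrites above bind each request once with an assignment expression so it is not indexed twice, and the truthiness check covers both `None` and empty strings (either way the sample falls back to `""`). A reduced sketch of the same pattern, with plain strings standing in for request objects:

```python
# Hypothetical sample data standing in for successful_requests[i].request_args.
request_args = ['{"prompt": "hi"\n}', None, ""]
sample_indices = [0, 1, 2]

sample_prompts = [
    # Bind the element once with the walrus operator, then sanitize it;
    # a None or empty value falls through to the empty-string branch.
    arg.replace("\n", " ").replace('"', "'")
    if (arg := request_args[i])
    else ""
    for i in sample_indices
]
print(sample_prompts)  # ["{'prompt': 'hi' }", "", ""]
```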

src/guidellm/utils/encoding.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -32,7 +32,7 @@
 
     HAS_MSGSPEC = True
 except ImportError:
-    MsgspecDecoder = MsgspecEncoder = None
+    MsgspecDecoder = MsgspecEncoder = None  # type: ignore[misc, assignment]  # HAS_MSGSPEC will be checked at runtime
     HAS_MSGSPEC = False
 
 
```
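The ignore above silences mypy's complaint about rebinding the imported classes to `None` on the fallback path; `HAS_MSGSPEC` is what guards their use at runtime. A sketch of this optional-dependency pattern, assuming the names come from `msgspec.msgpack` (the actual import path in `encoding.py` is not shown in this diff):

```python
try:
    from msgspec.msgpack import Decoder as MsgspecDecoder
    from msgspec.msgpack import Encoder as MsgspecEncoder

    HAS_MSGSPEC = True
except ImportError:
    # Rebinding the imported names to None trips mypy's misc/assignment checks,
    # hence the inline ignore; HAS_MSGSPEC guards every use at runtime.
    MsgspecDecoder = MsgspecEncoder = None  # type: ignore[misc, assignment]
    HAS_MSGSPEC = False


def pack(obj: object) -> bytes:
    # Illustrative helper: refuse to encode when the optional dependency is absent.
    if not HAS_MSGSPEC:
        raise RuntimeError("msgspec is not installed")
    return MsgspecEncoder().encode(obj)
```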

src/guidellm/utils/imports.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@
 try:
     import orjson as json
 except ImportError:
-    import json
+    import json  # type: ignore[no-redef]  # Done only after a failure.
 
 
 __all__ = ["json"]
```

src/guidellm/utils/registry.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -65,7 +65,7 @@ class TokenProposal(RegistryMixin):
     :cvar registry_populated: Track whether auto-discovery has completed
     """
 
-    registry: ClassVar[dict[str, RegistryObjT] | None] = None
+    registry: ClassVar[dict[str, RegistryObjT] | None] = None  # type: ignore[misc]
     registry_auto_discovery: ClassVar[bool] = False
     registry_populated: ClassVar[bool] = False
 
```

src/guidellm/utils/statistics.py

Lines changed: 51 additions & 32 deletions
```diff
@@ -283,40 +283,12 @@ def from_request_times(
         )
 
         # First convert to timing events based on type
-        events: list[tuple[float, float]] = []
-
-        if distribution_type == "concurrency":
-            # For concurrency, each request adds to concurrency at start
-            # and subtracts at end
-            for (start, end), weight in zip(requests, weights, strict=False):
-                events.append((start, weight))
-                events.append((end, -1 * weight))
-        elif distribution_type == "rate":
-            # For rate, each request is added at the end time only
-            global_start = min(start for start, _ in requests) if requests else 0.0
-            events.append((global_start, 0.0))
-            for (_, end), weight in zip(requests, weights, strict=False):
-                events.append((end, weight))
-        else:
-            raise ValueError(
-                f"Invalid distribution_type '{distribution_type}'. "
-                "Must be 'concurrency' or 'rate'."
-            )
-
-        # Combine any events within epsilon of each other for stability
-        sorted_events = sorted(events, key=lambda event: event[0])
-        flattened_events: list[tuple[float, float]] = (
-            [sorted_events.pop(0)] if sorted_events else []
+        events = DistributionSummary._convert_to_timing_events(
+            requests, distribution_type, weights
         )
-        last_time = flattened_events[0][0] if flattened_events else 0.0
 
-        for time, val in sorted_events:
-            if abs(time - last_time) <= epsilon:
-                last_val = flattened_events[-1][1]
-                flattened_events[-1] = (last_time, last_val + val)
-            else:
-                last_time = time
-                flattened_events.append((time, val))
+        # Combine any events within epsilon of each other for stability
+        flattened_events = DistributionSummary._combine_events(events, epsilon)
 
         # Convert events to value distribution function
         distribution: dict[float, float] = defaultdict(float)
@@ -357,6 +329,53 @@
             include_cdf=include_cdf,
         )
 
+    @staticmethod
+    def _convert_to_timing_events(
+        requests: list[tuple[float, float]],
+        distribution_type: Literal["concurrency", "rate"],
+        weights: list[float],
+    ) -> list[tuple[float, float]]:
+        events: list[tuple[float, float]] = []
+
+        if distribution_type == "concurrency":
+            # For concurrency, each request adds to concurrency at start
+            # and subtracts at end
+            for (start, end), weight in zip(requests, weights, strict=False):
+                events.append((start, weight))
+                events.append((end, -1 * weight))
+        elif distribution_type == "rate":
+            # For rate, each request is added at the end time only
+            global_start = min(start for start, _ in requests) if requests else 0.0
+            events.append((global_start, 0.0))
+            for (_, end), weight in zip(requests, weights, strict=False):
+                events.append((end, weight))
+        else:
+            raise ValueError(
+                f"Invalid distribution_type '{distribution_type}'. "
+                "Must be 'concurrency' or 'rate'."
+            )
+        return events
+
+    @staticmethod
+    def _combine_events(
+        events: list[tuple[float, float]],
+        epsilon: float,
+    ) -> list[tuple[float, float]]:
+        sorted_events = sorted(events, key=lambda event: event[0])
+        flattened_events: list[tuple[float, float]] = (
+            [sorted_events.pop(0)] if sorted_events else []
+        )
+        last_time = flattened_events[0][0] if flattened_events else 0.0
+
+        for time, val in sorted_events:
+            if abs(time - last_time) <= epsilon:
+                last_val = flattened_events[-1][1]
+                flattened_events[-1] = (last_time, last_val + val)
+            else:
+                last_time = time
+                flattened_events.append((time, val))
+        return flattened_events
+
     @staticmethod
     def from_iterable_request_times(
         requests: list[tuple[float, float]],
```
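The refactor above lifts the event-building and epsilon-merging steps out of `from_request_times` into two static helpers without changing their logic. A standalone sketch of the merging step on toy timestamps (a trimmed copy of `_combine_events`, with made-up sample data):

```python
def combine_events(
    events: list[tuple[float, float]], epsilon: float
) -> list[tuple[float, float]]:
    # Sort by time, then merge any event that lands within epsilon of the
    # previously kept event by summing their values, as the helper above does.
    sorted_events = sorted(events, key=lambda event: event[0])
    flattened: list[tuple[float, float]] = [sorted_events.pop(0)] if sorted_events else []
    last_time = flattened[0][0] if flattened else 0.0

    for time, val in sorted_events:
        if abs(time - last_time) <= epsilon:
            flattened[-1] = (last_time, flattened[-1][1] + val)
        else:
            last_time = time
            flattened.append((time, val))
    return flattened


# Two concurrency events at ~1.0 merge; the event at 2.0 stays separate.
print(combine_events([(1.0, 1.0), (1.0000001, 1.0), (2.0, -1.0)], epsilon=1e-6))
# [(1.0, 2.0), (2.0, -1.0)]
```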
