
Commit 6084d41

tjohnson31415 authored and z103cb committed
format: make mypy happy (#24)
`format.sh` now has mypy checks after pulling in upstream changes. This PR makes the mypy-suggested modifications to our code.

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
1 parent 6100f4b commit 6084d41

4 files changed: +20 -21 lines

vllm/entrypoints/grpc/grpc_server.py

Lines changed: 12 additions & 11 deletions
@@ -14,7 +14,7 @@
 from vllm import (AsyncLLMEngine, CompletionOutput, RequestOutput,
                   SamplingParams)
 from vllm.config import ModelConfig
-from vllm.entrypoints.grpc.pb import generation_pb2_grpc
+from vllm.entrypoints.grpc.pb import generation_pb2_grpc  # type: ignore
 # yapf: disable
 from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
                                                      BatchedGenerationResponse,
@@ -54,15 +54,15 @@ async def _handle_exception(e: Exception, func, *args, **kwargs):
     if not isinstance(e, AbortError):
         if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check
             context = kwargs.get("context", None) or args[-1]
-            logger.exception(f"{func.__name__} caused GPU OOM error")
+            logger.exception("%s caused GPU OOM error", func.__name__)
             service_metrics.count_request_failure(FailureReasonLabel.OOM)
             await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e))
         else:
             if "generate" in func.__name__.lower():
                 service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
             else:
                 service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
-            logger.exception(f"{func.__name__} failed")
+            logger.exception("%s failed", func.__name__)
     raise e


@@ -298,7 +298,7 @@ def _convert_output(self,
             text=output.text[text_start_offset:],
             generated_token_count=len(output.token_ids),
             stop_reason=stop_reason,
-            stop_sequence=stop_sequence,
+            stop_sequence=stop_sequence if stop_sequence else '',
         )

         if resp_options.generated_tokens:
@@ -416,7 +416,8 @@ async def _validate_and_convert_params(

     @staticmethod
     def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
-                        time_limit_reached: bool) -> Tuple['StopReason', str]:
+                        time_limit_reached: bool
+                        ) -> Tuple[StopReason.ValueType, Optional[str]]:
         finish_reason = output.finish_reason
         stop_sequence = None
         if finish_reason is None:
@@ -436,20 +437,20 @@ def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
                     stop_sequence = stop_str_or_tok
                 else:
                     logger.warning(
-                        f"Unexpected stop_reason type: {type(stop_str_or_tok)}"
+                        "Unexpected stop_reason type: %s", type(stop_str_or_tok)
                     )
         elif finish_reason == "abort":
             stop_reason = StopReason.CANCELLED
         else:
-            logger.warning(f"Unrecognized finish_reason: {finish_reason}")
+            logger.warning("Unrecognized finish_reason: %s", finish_reason)
             stop_reason = StopReason.CANCELLED

         return stop_reason, stop_sequence

     def _convert_tokens(
         self,
-        token_ids: list[int],
-        logprobs_list: Optional[list[Dict[int, Logprob]]],
+        token_ids: List[int],
+        logprobs_list: Optional[List[Dict[int, Logprob]]],
         include_logprobs: bool,
         include_ranks: bool,
         top_n_tokens: int,
@@ -502,7 +503,7 @@ async def _validate_prompt_and_tokenize(
         # "max_length": truncate_input_tokens} \
         #     if truncate_input_tokens is not None else {
         #     "truncation": True, "max_length": max_model_len + 1}
-        tokenize_kwargs = {}
+        tokenize_kwargs: Dict[str, Any] = {}

         input_ids = await self.tokenizer_group.encode_async(
             prompt, **tokenize_kwargs)
@@ -664,6 +665,6 @@ async def start_grpc_server(engine: AsyncLLMEngine,
     server.add_insecure_port(listen_on)

     await server.start()
-    logger.info(f"gRPC Server started at {listen_on}")
+    logger.info("gRPC Server started at %s", listen_on)

     return server
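The edits above are typing fixes rather than behavior changes, and they follow a few recurring mypy patterns: a per-line `# type: ignore` on imports of generated protobuf modules that ship no type stubs, an explicit annotation on a dict that would otherwise be inferred as empty, an `Optional[...]` component in a return type that can legitimately be `None`, and `typing` generics (`List`, `Dict`) in place of bare built-ins. A minimal standalone sketch of the same patterns, with hypothetical module and variable names (nothing below is taken from this repo):

from typing import Any, Dict, List, Optional, Tuple

# Generated gRPC/protobuf modules usually have no stubs, so mypy reports the
# import as untyped; a per-line ignore silences only that error.
# (hypothetical module path)
# from my_service.pb import my_pb2_grpc  # type: ignore


def convert_reason(finish_reason: Optional[str]) -> Tuple[str, Optional[str]]:
    # The second element may stay None, so the return type declares
    # Optional[str] instead of plain str.
    stop_sequence: Optional[str] = None
    if finish_reason == "stop":
        stop_sequence = "\n\n"
    return finish_reason or "unknown", stop_sequence


# An empty dict literal gives mypy nothing to infer, so it needs an explicit
# annotation before being passed along as **kwargs.
tokenize_kwargs: Dict[str, Any] = {}

# typing.List/Dict rather than list/dict keeps mypy happy when the configured
# target Python version predates PEP 585 (3.9).
token_ids: List[int] = [1, 2, 3]

These are hints for the type checker only; at runtime the annotated code behaves exactly as before.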

vllm/entrypoints/openai/api_server.py

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-import argparse
 import asyncio
 import importlib
 import inspect

vllm/tgis_utils/args.py

Lines changed: 4 additions & 4 deletions
@@ -129,10 +129,10 @@ def postprocess_tgis_args(args: argparse.Namespace) -> argparse.Namespace:
     if args.max_batch_size is not None:
         # Existing MAX_BATCH_SIZE settings in TGIS configs may not necessarily
         # be best for vLLM so we'll just log a warning for now
-        logger.warn(
-            f"max_batch_size is set to {args.max_batch_size} but will be "
-            f"ignored for now. max_num_seqs can be used if this is still "
-            f"needed.")
+        logger.warning(
+            "max_batch_size is set to %d but will be ignored for now."
+            "max_num_seqs can be used if this is still needed.",
+            args.max_batch_size)

     if args.tls_cert_path:
         args.ssl_certfile = args.tls_cert_path
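Two things change in this hunk: `logger.warn`, a long-deprecated alias of `logger.warning`, is replaced, and the message moves from f-strings to deferred %-style formatting, where the logging framework interpolates the arguments only if the record is actually emitted (this is what logging lint rules typically ask for). A small illustration with an arbitrary logger name:

import logging

logger = logging.getLogger("example")  # arbitrary name, for illustration only

max_batch_size = 32

# Eager: the f-string is built even when WARNING records are filtered out.
# logger.warning(f"max_batch_size is set to {max_batch_size} but will be ignored")

# Deferred: the format string and its arguments are combined by the logging
# framework only when the record is emitted.
logger.warning("max_batch_size is set to %d but will be ignored", max_batch_size)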

vllm/tgis_utils/logs.py

Lines changed: 4 additions & 5 deletions
@@ -45,11 +45,10 @@ def log_response(inputs: List[str], params: Parameters, prefix_id: str,
         level = logging.WARN
     else:
         level = logging.INFO
-    logger.log(
-        level, f"{span_str}: {kind_log} generated "
-        f"{response.generated_token_count} tokens before "
-        f"{stop_reason_str}, output {output_len} chars: "
-        f"{short_output}")
+    logger.log(level,
+               "%s: %s generated %d tokens before %s, output %d chars: %s",
+               span_str, kind_log, response.generated_token_count,
+               stop_reason_str, output_len, short_output)


 def _truncate(text: str, len_: int) -> bytes:
