From a05b7f0b17eff91bffde6d412651a03ba1403a7a Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 18 Sep 2024 14:14:11 +0530 Subject: [PATCH] minor LitServer code clean up (#279) * clean up * update * update --- src/litserve/server.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index d4e76000..e1716106 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -196,7 +196,7 @@ def __init__( device_list = range(devices) self.devices = [self.device_identifiers(accelerator, device) for device in device_list] - self.workers = self.devices * self.workers_per_device + self.inference_workers = self.devices * self.workers_per_device self.setup_server() def launch_inference_worker(self, num_uvicorn_servers: int): @@ -204,23 +204,17 @@ def launch_inference_worker(self, num_uvicorn_servers: int): self.workers_setup_status = manager.dict() self.request_queue = manager.Queue() - self.response_queues = [] - for _ in range(num_uvicorn_servers): - response_queue = manager.Queue() - self.response_queues.append(response_queue) + self.response_queues = [manager.Queue() for _ in range(num_uvicorn_servers)] for spec in self._specs: # Objects of Server class are referenced (not copied) logging.debug(f"shallow copy for Server is created for for spec {spec}") server_copy = copy.copy(self) del server_copy.app - try: - spec.setup(server_copy) - except Exception as e: - raise e + spec.setup(server_copy) process_list = [] - for worker_id, device in enumerate(self.devices * self.workers_per_device): + for worker_id, device in enumerate(self.inference_workers): if len(device) == 1: device = device[0] @@ -258,7 +252,7 @@ async def lifespan(self, app: FastAPI): ) response_queue = self.response_queues[app.response_queue_id] - response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) + response_executor = ThreadPoolExecutor(max_workers=len(self.inference_workers)) future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor) task = loop.create_task(future) @@ -405,7 +399,7 @@ def run( **kwargs, ): if generate_client_file: - self.generate_client_file() + LitServer.generate_client_file() port_msg = f"port must be a value from 1024 to 65535 but got {port}" try: @@ -420,13 +414,15 @@ def run( sockets = [config.bind_socket()] if num_api_servers is None: - num_api_servers = len(self.workers) + num_api_servers = len(self.inference_workers) if num_api_servers < 1: raise ValueError("num_api_servers must be greater than 0") if sys.platform == "win32": - print("Windows does not support forking. Using threads api_server_worker_type will be set to 'thread'") + warnings.warn( + "Windows does not support forking. Using threads" " api_server_worker_type will be set to 'thread'" + ) api_server_worker_type = "thread" elif api_server_worker_type is None: api_server_worker_type = "process"