Skip to content

Commit

Permalink
minor LitServer code clean up (#279)
Browse files Browse the repository at this point in the history
* clean up

* update

* update
  • Loading branch information
aniketmaurya authored Sep 18, 2024
1 parent f475369 commit a05b7f0
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,31 +196,25 @@ def __init__(
device_list = range(devices)
self.devices = [self.device_identifiers(accelerator, device) for device in device_list]

self.workers = self.devices * self.workers_per_device
self.inference_workers = self.devices * self.workers_per_device
self.setup_server()

def launch_inference_worker(self, num_uvicorn_servers: int):
manager = mp.Manager()
self.workers_setup_status = manager.dict()
self.request_queue = manager.Queue()

self.response_queues = []
for _ in range(num_uvicorn_servers):
response_queue = manager.Queue()
self.response_queues.append(response_queue)
self.response_queues = [manager.Queue() for _ in range(num_uvicorn_servers)]

for spec in self._specs:
# Objects of Server class are referenced (not copied)
logging.debug(f"shallow copy for Server is created for for spec {spec}")
server_copy = copy.copy(self)
del server_copy.app
try:
spec.setup(server_copy)
except Exception as e:
raise e
spec.setup(server_copy)

process_list = []
for worker_id, device in enumerate(self.devices * self.workers_per_device):
for worker_id, device in enumerate(self.inference_workers):
if len(device) == 1:
device = device[0]

Expand Down Expand Up @@ -258,7 +252,7 @@ async def lifespan(self, app: FastAPI):
)

response_queue = self.response_queues[app.response_queue_id]
response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device))
response_executor = ThreadPoolExecutor(max_workers=len(self.inference_workers))
future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor)
task = loop.create_task(future)

Expand Down Expand Up @@ -405,7 +399,7 @@ def run(
**kwargs,
):
if generate_client_file:
self.generate_client_file()
LitServer.generate_client_file()

port_msg = f"port must be a value from 1024 to 65535 but got {port}"
try:
Expand All @@ -420,13 +414,15 @@ def run(
sockets = [config.bind_socket()]

if num_api_servers is None:
num_api_servers = len(self.workers)
num_api_servers = len(self.inference_workers)

if num_api_servers < 1:
raise ValueError("num_api_servers must be greater than 0")

if sys.platform == "win32":
print("Windows does not support forking. Using threads api_server_worker_type will be set to 'thread'")
warnings.warn(
"Windows does not support forking. Using threads" " api_server_worker_type will be set to 'thread'"
)
api_server_worker_type = "thread"
elif api_server_worker_type is None:
api_server_worker_type = "process"
Expand Down

0 comments on commit a05b7f0

Please sign in to comment.