From 4b39bac25b185573a650f2335a99030cd4fdde8d Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 18 Sep 2024 23:13:20 +0530 Subject: [PATCH] apply feedback --- src/litserve/callbacks/base.py | 12 ++++++------ src/litserve/server.py | 5 ++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/litserve/callbacks/base.py b/src/litserve/callbacks/base.py index 685f57a9..df3fb368 100644 --- a/src/litserve/callbacks/base.py +++ b/src/litserve/callbacks/base.py @@ -16,8 +16,8 @@ class EventTypes: AFTER_ENCODE_RESPONSE = "on_after_encode_response" BEFORE_PREDICT = "on_before_predict" AFTER_PREDICT = "on_after_predict" - BEFORE_SERVER_REGISTER = "on_before_server_register" - AFTER_SERVER_REGISTER = "on_after_server_register" + ON_SERVER_START = "on_server_start" + ON_SERVER_END = "on_serve_end" class Callback(ABC): @@ -45,11 +45,11 @@ def on_before_predict(self, *args, **kwargs): def on_after_predict(self, *args, **kwargs): """Called after prediction is completed.""" - def on_before_server_register(self, *args, **kwargs): - """Called before LitServer endpoint setup is started.""" + def on_server_start(self, *args, **kwargs): + """Called before server starts.""" - def on_after_server_register(self, *args, **kwargs): - """Called after LitServer endpoint setup is completed.""" + def on_server_end(self, *args, **kwargs): + """Called when server terminates.""" class CallbackRunner: diff --git a/src/litserve/server.py b/src/litserve/server.py index e129832f..289b720a 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -262,6 +262,7 @@ async def lifespan(self, app: FastAPI): yield + self._callback_runner.trigger_event(EventTypes.ON_SERVER_END, litserver=self) task.cancel() logger.debug("Shutting down response queue to buffer task") @@ -294,7 +295,7 @@ async def data_streamer(self, q: deque, data_available: asyncio.Event, send_stat def register_endpoints(self): """Register endpoint routes for the FastAPI app and setup middlewares.""" - self._callback_runner.trigger_event(EventTypes.BEFORE_SERVER_REGISTER, litserver=self) + self._callback_runner.trigger_event(EventTypes.ON_SERVER_START, litserver=self) workers_ready = False @self.app.get("/", dependencies=[Depends(self.setup_auth())]) @@ -380,8 +381,6 @@ async def stream_predict(request: self.request_type) -> self.response_type: elif callable(middleware): self.app.add_middleware(middleware) - self._callback_runner.trigger_event(EventTypes.AFTER_SERVER_REGISTER, litserver=self) - @staticmethod def generate_client_file(): src_path = os.path.join(os.path.dirname(__file__), "python_client.py")