Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed: load pickle files asynchronously in bento service #538

Merged
merged 1 commit into from
Sep 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions konfuzio_sdk/bento/extraction/rfextractionai_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,38 @@ class ExtractionService:
model_ref = bentoml.models.get(ai_model_name)

def __init__(self):
# NOTE(review): this span is a rendered diff without +/- markers, so old and new
# lines appear interleaved; the first docstring and the synchronous load_model
# assignment below look like the pre-change body — confirm against the PR.
"""Load the extraction model into memory."""
self.extraction_model = bentoml.picklable_model.load_model(self.model_ref)
"""Initialize the extraction service."""
print(f'Initializing service for model {self.model_ref}')
# Placeholder until load_model() stores the loaded model.
self.extraction_model = None
# Executor passed to run_in_executor for the blocking model load and the extraction call.
self.executor = ThreadPoolExecutor()
# NOTE(review): asyncio.create_task needs a running event loop — confirm the
# service is instantiated inside one.
self.model_load_task = asyncio.create_task(self.load_model())

async def load_model(self):
    """Asynchronously load the extraction model into memory using the executor.

    The blocking ``bentoml.picklable_model.load_model`` call is run in
    ``self.executor`` and its result is stored on ``self.extraction_model``.
    Any exception raised by the load propagates out of this coroutine.
    """
    print(f'Loading model {self.model_ref}')
    # A coroutine always executes inside a running loop, so ask for that loop
    # explicitly instead of going through the legacy get_event_loop() lookup.
    loop = asyncio.get_running_loop()
    self.extraction_model = await loop.run_in_executor(
        self.executor, bentoml.picklable_model.load_model, self.model_ref
    )
    print(f'Model {self.model_ref} loaded')

async def get_model(self):
    """Return the extraction model once the background load task has finished.

    Raises:
        RuntimeError: if the load task completed but no model was stored.
    """
    # Wait for the load task; an exception raised by the task surfaces here.
    await self.model_load_task
    loaded = self.extraction_model
    if loaded is not None:
        return loaded
    raise RuntimeError('Model failed to load')

@bentoml.api(input_spec=ExtractRequest20240117)
@handle_exceptions
async def extract(self, ctx: bentoml.Context, **request: t.Any) -> ExtractResponse20240117:
"""Send a call to the Extraction AI and process the response."""
# Even though the request is already validated against the pydantic schema, we need to get it back as an
# instance of the pydantic model to be able to pass it to the prepare_request function.
# Ensure the model is loaded
extraction_model = await self.get_model()

# The rest of the method remains the same
request = ExtractRequest20240117(**request)
project = self.extraction_model.project
project = extraction_model.project
# Add credentials from the request headers to the Project object, but only if the SDK version supports this.
# Older SDK versions do not have the credentials attribute on Project.
if hasattr(project, 'credentials'):
Expand All @@ -47,10 +67,10 @@ async def extract(self, ctx: bentoml.Context, **request: t.Any) -> ExtractRespon
document = prepare_request(
request=request,
project=project,
konfuzio_sdk_version=getattr(self.extraction_model, 'konfuzio_sdk_version', None),
konfuzio_sdk_version=getattr(extraction_model, 'konfuzio_sdk_version', None),
)
# Run the extraction in a separate thread, otherwise the API server will block
result = await asyncio.get_event_loop().run_in_executor(self.executor, self.extraction_model.extract, document)
result = await asyncio.get_event_loop().run_in_executor(self.executor, extraction_model.extract, document)
annotations_result = process_response(result)
# Remove the Document and its copies from the Project to avoid memory leaks
project._documents = [d for d in project._documents if d.id_ != document.id_ and d.copy_of_id != document.id_]
Expand Down
Loading