1 change: 0 additions & 1 deletion nilai-api/src/nilai_api/auth.py
@@ -9,7 +9,6 @@

 def get_user(credentials: HTTPAuthorizationCredentials = Security(bearer_scheme)):
     token = credentials.credentials
-    print(token)
     user = UserManager.check_api_key(token)
     if user:
         return user
4 changes: 4 additions & 0 deletions nilai-api/src/nilai_api/routers/private.py
@@ -1,6 +1,7 @@
 # Fast API and serving
 import logging
 import os
+import asyncio
 from base64 import b64encode
 from typing import AsyncGenerator, Union

@@ -168,6 +169,9 @@ async def stream_response() -> AsyncGenerator[str, None]:
             async for chunk in response.aiter_lines():
                 if chunk:  # Skip empty lines
                     yield f"{chunk}\n"
+                    await asyncio.sleep(0)  # Yield to the event loop so each chunk is returned immediately
         except httpx.HTTPStatusError as e:
             raise HTTPException(
                 status_code=e.response.status_code,
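Review note: the added `await asyncio.sleep(0)` is a cooperative yield. Inside an async generator it hands control back to the event loop after every chunk, so FastAPI's response writer can flush each line to the client right away instead of waiting on the upstream stream. A minimal, self-contained sketch of that pattern (the `stream_numbers` / `main` names are illustrative, not part of this PR):

```python
import asyncio
from typing import AsyncGenerator


async def stream_numbers() -> AsyncGenerator[str, None]:
    for i in range(3):
        yield f"data: {i}\n"
        # Cooperative yield: give pending tasks (e.g. the response writer) a turn.
        await asyncio.sleep(0)


async def main() -> None:
    async for line in stream_numbers():
        print(line, end="")


asyncio.run(main())
```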
23 changes: 15 additions & 8 deletions nilai-models/src/nilai_models/models/llama_model.py
@@ -1,6 +1,7 @@
+import asyncio
 import json
 import logging
-from typing import Any, Generator
+from typing import AsyncGenerator

 from fastapi import HTTPException
 from fastapi.responses import StreamingResponse
@@ -71,15 +72,20 @@ async def chat_completion(
         # Streaming response logic
         if req.stream:

-            def generate() -> Generator[str, Any, None]:
+            async def generate() -> AsyncGenerator[str, None]:
                 try:
                     # Create a generator for the streamed output
-                    for output in self.model.create_chat_completion(
-                        prompt,  # type: ignore
-                        stream=True,
-                        temperature=req.temperature if req.temperature else 0.2,
-                        max_tokens=req.max_tokens,
-                    ):
+                    loop = asyncio.get_event_loop()
+                    output_generator = await loop.run_in_executor(
+                        None,
+                        lambda: self.model.create_chat_completion(
+                            prompt,  # type: ignore
+                            stream=True,
+                            temperature=req.temperature if req.temperature else 0.2,
+                            max_tokens=req.max_tokens,
+                        ),
+                    )
+                    for output in output_generator:
                         # Extract delta content from output
                         choices = output.get("choices", [])  # type: ignore
                         if not choices or "delta" not in choices[0]:
@@ -92,6 +98,7 @@ def generate() -> Generator[str, Any, None]:
                         )  # Create a ChoiceChunk
                         completion_chunk = ChatCompletionChunk(choices=[chunk])
                         yield f"data: {completion_chunk.model_dump_json()}\n\n"  # Stream the chunk
+                        await asyncio.sleep(0)  # Yield to the event loop so the chunk is returned immediately

                     yield "data: [DONE]\n\n"
                 except Exception as e:
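Review note: the rewritten `generate()` moves the blocking `create_chat_completion(...)` call onto the default thread pool via `loop.run_in_executor`, so building the stream no longer stalls the event loop, and then awaits `asyncio.sleep(0)` between chunks so each one is flushed as soon as it is produced. A rough sketch of the same pattern, with `slow_token_source` as a stand-in for the blocking model call (not the real llama.cpp API):

```python
import asyncio
import time


def slow_token_source() -> list[str]:
    # Stand-in for a blocking call such as model.create_chat_completion(stream=True).
    time.sleep(1)
    return ["Hel", "lo", "!"]


async def generate() -> None:
    loop = asyncio.get_running_loop()
    # Run the blocking call in the default executor so the event loop stays free.
    tokens = await loop.run_in_executor(None, slow_token_source)
    for token in tokens:
        print(token)
        await asyncio.sleep(0)  # hand control back to the event loop between chunks


asyncio.run(generate())
```

As in the diff, only the initial call is dispatched to the executor; the subsequent iteration runs on the event loop, interleaved with the `sleep(0)` yields.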