4 changes: 2 additions & 2 deletions README.md
@@ -18,7 +18,7 @@ nilAI is a platform designed to run on Confidential VMs with Trusted Execution E
```
- Update the environment variables in `.env`:
- `HUGGINGFACE_API_TOKEN`: Your Hugging Face API token
- - Obtain Hugging Face token by requesting access on the specific model's [Hugging Face page](https://huggingface.co/meta-llama/Llama-3.2-1B)
+ - Obtain a token by requesting access on the specific model's Hugging Face page. For example, to request access for the Llama 1B model, you can ask [here](https://huggingface.co/meta-llama/Llama-3.2-1B). Note that for the Llama-8B model, you need to make a separate request.

## Deployment Options

@@ -115,7 +115,7 @@ docker run -d --name postgres \
2. **Run API Server**
```shell
# Development Environment
- fastapi dev nilai-api/src/nilai_api/__main__.py --port 8080
+ uv run fastapi dev nilai-api/src/nilai_api/__main__.py --port 8080

# Production Environment
uv run fastapi run nilai-api/src/nilai_api/__main__.py --port 8080
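For context, the `.env` setup the README describes ends up as an environment variable the service reads at runtime. A minimal sketch of consuming it (illustrative only: the variable name comes from the README above; the `python-dotenv` loading is an assumption, not necessarily how nilAI actually loads it):

```python
# Sketch: read the Hugging Face token configured in `.env`.
# Assumes python-dotenv; the variable name is from the README above.
import os

from dotenv import load_dotenv

load_dotenv()  # populate os.environ from a local .env file
token = os.environ.get("HUGGINGFACE_API_TOKEN")
if token is None:
    raise RuntimeError("HUGGINGFACE_API_TOKEN is not set; see the README setup steps")
```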
4 changes: 2 additions & 2 deletions docker/compose/docker-compose.llama-3b-gpu.yml
@@ -20,8 +20,8 @@ services:
condition: service_healthy
command: >
--model meta-llama/Llama-3.2-3B-Instruct
- --gpu-memory-utilization 0.085
- --max-model-len 4300
+ --gpu-memory-utilization 0.5
+ --max-model-len 30000
--tensor-parallel-size 1
--enable-auto-tool-choice
--tool-call-parser llama3_json
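This compose change raises vLLM's GPU memory share for the Llama-3B service from 0.085 to 0.5 and the maximum context length from 4300 to 30000 tokens. `--gpu-memory-utilization` is the fraction of each GPU that vLLM may claim for model weights plus KV cache, so the longer context only fits because the memory budget grows with it. A rough budgeting sketch (the GPU size is an assumption for illustration, not from the repo):

```python
# Back-of-the-envelope memory budget under the new flags.
gpu_mem_gib = 24.0   # assumed GPU, e.g. a single RTX 4090 (not from the repo)
utilization = 0.5    # --gpu-memory-utilization from the compose file above
budget_gib = gpu_mem_gib * utilization
print(f"vLLM may claim ~{budget_gib:.1f} GiB for weights + KV cache")  # ~12.0 GiB
```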
4 changes: 2 additions & 2 deletions nilai-api/pyproject.toml
@@ -21,8 +21,8 @@ dependencies = [
"sqlalchemy>=2.0.36",
"uvicorn>=0.32.1",
"httpx>=0.27.2",
"nilrag>=0.1.2",
"nilql>=0.0.0a3",
"nilrag>=0.1.10",
"nilql>=0.0.0a12",
"openai>=1.59.9",
"pg8000>=1.31.2",
"prometheus_fastapi_instrumentator>=7.0.2",
32 changes: 22 additions & 10 deletions nilai-api/src/nilai_api/handlers/nilrag.py
@@ -12,7 +12,7 @@
group_shares_by_id,
)
from sentence_transformers import SentenceTransformer
- from typing import Union
+ from typing import Union, Any, Dict, List

logger = logging.getLogger(__name__)

@@ -29,8 +29,6 @@ def generate_embeddings_huggingface(

Args:
chunks_or_query (str or list): Text string(s) to generate embeddings for
- model_name (str, optional): Name of the HuggingFace model to use.
-     Defaults to 'sentence-transformers/all-MiniLM-L6-v2'.

Returns:
numpy.ndarray: Array of embeddings for the input text
@@ -39,7 +39,7 @@
return embeddings


- def handle_nilrag(req: ChatRequest):
+ async def handle_nilrag(req: ChatRequest):
"""
Endpoint to process a client query.
1. Initialization: Secret share keys and NilDB instance.
@@ -74,8 +72,12 @@ def handle_nilrag(req: ChatRequest):

# Initialize secret keys
num_parties = len(nilDB.nodes)
- additive_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"sum": True})
- xor_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"store": True})
+ additive_key = nilql.ClusterKey.generate(
+     {"nodes": [{}] * num_parties}, {"sum": True}
+ )
+ xor_key = nilql.ClusterKey.generate(
+     {"nodes": [{}] * num_parties}, {"store": True}
+ )

# Step 2: Secret share query
logger.debug("Secret sharing query and sending to NilDB...")
@@ -95,7 +97,9 @@

# Step 3: Ask NilDB to compute the differences
logger.debug("Requesting computation from NilDB...")
- difference_shares = nilDB.diff_query_execute(nilql_query_embedding)
+ difference_shares: List[List[Dict[str, Any]]] = await nilDB.diff_query_execute(
+     nilql_query_embedding
+ )

# Step 4: Compute distances and sort
logger.debug("Compute distances and sort...")
@@ -106,7 +110,7 @@
)
# 4.2 Transpose the lists for each _id
difference_shares_by_id = {
- id: np.array(differences).T.tolist()
+ id: list(map(list, zip(*differences)))
for id, differences in difference_shares_by_id.items()
}
# 4.3 Decrypt and compute distances
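The transpose rewrite in step 4.2 trades a NumPy round-trip for the standard pure-Python idiom; on the rectangular lists of shares involved here the two expressions agree, as a quick check shows:

```python
# The old and new transposes above are equivalent on rectangular lists of lists.
import numpy as np

differences = [[1, 2, 3], [4, 5, 6]]  # toy share matrix: 2 rows, 3 columns

old = np.array(differences).T.tolist()    # [[1, 4], [2, 5], [3, 6]]
new = list(map(list, zip(*differences)))  # [[1, 4], [2, 5], [3, 6]]
assert old == new
```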
@@ -124,11 +128,16 @@

# Step 5: Query the top k
logger.debug("Query top k chunks...")
- top_k = 2
+ top_k = req.nilrag.get("num_chunks", 2)
+ if not isinstance(top_k, int):
+     raise HTTPException(
+         status_code=400,
+         detail="num_chunks must be an integer as it represents the number of chunks to be retrieved.",
+     )
top_k_ids = [item["_id"] for item in sorted_ids[:top_k]]

# 5.1 Query top k
- chunk_shares = nilDB.chunk_query_execute(top_k_ids)
+ chunk_shares = await nilDB.chunk_query_execute(top_k_ids)

# 5.2 Group chunk shares by ID
chunk_shares_by_id = group_shares_by_id(
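With this change the retrieval depth becomes client-tunable: `top_k` comes from the request's `nilrag` payload (defaulting to 2), and non-integer values are rejected with a 400. A hedged sketch of a request body exercising the option (only the `nilrag`/`num_chunks` shape is taken from the handler above; the other fields follow the usual chat-completions schema and are illustrative):

```python
# Hypothetical request payload using the new num_chunks option.
payload = {
    "model": "meta-llama/Llama-3.2-3B-Instruct",  # illustrative model id
    "messages": [{"role": "user", "content": "What does nilRAG retrieve?"}],
    "nilrag": {"num_chunks": 5},  # ask for the top-5 chunks instead of the default 2
}
```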
@@ -166,6 +175,9 @@

logger.debug(f"System message updated with relevant context:\n {req.messages}")

+ except HTTPException as e:
+     raise e
+
except Exception as e:
logger.error("An error occurred within nilrag: %s", str(e))
raise HTTPException(
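The new `except HTTPException` clause matters because of the blanket `except Exception` that follows it: re-raising first lets the 400 from the `num_chunks` validation reach the client instead of being rewrapped as a generic 500. The pattern in isolation (the 500 detail string is illustrative, since the diff truncates the original handler):

```python
# Re-raise HTTPException before the blanket handler so intended status codes survive.
from fastapi import HTTPException

def process() -> None:
    raise HTTPException(status_code=400, detail="num_chunks must be an integer ...")

try:
    process()
except HTTPException:
    raise  # preserve the 400 raised above
except Exception as e:
    # illustrative fallback; the real handler's detail is truncated in this diff
    raise HTTPException(status_code=500, detail=f"nilrag failed: {e}")
```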
11 changes: 1 addition & 10 deletions nilai-api/src/nilai_api/routers/private.py
@@ -93,15 +93,6 @@ async def get_models(user: UserModel = Depends(get_user)) -> List[ModelMetadata]
"""
logger.info(f"Retrieving models for user {user.userid} from pid {os.getpid()}")
return [endpoint.metadata for endpoint in (await state.models).values()]
- # result = [Model(
- #     id = endpoint.metadata.id,
- #     created = 0,
- #     object = "model",
- #     owned_by = endpoint.metadata.author,
- #     data = endpoint.metadata.dict(),
- # ) for endpoint in (await state.models).values()]
-
- # return result[0]


async def chat_completion_concurrent_rate_limit(request: Request) -> Tuple[int, str]:
@@ -196,7 +187,7 @@ async def chat_completion(
)

if req.nilrag:
- handle_nilrag(req)
+ await handle_nilrag(req)

if req.stream:
client = AsyncOpenAI(base_url=model_url, api_key="<not-needed>")
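The one-line change in `chat_completion` follows from `handle_nilrag` becoming a coroutine: without `await`, Python would create the coroutine object, never run it, silently skip the RAG step, and emit a `RuntimeWarning`. A minimal illustration with a stand-in handler:

```python
# Why the call site needed `await` once handle_nilrag became async.
import asyncio

async def handle_nilrag(req: dict) -> None:  # stand-in for the real handler
    req["augmented"] = True

async def chat_completion(req: dict) -> None:
    # handle_nilrag(req)      # WRONG now: coroutine created but never executed
    await handle_nilrag(req)  # correct: the handler actually runs

req: dict = {}
asyncio.run(chat_completion(req))
assert req["augmented"]
```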