Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ ENV/
env.bak/
venv.bak/

*.sqlite
.ruff_cache/
# Spyder project settings
.spyderproject
.spyproject
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# NilAI

Copy the `.env.sample` to `.env` to and replace the value of the `HUGGINGFACE_API_TOKEN` with the appropriate value. It is required to download Llama3.2 1B.

```shell
docker compose up --build web
Expand Down
1 change: 1 addition & 0 deletions db/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.sqlite
3 changes: 3 additions & 0 deletions db/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# DB

This directory is meant to host the db data.
6 changes: 3 additions & 3 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

services:
web:
nilai:
build:
context: .
dockerfile: docker/Dockerfile
ports:
- "12345:12345"
volumes:
- hugging_face_models:/root/.cache/huggingface
- ${PWD}/db/:/app/db/ # sqlite database for users
- hugging_face_models:/root/.cache/huggingface # cache models

volumes:
hugging_face_models:
9 changes: 6 additions & 3 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
FROM python:3.12-slim

COPY . /app
COPY --link nilai /app/nilai
COPY pyproject.toml uv.lock .env /app/

WORKDIR /app

Expand All @@ -9,5 +10,7 @@ RUN uv sync

EXPOSE 12345

ENTRYPOINT ["uv", "run", "fastapi", "run", "nilai/server.py"]
CMD ["--host", "0.0.0.0", "--port", "12345"]
# ENTRYPOINT ["uv", "run", "fastapi", "run", "nilai/main.py"]
# CMD ["--host", "0.0.0.0", "--port", "12345"]

CMD ["uv", "run", "fastapi", "run", "nilai/main.py", "--host", "0.0.0.0", "--port", "12345"]
11 changes: 11 additions & 0 deletions docker/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

```shell
docker build -t nillion/nilai:latest -f docker/Dockerfile .


docker run \
-p 12345:12345 \
-v hugging_face_models:/root/.cache/huggingface \
-v $(pwd)/users.sqlite:/app/users.sqlite \
nillion/nilai:latest
```
17 changes: 17 additions & 0 deletions nilai/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from fastapi import HTTPException, Security, status
from fastapi.security import APIKeyHeader

from nilai.db import UserManager

UserManager.initialize_db()

api_key_header = APIKeyHeader(name="X-API-Key")


def get_user(api_key_header: str = Security(api_key_header)):
user = UserManager.check_api_key(api_key_header)
if user:
return user
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing or invalid API key"
)
260 changes: 260 additions & 0 deletions nilai/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
import logging
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional

import sqlalchemy
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import QueuePool

# Configure logging

logger = logging.getLogger(__name__)


# Database configuration with better defaults and connection pooling
class DatabaseConfig:
# Use environment variables in a real-world scenario
DATABASE_URL = "sqlite:///db/users.sqlite"
POOL_SIZE = 5
MAX_OVERFLOW = 10
POOL_TIMEOUT = 30
POOL_RECYCLE = 3600 # Reconnect after 1 hour


# Create base and engine with improved configuration
Base = sqlalchemy.orm.declarative_base()
engine = create_engine(
DatabaseConfig.DATABASE_URL,
poolclass=QueuePool,
pool_size=DatabaseConfig.POOL_SIZE,
max_overflow=DatabaseConfig.MAX_OVERFLOW,
pool_timeout=DatabaseConfig.POOL_TIMEOUT,
pool_recycle=DatabaseConfig.POOL_RECYCLE,
echo=False, # Set to True for SQL logging during development
)

# Create session factory with improved settings
SessionLocal = sessionmaker(
bind=engine,
autocommit=False, # Changed to False for more explicit transaction control
autoflush=False, # More control over when to flush
expire_on_commit=False, # Keep objects usable after session closes
)


# Enhanced User Model with additional constraints and validation
class User(Base):
__tablename__ = "users"

userid = Column(String(36), primary_key=True, index=True)
name = Column(String(100), nullable=False)
apikey = Column(String(36), unique=True, nullable=False, index=True)
input_tokens = Column(Integer, default=0, nullable=False)
generated_tokens = Column(Integer, default=0, nullable=False)

def __repr__(self):
return f"<User(userid={self.userid}, name={self.name})>"


@dataclass
class UserData:
userid: str
name: str
apikey: str
input_tokens: int
generated_tokens: int


# Context manager for database sessions
@contextmanager
def get_db_session() -> "Generator[Session, Any, Any]":
"""Provide a transactional scope for database operations."""
session = SessionLocal()
try:
yield session
session.commit()
except SQLAlchemyError as e:
session.rollback()
logger.error(f"Database error: {e}")
raise
finally:
session.close()


class UserManager:
@staticmethod
def initialize_db() -> bool:
"""
Create database tables only if they do not already exist.

Returns:
bool: True if tables were created, False if tables already existed
"""
try:
# Create an inspector to check existing tables
inspector = sqlalchemy.inspect(engine)

# Check if the 'users' table already exists
if not inspector.has_table("users"):
# Create all tables that do not exist
Base.metadata.create_all(bind=engine)
logger.info("Database tables created successfully.")
return True
else:
logger.info("Database tables already exist. Skipping creation.")
return False
except SQLAlchemyError as e:
logger.error(f"Error checking or creating database tables: {e}")
raise

@staticmethod
def generate_user_id() -> str:
"""Generate a unique user ID."""
return str(uuid.uuid4())

@staticmethod
def generate_api_key() -> str:
"""Generate a unique API key."""
return str(uuid.uuid4())

@staticmethod
def insert_user(name: str) -> Dict[str, str]:
"""
Insert a new user into the database.

Args:
name (str): Name of the user

Returns:
Dict containing userid and apikey
"""
userid = UserManager.generate_user_id()
apikey = UserManager.generate_api_key()

try:
with get_db_session() as session:
user = User(userid=userid, name=name, apikey=apikey)
session.add(user)
logger.info(f"User {name} added successfully.")
return {"userid": userid, "apikey": apikey}
except SQLAlchemyError as e:
logger.error(f"Error inserting user: {e}")
raise

@staticmethod
def check_api_key(api_key: str) -> Optional[str]:
"""
Validate an API key.

Args:
api_key (str): API key to validate

Returns:
User's name if API key is valid, None otherwise
"""
try:
with get_db_session() as session:
user = session.query(User).filter(User.apikey == api_key).first()
return user.name if user else None # type: ignore
except SQLAlchemyError as e:
logger.error(f"Error checking API key: {e}")
return None

@staticmethod
def update_token_usage(userid: str, input_tokens: int, generated_tokens: int):
"""
Update token usage for a specific user.

Args:
userid (str): User's unique ID
input_tokens (int): Number of input tokens
generated_tokens (int): Number of generated tokens
"""
try:
with get_db_session() as session:
user = session.query(User).filter(User.userid == userid).first()
if user:
user.input_tokens += input_tokens # type: ignore
user.generated_tokens += generated_tokens # type: ignore
logger.info(f"Updated token usage for user {userid}")
else:
logger.warning(f"User {userid} not found")
except SQLAlchemyError as e:
logger.error(f"Error updating token usage: {e}")

@staticmethod
def get_all_users() -> Optional[List[UserData]]:
"""
Retrieve all users from the database.

Returns:
Dict of users or None if no users found
"""
try:
with get_db_session() as session:
users = session.query(User).all()
return [
UserData(
userid=user.userid, # type: ignore
name=user.name, # type: ignore
apikey=user.apikey, # type: ignore
input_tokens=user.input_tokens, # type: ignore
generated_tokens=user.generated_tokens, # type: ignore
)
for user in users
]
except SQLAlchemyError as e:
logger.error(f"Error retrieving all users: {e}")
return None

@staticmethod
def get_user_token_usage(userid: str) -> Optional[Dict[str, int]]:
"""
Retrieve total token usage for a user.

Args:
userid (str): User's unique ID

Returns:
Dict of token usage or None if user not found
"""
try:
with get_db_session() as session:
user = session.query(User).filter(User.userid == userid).first()
if user:
return {
"input_tokens": user.input_tokens,
"generated_tokens": user.generated_tokens,
} # type: ignore
return None
except SQLAlchemyError as e:
logger.error(f"Error retrieving token usage: {e}")
return None


# Example Usage
if __name__ == "__main__":
# Initialize the database
UserManager.initialize_db()

print(UserManager.get_all_users())

# Add some users
bob = UserManager.insert_user("Bob")
alice = UserManager.insert_user("Alice")

print(f"Bob's details: {bob}")
print(f"Alice's details: {alice}")

# Check API key
user_name = UserManager.check_api_key(bob["apikey"])
print(f"API key validation: {user_name}")

# Update and retrieve token usage
UserManager.update_token_usage(bob["userid"], input_tokens=50, generated_tokens=20)
usage = UserManager.get_user_token_usage(bob["userid"])
print(f"Bob's token usage: {usage}")
Loading