Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .env.ci
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,10 @@ GF_SECURITY_ADMIN_PASSWORD = "password"

# WebSearch Settings
BRAVE_SEARCH_API = "Your API here"

# Optional: Override default testnet URLs if needed
NILDB_NILCHAIN_URL=http://rpc.testnet.nilchain-rpc-proxy.nilogy.xyz
NILDB_NILAUTH_URL=https://nilauth.sandbox.app-cluster.sandbox.nilogy.xyz
NILDB_NODES=https://nildb-stg-n1.nillion.network,https://nildb-stg-n2.nillion.network,https://nildb-stg-n3.nillion.network
NILDB_BUILDER_PRIVATE_KEY=0x1234567890abcdef1234567890abcdef12345678
NILDB_COLLECTION=12345678-1234-1234-1234-123456789012
9 changes: 9 additions & 0 deletions .github/workflows/cicd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ jobs:
- name: Run Ruff linting
run: uv run ruff check --exclude packages/verifier/

- name: Create .env for tests
run: |
cp .env.ci .env
# Set dummy secrets for unit tests
sed -i 's/HF_TOKEN=.*/HF_TOKEN=dummy_token/' .env
sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=dummy_api/' .env

- name: Run tests
run: uv run pytest -v tests/unit

Expand Down Expand Up @@ -135,6 +142,8 @@ jobs:
# Copy secret into .env replacing the existing HF_TOKEN
sed -i 's/HF_TOKEN=.*/HF_TOKEN=${{ secrets.HF_TOKEN }}/' .env
sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=${{ secrets.BRAVE_SEARCH_API }}/' .env
sed -i 's/NILDB_BUILDER_PRIVATE_KEY=.*/NILDB_BUILDER_PRIVATE_KEY=${{ secrets.NILDB_BUILDER_PRIVATE_KEY }}/' .env
sed -i 's/NILDB_COLLECTION=.*/NILDB_COLLECTION=${{ secrets.NILDB_COLLECTION }}/' .env

- name: Compose docker-compose.yml
run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -o development-compose.yml
Expand Down
22 changes: 22 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Global pytest configuration."""

import asyncio
import warnings


def pytest_configure(config):
"""Configure pytest to suppress StreamWriter errors."""
# Suppress warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Monkey patch StreamWriter.__del__ to suppress exceptions
original_del = asyncio.StreamWriter.__del__

def silent_del(self):
try:
original_del(self)
except Exception:
pass

asyncio.StreamWriter.__del__ = silent_del
6 changes: 4 additions & 2 deletions nilai-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ dependencies = [
"web3>=7.8.0",
"click>=8.1.8",
"nuc-helpers",
"nuc",
"nuc>=0.1.0",
"pyyaml>=6.0.1",
"secretvaults",
]


Expand All @@ -45,4 +46,5 @@ build-backend = "hatchling.build"
nilai-common = { workspace = true }
nuc-helpers = { workspace = true }

nuc = { git = "https://github.com/NillionNetwork/nuc-py.git", rev = "4922b5e9354e611cc31322d681eb29da05be584e" }
# TODO: Remove this once the secretvaults package is released with the fix
secretvaults = { git = "https://github.com/jcabrero/secretvaults-py", rev = "main" }
3 changes: 3 additions & 0 deletions nilai-api/src/nilai_api/auth/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import HTTPException, status
from nilai_api.db.users import UserData
from nuc_helpers.usage import TokenRateLimits, TokenRateLimit
from nuc_helpers.nildb_document import PromptDocument


class AuthenticationError(HTTPException):
Expand All @@ -17,11 +18,13 @@ def __init__(self, detail: str):
class AuthenticationInfo(BaseModel):
user: UserData
token_rate_limit: Optional[TokenRateLimits]
prompt_document: Optional[PromptDocument]


__all__ = [
"AuthenticationError",
"AuthenticationInfo",
"TokenRateLimits",
"TokenRateLimit",
"PromptDocument",
]
6 changes: 6 additions & 0 deletions nilai-api/src/nilai_api/auth/nuc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nilai_common.logger import setup_logger

from nuc_helpers.usage import TokenRateLimits
from nuc_helpers.nildb_document import PromptDocument

logger = setup_logger(__name__)

Expand Down Expand Up @@ -120,3 +121,8 @@ def get_token_rate_limit(nuc_token: str) -> Optional[TokenRateLimits]:
raise AuthenticationError("Token has expired")

return token_rate_limits


def get_token_prompt_document(nuc_token: str) -> Optional[PromptDocument]:
prompt_document = PromptDocument.from_token(nuc_token)
return prompt_document
39 changes: 29 additions & 10 deletions nilai-api/src/nilai_api/auth/strategies.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Callable, Awaitable
from typing import Callable, Awaitable, Optional
from datetime import datetime, timezone

from nilai_api.db.users import UserManager, UserModel, UserData
from nilai_api.auth.jwt import validate_jwt
from nilai_api.auth.nuc import validate_nuc, get_token_rate_limit
from nilai_api.auth.nuc import (
validate_nuc,
get_token_rate_limit,
get_token_prompt_document,
)
from nilai_api.config import DOCS_TOKEN
from nilai_api.auth.common import (
PromptDocument,
TokenRateLimits,
AuthenticationInfo,
AuthenticationError,
Expand Down Expand Up @@ -55,6 +60,7 @@ async def wrapper(token) -> AuthenticationInfo:
return AuthenticationInfo(
user=UserData.from_sqlalchemy(user_model),
token_rate_limit=None,
prompt_document=None,
)
return await function(token)

Expand All @@ -65,21 +71,27 @@ async def wrapper(token) -> AuthenticationInfo:

@allow_token(DOCS_TOKEN)
async def api_key_strategy(api_key: str) -> AuthenticationInfo:
user_model: UserModel | None = await UserManager.check_api_key(api_key)
user_model: Optional[UserModel] = await UserManager.check_api_key(api_key)
if user_model:
return AuthenticationInfo(
user=UserData.from_sqlalchemy(user_model), token_rate_limit=None
user=UserData.from_sqlalchemy(user_model),
token_rate_limit=None,
prompt_document=None,
)
raise AuthenticationError("Missing or invalid API key")


@allow_token(DOCS_TOKEN)
async def jwt_strategy(jwt_creds: str) -> AuthenticationInfo:
result = validate_jwt(jwt_creds)
user_model: UserModel | None = await UserManager.check_api_key(result.user_address)
user_model: Optional[UserModel] = await UserManager.check_api_key(
result.user_address
)
if user_model:
return AuthenticationInfo(
user=UserData.from_sqlalchemy(user_model), token_rate_limit=None
user=UserData.from_sqlalchemy(user_model),
token_rate_limit=None,
prompt_document=None,
)
else:
user_model = UserModel(
Expand All @@ -89,7 +101,9 @@ async def jwt_strategy(jwt_creds: str) -> AuthenticationInfo:
)
await UserManager.insert_user_model(user_model)
return AuthenticationInfo(
user=UserData.from_sqlalchemy(user_model), token_rate_limit=None
user=UserData.from_sqlalchemy(user_model),
token_rate_limit=None,
prompt_document=None,
)


Expand All @@ -99,12 +113,15 @@ async def nuc_strategy(nuc_token) -> AuthenticationInfo:
Validate a NUC token and return the user model
"""
subscription_holder, user = validate_nuc(nuc_token)
token_rate_limits: TokenRateLimits | None = get_token_rate_limit(nuc_token)
user_model: UserModel | None = await UserManager.check_user(user)
token_rate_limits: Optional[TokenRateLimits] = get_token_rate_limit(nuc_token)
prompt_document: Optional[PromptDocument] = get_token_prompt_document(nuc_token)

user_model: Optional[UserModel] = await UserManager.check_user(user)
if user_model:
return AuthenticationInfo(
user=UserData.from_sqlalchemy(user_model),
token_rate_limit=token_rate_limits,
prompt_document=prompt_document,
)

user_model = UserModel(
Expand All @@ -114,7 +131,9 @@ async def nuc_strategy(nuc_token) -> AuthenticationInfo:
)
await UserManager.insert_user_model(user_model)
return AuthenticationInfo(
user=UserData.from_sqlalchemy(user_model), token_rate_limit=token_rate_limits
user=UserData.from_sqlalchemy(user_model),
token_rate_limit=token_rate_limits,
prompt_document=prompt_document,
)


Expand Down
4 changes: 4 additions & 0 deletions nilai-api/src/nilai_api/db/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def from_sqlalchemy(cls, user: UserModel) -> "UserData":
web_search_ratelimit_minute=user.web_search_ratelimit_minute,
)

@property
def is_subscription_owner(self):
return self.userid == self.apikey


class UserManager:
@staticmethod
Expand Down
Empty file.
13 changes: 13 additions & 0 deletions nilai-api/src/nilai_api/handlers/nildb/api_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pydantic import BaseModel, ConfigDict
from typing import TypeAlias

PromptDelegationRequest: TypeAlias = str


class PromptDelegationToken(BaseModel):
"""Delegation token model"""

model_config = ConfigDict(validate_assignment=True)

token: str
did: str
38 changes: 38 additions & 0 deletions nilai-api/src/nilai_api/handlers/nildb/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
from typing import Optional
from dotenv import load_dotenv
from pydantic import BaseModel
from pydantic import Field

from secretvaults.common.types import Uuid

load_dotenv()


class NilDBConfig(BaseModel):
NILCHAIN_URL: str = Field(..., description="The URL of the Nilchain")
NILAUTH_URL: str = Field(..., description="The URL of the Nilauth")
NODES: list[str] = Field(..., description="The URLs of the Nildb nodes")
BUILDER_PRIVATE_KEY: str = Field(..., description="The private key of the builder")
COLLECTION: Uuid = Field(..., description="The ID of the collection")


def get_required_env_var(name: str) -> str:
"""Get a required environment variable, raising an error if not set."""
value: Optional[str] = os.getenv(name, None)
if value is None:
raise ValueError(f"Required environment variable {name} is not set")
return value


# Validate environment variables at import time
CONFIG = NilDBConfig(
NILCHAIN_URL=get_required_env_var("NILDB_NILCHAIN_URL"),
NILAUTH_URL=get_required_env_var("NILDB_NILAUTH_URL"),
NODES=get_required_env_var("NILDB_NODES").split(","),
BUILDER_PRIVATE_KEY=get_required_env_var("NILDB_BUILDER_PRIVATE_KEY"),
COLLECTION=Uuid(get_required_env_var("NILDB_COLLECTION")),
)


__all__ = ["CONFIG"]
Loading
Loading