Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.ci
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ GF_SECURITY_ADMIN_PASSWORD = "password"
# WebSearch Settings
BRAVE_SEARCH_API = "Your API here"

# Optional: Override default testnet URLs if needed
# NilDB Configuration (Required)
NILDB_NILCHAIN_URL=http://rpc.testnet.nilchain-rpc-proxy.nilogy.xyz
NILDB_NILAUTH_URL=https://nilauth.sandbox.app-cluster.sandbox.nilogy.xyz
NILDB_NODES=https://nildb-stg-n1.nillion.network,https://nildb-stg-n2.nillion.network,https://nildb-stg-n3.nillion.network
Expand Down
7 changes: 5 additions & 2 deletions .github/workflows/cicd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ jobs:
sed -i 's/HF_TOKEN=.*/HF_TOKEN=dummy_token/' .env
sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=dummy_api/' .env

- name: pyright
run: uv run pyright

- name: Run tests
run: uv run pytest -v tests/unit

- name: pyright
run: uv run pyright


start-runner:
name: Start self-hosted EC2 runner
Expand Down Expand Up @@ -252,6 +254,7 @@ jobs:
run: |
set -e
export ENVIRONMENT=ci
export AUTH_STRATEGY=nuc
uv run pytest -v tests/e2e

- name: Run E2E tests for API Key
Expand Down
12 changes: 6 additions & 6 deletions nilai-api/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from nilai_api.db import Base
from nilai_api.db.users import UserModel
from nilai_api.db.logs import QueryLog
import nilai_api.config as nilai_config
from nilai_api.config import CONFIG as nilai_config

# If we don't use the models, they remain unused, and the migration fails
# This is a workaround to ensure the models are loaded
Expand Down Expand Up @@ -93,12 +93,12 @@ def run_migrations_online() -> None:


load_dotenv()
db_host = nilai_config.DB_HOST
db_host = nilai_config.database.host
if db_host:
db_port = nilai_config.DB_PORT
db_user = nilai_config.DB_USER
db_pass = nilai_config.DB_PASS
db_name = nilai_config.DB_NAME
db_port = nilai_config.database.port
db_user = nilai_config.database.user
db_pass = nilai_config.database.password
db_name = nilai_config.database.db
config.set_main_option(
"sqlalchemy.url",
f"postgresql+asyncpg://{db_user}:{db_pass}@{db_host}:{db_port}/{db_name}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@

from alembic import op
import sqlalchemy as sa
from nilai_api.config import (
USER_RATE_LIMIT_MINUTE,
USER_RATE_LIMIT_HOUR,
USER_RATE_LIMIT_DAY,
)
from nilai_api.config import CONFIG


# revision identifiers, used by Alembic.
Expand All @@ -39,15 +35,21 @@ def upgrade() -> None:
),
sa.Column("last_activity", sa.DateTime, nullable=True),
sa.Column(
"ratelimit_day", sa.Integer, default=USER_RATE_LIMIT_DAY, nullable=True
"ratelimit_day",
sa.Integer,
default=CONFIG.rate_limiting.user_rate_limit_day,
nullable=True,
),
sa.Column(
"ratelimit_hour", sa.Integer, default=USER_RATE_LIMIT_HOUR, nullable=True
"ratelimit_hour",
sa.Integer,
default=CONFIG.rate_limiting.user_rate_limit_hour,
nullable=True,
),
sa.Column(
"ratelimit_minute",
sa.Integer,
default=USER_RATE_LIMIT_MINUTE,
default=CONFIG.rate_limiting.user_rate_limit_minute,
nullable=True,
),
)
Expand Down
2 changes: 1 addition & 1 deletion nilai-api/src/nilai_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@asynccontextmanager
async def lifespan(app: FastAPI):
client, rate_limit_command = await setup_redis_conn(config.REDIS_URL)
client, rate_limit_command = await setup_redis_conn(config.CONFIG.redis.url)

yield {"redis": client, "redis_rate_limit_command": rate_limit_command}

Expand Down
4 changes: 2 additions & 2 deletions nilai-api/src/nilai_api/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from logging import getLogger

from nilai_api import config
from nilai_api.config import CONFIG
from nilai_api.db.users import UserManager
from nilai_api.auth.strategies import AuthenticationStrategy

Expand All @@ -25,7 +25,7 @@ async def get_auth_info(
credentials: HTTPAuthorizationCredentials = Security(bearer_scheme),
) -> AuthenticationInfo:
try:
strategy_name: str = config.AUTH_STRATEGY.upper()
strategy_name: str = CONFIG.auth.auth_strategy.upper()

try:
strategy = AuthenticationStrategy[strategy_name]
Expand Down
4 changes: 2 additions & 2 deletions nilai-api/src/nilai_api/auth/nuc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from nuc.nilauth import NilauthClient
from nuc.token import Did, NucToken, Command
from functools import lru_cache
from nilai_api.config import NILAUTH_TRUSTED_ROOT_ISSUERS
from nilai_api.config import CONFIG
from nilai_api.state import state
from nilai_api.auth.common import AuthenticationError

Expand All @@ -32,7 +32,7 @@ def get_validator() -> NucTokenValidator:
try:
nilauth_public_keys = [
Did(NilauthClient(key).about().public_key.serialize())
for key in NILAUTH_TRUSTED_ROOT_ISSUERS
for key in CONFIG.auth.nilauth_trusted_root_issuers
]
except Exception as e:
logger.error(f"Error getting validator: {e}")
Expand Down
8 changes: 4 additions & 4 deletions nilai-api/src/nilai_api/auth/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
get_token_rate_limit,
get_token_prompt_document,
)
from nilai_api.config import DOCS_TOKEN
from nilai_api.config import CONFIG
from nilai_api.auth.common import (
PromptDocument,
TokenRateLimits,
Expand Down Expand Up @@ -69,7 +69,7 @@ async def wrapper(token) -> AuthenticationInfo:
return decorator


@allow_token(DOCS_TOKEN)
@allow_token(CONFIG.docs.token)
async def api_key_strategy(api_key: str) -> AuthenticationInfo:
user_model: Optional[UserModel] = await UserManager.check_api_key(api_key)
if user_model:
Expand All @@ -81,7 +81,7 @@ async def api_key_strategy(api_key: str) -> AuthenticationInfo:
raise AuthenticationError("Missing or invalid API key")


@allow_token(DOCS_TOKEN)
@allow_token(CONFIG.docs.token)
async def jwt_strategy(jwt_creds: str) -> AuthenticationInfo:
result = validate_jwt(jwt_creds)
user_model: Optional[UserModel] = await UserManager.check_api_key(
Expand All @@ -107,7 +107,7 @@ async def jwt_strategy(jwt_creds: str) -> AuthenticationInfo:
)


@allow_token(DOCS_TOKEN)
@allow_token(CONFIG.docs.token)
async def nuc_strategy(nuc_token) -> AuthenticationInfo:
"""
Validate a NUC token and return the user model
Expand Down
143 changes: 62 additions & 81 deletions nilai-api/src/nilai_api/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,64 @@
import os
from typing import List, Dict, Any, Optional
import yaml
from dotenv import load_dotenv
from dataclasses import dataclass

load_dotenv()

ENVIRONMENT: str = os.getenv("ENVIRONMENT", "testnet")

ETCD_HOST: str = os.getenv("ETCD_HOST", "localhost")
ETCD_PORT: int = int(os.getenv("ETCD_PORT", 2379))


REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379")

DOCS_TOKEN: str | None = os.getenv("DOCS_TOKEN", None)

DB_USER: str = os.getenv("POSTGRES_USER", "postgres")
DB_PASS: str = os.getenv("POSTGRES_PASSWORD", "")
DB_HOST: str = os.getenv("POSTGRES_HOST", "localhost")
DB_PORT: int = int(os.getenv("POSTGRES_PORT", 5432))
DB_NAME: str = os.getenv("POSTGRES_DB", "nilai_users")


NILAUTH_TRUSTED_ROOT_ISSUERS: List[str] = os.getenv(
"NILAUTH_TRUSTED_ROOT_ISSUERS", ""
).split(",")

AUTH_STRATEGY: str = os.getenv("AUTH_STRATEGY", "api_key")


# Web Search API configuration
@dataclass
class WebSearchSettings:
api_key: Optional[str] = None
api_path: str = "https://api.search.brave.com/res/v1/web/search"
count: int = 3
lang: str = "en"
country: str = "us"
timeout: float = 20.0
max_concurrent_requests: int = 20
rps: int = 20


WEB_SEARCH_SETTINGS = WebSearchSettings(api_key=os.getenv("BRAVE_SEARCH_API"))

# Default values
USER_RATE_LIMIT_MINUTE: Optional[int] = 100
USER_RATE_LIMIT_HOUR: Optional[int] = 1000
USER_RATE_LIMIT_DAY: Optional[int] = 10000
WEB_SEARCH_RATE_LIMIT_MINUTE: Optional[int] = 1
WEB_SEARCH_RATE_LIMIT_HOUR: Optional[int] = 3
WEB_SEARCH_RATE_LIMIT_DAY: Optional[int] = 72
MODEL_CONCURRENT_RATE_LIMIT: Dict[str, int] = {}


def load_config_from_yaml(config_path: str) -> Dict[str, Any]:
if os.path.exists(config_path):
with open(config_path, "r") as f:
return yaml.safe_load(f)
return {}


config_file: str = "config.yaml"
config_path = os.path.join(os.path.dirname(__file__), config_file)

if not os.path.exists(config_path):
config_file = "config.yaml"
config_path = os.path.join(os.path.dirname(__file__), config_file)

config_data = load_config_from_yaml(config_path)

# Overwrite with values from yaml
if config_data:
USER_RATE_LIMIT_MINUTE = config_data.get(
"user_rate_limit_minute", USER_RATE_LIMIT_MINUTE
# Import all configuration models
import json
from .environment import EnvironmentConfig
from .database import DatabaseConfig, EtcdConfig, RedisConfig
from .auth import AuthConfig, DocsConfig
from .nildb import NilDBConfig
from .web_search import WebSearchSettings
from .rate_limiting import RateLimitingConfig
from .utils import create_config_model, CONFIG_DATA
from pydantic import BaseModel
import logging


class NilAIConfig(BaseModel):
"""Centralized configuration container for the Nilai API."""

environment: EnvironmentConfig = create_config_model(
EnvironmentConfig, "", CONFIG_DATA
)
database: DatabaseConfig = create_config_model(
DatabaseConfig, "database", CONFIG_DATA, "POSTGRES_"
)
etcd: EtcdConfig = create_config_model(EtcdConfig, "etcd", CONFIG_DATA, "ETCD_")
redis: RedisConfig = create_config_model(
RedisConfig, "redis", CONFIG_DATA, "REDIS_"
)
USER_RATE_LIMIT_HOUR = config_data.get("user_rate_limit_hour", USER_RATE_LIMIT_HOUR)
USER_RATE_LIMIT_DAY = config_data.get("user_rate_limit_day", USER_RATE_LIMIT_DAY)
MODEL_CONCURRENT_RATE_LIMIT = config_data.get(
"model_concurrent_rate_limit", MODEL_CONCURRENT_RATE_LIMIT
auth: AuthConfig = create_config_model(AuthConfig, "auth", CONFIG_DATA)
docs: DocsConfig = create_config_model(DocsConfig, "docs", CONFIG_DATA, "DOCS_")
web_search: WebSearchSettings = create_config_model(
WebSearchSettings, "web_search", CONFIG_DATA, "WEB_SEARCH_"
)
rate_limiting: RateLimitingConfig = create_config_model(
RateLimitingConfig, "rate_limiting", CONFIG_DATA
)
nildb: NilDBConfig = create_config_model(
NilDBConfig, "nildb", CONFIG_DATA, "NILDB_"
)

def prettify(self):
"""Print the config in a pretty format removing passwords and other sensitive information"""
config_dict = self.model_dump()
keywords = ["pass", "token", "key"]
for key, value in config_dict.items():
if isinstance(value, str):
for keyword in keywords:
print(key, keyword, keyword in key)
if keyword in key and value is not None:
config_dict[key] = "***************"
if isinstance(value, dict):
for k, v in value.items():
for keyword in keywords:
if keyword in k and v is not None:
value[k] = "***************"
return json.dumps(config_dict, indent=4)


# Global config instance
CONFIG = NilAIConfig()
__all__ = [
# Main config object
"CONFIG"
]

logging.info(CONFIG.prettify())
18 changes: 18 additions & 0 deletions nilai-api/src/nilai_api/config/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import List, Optional, Literal
from pydantic import BaseModel, Field


class AuthConfig(BaseModel):
auth_strategy: Literal["api_key", "jwt", "nuc"] = Field(
description="Authentication strategy"
)
nilauth_trusted_root_issuers: List[str] = Field(
description="Trusted root issuers for nilauth"
)
auth_token: Optional[str] = Field(
default=None, description="Auth token for testing"
)


class DocsConfig(BaseModel):
token: Optional[str] = Field(default=None, description="Documentation access token")
Loading
Loading