Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/key authentication #792

Merged
merged 3 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions gui/pages/Dashboard/Settings/Settings.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import React, {useState, useEffect, useRef} from 'react';
import {ToastContainer, toast} from 'react-toastify';
import 'react-toastify/dist/ReactToastify.css';
import agentStyles from "@/pages/Content/Agents/Agents.module.css";
import {getOrganisationConfig, updateOrganisationConfig} from "@/pages/api/DashboardService";
import {getOrganisationConfig, updateOrganisationConfig,validateLLMApiKey} from "@/pages/api/DashboardService";
import {EventBus} from "@/utils/eventBus";
import {removeTab, setLocalStorageValue} from "@/utils/utils";
import Image from "next/image";
Expand Down Expand Up @@ -83,8 +83,15 @@ export default function Settings({organisationId}) {
return
}

updateKey("model_api_key", modelApiKey);
updateKey("model_source", source);
validateLLMApiKey(source, modelApiKey)
.then((response) => {
if (response.data.status==="success") {
updateKey("model_api_key", modelApiKey);
updateKey("model_source", source);
} else {
toast.error("Invalid API key", {autoClose: 1800});
}
})
};

const handleTemperatureChange = (event) => {
Expand Down
3 changes: 3 additions & 0 deletions gui/pages/api/DashboardService.js
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ export const validateAccessToken = () => {
return api.get(`/validate-access-token`);
}

export const validateLLMApiKey = (model_source, model_api_key) => {
return api.post(`/validate-llm-api-key`,{model_source, model_api_key});
}
export const checkEnvironment = () => {
return api.get(`/configs/get/env`);
}
Expand Down
18 changes: 18 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from superagi.controllers.analytics import router as analytics_router
from superagi.helper.tool_helper import register_toolkits
from superagi.lib.logger import logger
from superagi.llms.google_palm import GooglePalm
from superagi.llms.openai import OpenAi
from superagi.helper.auth import get_current_user
from superagi.models.agent_workflow import AgentWorkflow
Expand All @@ -53,6 +54,7 @@
from superagi.models.toolkit import Toolkit
from superagi.models.oauth_tokens import OauthTokens
from superagi.models.types.login_request import LoginRequest
from superagi.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest
from superagi.models.user import User

app = FastAPI()
Expand Down Expand Up @@ -426,6 +428,22 @@ async def root(Authorize: AuthJWT = Depends()):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")


@app.post("/validate-llm-api-key")
async def validate_llm_api_key(request: ValidateAPIKeyRequest, Authorize: AuthJWT = Depends()):
"""API to validate LLM API Key"""
source = request.model_source
api_key = request.model_api_key
valid_api_key = False
if source == "OpenAi":
valid_api_key = OpenAi(api_key=api_key).verify_access_key()
elif source == "Google Palm":
valid_api_key = GooglePalm(api_key=api_key).verify_access_key()
if valid_api_key:
return {"message": "Valid API Key", "status": "success"}
else:
return {"message": "Invalid API Key", "status": "failed"}


@app.get("/validate-open-ai-key/{open_ai_key}")
async def root(open_ai_key: str, Authorize: AuthJWT = Depends()):
"""API to validate Open AI Key"""
Expand Down
6 changes: 5 additions & 1 deletion superagi/llms/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,8 @@ def get_api_key(self):

@abstractmethod
def get_model(self):
pass
pass

@abstractmethod
def verify_access_key(self):
pass
14 changes: 14 additions & 0 deletions superagi/llms/google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,17 @@ def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT
except Exception as exception:
logger.info("Google palm Exception:", exception)
return {"error": exception}

def verify_access_key(self):
"""
Verify the access key is valid.

Returns:
bool: True if the access key is valid, False otherwise.
"""
try:
models = palm.list_models()
return True
except Exception as exception:
logger.info("Google palm Exception:", exception)
return False
14 changes: 14 additions & 0 deletions superagi/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,17 @@ def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT
except Exception as exception:
logger.info("OpenAi Exception:", exception)
return {"error": exception}

def verify_access_key(self):
"""
Verify the access key is valid.

Returns:
bool: True if the access key is valid, False otherwise.
"""
try:
models = openai.Model.list()
return True
except Exception as exception:
logger.info("OpenAi Exception:", exception)
return False
6 changes: 6 additions & 0 deletions superagi/models/types/validate_llm_api_key_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pydantic import BaseModel


class ValidateAPIKeyRequest(BaseModel):
model_source: str
model_api_key: str
8 changes: 8 additions & 0 deletions tests/unit_tests/llms/test_google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,11 @@ def test_chat_completion(mock_palm):
top_p=palm_instance.top_p,
max_output_tokens=int(max_tokens)
)


def test_verify_access_key():
model = 'models/text-bison-001'
api_key = 'test_key'
palm_instance = GooglePalm(api_key, model=model)
result = palm_instance.verify_access_key()
assert result is False
9 changes: 9 additions & 0 deletions tests/unit_tests/llms/test_open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import MagicMock, patch
from superagi.llms.openai import OpenAi


@patch('superagi.llms.openai.openai')
def test_chat_completion(mock_openai):
# Arrange
Expand Down Expand Up @@ -30,3 +31,11 @@ def test_chat_completion(mock_openai):
frequency_penalty=openai_instance.frequency_penalty,
presence_penalty=openai_instance.presence_penalty
)


def test_verify_access_key():
model = 'gpt-4'
api_key = 'test_key'
openai_instance = OpenAi(api_key, model=model)
result = openai_instance.verify_access_key()
assert result is False