Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/cheshire-cat-ai/core int…
Browse files Browse the repository at this point in the history
…o develop
  • Loading branch information
pieroit committed Nov 25, 2024
2 parents 3a3e1be + b55773d commit 2fdf311
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 30 deletions.
12 changes: 6 additions & 6 deletions core/cat/auth/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ async def __call__(
# get protocol from Starlette request
protocol = connection.scope.get('type')
# extract credentials (user_id, token_or_key) from connection
user_id, credential = await self.extract_credentials(connection)
user_id, credential = self.extract_credentials(connection)
auth_handlers = [
# try to get user from local idp
connection.app.state.ccat.core_auth_handler,
# try to get user from auth_handler
connection.app.state.ccat.custom_auth_handler,
]
for ah in auth_handlers:
user: AuthUserInfo = await ah.authorize_user_from_credential(
user: AuthUserInfo = ah.authorize_user_from_credential(
protocol, credential, self.resource, self.permission, user_id=user_id
)
if user:
Expand All @@ -59,7 +59,7 @@ async def __call__(
self.not_allowed(connection)

@abstractmethod
async def extract_credentials(self, connection: Request | WebSocket) -> Tuple[str] | None:
def extract_credentials(self, connection: Request | WebSocket) -> Tuple[str] | None:
pass

@abstractmethod
Expand All @@ -73,7 +73,7 @@ def not_allowed(self, connection: Request | WebSocket):

class HTTPAuth(ConnectionAuth):

async def extract_credentials(self, connection: Request) -> Tuple[str, str] | None:
def extract_credentials(self, connection: Request) -> Tuple[str, str] | None:
"""
Extract user_id and token/key from headers
"""
Expand Down Expand Up @@ -121,7 +121,7 @@ def not_allowed(self, connection: Request):

class WebSocketAuth(ConnectionAuth):

async def extract_credentials(self, connection: WebSocket) -> Tuple[str, str] | None:
def extract_credentials(self, connection: WebSocket) -> Tuple[str, str] | None:
"""
Extract user_id from WebSocket path params
Extract token from WebSocket query string
Expand Down Expand Up @@ -166,7 +166,7 @@ def not_allowed(self, connection: WebSocket):

class CoreFrontendAuth(HTTPAuth):

async def extract_credentials(self, connection: Request) -> Tuple[str, str] | None:
def extract_credentials(self, connection: Request) -> Tuple[str, str] | None:
"""
Extract user_id from cookie
"""
Expand Down
20 changes: 10 additions & 10 deletions core/cat/factory/custom_auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class BaseAuthHandler(ABC): # TODOAUTH: pydantic model?
MUST be implemented by subclasses.
"""

async def authorize_user_from_credential(
def authorize_user_from_credential(
self,
protocol: Literal["http", "websocket"],
credential: str,
Expand All @@ -32,17 +32,17 @@ async def authorize_user_from_credential(
) -> AuthUserInfo | None:
if is_jwt(credential):
# JSON Web Token auth
return await self.authorize_user_from_jwt(
return self.authorize_user_from_jwt(
credential, auth_resource, auth_permission
)
else:
# API_KEY auth
return await self.authorize_user_from_key(
return self.authorize_user_from_key(
protocol, user_id, credential, auth_resource, auth_permission
)

@abstractmethod
async def authorize_user_from_jwt(
def authorize_user_from_jwt(
self,
token: str,
auth_resource: AuthResource,
Expand All @@ -52,7 +52,7 @@ async def authorize_user_from_jwt(
pass

@abstractmethod
async def authorize_user_from_key(
def authorize_user_from_key(
self,
protocol: Literal["http", "websocket"],
user_id: str,
Expand All @@ -67,7 +67,7 @@ async def authorize_user_from_key(
# Core auth handler, verify token on local idp
class CoreAuthHandler(BaseAuthHandler):

async def authorize_user_from_jwt(
def authorize_user_from_jwt(
self, token: str, auth_resource: AuthResource, auth_permission: AuthPermission
) -> AuthUserInfo | None:
try:
Expand Down Expand Up @@ -98,7 +98,7 @@ async def authorize_user_from_jwt(
# do not pass
return None

async def authorize_user_from_key(
def authorize_user_from_key(
self,
protocol: Literal["http", "websocket"],
user_id: str,
Expand Down Expand Up @@ -147,7 +147,7 @@ def _authorize_websocket_key(self, user_id: str, api_key: str, ws_key: str) -> A
# No match -> deny access
return None

async def issue_jwt(self, username: str, password: str) -> str | None:
def issue_jwt(self, username: str, password: str) -> str | None:
# authenticate local user credentials and return a JWT token

# brutal search over users, which are stored in a simple dictionary.
Expand Down Expand Up @@ -178,10 +178,10 @@ async def issue_jwt(self, username: str, password: str) -> str | None:

# Default Auth, always deny auth by default (only core auth decides).
class CoreOnlyAuthHandler(BaseAuthHandler):
async def authorize_user_from_jwt(*args, **kwargs) -> AuthUserInfo | None:
def authorize_user_from_jwt(*args, **kwargs) -> AuthUserInfo | None:
return None

async def authorize_user_from_key(*args, **kwargs) -> AuthUserInfo | None:
def authorize_user_from_key(*args, **kwargs) -> AuthUserInfo | None:
return None


4 changes: 2 additions & 2 deletions core/cat/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def core_login_token(request: Request, response: Response):

# use username and password to authenticate user from local identity provider and get token
auth_handler = request.app.state.ccat.core_auth_handler
access_token = await auth_handler.issue_jwt(
access_token = auth_handler.issue_jwt(
form_data["username"], form_data["password"]
)

Expand Down Expand Up @@ -95,7 +95,7 @@ async def auth_token(request: Request, credentials: UserCredentials):

# use username and password to authenticate user from local identity provider and get token
auth_handler = request.app.state.ccat.core_auth_handler
access_token = await auth_handler.issue_jwt(
access_token = auth_handler.issue_jwt(
credentials.username, credentials.password
)

Expand Down
93 changes: 90 additions & 3 deletions core/cat/routes/memory/points.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Dict, List
from pydantic import BaseModel
from fastapi import Query, Request, APIRouter, HTTPException, Depends
from fastapi import Query, Body, Request, APIRouter, HTTPException, Depends
import time

from cat.auth.connection import HTTPAuth
from cat.auth.permissions import AuthPermission, AuthResource
from cat.memory.vector_memory import VectorMemory
from cat.looking_glass.stray_cat import StrayCat

from cat.log import log

class MemoryPointBase(BaseModel):
content: str
Expand All @@ -24,14 +24,15 @@ class MemoryPoint(MemoryPointBase):


# GET memories from recall
@router.get("/recall")
@router.get("/recall", deprecated=True)
async def recall_memory_points_from_text(
request: Request,
text: str = Query(description="Find memories similar to this text."),
k: int = Query(default=100, description="How many memories to return."),
stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)),
) -> Dict:
"""Search k memories similar to given text."""
log.warning("Deprecated: This endpoint will be removed in the next major version.")

# Embed the query to plot it in the Memory page
query_embedding = stray.embedder.embed_query(text)
Expand Down Expand Up @@ -76,6 +77,92 @@ async def recall_memory_points_from_text(
},
}

# POST memories from recall
@router.post("/recall")
async def recall_memory_points(
request: Request,
text: str = Body(description="Find memories similar to this text."),
k: int = Body(default=100, description="How many memories to return."),
metadata: Dict = Body(default={},
description="Flat dictionary where each key-value pair represents a filter."
"The memory points returned will match the specified metadata criteria."
),
stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)),
) -> Dict:
"""Search k memories similar to given text with specified metadata criteria.
Example
----------
```
collection = "episodic"
content = "MIAO!"
metadata = {"custom_key": "custom_value"}
req_json = {
"content": content,
"metadata": metadata,
}
# create a point
res = requests.post(
f"http://localhost:1865/memory/collections/{collection}/points", json=req_json
)
# recall with metadata
req_json = {
"text": "CAT",
"metadata":{"custom_key":"custom_value"}
}
res = requests.post(
f"http://localhost:1865/memory/recall", json=req_json
)
json = res.json()
print(json)
```
"""

# Embed the query to plot it in the Memory page
query_embedding = stray.embedder.embed_query(text)
query = {
"text": text,
"vector": query_embedding,
}

# Loop over collections and retrieve nearby memories
collections = list(
stray.memory.vectors.collections.keys()
)
recalled = {}
for c in collections:
# only episodic collection has users
user_id = stray.user_id
if c == "episodic":
metadata["source"] = user_id
else:
metadata.pop("source", None)

memories = stray.memory.vectors.collections[c].recall_memories_from_embedding(
query_embedding, k=k, metadata=metadata
)

recalled[c] = []
for metadata_memories, score, vector, id in memories:
memory_dict = dict(metadata_memories)
memory_dict.pop("lc_kwargs", None) # langchain stuff, not needed
memory_dict["id"] = id
memory_dict["score"] = float(score)
memory_dict["vector"] = vector
recalled[c].append(memory_dict)

return {
"query": query,
"vectors": {
"embedder": str(
stray.embedder.__class__.__name__
), # TODO: should be the config class name
"collections": recalled,
},
}

# CREATE a point in memory
@router.post("/collections/{collection_id}/points", response_model=MemoryPoint)
async def create_memory_point(
Expand Down
8 changes: 3 additions & 5 deletions core/tests/routes/auth/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ def test_refuse_issue_jwt(client):
assert json["detail"]["error"] == "Invalid Credentials"


@pytest.mark.asyncio # to test async functions
async def test_issue_jwt(client):
def test_issue_jwt(client):
creds = {
"username": "admin",
"password": "admin"
Expand All @@ -49,7 +48,7 @@ async def test_issue_jwt(client):

# is the JWT correct for core auth handler?
auth_handler = client.app.state.ccat.core_auth_handler
user_info = await auth_handler.authorize_user_from_jwt(
user_info = auth_handler.authorize_user_from_jwt(
received_token, AuthResource.LLM, AuthPermission.WRITE
)
assert len(user_info.id) == 36 and len(user_info.id.split("-")) == 5 # uuid4
Expand All @@ -70,8 +69,7 @@ async def test_issue_jwt(client):
assert False


@pytest.mark.asyncio
async def test_issue_jwt_for_new_user(client):
def test_issue_jwt_for_new_user(client):

# create new user
creds = {
Expand Down
51 changes: 47 additions & 4 deletions core/tests/routes/memory/test_memory_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# search on default startup memory
def test_memory_recall_default_success(client):
params = {"text": "Red Queen"}
response = client.get("/memory/recall/", params=params)
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200

Expand All @@ -30,7 +30,7 @@ def test_memory_recall_default_success(client):

# search without query should throw error
def test_memory_recall_without_query_error(client):
response = client.get("/memory/recall")
response = client.post("/memory/recall")
assert response.status_code == 400


Expand All @@ -42,7 +42,7 @@ def test_memory_recall_success(client):

# recall
params = {"text": "Red Queen"}
response = client.get("/memory/recall/", params=params)
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
Expand All @@ -58,8 +58,51 @@ def test_memory_recall_with_k_success(client):
# recall at max k memories
max_k = 2
params = {"k": max_k, "text": "Red Queen"}
response = client.get("/memory/recall/", params=params)
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
assert len(episodic_memories) == max_k # only 2 of 6 memories recalled

# search with query and metadata
def test_memory_recall_with_metadata(client):
messages = [
{
"content": "MIAO_1",
"metadata": {"key_1":"v1","key_2":"v2"},
},
{
"content": "MIAO_2",
"metadata": {"key_1":"v1","key_2":"v3"},
},
{
"content": "MIAO_3",
"metadata": {},
}
]

# insert a new points with metadata
for req_json in messages:
client.post(
"/memory/collections/episodic/points", json=req_json
)

# recall with metadata
params = {"text": "MIAO", "metadata":{"key_1":"v1"}}
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
assert len(episodic_memories) == 2

# recall with metadata multiple keys in metadata
params = {"text": "MIAO", "metadata":{"key_1":"v1","key_2":"v2"}}
response = client.post("/memory/recall/", json=params)
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
assert len(episodic_memories) == 1
assert episodic_memories[0]["page_content"] == "MIAO_1"



0 comments on commit 2fdf311

Please sign in to comment.