Skip to content

Resolve m types #351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Track mcp.json.
- Rerun autogen tool on entitycore.
- Prompt engineering on tools.
- mtype resolving now use embedddings.

### Removed
- Knowledge graph tools and utils.
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@ docker exec -it neuroagent-backend-1 alembic -x url=postgresql://postgres:pwd@po
docker exec -it neuroagent-minio-1 mc alias set myminio http://minio:9000 minioadmin minioadmin && docker exec -it neuroagent-minio-1 mc mb myminio/neuroagent
```

To enable the brain region resolving tool, retrieve your bearer token and make sure to run the following script:
To enable the brain region and m-type resolving tools, retrieve your bearer token and make sure to run the following scripts:
```bash
python backend/src/neuroagent/scripts/embed_hierarchies.py $token -e https://staging.openbraininstitute.org/api/entitycore/ -u http://localhost:9000 -b neuroagent -a minioadmin -s minioadmin
```
```bash
python backend/src/neuroagent/scripts/embed_mtypes.py $token -e https://staging.openbraininstitute.org/api/entitycore/ -u http://localhost:9000 -b neuroagent -a minioadmin -s minioadmin -p 1000
```
which stores a json file in your minio/s3 instance.

4. Access the application at `http://localhost:3000`
Expand Down
22 changes: 21 additions & 1 deletion backend/src/neuroagent/app/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Any, Sequence

import botocore
import yaml
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field
Expand All @@ -29,7 +30,7 @@
ToolCallPartVercel,
ToolCallVercel,
)
from neuroagent.schemas import EmbeddedBrainRegions
from neuroagent.schemas import EmbeddedBrainRegions, EmbeddedMTypes

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -237,6 +238,25 @@ def get_br_embeddings(
return output


def get_mtype_embeddings(
s3_client: Any, bucket_name: str, folder: str
) -> EmbeddedMTypes | None:
"""Retrieve mtype embeddings from s3."""
try:
resp = s3_client.get_object(
Bucket=bucket_name, Key="shared/mtypes_embeddings.json"
)
except botocore.exceptions.ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code in ("NoSuchKey", "404", "NotFound"):
return None
raise

payload = resp["Body"].read().decode("utf‑8")
data = json.loads(payload)
return EmbeddedMTypes(**data)


def format_messages_output(
db_messages: Sequence[Messages],
tool_hil_mapping: dict[str, bool],
Expand Down
1 change: 1 addition & 0 deletions backend/src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ def get_context_variables(
"entitycore_url": settings.tools.entitycore.url,
"httpx_client": httpx_client,
"literature_search_url": settings.tools.literature.url,
"mtype_embeddings": request.app.state.mtype_embeddings,
"obi_one_url": settings.tools.obi_one.url,
"openai_client": openai_client,
"project_id": thread.project_id,
Expand Down
7 changes: 7 additions & 0 deletions backend/src/neuroagent/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from neuroagent import __version__
from neuroagent.app.app_utils import (
get_br_embeddings,
get_mtype_embeddings,
get_semantic_router,
setup_engine,
)
Expand Down Expand Up @@ -127,7 +128,13 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncContextManager[None]: # type:
bucket_name=app_settings.storage.bucket_name,
folder="shared",
)
mtype_embeddings = get_mtype_embeddings(
s3_client=s3_client,
bucket_name=app_settings.storage.bucket_name,
folder="shared",
)
fastapi_app.state.br_embeddings = br_embeddings
fastapi_app.state.mtype_embeddings = mtype_embeddings

async with aclosing(
AsyncAccountingSessionFactory(
Expand Down
14 changes: 14 additions & 0 deletions backend/src/neuroagent/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,17 @@ class EmbeddedBrainRegions(BaseModel):

regions: list[EmbeddedBrainRegion]
hierarchy_id: str


class EmbeddedMType(BaseModel):
"""MType embedding schema."""

id: str
pref_label: str
pref_label_embedding: list[float] | None = None


class EmbeddedMTypes(BaseModel):
"""Schema for dumping."""

mtypes: list[EmbeddedMType]
173 changes: 173 additions & 0 deletions backend/src/neuroagent/scripts/embed_mtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""Using the get-all endpoint from entitycore, embed all the m-types for resolving."""

import argparse
import asyncio
import logging
import os

import boto3
from dotenv import load_dotenv
from httpx import AsyncClient
from openai import AsyncOpenAI

from neuroagent.schemas import EmbeddedMType, EmbeddedMTypes

logging.basicConfig(
format="[%(levelname)s] %(asctime)s %(name)s %(message)s", level=logging.INFO
)

logger = logging.getLogger(__name__)


def get_parser() -> argparse.ArgumentParser:
"""Get parser for command line arguments."""
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"token",
type=str,
help="Bearer token for the entity core call.",
)
parser.add_argument(
"--entity-core-url",
"-e",
required=False,
default=None,
help="URL of the entity core API. Read from env if not specified.",
)
parser.add_argument(
"--page-size",
"-p",
required=False,
default=1000,
type=int,
help="page size to get all the m-types.",
)
(
parser.add_argument(
"--s3-url",
"-u",
type=str,
required=False,
default=None,
help="URL of the s3 bucket. Read from env if not specified.",
),
)
(
parser.add_argument(
"--s3-bucket-name",
"-b",
type=str,
required=False,
default=None,
help="Name of the s3 bucket. Read from env if not specified.",
),
)
(
parser.add_argument(
"--s3-access-key",
"-a",
type=str,
required=False,
default=None,
help="Access key of the s3 bucket. Read from env if not specified.",
),
)
(
parser.add_argument(
"--s3-secret-key",
"-s",
type=str,
required=False,
default=None,
help="Secret key of the s3 bucket. Read from env if not specified.",
),
)

return parser


async def push_mtype_embeddings_to_s3(
s3_url: str | None,
entity_core_url: str | None,
s3_access_key: str | None,
s3_secret_key: str | None,
s3_bucket_name: str,
token: str,
page_size: int,
) -> None:
"""Compute and push m-type embeddings to s3."""
httpx_client = AsyncClient(timeout=None)
logger.info("Getting list of all m-types from Entity-Core.")

response = await httpx_client.get(
f"{(entity_core_url or os.getenv('NEUROAGENT_TOOLS__ENTITYCORE__URL')).rstrip('/')}/mtype", # type: ignore
params={"page_size": page_size},
headers={"Authorization": f"Bearer {token}"},
)
if response.status_code != 200:
raise ValueError(
f"Entity core returned a non 200 status code. Could not update the brain region embeddings. Error: {response.text}"
)

m_types_response = response.json()
if m_types_response["pagination"]["total_items"] > page_size:
raise ValueError(
"Not all m-types were retreived, please increase the page size."
)

Comment on lines +115 to +119
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest we paginate instead of raising an error. For instance, we can remove the -p argument from the argparse, set manually a page_size and do something along the lines of:

retrieved_items = []
page = 1
while m_types_response["pagination"]["total_items"] > len(retrieved_items):
    response = await httpx_client.get(..., params={"page_size": page_size, "page": page}
    retrieved_items.extend(...)
    page += 1

m_types = EmbeddedMTypes(
mtypes=[
EmbeddedMType(id=m_type["id"], pref_label=m_type["pref_label"])
for m_type in m_types_response["data"]
]
)
# Gather the names
pref_labels = [m_types.pref_label for m_types in m_types.mtypes]
Comment on lines +126 to +127
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we ever get an answer from Daniela about what field scientists will use to search for an mtype ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, I guess there are more important things to work on for CNS


# Embed them
logger.info("Embedding the m_types pref_labels.")
openai_client = AsyncOpenAI(api_key=os.getenv("NEUROAGENT_OPENAI__TOKEN"))
m_types_embeddings = await openai_client.embeddings.create(
input=pref_labels, model="text-embedding-3-small"
)

# Set the embeddings in the original class
for m_types_class, pref_label_embedding in zip(
m_types.mtypes, m_types_embeddings.data
):
m_types_class.pref_label_embedding = pref_label_embedding.embedding

# Put the result in the s3 bucket
logger.info(
f"Saving the results in s3 bucket: {s3_url or os.getenv('NEUROAGENT_STORAGE__ENDPOINT_URL')} at location: {'shared/mtypes_embeddings.json'}"
)
s3_client = boto3.client(
"s3",
endpoint_url=s3_url or os.getenv("NEUROAGENT_STORAGE__ENDPOINT_URL"),
aws_access_key_id=s3_access_key or os.getenv("NEUROAGENT_STORAGE__ACCESS_KEY"),
aws_secret_access_key=s3_secret_key
or os.getenv("NEUROAGENT_STORAGE__SECRET_KEY"),
aws_session_token=None,
config=boto3.session.Config(signature_version="s3v4"),
)

s3_client.put_object(
Bucket=s3_bucket_name or os.getenv("NEUROAGENT_STORAGE__BUCKET_NAME"),
Key="shared/mtypes_embeddings.json",
Body=m_types.model_dump_json(),
ContentType="application/json",
)


async def main() -> None:
"""Run main logic."""
parser = get_parser()
args = parser.parse_args()
await push_mtype_embeddings_to_s3(**vars(args))


if __name__ == "__main__":
load_dotenv()
asyncio.run(main())
Loading