Skip to content

Commit da80226

Browse files
authored
Merge pull request #628 from yangcao77/shield-endpoint
[RHDHPAI-1150] create /v1/shields endpoint
2 parents f74a8be + 5a98473 commit da80226

File tree

7 files changed

+550
-2
lines changed

7 files changed

+550
-2
lines changed

docs/openapi.json

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,40 @@
123123
}
124124
}
125125
},
126+
"/v1/shields": {
127+
"get": {
128+
"tags": [
129+
"shields"
130+
],
131+
"summary": "Shields Endpoint Handler",
132+
"description": "Handle requests to the /shields endpoint.\n\nProcess GET requests to the /shields endpoint, returning a list of available\nshields from the Llama Stack service.\n\nRaises:\n HTTPException: If unable to connect to the Llama Stack server or if\n shield retrieval fails for any reason.\n\nReturns:\n ShieldsResponse: An object containing the list of available shields.",
133+
"operationId": "shields_endpoint_handler_v1_shields_get",
134+
"responses": {
135+
"200": {
136+
"description": "Successful Response",
137+
"content": {
138+
"application/json": {
139+
"schema": {
140+
"$ref": "#/components/schemas/ShieldsResponse"
141+
}
142+
}
143+
},
144+
"shields": [
145+
{
146+
"identifier": "lightspeed_question_validity-shield",
147+
"provider_resource_id": "lightspeed_question_validity-shield",
148+
"provider_id": "lightspeed_question_validity",
149+
"type": "shield",
150+
"params": {}
151+
}
152+
]
153+
},
154+
"500": {
155+
"description": "Connection to Llama Stack is broken"
156+
}
157+
}
158+
}
159+
},
126160
"/v1/query": {
127161
"post": {
128162
"tags": [
@@ -1082,6 +1116,7 @@
10821116
"delete_conversation",
10831117
"feedback",
10841118
"get_models",
1119+
"get_shields",
10851120
"get_metrics",
10861121
"get_config",
10871122
"info",
@@ -2990,6 +3025,34 @@
29903025
"title": "ServiceConfiguration",
29913026
"description": "Service configuration."
29923027
},
3028+
"ShieldsResponse": {
3029+
"properties": {
3030+
"shields": {
3031+
"items": {
3032+
"additionalProperties": true,
3033+
"type": "object"
3034+
},
3035+
"type": "array",
3036+
"title": "Shields",
3037+
"description": "List of shields available",
3038+
"examples": [
3039+
{
3040+
"identifier": "lightspeed_question_validity-shield",
3041+
"params": {},
3042+
"provider_id": "lightspeed_question_validity",
3043+
"provider_resource_id": "lightspeed_question_validity-shield",
3044+
"type": "shield"
3045+
}
3046+
]
3047+
}
3048+
},
3049+
"type": "object",
3050+
"required": [
3051+
"shields"
3052+
],
3053+
"title": "ShieldsResponse",
3054+
"description": "Model representing a response to shields request."
3055+
},
29933056
"StatusResponse": {
29943057
"properties": {
29953058
"functionality": {

src/app/endpoints/shields.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""Handler for REST API call to list available shields."""
2+
3+
import logging
4+
from typing import Annotated, Any
5+
6+
from fastapi import APIRouter, HTTPException, Request, status
7+
from fastapi.params import Depends
8+
from llama_stack_client import APIConnectionError
9+
10+
from authentication import get_auth_dependency
11+
from authentication.interface import AuthTuple
12+
from authorization.middleware import authorize
13+
from client import AsyncLlamaStackClientHolder
14+
from configuration import configuration
15+
from models.config import Action
16+
from models.responses import ShieldsResponse
17+
from utils.endpoints import check_configuration_loaded
18+
19+
logger = logging.getLogger(__name__)
20+
router = APIRouter(tags=["shields"])
21+
22+
23+
shields_responses: dict[int | str, dict[str, Any]] = {
24+
200: {
25+
"shields": [
26+
{
27+
"identifier": "lightspeed_question_validity-shield",
28+
"provider_resource_id": "lightspeed_question_validity-shield",
29+
"provider_id": "lightspeed_question_validity",
30+
"type": "shield",
31+
"params": {},
32+
}
33+
]
34+
},
35+
500: {"description": "Connection to Llama Stack is broken"},
36+
}
37+
38+
39+
def _shields_failure(response: str, cause: Exception) -> HTTPException:
    """Build the HTTP 500 error raised when the shield list cannot be served."""
    return HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail={
            "response": response,
            "cause": str(cause),
        },
    )


@router.get("/shields", responses=shields_responses)
@authorize(Action.GET_SHIELDS)
async def shields_endpoint_handler(
    request: Request,
    auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
) -> ShieldsResponse:
    """
    Handle requests to the /shields endpoint.

    Process GET requests to the /shields endpoint, returning a list of available
    shields from the Llama Stack service.

    Raises:
        HTTPException: If unable to connect to the Llama Stack server or if
            shield retrieval fails for any reason.

    Returns:
        ShieldsResponse: An object containing the list of available shields.
    """
    # Neither parameter is read below: the request carries nothing this
    # handler needs, and the auth tuple is consumed by the middleware only.
    _ = request
    _ = auth

    check_configuration_loaded(configuration)

    logger.info("Llama stack config: %s", configuration.llama_stack_configuration)

    try:
        # Client lookup, the remote call and the conversion to plain dicts all
        # stay inside the try so that any failure surfaces as an HTTP 500.
        llama_client = AsyncLlamaStackClientHolder().get_client()
        available = await llama_client.shields.list()
        return ShieldsResponse(shields=[dict(shield) for shield in available])
    except APIConnectionError as exc:
        # The Llama Stack server could not be reached.
        logger.error("Unable to connect to Llama Stack: %s", exc)
        raise _shields_failure("Unable to connect to Llama Stack", exc) from exc
    except Exception as exc:
        # Anything else that went wrong while listing the shields.
        logger.error("Unable to retrieve list of shields: %s", exc)
        raise _shields_failure("Unable to retrieve list of shields", exc) from exc

src/app/routers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from app.endpoints import (
66
info,
77
models,
8+
shields,
89
root,
910
query,
1011
health,
@@ -27,6 +28,7 @@ def include_routers(app: FastAPI) -> None:
2728
app.include_router(root.router)
2829
app.include_router(info.router, prefix="/v1")
2930
app.include_router(models.router, prefix="/v1")
31+
app.include_router(shields.router, prefix="/v1")
3032
app.include_router(query.router, prefix="/v1")
3133
app.include_router(streaming_query.router, prefix="/v1")
3234
app.include_router(config.router, prefix="/v1")

src/models/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ class Action(str, Enum):
358358
DELETE_CONVERSATION = "delete_conversation"
359359
FEEDBACK = "feedback"
360360
GET_MODELS = "get_models"
361+
GET_SHIELDS = "get_shields"
361362
GET_METRICS = "get_metrics"
362363
GET_CONFIG = "get_config"
363364

src/models/responses.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,24 @@ class ModelsResponse(BaseModel):
3636
)
3737

3838

39+
class ShieldsResponse(BaseModel):
    """Model representing a response to shields request.

    Attributes:
        shields: Shields available, one free-form dictionary per shield.
    """

    shields: list[dict[str, Any]] = Field(
        ...,
        description="List of shields available",
        # Every entry of ``examples`` is one complete sample value of the
        # field. The field is a list, so the sample must be a list as well;
        # a bare dict would be published as an object example for an array.
        examples=[
            [
                {
                    "identifier": "lightspeed_question_validity-shield",
                    "provider_resource_id": "lightspeed_question_validity-shield",
                    "provider_id": "lightspeed_question_validity",
                    "type": "shield",
                    "params": {},
                }
            ]
        ],
    )
55+
56+
3957
class RAGChunk(BaseModel):
4058
"""Model representing a RAG chunk used in the response."""
4159

0 commit comments

Comments
 (0)