Skip to content

Commit 9c896bf

Browse files
committed
Add model selector
1 parent 3bacf72 commit 9c896bf

File tree

6 files changed

+205
-85
lines changed

6 files changed

+205
-85
lines changed

backend/text_playground/jurassic2.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import boto3
2+
import json
3+
4+
bedrock_runtime = boto3.client(
5+
service_name="bedrock-runtime",
6+
region_name="us-east-1",
7+
)
8+
9+
def invoke(prompt, temperature, max_tokens):
10+
prompt_config = {
11+
"prompt": prompt,
12+
"maxTokens": max_tokens,
13+
"temperature": temperature
14+
}
15+
16+
response = bedrock_runtime.invoke_model(
17+
body=json.dumps(prompt_config),
18+
modelId="ai21.j2-mid-v1"
19+
)
20+
21+
response_body = json.loads(response.get("body").read())
22+
23+
completion = response_body["completions"][0]["data"]["text"]
24+
if completion.startswith("\n"):
25+
completion = completion[1:]
26+
27+
return completion

backend/text_playground/models.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
11
from pydantic import BaseModel
22

3-
class TextRequest(BaseModel):
3+
class ClaudeRequest(BaseModel):
44
prompt: str
5-
temperature: float
6-
maxTokens: int
5+
# Randomness and diversity
6+
# min: 0, max: 1, default: 0.5
7+
temperature: float = 0.5
8+
# Length
9+
# min: 0, max: 4096, default: 200
10+
maxTokens: int = 200
711

812
class TextResponse(BaseModel):
9-
completion: str
13+
completion: str
14+
15+
class Jurassic2Request(BaseModel):
16+
prompt: str
17+
# Randomness and diversity
18+
# min: 0, max: 1, default: 0.5
19+
temperature: float = 0.5
20+
# Length
21+
# min: 0, max: 8191, default: 200
22+
maxTokens: int = 200

backend/text_playground/routes.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
from fastapi import APIRouter
22
from . import models
3-
from . import services
3+
from . import claude
4+
from . import jurassic2
5+
46

57
router = APIRouter()
68

79

810
@router.post("/foundation-models/model/text/anthropic.claude-v2/invoke")
9-
def invoke(body: models.TextRequest):
10-
completion = services.invoke(body.prompt, body.temperature, body.maxTokens)
11+
def invoke(body: models.ClaudeRequest):
12+
completion = claude.invoke(body.prompt, body.temperature, body.maxTokens)
1113

1214
return models.TextResponse(
1315
completion=completion
1416
)
17+
18+
@router.post("/foundation-models/model/text/ai21.j2-mid-v1/invoke")
19+
def new_route_function(request: models.Jurassic2Request):
20+
completion = jurassic2.invoke(request.prompt, request.temperature, request.maxTokens)
21+
22+
return models.TextResponse(
23+
completion=completion
24+
)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import React, {useState} from "react";
2+
3+
export default function ModelSelector({ onModelChange }) {
4+
const defaultModel = {
5+
modelName: "Anthropic Claude V2",
6+
modelId: "anthropic.claude-v2",
7+
};
8+
9+
const [isOpen, setIsOpen] = useState(false);
10+
const [selectedModel, setSelectedModel] = useState(defaultModel);
11+
12+
const models = [
13+
defaultModel,
14+
{
15+
modelName: "AI21 Labs Jurassic-2",
16+
modelId: "ai21.j2-mid-v1",
17+
}
18+
];
19+
20+
const toggleDropdown = () => {
21+
setIsOpen(!isOpen);
22+
};
23+
24+
const selectModel = (item) => {
25+
setSelectedModel(item);
26+
setIsOpen(false);
27+
onModelChange(item);
28+
}
29+
30+
return (
31+
<div className="w-64 ml-4">
32+
<div className="relative w-full">
33+
<button id="dropdown-button"
34+
onClick={toggleDropdown}
35+
className="inline-flex justify-center w-full px-4 py-2 text-sm font-medium text-gray-700 bg-white border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-offset-gray-100 focus:ring-blue-500">
36+
<span className="mr-2">{selectedModel.modelName}</span>
37+
<svg xmlns="http://www.w3.org/2000/svg"
38+
className="w-5 h-5 ml-2 -mr-1"
39+
viewBox="0 0 20 20"
40+
fill="currentColor"
41+
aria-hidden="true">
42+
<path fillRule="evenodd"
43+
d="M6.293 9.293a1 1 0 011.414 0L10 11.586l2.293-2.293a1 1 0 111.414 1.414l-3 3a1 1 0 01-1.414 0l-3-3a1 1 0 010-1.414z"
44+
clipRule="evenodd" />
45+
</svg>
46+
</button>
47+
{isOpen && (
48+
<div className="absolute right-0 mt-2 rounded-md shadow-lg bg-white ring-1 ring-black ring-opacity-5 p-1 text-sm w-64">
49+
{models.map((item, index) => (
50+
<a key={index}
51+
onClick={() => selectModel(item)}
52+
href="#"
53+
className="block px-4 py-2 text-gray-700 hover:bg-gray-100 active:bg-blue-100 cursor-pointer rounded-md">
54+
{item.modelName}
55+
</a>
56+
))}
57+
</div>
58+
)}
59+
</div>
60+
</div>
61+
);
62+
};

frontend/components/text/TextContainer.jsx

Lines changed: 86 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22

33
import React, { useState } from "react";
44
import GlobalConfig from "@/app/app.config"
5+
import ModelSelector from "./ModelSelector";
56

67
export default function TextContainer() {
78
const [inputValue, setInputValue] = useState("");
89
const [isLoading, setIsLoading] = useState(false);
910
const [temperatureValue, setTemperatureValue] = useState(0.8);
1011
const [maxTokensValue, setMaxTokensValue] = useState(300);
1112

13+
const onModelChange = (newModel) => {
14+
console.log('Model changed to:', newModel);
15+
}
16+
1217
const handleInputChange = (e) => {
1318
setInputValue(e.target.value);
1419
};
@@ -110,94 +115,97 @@ export default function TextContainer() {
110115
}
111116
};
112117

113-
return <div className="flex flex-col flex-auto h-full p-6">
114-
<h3 className="text-3xl font-medium text-gray-700">Text Playground (Anthropic Claude V2)</h3>
115-
<div className="flex flex-col flex-shrink-0 rounded-2xl bg-gray-100 p-4 mt-8">
116-
<div className="flex flex-col h-full overflow-x-auto mb-4">
117-
<div className="flex flex-col h-full">
118-
<div className="mb-4 w-full bg-gray-50 rounded-lg border border-gray-200">
119-
<div className="p-0 bg-white rounded-xl">
120-
<textarea id="input" rows="20"
121-
disabled={isLoading}
122-
value={inputValue}
123-
onChange={handleInputChange}
124-
className="block p-4 w-full text-sm text-gray-800 bg-white"
125-
placeholder="Write something..." required>
126-
</textarea>
118+
return (
119+
<div className="flex flex-col flex-auto h-full p-6">
120+
<ModelSelector onModelChange={ onModelChange } />
121+
<h3 className="text-3xl font-medium text-gray-700">Text Playground (Anthropic Claude V2)</h3>
122+
<div className="flex flex-col flex-shrink-0 rounded-2xl bg-gray-100 p-4 mt-8">
123+
<div className="flex flex-col h-full overflow-x-auto mb-4">
124+
<div className="flex flex-col h-full">
125+
<div className="mb-4 w-full bg-gray-50 rounded-lg border border-gray-200">
126+
<div className="p-0 bg-white rounded-xl">
127+
<textarea id="input" rows="20"
128+
disabled={isLoading}
129+
value={inputValue}
130+
onChange={handleInputChange}
131+
className="block p-4 w-full text-sm text-gray-800 bg-white"
132+
placeholder="Write something..." required>
133+
</textarea>
134+
</div>
127135
</div>
128136
</div>
129137
</div>
130-
</div>
131-
<div className="flex flex-row items-center h-16 rounded-xl bg-white w-full px-4">
132138
<div className="flex flex-row items-center h-16 rounded-xl bg-white w-full px-4">
133-
<div className="">
134-
<div className="relative w-full">
135-
<label htmlFor="temperature">
136-
Temperature:
137-
</label>
139+
<div className="flex flex-row items-center h-16 rounded-xl bg-white w-full px-4">
140+
<div className="">
141+
<div className="relative w-full">
142+
<label htmlFor="temperature">
143+
Temperature:
144+
</label>
145+
</div>
138146
</div>
139-
</div>
140-
<div className="ml-4">
141-
<div className="relative w-14">
142-
<input
143-
placeholder="0.8"
144-
id="temperature"
145-
type="text"
146-
value={temperatureValue}
147-
onChange={handleTemperatureValueChange}
148-
onBlur={handleTemperatureValueBlur}
149-
className="flex w-full border rounded-xl focus:outline-none focus:border-indigo-300 pl-4 h-10"
150-
/>
151-
147+
<div className="ml-4">
148+
<div className="relative w-14">
149+
<input
150+
placeholder="0.8"
151+
id="temperature"
152+
type="text"
153+
value={temperatureValue}
154+
onChange={handleTemperatureValueChange}
155+
onBlur={handleTemperatureValueBlur}
156+
className="flex w-full border rounded-xl focus:outline-none focus:border-indigo-300 pl-4 h-10"
157+
/>
158+
159+
</div>
152160
</div>
153-
</div>
154-
<div className="ml-8">
155-
<div className="relative">
156-
<label htmlFor="tokens">
157-
Max. length:
158-
</label>
161+
<div className="ml-8">
162+
<div className="relative">
163+
<label htmlFor="tokens">
164+
Max. length:
165+
</label>
166+
</div>
159167
</div>
160-
</div>
161-
<div className="ml-4">
162-
<div className="relative w-20">
163-
<input
164-
placeholder="300"
165-
id="tokens"
166-
type="text"
167-
value={maxTokensValue}
168-
onChange={handleMaxTokensValueChange}
169-
onBlur={handleMaxTokensValueBlur}
170-
className="flex w-full border rounded-xl focus:outline-none focus:border-indigo-300 pl-4 h-10"
171-
/>
172-
168+
<div className="ml-4">
169+
<div className="relative w-20">
170+
<input
171+
placeholder="300"
172+
id="tokens"
173+
type="text"
174+
value={maxTokensValue}
175+
onChange={handleMaxTokensValueChange}
176+
onBlur={handleMaxTokensValueBlur}
177+
className="flex w-full border rounded-xl focus:outline-none focus:border-indigo-300 pl-4 h-10"
178+
/>
179+
180+
</div>
181+
</div>
182+
<div className="ml-4 ml-auto">
183+
<button
184+
type="button"
185+
disabled={isLoading}
186+
onClick={sendMessage}
187+
className={getButtonClass()}>
188+
<span>Send</span>
189+
<span className="ml-2">
190+
<svg
191+
className="w-4 h-4 transform rotate-45 -mt-px"
192+
fill="none"
193+
stroke="currentColor"
194+
viewBox="0 0 24 24"
195+
xmlns="http://www.w3.org/2000/svg">
196+
<path
197+
strokeLinecap="round"
198+
strokeLinejoin="round"
199+
strokeWidth="2"
200+
d="M12 19l9 2-9-18-9 18 9-2zm0 0v-8">
201+
</path>
202+
</svg>
203+
</span>
204+
</button>
173205
</div>
174-
</div>
175-
<div className="ml-4 ml-auto">
176-
<button
177-
type="button"
178-
disabled={isLoading}
179-
onClick={sendMessage}
180-
className={getButtonClass()}>
181-
<span>Send</span>
182-
<span className="ml-2">
183-
<svg
184-
className="w-4 h-4 transform rotate-45 -mt-px"
185-
fill="none"
186-
stroke="currentColor"
187-
viewBox="0 0 24 24"
188-
xmlns="http://www.w3.org/2000/svg">
189-
<path
190-
strokeLinecap="round"
191-
strokeLinejoin="round"
192-
strokeWidth="2"
193-
d="M12 19l9 2-9-18-9 18 9-2zm0 0v-8">
194-
</path>
195-
</svg>
196-
</span>
197-
</button>
198206
</div>
199207
</div>
200208
</div>
201209
</div>
202-
</div>
210+
)
203211
};

0 commit comments

Comments
 (0)