
Commit 7ec2018

multichat
1 parent d86ef69 commit 7ec2018

File tree: 4 files changed (+140, -42 lines)

- README.md
- app/lib/__init__.py
- app/lib/endpoints.py
- app/main.py


README.md

Lines changed: 13 additions & 1 deletion
@@ -41,4 +41,16 @@ Run the docker image:
 docker run -d --name ai_container -p 8080:8080 fastapi_bitnet
 ```
 
-Once it's running navigate to http://127.0.0.1:8080/docs
+Once it's running navigate to http://127.0.0.1:8080/docs
+
+## Docker hub repository
+
+You can fetch the dockerfile at: https://hub.docker.com/repository/docker/grctest/fastapi_bitnet/general
+
+## How to add to VSCode!
+
+Run the dockerfile locally using the command above, then navigate to the VSCode Copilot chat window and find the wrench icon "Configure Tools...".
+
+In the tool configuration overview scroll to the bottom and select 'Add more tools...' then '+ Add MCP Server' then 'HTTP'.
+
+Enter into the URL field `http://127.0.0.1:8080/mcp` then your copilot will be able to launch new bitnet server instances and chat with them.
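As a quick check outside VSCode, the same API the MCP wrapper exposes can be exercised directly over HTTP. Below is a minimal client sketch using httpx; the client-side use of httpx, the spawned instance's port 8081, the prompt, and the timeouts are illustrative assumptions, while the parameter names follow the /initialize-server and /chat signatures introduced in this commit.

```python
# Minimal sketch: call the FastAPI proxy directly instead of through MCP.
# Assumes the container from the README is listening on 127.0.0.1:8080
# and that port 8081 is free for the spawned bitnet instance.
import httpx

BASE = "http://127.0.0.1:8080"

# Launch a bitnet llama-server instance (query parameters, as in /initialize-server).
httpx.post(f"{BASE}/initialize-server", params={"threads": 1, "port": 8081}, timeout=120.0)

# Chat with that instance through the /chat middleman endpoint (JSON body = ChatRequest).
reply = httpx.post(
    f"{BASE}/chat",
    json={"message": "Hello, BitNet!", "port": 8081, "n_predict": 64},
    timeout=180.0,
)
print(reply.json())
```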

app/lib/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+from .endpoints import ChatRequest
+from typing import List
+from pydantic import BaseModel
+
+__all__ = ["ChatRequest", "MultiChatRequest"]
+
+# Re-export for import convenience
+class MultiChatRequest(BaseModel):
+    requests: List[ChatRequest]
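This shim re-exports ChatRequest and defines a package-level MultiChatRequest so both models can be imported from `lib` in one place. A small sketch of that import convenience; the prompts and the second port are illustrative, and it assumes the `app` directory is on the import path (as it is inside the container).

```python
# Sketch: import both request models straight from the lib package,
# matching the __all__ list above, and build a two-request batch.
from lib import ChatRequest, MultiChatRequest

batch = MultiChatRequest(requests=[
    ChatRequest(message="Summarise BitNet in one sentence.", port=8081),
    ChatRequest(message="What is a 1-bit LLM?", port=8082),  # hypothetical second instance
])
print(batch)
```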

app/lib/endpoints.py

Lines changed: 50 additions & 19 deletions
@@ -1,4 +1,4 @@
-# --- Braincell Orchestrator (Middleman Proxy) ---
+# --- bitnet Orchestrator (Middleman Proxy) ---
 from pydantic import BaseModel
 
 from fastapi import FastAPI, HTTPException, Query, Depends
@@ -10,6 +10,11 @@
 import time
 import httpx
 
+from typing import List
+from pydantic import BaseModel, Field
+from fastapi import HTTPException
+import asyncio
+
 # --- Server Process Management ---
 # Each server instance is tracked by a unique (host, port) key
 server_processes = {}
@@ -40,12 +45,11 @@ def _max_threads():
     return os.cpu_count() or 1
 
 async def initialize_server_endpoint(
-    model: ModelEnum,
-    threads: int = Query(os.cpu_count() // 2, gt=0, le=os.cpu_count()),
+    threads: int = Query(1, gt=0, le=os.cpu_count()),
     ctx_size: int = Query(2048, gt=0),
-    port: int = Query(8081, gt=1023),
+    port: int = Query(8081, gt=8080, le=65535),
     system_prompt: str = Query("You are a helpful assistant.", description="Unique system prompt for this server instance"),
-    n_predict: int = Query(4096, gt=0, description="Number of tokens to predict for the server instance"),
+    n_predict: int = Query(256, gt=0, description="Number of tokens to predict for the server instance."),
     temperature: float = Query(0.8, gt=0.0, le=2.0, description="Temperature for sampling")
 ):
     """
@@ -71,7 +75,7 @@ async def initialize_server_endpoint(
         raise HTTPException(status_code=429, detail=f"Cannot start server: would oversubscribe CPU threads (in use: {threads_in_use}, requested: {threads}, max: {max_threads})")
     command = [
         server_path,
-        '-m', model.value,
+        '-m', "models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf",
         '-c', str(ctx_size),
         '-t', str(threads),
         '-n', str(n_predict),
@@ -96,7 +100,7 @@ async def initialize_server_endpoint(
         raise HTTPException(status_code=500, detail=f"Server failed to start. Stderr: {stderr_output}")
     server_processes[key] = proc
     server_configs[key] = {
-        "model": model.value,
+        "model": "models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf",
         "threads": threads,
         "ctx_size": ctx_size,
         "host": host,
@@ -241,43 +245,70 @@ def get_model_sizes():
 
 class ChatRequest(BaseModel):
     message: str
-    port: int
-    # Optionally add user/session id, etc.
+    port: int = 8081
+    threads: int = 1
+    ctx_size: int = 2048
+    n_predict: int = 256
+    temperature: float = 0.8
 
-def chat_with_braincell(
+def chat_with_bitnet(
     chat: ChatRequest
 ):
     """
-    Middleman endpoint: receives a chat message and forwards it to the specified braincell (llama server instance) by port.
-    Returns the response from the braincell.
+    Middleman endpoint: receives a chat message and forwards it to the specified bitnet (llama server instance) by port.
+    Returns the response from the bitnet.
     """
     host = "127.0.0.1"
     key = (host, chat.port)
     proc = server_processes.get(key)
     cfg = server_configs.get(key)
     if not (proc and proc.poll() is None and cfg):
-        raise HTTPException(status_code=503, detail=f"Braincell server not running on {host}:{chat.port}. Initialize it first.")
+        raise HTTPException(status_code=503, detail=f"bitnet server not running on {host}:{chat.port}. Initialize it first.")
     server_url = f"http://{host}:{chat.port}/completion"
     payload = {
-        "prompt": chat.message
+        "prompt": chat.message,
+        "threads": chat.threads,
+        "ctx_size": chat.ctx_size,
+        "n_predict": chat.n_predict,
+        "temperature": chat.temperature
     }
     async def _chat():
         async with httpx.AsyncClient() as client:
             try:
-                response = await client.post(server_url, json=payload, timeout=120.0)
+                response = await client.post(server_url, json=payload, timeout=180.0)
                 response.raise_for_status()
                 result_data = response.json()
                 content = result_data.get("content", result_data)
                 return {"result": content}
             except httpx.TimeoutException:
-                raise HTTPException(status_code=504, detail="Request to braincell server timed out.")
+                raise HTTPException(status_code=504, detail="Request to bitnet server timed out.")
             except httpx.ConnectError:
-                raise HTTPException(status_code=503, detail=f"Could not connect to braincell server at {server_url}. Is it running?")
+                raise HTTPException(status_code=503, detail=f"Could not connect to bitnet server at {server_url}. Is it running?")
             except httpx.RequestError as e:
-                raise HTTPException(status_code=500, detail=f"Error during request to braincell server: {str(e)}")
+                raise HTTPException(status_code=500, detail=f"Error during request to bitnet server: {str(e)}")
             except httpx.HTTPStatusError as e:
                 error_detail = e.response.text or str(e)
-                raise HTTPException(status_code=e.response.status_code, detail=f"Braincell server returned error: {error_detail}")
+                raise HTTPException(status_code=e.response.status_code, detail=f"bitnet server returned error: {error_detail}")
             except Exception as e:
                 raise HTTPException(status_code=500, detail=f"Unexpected error during chat: {str(e)}")
     return _chat
+
+class MultiChatRequest(BaseModel):
+    requests: List[ChatRequest]
+
+async def multichat_with_bitnet(multichat: MultiChatRequest):
+    async def run_chat(chat_req: ChatRequest):
+        chat_fn = chat_with_bitnet(chat_req)
+        return await chat_fn()
+    results = await asyncio.gather(*(run_chat(req) for req in multichat.requests), return_exceptions=True)
+    # Format results: if exception, return error message
+    formatted = []
+    for res in results:
+        if isinstance(res, Exception):
+            if isinstance(res, HTTPException):
+                formatted.append({"error": res.detail, "status_code": res.status_code})
+            else:
+                formatted.append({"error": str(res)})
+        else:
+            formatted.append(res)
+    return {"results": formatted}
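The new `multichat_with_bitnet` fans the per-port calls out with `asyncio.gather(..., return_exceptions=True)`, so one failing instance yields an error entry instead of aborting the whole batch. A minimal client-side sketch of the resulting /multichat contract follows; the host, ports, prompts, and timeout are illustrative, and both target instances are assumed to have been started via /initialize-server.

```python
# Sketch of a client call to the new /multichat fan-out endpoint.
import httpx

payload = {
    "requests": [
        {"message": "Explain ternary weights briefly.", "port": 8081, "n_predict": 64},
        {"message": "List two uses of BitNet.", "port": 8082, "n_predict": 64},
    ]
}

resp = httpx.post("http://127.0.0.1:8080/multichat", json=payload, timeout=300.0)
# Each entry is either {"result": ...} or {"error": ..., "status_code": ...},
# mirroring how multichat_with_bitnet formats gathered results.
print(resp.json()["results"])
```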

app/main.py

Lines changed: 68 additions & 22 deletions
@@ -3,39 +3,59 @@
 from fastapi_mcp import FastApiMCP
 from lib.models import ModelEnum
 import lib.endpoints as endpoints
-from lib.endpoints import chat_with_braincell, ChatRequest
+from lib.endpoints import chat_with_bitnet, ChatRequest, multichat_with_bitnet, MultiChatRequest
+import traceback
 
 app = FastAPI()
 
-# Wrap with MCP for Model Context Protocol support
-mcp = FastApiMCP(app)
-
-# Mount the MCP server directly to your FastAPI app
-mcp.mount()
-
 @app.post("/initialize-server")
 async def initialize_server(
-    model: ModelEnum,
     threads: int = Query(os.cpu_count() // 2, gt=0, le=os.cpu_count()),
     ctx_size: int = Query(2048, gt=0),
     port: int = Query(8081, gt=1023),
     system_prompt: str = Query("You are a helpful assistant.", description="Unique system prompt for this server instance"),
     n_predict: int = Query(4096, gt=0, description="Number of tokens to predict for the server instance"),
     temperature: float = Query(0.8, gt=0.0, le=2.0, description="Temperature for sampling")
 ):
-    return await endpoints.initialize_server_endpoint(
-        model=model,
-        threads=threads,
-        ctx_size=ctx_size,
-        port=port,
-        system_prompt=system_prompt,
-        n_predict=n_predict,
-        temperature=temperature
-    )
+    try:
+        return await endpoints.initialize_server_endpoint(
+            threads=threads,
+            ctx_size=ctx_size,
+            port=port,
+            system_prompt=system_prompt,
+            n_predict=n_predict,
+            temperature=temperature
+        )
+    except Exception as e:
+        print(traceback.format_exc())
+        raise HTTPException(status_code=500, detail=str(e))
+
+def _max_threads():
+    return os.cpu_count() or 1
+
+# --- Server Initialization and Shutdown Endpoints ---
+def validate_thread_allocation(requests):
+    max_threads = _max_threads()
+    total_requested = sum(req["threads"] for req in requests)
+    for req in requests:
+        if req["threads"] > max_threads:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Requested {req['threads']} threads for a server, but only {max_threads} are available."
+            )
+    if total_requested > max_threads:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Total requested threads ({total_requested}) exceed available threads ({max_threads})."
+        )
 
 @app.post("/shutdown-server")
 async def shutdown_server(port: int = Query(8081, gt=1023)):
-    return await endpoints.shutdown_server_endpoint(port=port)
+    try:
+        return await endpoints.shutdown_server_endpoint(port=port)
+    except Exception as e:
+        print(traceback.format_exc())
+        raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/server-status")
 async def server_status_endpoint(port: int = Query(8081, gt=1023)): # Renamed for clarity
@@ -48,7 +68,11 @@ async def benchmark(
     threads: int = Query(2, gt=0, le=os.cpu_count()),
     n_prompt: int = Query(32, gt=0)
 ):
-    return await endpoints.run_benchmark(model, n_token, threads, n_prompt)
+    try:
+        return await endpoints.run_benchmark(model, n_token, threads, n_prompt)
+    except Exception as e:
+        print(traceback.format_exc())
+        raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/perplexity")
 async def perplexity(
@@ -58,13 +82,35 @@ async def perplexity(
     ctx_size: int = Query(4, gt=0),
     ppl_stride: int = Query(0, ge=0)
 ):
-    return await endpoints.run_perplexity(model, prompt, threads, ctx_size, ppl_stride)
+    try:
+        return await endpoints.run_perplexity(model, prompt, threads, ctx_size, ppl_stride)
+    except Exception as e:
+        print(traceback.format_exc())
+        raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/model-sizes")
 def model_sizes():
     return endpoints.get_model_sizes()
 
 @app.post("/chat")
 async def chat(chat: ChatRequest):
-    chat_fn = chat_with_braincell(chat)
-    return await chat_fn()
+    try:
+        return await chat_with_bitnet(chat)
+    except Exception as e:
+        print(traceback.format_exc())
+        raise HTTPException(status_code=500, detail=str(e))
+
+# Parallel multi-chat endpoint
+@app.post("/multichat")
+async def multichat(multichat: MultiChatRequest):
+    try:
+        return await multichat_with_bitnet(multichat)
+    except Exception as e:
+        print(traceback.format_exc())
+        raise HTTPException(status_code=500, detail=str(e))
+
+# Wrap with MCP for Model Context Protocol support
+mcp = FastApiMCP(app)
+
+# Mount the MCP server directly to your FastAPI app
+mcp.mount()
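Two notes on this file: `mcp = FastApiMCP(app)` and `mcp.mount()` now run after every route is declared, presumably so the new /chat and /multichat routes already exist when the MCP wrapper is created, and the new `validate_thread_allocation` helper is defined but not invoked by any endpoint shown in this diff. A hedged sketch of how that helper behaves follows; the dict-shaped requests and the call site are assumptions for illustration.

```python
# Illustrative use of validate_thread_allocation, the helper defined in
# app/main.py above; this call site is an assumption, not part of the commit.
from fastapi import HTTPException

try:
    validate_thread_allocation([
        {"threads": 2},  # hypothetical server A
        {"threads": 2},  # hypothetical server B
    ])
except HTTPException as exc:
    # The helper raises status_code=400 when a single request, or the total,
    # exceeds os.cpu_count().
    print(exc.status_code, exc.detail)
```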
