1- import { UIMessage , appendResponseMessages , createDataStreamResponse , smoothStream , streamText } from 'ai' ;
1+ import { appendResponseMessages , createDataStreamResponse , DataStreamWriter , UIMessage } from 'ai' ;
22import { notFound } from 'next/navigation' ;
33import { NextRequest } from 'next/server' ;
44import { generateTitleFromUserMessage } from '~/app/(main)/projects/[project]/chats/actions' ;
5- import { generateUUID } from '~/components/chat/utils' ;
6- import { getChatSystemPrompt } from '~/lib/ai/agent' ;
7- import { getLanguageModel } from '~/lib/ai/providers' ;
8- import { getTools } from '~/lib/ai/tools' ;
5+ import { AugmentedLanguageModel } from '~/lib/ai/model' ;
6+ import {
7+ artifactsPrompt ,
8+ awsCloudProviderPrompt ,
9+ chatSystemPrompt ,
10+ commonSystemPrompt ,
11+ gcpCloudProviderPrompt
12+ } from '~/lib/ai/prompts' ;
13+ import { getLanguageModel , getProviderRegistry } from '~/lib/ai/providers' ;
14+ import {
15+ AWSDBClusterTools ,
16+ CommonDBClusterTools ,
17+ commonToolset ,
18+ GCPDBClusterTools ,
19+ getDBSQLTools ,
20+ getPlaybookToolset
21+ } from '~/lib/ai/tools' ;
22+ import { getArtifactTools } from '~/lib/ai/tools/artifacts' ;
23+ import { mcpToolset } from '~/lib/ai/tools/user-mcp' ;
924import { deleteChatById , getChatById , getChatsByProject , saveChat } from '~/lib/db/chats' ;
1025import { getConnection } from '~/lib/db/connections' ;
11- import { getUserSessionDBAccess } from '~/lib/db/db' ;
26+ import { DBAccess , getUserSessionDBAccess } from '~/lib/db/db' ;
1227import { getProjectById } from '~/lib/db/projects' ;
13- import { getTargetDbPool } from '~/lib/targetdb/db' ;
28+ import { Connection , Project } from '~/lib/db/schema' ;
29+ import { getTargetDbPool , Pool } from '~/lib/targetdb/db' ;
1430import { requireUserSession } from '~/utils/route' ;
1531
1632export const maxDuration = 60 ;
@@ -33,6 +49,66 @@ export async function GET(request: NextRequest) {
3349 return Response . json ( { chats } ) ;
3450}
3551
52+ type ChatModelDeps = {
53+ dbAccess : DBAccess ;
54+ project : Project ;
55+ connection : Connection ;
56+ targetDb : Pool ;
57+ useArtifacts : boolean ;
58+ userId : string ;
59+ dataStream ?: DataStreamWriter ;
60+ } ;
61+
62+ const chatModel = new AugmentedLanguageModel < ChatModelDeps > ( {
63+ providerRegistry : getProviderRegistry ,
64+ baseModel : 'chat' ,
65+ metadata : {
66+ tags : [ 'chat' ]
67+ } ,
68+ systemPrompt : [ commonSystemPrompt , chatSystemPrompt ] ,
69+ toolsSets : [
70+ { tools : mcpToolset . listMCPTools } ,
71+ { tools : commonToolset } ,
72+ { tools : async ( { targetDb } ) => getDBSQLTools ( targetDb ) . toolset ( ) } ,
73+
74+ // Playbook support
75+ { tools : async ( { dbAccess, project } ) => getPlaybookToolset ( dbAccess , project . id ) } ,
76+
77+ // Common cloud provider DB support
78+ {
79+ tools : async ( { dbAccess, connection } ) =>
80+ new CommonDBClusterTools ( dbAccess , ( ) => Promise . resolve ( connection ) ) . toolset ( )
81+ }
82+ ]
83+ } ) ;
84+
85+ // AWS cloud provider support
86+ chatModel . addSystemPrompts ( ( { project } ) => ( project . cloudProvider === 'aws' ? awsCloudProviderPrompt : '' ) ) ;
87+ chatModel . addToolSet ( {
88+ active : ( deps ?: ChatModelDeps ) => deps ?. project . cloudProvider === 'aws' ,
89+ tools : async ( { dbAccess, connection } ) => {
90+ return new AWSDBClusterTools ( dbAccess , ( ) => Promise . resolve ( connection ) ) . toolset ( ) ;
91+ }
92+ } ) ;
93+
94+ // GCP cloud provider support
95+ chatModel . addSystemPrompts ( ( { project } ) => ( project . cloudProvider === 'gcp' ? gcpCloudProviderPrompt : '' ) ) ;
96+ chatModel . addToolSet ( {
97+ active : ( deps ?: ChatModelDeps ) => deps ?. project . cloudProvider === 'gcp' ,
98+ tools : async ( { dbAccess, connection } ) => {
99+ return new GCPDBClusterTools ( dbAccess , ( ) => Promise . resolve ( connection ) ) . toolset ( ) ;
100+ }
101+ } ) ;
102+
103+ // Artifacts support
104+ chatModel . addSystemPrompts ( ( { useArtifacts } ) => ( useArtifacts ? artifactsPrompt : '' ) ) ;
105+ chatModel . addToolSet ( {
106+ active : ( deps ?: ChatModelDeps ) => ! ! deps ?. useArtifacts && ! ! deps ?. dataStream ,
107+ tools : async ( { dbAccess, userId, project, dataStream } ) => {
108+ return getArtifactTools ( { dbAccess, userId, projectId : project . id , dataStream : dataStream ! } ) ;
109+ }
110+ } ) ;
111+
36112export async function POST ( request : Request ) {
37113 try {
38114 const { id, messages, connectionId, model : modelId , useArtifacts } = await request . json ( ) ;
@@ -59,34 +135,32 @@ export async function POST(request: Request) {
59135 if ( ! chat ) notFound ( ) ;
60136
61137 const targetDb = getTargetDbPool ( connection . connectionString ) ;
62- const context = getChatSystemPrompt ( { cloudProvider : project . cloudProvider , useArtifacts } ) ;
138+ // const context = getChatSystemPrompt({ cloudProvider: project.cloudProvider, useArtifacts });
63139 const model = await getLanguageModel ( modelId ) ;
64140
65141 return createDataStreamResponse ( {
66142 execute : async ( dataStream ) => {
67- const tools = await getTools ( { project, connection, targetDb, useArtifacts, userId, dataStream } ) ;
68-
69- const result = streamText ( {
143+ const result = await chatModel . streamText ( {
70144 model : model . instance ( ) ,
71- system : context ,
145+ deps : {
146+ dbAccess : await getUserSessionDBAccess ( ) ,
147+ project,
148+ connection,
149+ targetDb,
150+ useArtifacts,
151+ userId,
152+ dataStream
153+ } ,
72154 messages,
73155 maxSteps : 20 ,
74- toolCallStreaming : true ,
75- experimental_transform : smoothStream ( { chunking : 'word' } ) ,
76- experimental_generateMessageId : generateUUID ,
77- experimental_telemetry : {
78- isEnabled : true ,
79- metadata : {
80- projectId : connection . projectId ,
81- connectionId : connectionId ,
82- sessionId : id ,
83- model : model . info ( ) . id ,
84- userId,
85- cloudProvider : project . cloudProvider ,
86- tags : [ 'chat' ]
87- }
156+ metadata : {
157+ projectId : connection . projectId ,
158+ connectionId : connectionId ,
159+ sessionId : id ,
160+ model : model . info ( ) . id ,
161+ userId,
162+ cloudProvider : project . cloudProvider
88163 } ,
89- tools,
90164 onFinish : async ( { response } ) => {
91165 try {
92166 const assistantId = getTrailingMessageId ( {
0 commit comments