Skip to content

Commit 421b385

Browse files
committed
Experiment: model with dependency injection
1 parent 726af8b commit 421b385

File tree

8 files changed

+546
-89
lines changed

8 files changed

+546
-89
lines changed

apps/dbagent/src/app/(main)/projects/[project]/chats/actions.ts

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
'use server';
22

3-
import { generateText, Message } from 'ai';
4-
import { getModelInstance } from '~/lib/ai/agent';
3+
import { Message } from 'ai';
4+
import { AugmentedLanguageModel } from '~/lib/ai/model';
5+
import { getProviderRegistry } from '~/lib/ai/providers';
56
import { deleteMessagesByChatIdAfterTimestamp, getMessageById } from '~/lib/db/chats';
67
import { getUserSessionDBAccess } from '~/lib/db/db';
78

8-
export async function generateTitleFromUserMessage({ message }: { message: Message }) {
9-
const { text: title } = await generateText({
10-
model: await getModelInstance('title'),
11-
experimental_telemetry: {
12-
isEnabled: true,
13-
metadata: {
14-
tags: ['internal', 'chat', 'title']
15-
}
16-
},
17-
system: `\n
9+
const titleModel = new AugmentedLanguageModel({
10+
providerRegistry: getProviderRegistry,
11+
baseModel: 'title',
12+
systemPrompt: `\n
1813
- you will generate a short title based on the first message a user begins a conversation with
1914
- ensure it is not more than 80 characters long
2015
- the title should be a summary of the user's message
2116
- do not use quotes or colons`,
17+
metadata: {
18+
tags: ['internal', 'chat', 'title']
19+
}
20+
});
21+
22+
export async function generateTitleFromUserMessage({ message }: { message: Message }) {
23+
const { text: title } = await titleModel.generateText({
2224
prompt: JSON.stringify(message)
2325
});
24-
2526
return title;
2627
}
2728

apps/dbagent/src/app/api/chat/route.ts

Lines changed: 101 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,32 @@
1-
import { UIMessage, appendResponseMessages, createDataStreamResponse, smoothStream, streamText } from 'ai';
1+
import { appendResponseMessages, createDataStreamResponse, DataStreamWriter, UIMessage } from 'ai';
22
import { notFound } from 'next/navigation';
33
import { NextRequest } from 'next/server';
44
import { generateTitleFromUserMessage } from '~/app/(main)/projects/[project]/chats/actions';
5-
import { generateUUID } from '~/components/chat/utils';
6-
import { getChatSystemPrompt } from '~/lib/ai/agent';
7-
import { getLanguageModel } from '~/lib/ai/providers';
8-
import { getTools } from '~/lib/ai/tools';
5+
import { AugmentedLanguageModel } from '~/lib/ai/model';
6+
import {
7+
artifactsPrompt,
8+
awsCloudProviderPrompt,
9+
chatSystemPrompt,
10+
commonSystemPrompt,
11+
gcpCloudProviderPrompt
12+
} from '~/lib/ai/prompts';
13+
import { getLanguageModel, getProviderRegistry } from '~/lib/ai/providers';
14+
import {
15+
AWSDBClusterTools,
16+
CommonDBClusterTools,
17+
commonToolset,
18+
GCPDBClusterTools,
19+
getDBSQLTools,
20+
getPlaybookToolset
21+
} from '~/lib/ai/tools';
22+
import { getArtifactTools } from '~/lib/ai/tools/artifacts';
23+
import { mcpToolset } from '~/lib/ai/tools/user-mcp';
924
import { deleteChatById, getChatById, getChatsByProject, saveChat } from '~/lib/db/chats';
1025
import { getConnection } from '~/lib/db/connections';
11-
import { getUserSessionDBAccess } from '~/lib/db/db';
26+
import { DBAccess, getUserSessionDBAccess } from '~/lib/db/db';
1227
import { getProjectById } from '~/lib/db/projects';
13-
import { getTargetDbPool } from '~/lib/targetdb/db';
28+
import { Connection, Project } from '~/lib/db/schema';
29+
import { getTargetDbPool, Pool } from '~/lib/targetdb/db';
1430
import { requireUserSession } from '~/utils/route';
1531

1632
export const maxDuration = 60;
@@ -33,6 +49,66 @@ export async function GET(request: NextRequest) {
3349
return Response.json({ chats });
3450
}
3551

52+
type ChatModelDeps = {
53+
dbAccess: DBAccess;
54+
project: Project;
55+
connection: Connection;
56+
targetDb: Pool;
57+
useArtifacts: boolean;
58+
userId: string;
59+
dataStream?: DataStreamWriter;
60+
};
61+
62+
const chatModel = new AugmentedLanguageModel<ChatModelDeps>({
63+
providerRegistry: getProviderRegistry,
64+
baseModel: 'chat',
65+
metadata: {
66+
tags: ['chat']
67+
},
68+
systemPrompt: [commonSystemPrompt, chatSystemPrompt],
69+
toolsSets: [
70+
{ tools: mcpToolset.listMCPTools },
71+
{ tools: commonToolset },
72+
{ tools: async ({ targetDb }) => getDBSQLTools(targetDb).toolset() },
73+
74+
// Playbook support
75+
{ tools: async ({ dbAccess, project }) => getPlaybookToolset(dbAccess, project.id) },
76+
77+
// Common cloud provider DB support
78+
{
79+
tools: async ({ dbAccess, connection }) =>
80+
new CommonDBClusterTools(dbAccess, () => Promise.resolve(connection)).toolset()
81+
}
82+
]
83+
});
84+
85+
// AWS cloud provider support
86+
chatModel.addSystemPrompts(({ project }) => (project.cloudProvider === 'aws' ? awsCloudProviderPrompt : ''));
87+
chatModel.addToolSet({
88+
active: (deps?: ChatModelDeps) => deps?.project.cloudProvider === 'aws',
89+
tools: async ({ dbAccess, connection }) => {
90+
return new AWSDBClusterTools(dbAccess, () => Promise.resolve(connection)).toolset();
91+
}
92+
});
93+
94+
// GCP cloud provider support
95+
chatModel.addSystemPrompts(({ project }) => (project.cloudProvider === 'gcp' ? gcpCloudProviderPrompt : ''));
96+
chatModel.addToolSet({
97+
active: (deps?: ChatModelDeps) => deps?.project.cloudProvider === 'gcp',
98+
tools: async ({ dbAccess, connection }) => {
99+
return new GCPDBClusterTools(dbAccess, () => Promise.resolve(connection)).toolset();
100+
}
101+
});
102+
103+
// Artifacts support
104+
chatModel.addSystemPrompts(({ useArtifacts }) => (useArtifacts ? artifactsPrompt : ''));
105+
chatModel.addToolSet({
106+
active: (deps?: ChatModelDeps) => !!deps?.useArtifacts && !!deps?.dataStream,
107+
tools: async ({ dbAccess, userId, project, dataStream }) => {
108+
return getArtifactTools({ dbAccess, userId, projectId: project.id, dataStream: dataStream! });
109+
}
110+
});
111+
36112
export async function POST(request: Request) {
37113
try {
38114
const { id, messages, connectionId, model: modelId, useArtifacts } = await request.json();
@@ -59,34 +135,32 @@ export async function POST(request: Request) {
59135
if (!chat) notFound();
60136

61137
const targetDb = getTargetDbPool(connection.connectionString);
62-
const context = getChatSystemPrompt({ cloudProvider: project.cloudProvider, useArtifacts });
138+
// const context = getChatSystemPrompt({ cloudProvider: project.cloudProvider, useArtifacts });
63139
const model = await getLanguageModel(modelId);
64140

65141
return createDataStreamResponse({
66142
execute: async (dataStream) => {
67-
const tools = await getTools({ project, connection, targetDb, useArtifacts, userId, dataStream });
68-
69-
const result = streamText({
143+
const result = await chatModel.streamText({
70144
model: model.instance(),
71-
system: context,
145+
deps: {
146+
dbAccess: await getUserSessionDBAccess(),
147+
project,
148+
connection,
149+
targetDb,
150+
useArtifacts,
151+
userId,
152+
dataStream
153+
},
72154
messages,
73155
maxSteps: 20,
74-
toolCallStreaming: true,
75-
experimental_transform: smoothStream({ chunking: 'word' }),
76-
experimental_generateMessageId: generateUUID,
77-
experimental_telemetry: {
78-
isEnabled: true,
79-
metadata: {
80-
projectId: connection.projectId,
81-
connectionId: connectionId,
82-
sessionId: id,
83-
model: model.info().id,
84-
userId,
85-
cloudProvider: project.cloudProvider,
86-
tags: ['chat']
87-
}
156+
metadata: {
157+
projectId: connection.projectId,
158+
connectionId: connectionId,
159+
sessionId: id,
160+
model: model.info().id,
161+
userId,
162+
cloudProvider: project.cloudProvider
88163
},
89-
tools,
90164
onFinish: async ({ response }) => {
91165
try {
92166
const assistantId = getTrailingMessageId({

apps/dbagent/src/components/chat/artifacts/text/server.ts

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,36 @@
1-
import { smoothStream, streamText } from 'ai';
2-
import { getModelInstance } from '~/lib/ai/agent';
1+
import { AugmentedLanguageModel } from '~/lib/ai/model';
32
import { updateDocumentPrompt } from '~/lib/ai/prompts';
3+
import { getProviderRegistry } from '~/lib/ai/providers';
44
import { createDocumentHandler } from '../server';
55

6+
const titleModel = new AugmentedLanguageModel({
7+
providerRegistry: getProviderRegistry,
8+
baseModel: 'title',
9+
systemPrompt: 'Write about the given topic. Markdown is supported. Use headings wherever appropriate.',
10+
metadata: {
11+
tags: ['artifact', 'text', 'create']
12+
}
13+
});
14+
15+
type DocumentUpdateDeps = {
16+
content: string | null;
17+
};
18+
19+
const documentUpdateModel = new AugmentedLanguageModel<DocumentUpdateDeps>({
20+
providerRegistry: getProviderRegistry,
21+
baseModel: 'chat',
22+
systemPrompt: ({ content }) => updateDocumentPrompt(content, 'text'),
23+
metadata: {
24+
tags: ['artifact', 'text', 'update']
25+
}
26+
});
27+
628
export const textDocumentHandler = createDocumentHandler<'text'>({
729
kind: 'text',
830
onCreateDocument: async ({ title, dataStream }) => {
931
let draftContent = '';
1032

11-
const { fullStream } = streamText({
12-
model: await getModelInstance('chat'),
13-
system: 'Write about the given topic. Markdown is supported. Use headings wherever appropriate.',
14-
experimental_transform: smoothStream({ chunking: 'word' }),
15-
experimental_telemetry: {
16-
isEnabled: true,
17-
metadata: {
18-
tags: ['artifact', 'text', 'create']
19-
}
20-
},
21-
prompt: title
22-
});
23-
33+
const { fullStream } = await titleModel.streamText({ prompt: title });
2434
for await (const delta of fullStream) {
2535
const { type } = delta;
2636

@@ -41,16 +51,8 @@ export const textDocumentHandler = createDocumentHandler<'text'>({
4151
onUpdateDocument: async ({ document, description, dataStream }) => {
4252
let draftContent = '';
4353

44-
const { fullStream } = streamText({
45-
model: await getModelInstance('chat'),
46-
system: updateDocumentPrompt(document.content, 'text'),
47-
experimental_transform: smoothStream({ chunking: 'word' }),
48-
experimental_telemetry: {
49-
isEnabled: true,
50-
metadata: {
51-
tags: ['artifact', 'text', 'update']
52-
}
53-
},
54+
const { fullStream } = await documentUpdateModel.streamText({
55+
deps: { content: document.content },
5456
prompt: description,
5557
experimental_providerMetadata: {
5658
openai: {

apps/dbagent/src/components/monitoring/actions.ts

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
'use server';
22

3-
import { openai } from '@ai-sdk/openai';
4-
import { generateText } from 'ai';
53
import { auth } from '~/auth';
4+
import { AugmentedLanguageModel } from '~/lib/ai/model';
5+
import { getProviderRegistry } from '~/lib/ai/providers';
66
import { getUserDBAccess, getUserSessionDBAccess } from '~/lib/db/db';
77
import { getScheduleRuns } from '~/lib/db/schedule-runs';
88
import {
@@ -17,21 +17,19 @@ import { Schedule, ScheduleInsert, ScheduleRun } from '~/lib/db/schema';
1717
import { scheduleGetNextRun, utcToLocalDate } from '~/lib/monitoring/scheduler';
1818
import { listPlaybooks } from '~/lib/tools/playbooks';
1919

20+
const utilModel = new AugmentedLanguageModel({
21+
providerRegistry: getProviderRegistry,
22+
baseModel: 'openai:gpt-4o',
23+
metadata: {
24+
tags: ['internal', 'monitoring', 'cron']
25+
}
26+
});
27+
2028
export async function generateCronExpression(description: string): Promise<string> {
2129
const prompt = `Generate a cron expression for the following schedule description: "${description}".
2230
Return strictly the cron expression, no quotes or anything else.`;
2331

24-
const { text } = await generateText({
25-
model: openai('gpt-4o'),
26-
prompt: prompt,
27-
experimental_telemetry: {
28-
isEnabled: true,
29-
metadata: {
30-
tags: ['internal', 'monitoring', 'cron']
31-
}
32-
}
33-
});
34-
32+
const { text } = await utilModel.generateText({ prompt });
3533
return text.trim();
3634
}
3735

apps/dbagent/src/lib/ai/agent.ts

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
11
import { LanguageModel } from 'ai';
22
import { CloudProvider } from '../db/schema';
3-
import { artifactsPrompt, chatSystemPrompt, commonSystemPrompt, monitoringSystemPrompt } from './prompts';
3+
import {
4+
artifactsPrompt,
5+
awsCloudProviderPrompt,
6+
chatSystemPrompt,
7+
commonSystemPrompt,
8+
gcpCloudProviderPrompt,
9+
monitoringSystemPrompt
10+
} from './prompts';
411
import { getLanguageModel, getLanguageModelWithFallback, ModelWithFallback } from './providers';
512

613
function getCloudProviderPrompt(cloudProvider: string): string {
714
switch (cloudProvider) {
815
case 'aws':
9-
return `All Postgres instances in this project are hosted on AWS, either RDS or Aurora.
10-
When recommending actions, only recommend actions that can be performed on RDS or Aurora.
11-
If you need to know more about the instance, you can use the getClusterInfo tool.
12-
If you want to recommend changes to the instance, provide instructions specific to RDS or Aurora.
13-
`;
16+
return awsCloudProviderPrompt;
1417
case 'gcp':
15-
return `All Postgres instances in this project are GCP Cloud SQL instances.
16-
When recommending actions, only recommend actions that can be performed on GCP Cloud SQL.
17-
If you need to know more about the instance, you can use the getInstanceInfo tool.
18-
If you want to recommend changes to the instance, provide instructions specific to GCP Cloud SQL.
19-
`;
18+
return gcpCloudProviderPrompt;
2019
default:
2120
return '';
2221
}

0 commit comments

Comments
 (0)