Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/main/presenter/llmProviderPresenter/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ export class LLMProviderPresenter implements ILlmProviderPresenter {
modelId: string,
eventId: string,
temperature: number = 0.6,
maxTokens: number = 4096
maxTokens: number = 4096,
enabledMcpTools?: string[]
): AsyncGenerator<LLMAgentEvent, void, unknown> {
console.log(`[Agent Loop] Starting agent loop for event: ${eventId} with model: ${modelId}`)
if (!this.canStartNewStream()) {
Expand Down Expand Up @@ -371,7 +372,7 @@ export class LLMProviderPresenter implements ILlmProviderPresenter {

try {
console.log(`[Agent Loop] Iteration ${toolCallCount + 1} for event: ${eventId}`)
const mcpTools = await presenter.mcpPresenter.getAllToolDefinitions()
const mcpTools = await presenter.mcpPresenter.getAllToolDefinitions(enabledMcpTools)
// Call the provider's core stream method, expecting LLMCoreStreamEvent
const stream = provider.coreStream(
conversationMessages,
Expand Down Expand Up @@ -591,9 +592,9 @@ export class LLMProviderPresenter implements ILlmProviderPresenter {
toolCallCount++

// Find the tool definition to get server info
const toolDef = (await presenter.mcpPresenter.getAllToolDefinitions()).find(
(t) => t.function.name === toolCall.name
)
const toolDef = (
await presenter.mcpPresenter.getAllToolDefinitions(enabledMcpTools)
).find((t) => t.function.name === toolCall.name)

if (!toolDef) {
console.error(`Tool definition not found for ${toolCall.name}. Skipping execution.`)
Expand Down
5 changes: 2 additions & 3 deletions src/main/presenter/mcpPresenter/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,10 @@ export class McpPresenter implements IMCPPresenter {
// 通知渲染进程服务器已停止
eventBus.send(MCP_EVENTS.SERVER_STOPPED, SendTarget.ALL_WINDOWS, serverName)
}

async getAllToolDefinitions(): Promise<MCPToolDefinition[]> {
async getAllToolDefinitions(enabledMcpTools?: string[]): Promise<MCPToolDefinition[]> {
const enabled = await this.configPresenter.getMcpEnabled()
if (enabled) {
return this.toolManager.getAllToolDefinitions()
return await this.toolManager.getAllToolDefinitions(enabledMcpTools)
}
return []
}
Expand Down
21 changes: 19 additions & 2 deletions src/main/presenter/mcpPresenter/toolManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,17 @@ export class ToolManager {
public async getRunningClients(): Promise<McpClient[]> {
return this.serverManager.getRunningClients()
}

// 获取所有工具定义
public async getAllToolDefinitions(): Promise<MCPToolDefinition[]> {
public async getAllToolDefinitions(enabledTools?: string[]): Promise<MCPToolDefinition[]> {
if (this.cachedToolDefinitions !== null && this.cachedToolDefinitions.length > 0) {
if (enabledTools) {
const enabledSet = new Set(enabledTools)
return this.cachedToolDefinitions.filter((toolDef) => {
const finalName = toolDef.function.name
const originalName = this.toolNameToTargetMap?.get(finalName)?.originalName || finalName
return enabledSet.has(finalName) || enabledSet.has(originalName)
})
}
return this.cachedToolDefinitions
}

Expand Down Expand Up @@ -200,6 +207,16 @@ export class ToolManager {
// 缓存结果并返回
this.cachedToolDefinitions = results
console.info(`Cached ${results.length} final tool definitions and populated target map.`)

if (enabledTools && enabledTools.length > 0) {
const enabledSet = new Set(enabledTools)
return this.cachedToolDefinitions.filter((toolDef) => {
const finalName = toolDef.function.name
const originalName = this.toolNameToTargetMap?.get(finalName)?.originalName || finalName
return enabledSet.has(finalName) || enabledSet.has(originalName)
})
}

return this.cachedToolDefinitions
}

Expand Down
44 changes: 34 additions & 10 deletions src/main/presenter/sqlitePresenter/tables/conversations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ type ConversationRow = {
artifacts: number
is_new: number
is_pinned: number
enabled_mcp_tools: string | null
}

// 解析 JSON 字段
function getJsonField<T>(val: string | null | undefined, fallback: T): T {
try {
return val ? JSON.parse(val) : fallback
} catch {
return fallback
}
}

export class ConversationsTable extends BaseTable {
Expand Down Expand Up @@ -46,7 +56,6 @@ export class ConversationsTable extends BaseTable {
CREATE INDEX idx_conversations_pinned ON conversations(is_pinned);
`
}

getMigrationSQL(version: number): string | null {
if (version === 1) {
return `
Expand All @@ -67,11 +76,17 @@ export class ConversationsTable extends BaseTable {
UPDATE conversations SET artifacts = 0;
`
}
if (version === 3) {
return `
ALTER TABLE conversations ADD COLUMN enabled_mcp_tools TEXT DEFAULT '[]';
`
}

return null
}

getLatestVersion(): number {
return 2
return 3
}

async create(title: string, settings: Partial<CONVERSATION_SETTINGS> = {}): Promise<string> {
Expand All @@ -89,9 +104,10 @@ export class ConversationsTable extends BaseTable {
model_id,
is_new,
artifacts,
is_pinned
is_pinned,
enabled_mcp_tools
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ,?)
`)
const conv_id = nanoid()
const now = Date.now()
Expand All @@ -108,7 +124,8 @@ export class ConversationsTable extends BaseTable {
settings.modelId || 'gpt-4',
1,
settings.artifacts || 0,
0 // Default is_pinned to 0
0, // Default is_pinned to 0
settings.enabledMcpTools ? JSON.stringify(settings.enabledMcpTools) : '[]'
)
return conv_id
}
Expand All @@ -130,7 +147,8 @@ export class ConversationsTable extends BaseTable {
model_id as modelId,
is_new,
artifacts,
is_pinned
is_pinned,
enabled_mcp_tools
FROM conversations
WHERE conv_id = ?
`
Expand All @@ -155,7 +173,8 @@ export class ConversationsTable extends BaseTable {
maxTokens: result.maxTokens,
providerId: result.providerId,
modelId: result.modelId,
artifacts: result.artifacts as 0 | 1
artifacts: result.artifacts as 0 | 1,
enabledMcpTools: getJsonField(result.enabled_mcp_tools, [])
}
}
}
Expand Down Expand Up @@ -208,8 +227,11 @@ export class ConversationsTable extends BaseTable {
updates.push('artifacts = ?')
params.push(data.settings.artifacts)
}
if (data.settings.enabledMcpTools !== undefined) {
updates.push('enabled_mcp_tools = ?')
params.push(JSON.stringify(data.settings.enabledMcpTools))
}
}

if (updates.length > 0 || data.updatedAt) {
updates.push('updated_at = ?')
params.push(data.updatedAt || Date.now())
Expand Down Expand Up @@ -252,7 +274,8 @@ export class ConversationsTable extends BaseTable {
model_id as modelId,
is_new,
artifacts,
is_pinned
is_pinned,
enabled_mcp_tools
FROM conversations
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
Expand All @@ -276,7 +299,8 @@ export class ConversationsTable extends BaseTable {
maxTokens: row.maxTokens,
providerId: row.providerId,
modelId: row.modelId,
artifacts: row.artifacts as 0 | 1
artifacts: row.artifacts as 0 | 1,
enabledMcpTools: getJsonField(row.enabled_mcp_tools, [])
}
}))
}
Expand Down
22 changes: 12 additions & 10 deletions src/main/presenter/threadPresenter/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,6 @@ export class ThreadPresenter implements IThreadPresenter {

return conversation
}

async createConversation(
title: string,
settings: Partial<CONVERSATION_SETTINGS> = {},
Expand Down Expand Up @@ -1467,16 +1466,17 @@ export class ThreadPresenter implements IThreadPresenter {
providerId: currentProviderId,
modelId: currentModelId,
temperature: currentTemperature,
maxTokens: currentMaxTokens
maxTokens: currentMaxTokens,
enabledMcpTools: crrentEnabledMcpTools
} = currentConversation.settings

const stream = this.llmProviderPresenter.startStreamCompletion(
currentProviderId, // 使用最新的设置
finalContent,
currentModelId, // 使用最新的设置
state.message.id,
currentTemperature, // 使用最新的设置
currentMaxTokens // 使用最新的设置
currentMaxTokens, // 使用最新的设置
crrentEnabledMcpTools
)
for await (const event of stream) {
const msg = event.data
Expand Down Expand Up @@ -1574,7 +1574,7 @@ export class ThreadPresenter implements IThreadPresenter {
this.throwIfCancelled(state.message.id)

// 7. 准备提示内容
const { providerId, modelId, temperature, maxTokens } = conversation.settings
const { providerId, modelId, temperature, maxTokens, enabledMcpTools } = conversation.settings
const modelConfig = this.configPresenter.getModelConfig(modelId, providerId)

const { finalContent, promptTokens } = await this.preparePromptContent(
Expand Down Expand Up @@ -1641,7 +1641,8 @@ export class ThreadPresenter implements IThreadPresenter {
modelId,
state.message.id,
temperature,
maxTokens
maxTokens,
enabledMcpTools
)
for await (const event of stream) {
const msg = event.data
Expand Down Expand Up @@ -1789,7 +1790,7 @@ export class ThreadPresenter implements IThreadPresenter {
finalContent: ChatMessage[]
promptTokens: number
}> {
const { systemPrompt, contextLength, artifacts } = conversation.settings
const { systemPrompt, contextLength, artifacts, enabledMcpTools } = conversation.settings

const searchPrompt = searchResults ? generateSearchPrompt(userContent, searchResults) : ''
const enrichedUserMessage =
Expand All @@ -1801,7 +1802,7 @@ export class ThreadPresenter implements IThreadPresenter {
const searchPromptTokens = searchPrompt ? approximateTokenSize(searchPrompt ?? '') : 0
const systemPromptTokens = systemPrompt ? approximateTokenSize(systemPrompt ?? '') : 0
const userMessageTokens = approximateTokenSize(userContent + enrichedUserMessage)
const mcpTools = await presenter.mcpPresenter.getAllToolDefinitions()
const mcpTools = await presenter.mcpPresenter.getAllToolDefinitions(enabledMcpTools)
const mcpToolsTokens = mcpTools.reduce(
(acc, tool) => acc + approximateTokenSize(JSON.stringify(tool)),
0
Expand Down Expand Up @@ -3049,7 +3050,7 @@ export class ThreadPresenter implements IThreadPresenter {
throw new Error(errorMsg)
}

const { providerId, modelId, temperature, maxTokens } = conversation.settings
const { providerId, modelId, temperature, maxTokens, enabledMcpTools } = conversation.settings
const modelConfig = this.configPresenter.getModelConfig(modelId, providerId)

if (!modelConfig) {
Expand Down Expand Up @@ -3102,7 +3103,8 @@ export class ThreadPresenter implements IThreadPresenter {
modelId,
messageId,
temperature,
maxTokens
maxTokens,
enabledMcpTools
)

for await (const event of stream) {
Expand Down
3 changes: 2 additions & 1 deletion src/renderer/src/components/NewThread.vue
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ const handleSend = async (content: UserMessageContent) => {
temperature: temperature.value,
contextLength: contextLength.value,
maxTokens: maxTokens.value,
artifacts: artifacts.value as 0 | 1
artifacts: artifacts.value as 0 | 1,
enabledMcpTools: chatStore.chatConfig.enabledMcpTools
})
console.log('threadId', threadId, activeModel.value)
chatStore.sendMessage(content)
Expand Down
Loading