Skip to content

Commit 272eee1

Browse files
committed
feat: simplify Prompt function signature
1 parent fb1a428 commit 272eee1

File tree

1 file changed

+51
-26
lines changed

1 file changed

+51
-26
lines changed

lua/gp/init.lua

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,16 +1144,7 @@ M.prepare_commands = function()
11441144
template = M.config.template_prepend
11451145
end
11461146
end
1147-
M.Prompt(
1148-
params,
1149-
target,
1150-
agent.cmd_prefix,
1151-
agent.model,
1152-
template,
1153-
agent.system_prompt,
1154-
whisper,
1155-
agent.provider
1156-
)
1147+
M.Prompt(params, target, agent, template, agent.cmd_prefix, whisper)
11571148
end
11581149

11591150
M.cmd[command] = function(params)
@@ -1179,12 +1170,9 @@ M.call_hook = function(name, params)
11791170
end
11801171

11811172
---@param messages table
1182-
---@param model string | table | nil
1183-
---@param default_model string | table
1173+
---@param model string | table
11841174
---@param provider string | nil
1185-
M.prepare_payload = function(messages, model, default_model, provider)
1186-
model = model or default_model
1187-
1175+
M.prepare_payload = function(messages, model, provider)
11881176
if type(model) == "string" then
11891177
return {
11901178
model = model,
@@ -1724,17 +1712,25 @@ M.not_chat = function(buf, file_name)
17241712
end
17251713

17261714
local lines = vim.api.nvim_buf_get_lines(buf, 0, -1, false)
1727-
if #lines < 4 then
1715+
if #lines < 7 then
17281716
return "file too short"
17291717
end
17301718

17311719
if not lines[1]:match("^# ") then
17321720
return "missing topic header"
17331721
end
17341722

1735-
if not (lines[3]:match("^- file: ") or lines[4]:match("^- file: ")) then
1723+
local header_found = nil
1724+
for i = 1, 6 do
1725+
if lines[i]:match("^- file: ") then
1726+
header_found = true
1727+
break
1728+
end
1729+
end
1730+
if not header_found then
17361731
return "missing file header"
17371732
end
1733+
17381734
return nil
17391735
end
17401736

@@ -2287,6 +2283,10 @@ M.chat_respond = function(params)
22872283
agent_name = agent_name .. " & custom role"
22882284
end
22892285

2286+
if headers.model and not headers.provider then
2287+
headers.provider = "openai"
2288+
end
2289+
22902290
local agent_prefix = config.chat_assistant_prefix[1]
22912291
local agent_suffix = config.chat_assistant_prefix[2]
22922292
if type(M.config.chat_assistant_prefix) == "string" then
@@ -2352,8 +2352,8 @@ M.chat_respond = function(params)
23522352
-- call the model and write response
23532353
M.query(
23542354
buf,
2355-
agent.provider,
2356-
M.prepare_payload(messages, headers.model, agent.model, agent.provider),
2355+
headers.provider or agent.provider,
2356+
M.prepare_payload(messages, headers.model or agent.model, headers.provider or agent.provider),
23572357
M.create_handler(buf, win, M._H.last_content_line(buf), true, "", not M.config.chat_free_cursor),
23582358
vim.schedule_wrap(function(qid)
23592359
local qt = M.get_query(qid)
@@ -2395,8 +2395,8 @@ M.chat_respond = function(params)
23952395
-- call the model
23962396
M.query(
23972397
nil,
2398-
agent.provider,
2399-
M.prepare_payload(messages, nil, agent.model, agent.provider),
2398+
headers.provider or agent.provider,
2399+
M.prepare_payload(messages, headers.model or agent.model, headers.provider or agent.provider),
24002400
topic_handler,
24012401
vim.schedule_wrap(function()
24022402
-- get topic from invisible buffer
@@ -2898,7 +2898,33 @@ M.cmd.Context = function(params)
28982898
M._H.feedkeys("G", "xn")
28992899
end
29002900

2901-
M.Prompt = function(params, target, prompt, model, template, system_template, whisper, provider)
2901+
local exampleHook = [[
2902+
UnitTests = function(gp, params)
2903+
local template = "I have the following code from {{filename}}:\n\n"
2904+
.. "```{{filetype}}\n{{selection}}\n```\n\n"
2905+
.. "Please respond by writing table driven unit tests for the code above."
2906+
local agent = gp.get_command_agent()
2907+
gp.Prompt(params, gp.Target.vnew, agent, template)
2908+
end,
2909+
]]
2910+
2911+
---@param params table
2912+
---@param target integer | function | table
2913+
---@param agent table # obtained from get_command_agent or get_chat_agent
2914+
---@param template string # te
2915+
---@param prompt string | nil # nil for non interactive commads
2916+
---@param whisper string | nil # predefined input (e.g. obtained from Whisper)
2917+
M.Prompt = function(params, target, agent, template, prompt, whisper)
2918+
if not agent or not type(agent) == "table" or not agent.provider then
2919+
M.warning(
2920+
"The `gp.Prompt` method signature has changed.\n"
2921+
.. "Please update your hook functions as demonstrated in the example below::\n\n"
2922+
.. exampleHook
2923+
.. "\nFor more information, refer to the 'Extend Functionality' section in the documentation."
2924+
)
2925+
return
2926+
end
2927+
29022928
-- enew, new, vnew, tabnew should be resolved into table
29032929
if type(target) == "function" then
29042930
target = target()
@@ -3060,7 +3086,7 @@ M.Prompt = function(params, target, prompt, model, template, system_template, wh
30603086
local filetype = M._H.get_filetype(buf)
30613087
local filename = vim.api.nvim_buf_get_name(buf)
30623088

3063-
local sys_prompt = M.template_render(system_template, command, selection, filetype, filename)
3089+
local sys_prompt = M.template_render(agent.system_template, command, selection, filetype, filename)
30643090
sys_prompt = sys_prompt or ""
30653091
table.insert(messages, { role = "system", content = sys_prompt })
30663092

@@ -3163,11 +3189,10 @@ M.Prompt = function(params, target, prompt, model, template, system_template, wh
31633189
end
31643190

31653191
-- call the model and write the response
3166-
local agent = M.get_command_agent()
31673192
M.query(
31683193
buf,
3169-
provider or agent.provider,
3170-
M.prepare_payload(messages, model, agent.model, agent.provider),
3194+
agent.provider,
3195+
M.prepare_payload(messages, agent.model, agent.provider),
31713196
handler,
31723197
vim.schedule_wrap(function(qid)
31733198
on_exit(qid)

0 commit comments

Comments
 (0)