Skip to content

Commit 04f16a5

Browse files
committed
claudetool: make it easier to parameterize patch tool
1 parent bdc6889 commit 04f16a5

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

claudetool/bash.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import (
2525
// PermissionCallback is a function type for checking if a command is allowed to run
2626
type PermissionCallback func(command string) error
2727

28-
// BashTool specifies a llm.Tool for executing shell commands.
28+
// BashTool specifies an llm.Tool for executing shell commands.
2929
type BashTool struct {
3030
// CheckPermission is called before running any command, if set
3131
CheckPermission PermissionCallback

claudetool/patch.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,22 @@ import (
2525
// and returns a new, possibly altered tool output.
2626
type PatchCallback func(input PatchInput, output llm.ToolOut) llm.ToolOut
2727

28-
// Patch creates a patch tool. The callback may be nil.
29-
func Patch(callback PatchCallback) *llm.Tool {
28+
// PatchTool specifies an llm.Tool for patching files.
29+
type PatchTool struct {
30+
Callback PatchCallback // may be nil
31+
}
32+
33+
// Tool returns an llm.Tool based on p.
34+
func (p *PatchTool) Tool() *llm.Tool {
3035
return &llm.Tool{
3136
Name: PatchName,
3237
Description: strings.TrimSpace(PatchDescription),
3338
InputSchema: llm.MustSchema(PatchInputSchema),
3439
Run: func(ctx context.Context, m json.RawMessage) llm.ToolOut {
3540
var input PatchInput
3641
output := patchRun(ctx, m, &input)
37-
if callback != nil {
38-
return callback(input, output)
42+
if p.Callback != nil {
43+
return p.Callback(input, output)
3944
}
4045
return output
4146
},

loop/agent.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,9 @@ func (a *Agent) initConvoWithUsage(usage *conversation.CumulativeUsage) *convers
13901390
Timeouts: a.config.BashTimeouts,
13911391
Pwd: a.workingDir,
13921392
}
1393+
patchTool := &claudetool.PatchTool{
1394+
Callback: a.patchCallback,
1395+
}
13931396

13941397
// Register all tools with the conversation
13951398
// When adding, removing, or modifying tools here, double-check that the termui tool display
@@ -1411,7 +1414,7 @@ func (a *Agent) initConvoWithUsage(usage *conversation.CumulativeUsage) *convers
14111414
convo.Tools = []*llm.Tool{
14121415
bashTool.Tool(),
14131416
claudetool.Keyword,
1414-
claudetool.Patch(a.patchCallback),
1417+
patchTool.Tool(),
14151418
claudetool.Think,
14161419
claudetool.TodoRead,
14171420
claudetool.TodoWrite,

0 commit comments

Comments
 (0)