Skip to content

Commit 58ffba4

Browse files
committed
feat: agent callbacks for adk
1 parent 997abc9 commit 58ffba4

File tree

4 files changed

+641
-226
lines changed

4 files changed

+641
-226
lines changed

adk/agent_middleware.go

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
package adk
2+
3+
import (
4+
"context"
5+
"runtime/debug"
6+
"sync/atomic"
7+
8+
"github.com/cloudwego/eino/components/tool"
9+
"github.com/cloudwego/eino/compose"
10+
"github.com/cloudwego/eino/internal/safe"
11+
)
12+
13+
// TODO(n3ko): comment
14+
15+
// AgentMiddleware provides hooks to customize agent behavior at various stages of execution.
16+
type AgentMiddleware struct {
17+
// AdditionalInstruction adds supplementary text to the agent's system instruction.
18+
// This instruction is concatenated with the base instruction before each chat model call.
19+
AdditionalInstruction string
20+
21+
// AdditionalTools adds supplementary tools to the agent's available toolset.
22+
// These tools are combined with the tools configured for the agent.
23+
AdditionalTools []tool.BaseTool
24+
25+
// BeforeChatModel is called before each ChatModel invocation, allowing modification of the agent state.
26+
BeforeChatModel func(context.Context, *ChatModelAgentState) error
27+
28+
// AfterChatModel is called after each ChatModel invocation, allowing modification of the agent state.
29+
AfterChatModel func(context.Context, *ChatModelAgentState) error
30+
31+
// WrapToolCall wraps tool calls with custom middleware logic.
32+
// Each middleware contains Invokable and/or Streamable functions for tool calls.
33+
WrapToolCall compose.ToolMiddleware
34+
35+
BeforeAgent func(ctx context.Context, arc *AgentContext) (nextContext context.Context, err error)
36+
37+
OnEvents func(ctx context.Context, arc *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent])
38+
}
39+
40+
type AgentMiddlewareChecker interface {
41+
IsAgentMiddlewareEnabled() bool
42+
}
43+
44+
// ChatModelAgentState represents the state of a chat model agent during conversation.
45+
type ChatModelAgentState struct {
46+
// Messages contains all messages in the current conversation session.
47+
Messages []Message
48+
}
49+
50+
type EntranceType string
51+
52+
const (
53+
EntranceTypeRun EntranceType = "Run"
54+
EntranceTypeResume EntranceType = "Resume"
55+
)
56+
57+
type AgentContext struct {
58+
AgentInput *AgentInput
59+
ResumeInfo *ResumeInfo
60+
AgentRunOptions []AgentRunOption
61+
62+
// internal properties, read only
63+
agentName string
64+
isRootAgent bool
65+
entrance EntranceType
66+
}
67+
68+
func (a *AgentContext) AgentName() string {
69+
return a.agentName
70+
}
71+
72+
func (a *AgentContext) IsRootAgent() bool {
73+
return a.isRootAgent
74+
}
75+
76+
func (a *AgentContext) EntranceType() EntranceType {
77+
return a.entrance
78+
}
79+
80+
type (
81+
runnerPassedMiddlewaresCtxKey struct{}
82+
runnerPassedMiddlewaresInfo struct {
83+
middlewares []AgentMiddleware
84+
isRootAgent int32
85+
}
86+
)
87+
88+
func isRootAgent(ctx context.Context) bool {
89+
if v, ok := ctx.Value(runnerPassedMiddlewaresCtxKey{}).(*runnerPassedMiddlewaresInfo); ok && v != nil {
90+
val := atomic.SwapInt32(&v.isRootAgent, 1)
91+
return val == 0
92+
}
93+
return false
94+
}
95+
96+
func getRunnerPassedAgentMWs(ctx context.Context) []AgentMiddleware {
97+
if v, ok := ctx.Value(runnerPassedMiddlewaresCtxKey{}).(*runnerPassedMiddlewaresInfo); ok && v != nil {
98+
return v.middlewares
99+
}
100+
return nil
101+
}
102+
103+
func isAgentMiddlewareEnabled(a Agent) bool {
104+
if c, ok := a.(AgentMiddlewareChecker); ok && c.IsAgentMiddlewareEnabled() {
105+
return true
106+
}
107+
return false
108+
}
109+
110+
type agentMWRunner struct {
111+
beforeAgentFns []func(ctx context.Context, arc *AgentContext) (nextContext context.Context, err error)
112+
onEventsFns []func(ctx context.Context, arc *AgentContext, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent])
113+
}
114+
115+
func (a *agentMWRunner) execBeforeAgents(ctx context.Context, ac *AgentContext) (context.Context, *AsyncIterator[*AgentEvent]) {
116+
var err error
117+
for i, beforeAgent := range a.beforeAgentFns {
118+
if beforeAgent == nil {
119+
continue
120+
}
121+
ctx, err = beforeAgent(ctx, ac)
122+
if err != nil {
123+
iter, gen := NewAsyncIteratorPair[*AgentEvent]()
124+
gen.Send(&AgentEvent{Err: err})
125+
gen.Close()
126+
return ctx, a.execOnEventsFromIndex(ctx, ac, i-1, iter)
127+
}
128+
}
129+
return ctx, nil
130+
}
131+
132+
func (a *agentMWRunner) execOnEvents(ctx context.Context, ac *AgentContext, iter *AsyncIterator[*AgentEvent]) *AsyncIterator[*AgentEvent] {
133+
return a.execOnEventsFromIndex(ctx, ac, len(a.onEventsFns)-1, iter)
134+
}
135+
136+
func (a *agentMWRunner) execOnEventsFromIndex(ctx context.Context, ac *AgentContext, fromIdx int, iter *AsyncIterator[*AgentEvent]) *AsyncIterator[*AgentEvent] {
137+
for idx := fromIdx; idx >= 0; idx-- {
138+
onEvents := a.onEventsFns[idx]
139+
if onEvents == nil {
140+
continue
141+
}
142+
i, g := NewAsyncIteratorPair[*AgentEvent]()
143+
go func() {
144+
defer func() {
145+
panicErr := recover()
146+
if panicErr != nil {
147+
e := safe.NewPanicErr(panicErr, debug.Stack())
148+
g.Send(&AgentEvent{Err: e})
149+
}
150+
g.Close()
151+
}()
152+
onEvents(ctx, ac, iter, g)
153+
}()
154+
iter = i
155+
}
156+
return iter
157+
}
158+
159+
func NewAsyncIteratorPairWithConversion(
160+
ctx context.Context,
161+
iter *AsyncIterator[*AgentEvent],
162+
fn func(ctx context.Context, srcIter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]),
163+
) *AsyncIterator[*AgentEvent] {
164+
i, g := NewAsyncIteratorPair[*AgentEvent]()
165+
go func() {
166+
defer func() {
167+
panicErr := recover()
168+
if panicErr != nil {
169+
e := safe.NewPanicErr(panicErr, debug.Stack())
170+
g.Send(&AgentEvent{Err: e})
171+
}
172+
g.Close()
173+
}()
174+
175+
fn(ctx, iter, g)
176+
}()
177+
178+
return i
179+
}
180+
181+
func BypassIterator(ctx context.Context, srcIter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
182+
defer gen.Close()
183+
for {
184+
event, ok := srcIter.Next()
185+
if !ok {
186+
break
187+
}
188+
gen.Send(event)
189+
}
190+
}
191+
192+
type OnEventFn[T any] func(ctx context.Context, input T, event *AgentEvent) (stop bool, err error)
193+
194+
func NewOnEventProcessor[T any](onEvent OnEventFn[T]) func(ctx context.Context, input T, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
195+
return func(ctx context.Context, input T, iter *AsyncIterator[*AgentEvent], gen *AsyncGenerator[*AgentEvent]) {
196+
go func() {
197+
defer func() {
198+
panicErr := recover()
199+
if panicErr != nil {
200+
e := safe.NewPanicErr(panicErr, debug.Stack())
201+
gen.Send(&AgentEvent{Err: e})
202+
}
203+
gen.Close()
204+
}()
205+
206+
for {
207+
event, ok := iter.Next()
208+
if !ok {
209+
break
210+
}
211+
212+
breakIter, err := onEvent(ctx, input, event)
213+
if err != nil {
214+
gen.Send(&AgentEvent{Err: err})
215+
}
216+
if breakIter {
217+
break
218+
}
219+
}
220+
}()
221+
}
222+
}
223+
224+
func iterWithEvents(events ...*AgentEvent) *AsyncIterator[*AgentEvent] {
225+
iter, gen := NewAsyncIteratorPair[*AgentEvent]()
226+
for _, event := range events {
227+
gen.Send(event)
228+
}
229+
gen.Close()
230+
return iter
231+
}

0 commit comments

Comments
 (0)