Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
35 changes: 31 additions & 4 deletions internal/toolinternal/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package toolinternal

import (
"context"
"errors"

"github.com/google/uuid"
"google.golang.org/genai"
Expand All @@ -28,12 +29,18 @@ import (
"google.golang.org/adk/tool"
)

// ErrArtifactServiceNotConfigured is returned when artifact service operations are attempted without configuration.
var ErrArtifactServiceNotConfigured = errors.New("artifact service not configured")

type internalArtifacts struct {
agent.Artifacts
eventActions *session.EventActions
}

func (ia *internalArtifacts) Save(ctx context.Context, name string, data *genai.Part) (*artifact.SaveResponse, error) {
if ia == nil {
return nil, ErrArtifactServiceNotConfigured
}
resp, err := ia.Artifacts.Save(ctx, name, data)
if err != nil {
return resp, err
Expand All @@ -48,6 +55,20 @@ func (ia *internalArtifacts) Save(ctx context.Context, name string, data *genai.
return resp, nil
}

func (ia *internalArtifacts) List(ctx context.Context) (*artifact.ListResponse, error) {
if ia == nil {
return nil, ErrArtifactServiceNotConfigured
}
return ia.Artifacts.List(ctx)
}

func (ia *internalArtifacts) Load(ctx context.Context, name string) (*artifact.LoadResponse, error) {
if ia == nil {
return nil, ErrArtifactServiceNotConfigured
}
return ia.Artifacts.Load(ctx, name)
}

func NewToolContext(ctx agent.InvocationContext, functionCallID string, actions *session.EventActions) tool.Context {
if functionCallID == "" {
functionCallID = uuid.NewString()
Expand All @@ -60,15 +81,21 @@ func NewToolContext(ctx agent.InvocationContext, functionCallID string, actions
}
cbCtx := contextinternal.NewCallbackContextWithDelta(ctx, actions.StateDelta)

// Only create internalArtifacts if the underlying Artifacts service is configured
var artifacts *internalArtifacts
if ctx.Artifacts() != nil {
artifacts = &internalArtifacts{
Artifacts: ctx.Artifacts(),
eventActions: actions,
}
}

return &toolContext{
CallbackContext: cbCtx,
invocationContext: ctx,
functionCallID: functionCallID,
eventActions: actions,
artifacts: &internalArtifacts{
Artifacts: ctx.Artifacts(),
eventActions: actions,
},
artifacts: artifacts,
}
}

Expand Down
43 changes: 43 additions & 0 deletions internal/toolinternal/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package toolinternal

import (
"errors"
"testing"

"google.golang.org/adk/agent"
Expand All @@ -36,3 +37,45 @@ func TestToolContext(t *testing.T) {
t.Errorf("ToolContext(%+T) is unexpectedly an InvocationContext", got)
}
}

func TestInternalArtifacts_NilSafe(t *testing.T) {
// Create invocation context without artifact service
inv := contextinternal.NewInvocationContext(t.Context(), contextinternal.InvocationContextParams{
Artifacts: nil,
})
toolCtx := NewToolContext(inv, "fn1", &session.EventActions{})

artifacts := toolCtx.Artifacts()
// artifacts will be nil when service not configured

tests := []struct {
name string
call func() (any, error)
}{
{
name: "List",
call: func() (any, error) { return artifacts.List(t.Context()) },
},
{
name: "Load",
call: func() (any, error) { return artifacts.Load(t.Context(), "test.txt") },
},
{
name: "Save",
call: func() (any, error) { return artifacts.Save(t.Context(), "test.txt", nil) },
},
}

for _, tt := range tests {
t.Run(tt.name+" returns error", func(t *testing.T) {
_, err := tt.call()
if err == nil {
t.Error("Expected an error, got nil")
return
}
if !errors.Is(err, ErrArtifactServiceNotConfigured) {
t.Errorf("Expected ErrArtifactServiceNotConfigured, got: %v", err)
}
})
}
}
36 changes: 36 additions & 0 deletions runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ func New(cfg Config) (*Runner, error) {
return nil, fmt.Errorf("failed to create agent tree: %w", err)
}

// Validate that required services are configured for tools
if err := validateConfiguration(cfg.Agent, cfg.ArtifactService); err != nil {
return nil, err
}

return &Runner{
appName: cfg.AppName,
rootAgent: cfg.Agent,
Expand Down Expand Up @@ -268,3 +273,34 @@ func findAgent(curAgent agent.Agent, targetName string) agent.Agent {
}
return nil
}

// validateConfiguration checks that required services are available for tools.
func validateConfiguration(rootAgent agent.Agent, artifactService artifact.Service) error {
return walkAgentTree(rootAgent, func(a agent.Agent) error {
llmAgent, ok := a.(llminternal.Agent)
if !ok {
return nil
}

state := llminternal.Reveal(llmAgent)
for _, t := range state.Tools {
if t.Name() == "load_artifacts" && artifactService == nil {
return fmt.Errorf("agent %q uses load_artifacts tool but ArtifactService not configured in runner", a.Name())
}
}
return nil
})
}

// walkAgentTree recursively walks the agent tree and applies fn to each agent.
func walkAgentTree(a agent.Agent, fn func(agent.Agent) error) error {
if err := fn(a); err != nil {
return err
}
for _, sub := range a.SubAgents() {
if err := walkAgentTree(sub, fn); err != nil {
return err
}
}
return nil
}
85 changes: 83 additions & 2 deletions runner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ import (
"strings"
"testing"

"google.golang.org/genai"

"google.golang.org/adk/agent"
"google.golang.org/adk/agent/llmagent"
"google.golang.org/adk/artifact"
"google.golang.org/adk/session"
"google.golang.org/adk/tool"
"google.golang.org/adk/tool/loadartifactstool"
"google.golang.org/genai"
)

func TestRunner_findAgentToRun(t *testing.T) {
Expand Down Expand Up @@ -314,6 +315,86 @@ func TestRunner_SaveInputBlobsAsArtifacts(t *testing.T) {
}
}

func TestNew_ValidatesLoadArtifactsToolRequiresArtifactService(t *testing.T) {
t.Parallel()

tests := []struct {
name string
agent agent.Agent
artifactService artifact.Service
wantErr bool
errContains string
}{
{
name: "error when load_artifacts tool present but no artifact service",
agent: must(llmagent.New(llmagent.Config{
Name: "test_agent",
Tools: []tool.Tool{loadartifactstool.New()},
})),
artifactService: nil,
wantErr: true,
errContains: "load_artifacts tool but ArtifactService not configured",
},
{
name: "ok when load_artifacts tool and artifact service both present",
agent: must(llmagent.New(llmagent.Config{
Name: "test_agent",
Tools: []tool.Tool{loadartifactstool.New()},
})),
artifactService: artifact.InMemoryService(),
wantErr: false,
},
{
name: "ok when no load_artifacts tool and no artifact service",
agent: must(llmagent.New(llmagent.Config{
Name: "test_agent",
})),
artifactService: nil,
wantErr: false,
},
{
name: "error when load_artifacts in sub-agent but no artifact service",
agent: must(llmagent.New(llmagent.Config{
Name: "parent_agent",
SubAgents: []agent.Agent{
must(llmagent.New(llmagent.Config{
Name: "child_agent",
Tools: []tool.Tool{loadartifactstool.New()},
})),
},
})),
artifactService: nil,
wantErr: true,
errContains: "child_agent",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := New(Config{
AppName: "testApp",
Agent: tt.agent,
SessionService: session.InMemoryService(),
ArtifactService: tt.artifactService,
})

if tt.wantErr {
if err == nil {
t.Errorf("New() expected error but got nil")
return
}
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("New() error = %v, want error containing %q", err, tt.errContains)
}
} else {
if err != nil {
t.Errorf("New() unexpected error = %v", err)
}
}
})
}
}

// creates agentTree for tests and returns references to the agents
func agentTree(t *testing.T) agentTreeStruct {
t.Helper()
Expand Down
27 changes: 27 additions & 0 deletions tool/loadartifactstool/load_artifacts_tool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package loadartifactstool_test

import (
"errors"
"strings"
"testing"

Expand Down Expand Up @@ -277,6 +278,32 @@ func TestLoadArtifactsTool_ProcessRequest_Artifacts_OtherFunctionCall(t *testing
}
}

func TestLoadArtifactsTool_ProcessRequest_NoArtifactService(t *testing.T) {
loadArtifactsTool := loadartifactstool.New()

// Create tool context WITHOUT artifact service configured
ctx := icontext.NewInvocationContext(t.Context(), icontext.InvocationContextParams{
Artifacts: nil, // No artifact service
})
tc := toolinternal.NewToolContext(ctx, "", nil)

llmRequest := &model.LLMRequest{}

requestProcessor, ok := loadArtifactsTool.(toolinternal.RequestProcessor)
if !ok {
t.Fatal("loadArtifactsTool does not implement RequestProcessor")
}

err := requestProcessor.ProcessRequest(tc, llmRequest)
if err == nil {
t.Fatal("Expected error when artifact service not configured, got nil")
}

if !errors.Is(err, toolinternal.ErrArtifactServiceNotConfigured) {
t.Errorf("Expected ErrArtifactServiceNotConfigured, got: %v", err)
}
}

func createToolContext(t *testing.T) tool.Context {
t.Helper()

Expand Down
2 changes: 2 additions & 0 deletions vendor/cel.dev/expr/.bazelversion
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
7.3.2
# Keep this pinned version in parity with cel-go
2 changes: 2 additions & 0 deletions vendor/cel.dev/expr/.gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.pb.go linguist-generated=true
*.pb.go -diff -merge
2 changes: 2 additions & 0 deletions vendor/cel.dev/expr/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
bazel-*
MODULE.bazel.lock
34 changes: 34 additions & 0 deletions vendor/cel.dev/expr/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

package(default_visibility = ["//visibility:public"])

licenses(["notice"]) # Apache 2.0

go_library(
name = "expr",
srcs = [
"checked.pb.go",
"eval.pb.go",
"explain.pb.go",
"syntax.pb.go",
"value.pb.go",
],
importpath = "cel.dev/expr",
visibility = ["//visibility:public"],
deps = [
"@org_golang_google_genproto_googleapis_rpc//status:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect",
"@org_golang_google_protobuf//runtime/protoimpl",
"@org_golang_google_protobuf//types/known/anypb",
"@org_golang_google_protobuf//types/known/durationpb",
"@org_golang_google_protobuf//types/known/emptypb",
"@org_golang_google_protobuf//types/known/structpb",
"@org_golang_google_protobuf//types/known/timestamppb",
],
)

alias(
name = "go_default_library",
actual = ":expr",
visibility = ["//visibility:public"],
)
25 changes: 25 additions & 0 deletions vendor/cel.dev/expr/CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Contributor Code of Conduct
## Version 0.1.1 (adapted from 0.3b-angular)

As contributors and maintainers of the Common Expression Language
(CEL) project, we pledge to respect everyone who contributes by
posting issues, updating documentation, submitting pull requests,
providing feedback in comments, and any other activities.

Communication through any of CEL's channels (GitHub, Gitter, IRC,
mailing lists, Google+, Twitter, etc.) must be constructive and never
resort to personal attacks, trolling, public or private harassment,
insults, or other unprofessional conduct.

We promise to extend courtesy and respect to everyone involved in this
project regardless of gender, gender identity, sexual orientation,
disability, age, race, ethnicity, religion, or level of experience. We
expect anyone contributing to the project to do the same.

If any member of the community violates this code of conduct, the
maintainers of the CEL project may take action, removing issues,
comments, and PRs or blocking accounts as deemed appropriate.

If you are subject to or witness unacceptable behavior, or have any
other concerns, please email us at
[cel-conduct@google.com](mailto:cel-conduct@google.com).
Loading