Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions llms/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC
model.SetTopP(float32(opts.TopP))
model.SetTopK(int32(opts.TopK))
model.StopSequences = opts.StopWords
model.SafetySettings = []*genai.SafetySetting{
{
Category: genai.HarmCategoryDangerousContent,
Threshold: genai.HarmBlockThreshold(g.opts.HarmThreshold),
},
{
Category: genai.HarmCategoryHarassment,
Threshold: genai.HarmBlockThreshold(g.opts.HarmThreshold),
},
{
Category: genai.HarmCategoryHateSpeech,
Threshold: genai.HarmBlockThreshold(g.opts.HarmThreshold),
},
{
Category: genai.HarmCategorySexuallyExplicit,
Threshold: genai.HarmBlockThreshold(g.opts.HarmThreshold),
},
}

var response *llms.ContentResponse
var err error
Expand Down
25 changes: 25 additions & 0 deletions llms/googleai/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type Options struct {
DefaultTemperature float64
DefaultTopK int
DefaultTopP float64
HarmThreshold HarmBlockThreshold
}

func DefaultOptions() Options {
Expand All @@ -26,6 +27,7 @@ func DefaultOptions() Options {
DefaultTemperature: 0.5,
DefaultTopK: 3,
DefaultTopP: 0.95,
HarmThreshold: HarmBlockOnlyHigh,
}
}

Expand Down Expand Up @@ -70,3 +72,26 @@ func WithDefaultEmbeddingModel(defaultEmbeddingModel string) Option {
opts.DefaultEmbeddingModel = defaultEmbeddingModel
}
}

// WithHarmThreshold sets the safety/harm setting for the model, potentially
// limiting any harmful content it may generate.
// WithHarmThreshold configures the safety threshold applied to the model,
// limiting how much potentially harmful content it is allowed to generate.
func WithHarmThreshold(ht HarmBlockThreshold) Option {
	return func(o *Options) {
		o.HarmThreshold = ht
	}
}

// HarmBlockThreshold selects how aggressively the model blocks potentially
// harmful content. The clients convert these values directly via
// genai.HarmBlockThreshold, so the numeric values must stay in sync with the
// underlying genai SDK's enum — NOTE(review): verify on SDK upgrades.
type HarmBlockThreshold int32

const (
	// HarmBlockUnspecified means threshold is unspecified.
	HarmBlockUnspecified HarmBlockThreshold = 0
	// HarmBlockLowAndAbove means content with NEGLIGIBLE will be allowed.
	HarmBlockLowAndAbove HarmBlockThreshold = 1
	// HarmBlockMediumAndAbove means content with NEGLIGIBLE and LOW will be allowed.
	HarmBlockMediumAndAbove HarmBlockThreshold = 2
	// HarmBlockOnlyHigh means content with NEGLIGIBLE, LOW, and MEDIUM will be allowed.
	HarmBlockOnlyHigh HarmBlockThreshold = 3
	// HarmBlockNone means all content will be allowed.
	HarmBlockNone HarmBlockThreshold = 4
)
65 changes: 37 additions & 28 deletions llms/googleai/shared_test/shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,22 @@ import (
"github.com/tmc/langchaingo/schema"
)

func newGoogleAIClient(t *testing.T) *googleai.GoogleAI {
func newGoogleAIClient(t *testing.T, opts ...googleai.Option) *googleai.GoogleAI {
t.Helper()

genaiKey := os.Getenv("GENAI_API_KEY")
if genaiKey == "" {
t.Skip("GENAI_API_KEY not set")
return nil
}
llm, err := googleai.New(context.Background(), googleai.WithAPIKey(genaiKey))

opts = append(opts, googleai.WithAPIKey(genaiKey))
llm, err := googleai.New(context.Background(), opts...)
require.NoError(t, err)
return llm
}

func newVertexClient(t *testing.T) *vertex.Vertex {
func newVertexClient(t *testing.T, opts ...googleai.Option) *vertex.Vertex {
t.Helper()

project := os.Getenv("VERTEX_PROJECT")
Expand All @@ -46,10 +48,10 @@ func newVertexClient(t *testing.T) *vertex.Vertex {
location = "us-central1"
}

llm, err := vertex.New(
context.Background(),
opts = append(opts,
googleai.WithCloudProject(project),
googleai.WithCloudLocation(location))
llm, err := vertex.New(context.Background(), opts...)
require.NoError(t, err)
return llm
}
Expand All @@ -62,31 +64,38 @@ func funcName(f any) string {
return parts[len(parts)-1]
}

type testFunc func(*testing.T, llms.Model)

// testFuncs is a list of all test functions in this file to run with both
// client types.
var testFuncs = []testFunc{
testMultiContentText,
testGenerateFromSinglePrompt,
testMultiContentTextChatSequence,
testMultiContentImageLink,
testMultiContentImageBinary,
testEmbeddings,
testCandidateCountSetting,
testMaxTokensSetting,
testWithStreaming,
// testConfigs is a list of all test functions in this file to run with both
// client types, and their client configurations.
type testConfig struct {
testFunc func(*testing.T, llms.Model)
opts []googleai.Option
}

var testConfigs = []testConfig{
{testMultiContentText, nil},
{testGenerateFromSinglePrompt, nil},
{testMultiContentTextChatSequence, nil},
{testMultiContentImageLink, nil},
{testMultiContentImageBinary, nil},
{testEmbeddings, nil},
{testCandidateCountSetting, nil},
{testMaxTokensSetting, nil},
{
testMultiContentText,
[]googleai.Option{googleai.WithHarmThreshold(googleai.HarmBlockMediumAndAbove)},
},
{testWithStreaming, nil},
}

func TestShared(t *testing.T) {
for _, f := range testFuncs {
t.Run(fmt.Sprintf("%s-googleai", funcName(f)), func(t *testing.T) {
llm := newGoogleAIClient(t)
f(t, llm)
for _, c := range testConfigs {
t.Run(fmt.Sprintf("%s-googleai", funcName(c.testFunc)), func(t *testing.T) {
llm := newGoogleAIClient(t, c.opts...)
c.testFunc(t, llm)
})
t.Run(fmt.Sprintf("%s-vertex", funcName(f)), func(t *testing.T) {
llm := newVertexClient(t)
f(t, llm)
t.Run(fmt.Sprintf("%s-vertex", funcName(c.testFunc)), func(t *testing.T) {
llm := newVertexClient(t, c.opts...)
c.testFunc(t, llm)
})
}
}
Expand All @@ -111,7 +120,7 @@ func testMultiContentText(t *testing.T, llm llms.Model) {

assert.NotEmpty(t, rsp.Choices)
c1 := rsp.Choices[0]
assert.Regexp(t, "(?i)dog|canid|canine", c1.Content)
assert.Regexp(t, "(?i)dog|carnivo|canid|canine", c1.Content)
}

func testMultiContentTextUsingTextParts(t *testing.T, llm llms.Model) {
Expand Down Expand Up @@ -309,7 +318,7 @@ func testMaxTokensSetting(t *testing.T, llm llms.Model) {
// a stop reason that max of tokens was reached.
{
rsp, err := llm.GenerateContent(context.Background(), content,
llms.WithMaxTokens(16))
llms.WithMaxTokens(64))
require.NoError(t, err)

assert.NotEmpty(t, rsp.Choices)
Expand Down
23 changes: 20 additions & 3 deletions llms/googleai/vertex/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"context"
"errors"
"fmt"
"log"
"strings"

"cloud.google.com/go/vertexai/genai"
Expand Down Expand Up @@ -61,6 +60,24 @@ func (g *Vertex) GenerateContent(ctx context.Context, messages []llms.MessageCon
model.SetTopP(float32(opts.TopP))
model.SetTopK(float32(opts.TopK))
model.StopSequences = opts.StopWords
model.SafetySettings = []*genai.SafetySetting{
{
Category: genai.HarmCategoryDangerousContent,
Threshold: genai.HarmBlockThreshold(g.opts.HarmThreshold),
},
{
Category: genai.HarmCategoryHarassment,
Threshold: genai.HarmBlockThreshold(g.opts.HarmThreshold),
},
{
Category: genai.HarmCategoryHateSpeech,
Threshold: genai.HarmBlockThreshold(g.opts.HarmThreshold),
},
{
Category: genai.HarmCategorySexuallyExplicit,
Threshold: genai.HarmBlockThreshold(g.opts.HarmThreshold),
},
}

var response *llms.ContentResponse
var err error
Expand Down Expand Up @@ -163,7 +180,7 @@ func convertContent(content llms.MessageContent) (*genai.Content, error) {
c.Role = RoleUser
case schema.ChatMessageTypeGeneric:
c.Role = RoleUser
case schema.ChatMessageTypeFunction:
case schema.ChatMessageTypeFunction, schema.ChatMessageTypeTool:
fallthrough
default:
return nil, fmt.Errorf("role %v not supported", content.Role)
Expand Down Expand Up @@ -251,7 +268,7 @@ DoStream:
break DoStream
}
if err != nil {
log.Fatal(err)
return nil, fmt.Errorf("error in stream mode: %w", err)
}

if len(resp.Candidates) != 1 {
Expand Down