Skip to content

feat(go/genkit) Session handling #2955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,11 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
genOpts.Config = modelRef.Config()
}

if genOpts.Session != nil {
// Set session details in context
ctx = genOpts.Session.SetContext(ctx)
}

actionOpts := &GenerateActionOptions{
Model: modelName,
Messages: messages,
Expand Down
99 changes: 98 additions & 1 deletion go/ai/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type ConfigOption interface {
applyRetriever(*retrieverOptions) error
applyEvaluator(*evaluatorOptions) error
applyIndexer(*indexerOptions) error
applySession(*sessionOptions) error
}

// applyConfig applies the option to the config options.
Expand Down Expand Up @@ -101,6 +102,11 @@ func (o *configOptions) applyIndexer(opts *indexerOptions) error {
return o.applyConfig(&opts.configOptions)
}

// applySession applies the option to the session options.
func (o *configOptions) applySession(opts *sessionOptions) error {
return o.applyConfig(&opts.configOptions)
}

// WithConfig sets the configuration.
func WithConfig(config any) ConfigOption {
return &configOptions{Config: config}
Expand Down Expand Up @@ -261,6 +267,11 @@ func WithToolChoice(toolChoice ToolChoice) CommonGenOption {
return &commonGenOptions{ToolChoice: toolChoice}
}

// WithSession sets the request session.
func WithSession(session Session) ExecutionOption {
return &executionOptions{Session: &session}
}

// promptOptions are options for defining a prompt.
type promptOptions struct {
commonGenOptions
Expand Down Expand Up @@ -529,7 +540,8 @@ func WithCustomConstrainedOutput() OutputOption {

// executionOptions are options for the execution of a prompt or generate request.
type executionOptions struct {
Stream ModelStreamCallback // Function to call with each chunk of the generated response.
Stream ModelStreamCallback // Function to call with each chunk of the generated response.
Session *Session // Session to use for request.
}

// ExecutionOption is an option for the execution of a prompt or generate request. It applies only to Generate() and prompt.Execute().
Expand All @@ -548,6 +560,13 @@ func (o *executionOptions) applyExecution(execOpts *executionOptions) error {
execOpts.Stream = o.Stream
}

if o.Session != nil {
if execOpts.Session != nil {
return errors.New("cannot set session more than once (WithSession)")
}
execOpts.Session = o.Session
}

return nil
}

Expand Down Expand Up @@ -836,3 +855,81 @@ func (o *promptExecutionOptions) applyPromptExecute(pgOpts *promptExecutionOptio
func WithInput(input any) PromptExecuteOption {
return &promptExecutionOptions{Input: input}
}

// sessionOptions are options for configuring the session parameters.
type sessionOptions struct {
configOptions
ID string // The session ID
Data *SessionData // The data for the session
Store SessionStore // The store for the session, defaults to in-memory storage
Schema map[string]any // The schema to use for the data
DefaultState any // The default state
}

// SessionOption is an option for configuring the session parameters.
// It applies only to [Session].
type SessionOption interface {
applySession(*sessionOptions) error
}

// applySession applies the option to the session options.
func (o *sessionOptions) applySession(sessOpts *sessionOptions) error {
if err := o.applyConfig(&sessOpts.configOptions); err != nil {
return err
}

if o.ID != "" {
if sessOpts.ID != "" {
return errors.New("cannot set session id more than once (WithSessionID)")
}
sessOpts.ID = o.ID
}

if o.Data != nil {
if sessOpts.Data != nil {
return errors.New("cannot set session data more than once (WithSessionData)")
}
sessOpts.Data = o.Data
}

if o.Store != nil {
if sessOpts.Store != nil {
return errors.New("cannot set session store more than once (WithSessionStore)")
}
sessOpts.Store = o.Store
}

if o.Schema != nil {
if sessOpts.Schema != nil {
return errors.New("cannot set state type more than once (WithSessionStateType)")
}
sessOpts.Schema = o.Schema
sessOpts.DefaultState = o.DefaultState
}

return nil
}

// WithSessionID sets the session id.
func WithSessionID(id string) SessionOption {
return &sessionOptions{ID: id}
}

// WithSessionData sets the session data.
func WithSessionData(data SessionData) SessionOption {
return &sessionOptions{Data: &data}
}

// WithSessionStore sets a session store for the session.
func WithSessionStore(store SessionStore) SessionOption {
return &sessionOptions{Store: store}
}

// WithStateType uses the struct provided to derive the state schema.
// If passing a struct with values, the struct definition will serve as the schema, the values will serve as the data.
func WithSessionStateType(state any) SessionOption {
return &sessionOptions{
Schema: base.SchemaAsMap(base.InferJSONSchema(state)),
DefaultState: state,
}
}
20 changes: 16 additions & 4 deletions go/ai/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,10 @@ func TestGenerateOptionsComplete(t *testing.T) {
tool := &mockTool{name: "test/tool"}
streamFunc := func(context.Context, *ModelResponseChunk) error { return nil }
doc := DocumentFromText("doc", nil)
session, err := NewSession(context.Background())
if err != nil {
t.Fatal(err.Error())
}
options := []GenerateOption{
WithModel(model),
WithMessages(NewUserTextMessage("message")),
Expand All @@ -384,6 +388,7 @@ func TestGenerateOptionsComplete(t *testing.T) {
WithOutputInstructions(""),
WithCustomConstrainedOutput(),
WithStreaming(streamFunc),
WithSession(*session),
}

for _, opt := range options {
Expand Down Expand Up @@ -419,7 +424,8 @@ func TestGenerateOptionsComplete(t *testing.T) {
CustomConstrained: true,
},
executionOptions: executionOptions{
Stream: streamFunc,
Stream: streamFunc,
Session: &Session{},
},
documentOptions: documentOptions{
Documents: []*Document{doc},
Expand All @@ -430,7 +436,7 @@ func TestGenerateOptionsComplete(t *testing.T) {
cmpopts.IgnoreFields(commonGenOptions{}, "MessagesFn", "Middleware"),
cmpopts.IgnoreFields(promptingOptions{}, "SystemFn", "PromptFn"),
cmpopts.IgnoreFields(executionOptions{}, "Stream"),
cmpopts.IgnoreUnexported(mockModel{}, mockTool{}),
cmpopts.IgnoreUnexported(mockModel{}, mockTool{}, Session{}),
cmp.AllowUnexported(generateOptions{}, commonGenOptions{}, promptingOptions{},
outputOptions{}, executionOptions{}, documentOptions{})); diff != "" {
t.Errorf("Options not applied correctly, diff (-want +got):\n%s", diff)
Expand Down Expand Up @@ -561,6 +567,10 @@ func TestPromptExecuteOptionsComplete(t *testing.T) {
streamFunc := func(context.Context, *ModelResponseChunk) error { return nil }
input := map[string]string{"key": "value"}
doc := DocumentFromText("doc", nil)
session, err := NewSession(context.Background())
if err != nil {
t.Fatal(err.Error())
}

options := []PromptExecuteOption{
WithModel(model),
Expand All @@ -574,6 +584,7 @@ func TestPromptExecuteOptionsComplete(t *testing.T) {
WithDocs(doc),
WithStreaming(streamFunc),
WithInput(input),
WithSession(*session),
}

for _, opt := range options {
Expand All @@ -596,7 +607,8 @@ func TestPromptExecuteOptionsComplete(t *testing.T) {
Middleware: []ModelMiddleware{mw},
},
executionOptions: executionOptions{
Stream: streamFunc,
Stream: streamFunc,
Session: &Session{},
},
documentOptions: documentOptions{
Documents: []*Document{doc},
Expand All @@ -607,7 +619,7 @@ func TestPromptExecuteOptionsComplete(t *testing.T) {
if diff := cmp.Diff(expected, opts,
cmpopts.IgnoreFields(commonGenOptions{}, "MessagesFn", "Middleware"),
cmpopts.IgnoreFields(executionOptions{}, "Stream"),
cmpopts.IgnoreUnexported(mockModel{}, mockTool{}),
cmpopts.IgnoreUnexported(mockModel{}, mockTool{}, Session{}),
cmp.AllowUnexported(promptExecutionOptions{}, commonGenOptions{},
executionOptions{})); diff != "" {
t.Errorf("Options not applied correctly, diff (-want +got):\n%s", diff)
Expand Down
5 changes: 5 additions & 0 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ func (p *Prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod
actionOpts.ReturnToolRequests = *genOpts.ReturnToolRequests
}

if genOpts.Session != nil {
// Set session details in context
ctx = genOpts.Session.SetContext(ctx)
}

return GenerateWithRequest(ctx, p.registry, actionOpts, genOpts.Middleware, genOpts.Stream)
}

Expand Down
Loading
Loading