Skip to content

Commit ad8f46e

Browse files
sahdev77Sahdev Gargapascal07
authored
feat(go/genkit): added background action and model support (#3262)
Co-authored-by: Sahdev Garg <sahdevgarg@google.com> Co-authored-by: Alex Pascal <apascal07@gmail.com>
1 parent 72fa4ea commit ad8f46e

File tree

21 files changed

+1659
-149
lines changed

21 files changed

+1659
-149
lines changed

genkit-tools/common/src/types/model.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ export const ModelInfoSchema = z.object({
147147
constrained: z.enum(['none', 'all', 'no-tools']).optional(),
148148
/** Model supports controlling tool choice, e.g. forced tool calling. */
149149
toolChoice: z.boolean().optional(),
150+
/** Model supports long running operations. */
151+
longRunning: z.boolean().optional(),
150152
})
151153
.optional(),
152154
/** At which stage of development this model is.

genkit-tools/genkit-schema.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,9 @@
10271027
},
10281028
"toolChoice": {
10291029
"type": "boolean"
1030+
},
1031+
"longRunning": {
1032+
"type": "boolean"
10301033
}
10311034
},
10321035
"additionalProperties": false

go/ai/background_model.go

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// SPDX-License-Identifier: Apache-2.0
16+
17+
package ai
18+
19+
import (
20+
"context"
21+
"errors"
22+
23+
"github.com/firebase/genkit/go/core"
24+
"github.com/firebase/genkit/go/core/api"
25+
"github.com/firebase/genkit/go/internal/registry"
26+
)
27+
28+
// BackgroundModel represents a model that can run operations in the background.
29+
type BackgroundModel interface {
30+
// Name returns the registry name of the background model.
31+
Name() string
32+
// Register registers the model with the given registry.
33+
Register(r api.Registry)
34+
// Start starts a background operation.
35+
Start(ctx context.Context, req *ModelRequest) (*ModelOperation, error)
36+
// Check checks the status of a background operation.
37+
Check(ctx context.Context, op *ModelOperation) (*ModelOperation, error)
38+
// Cancel cancels a background operation.
39+
Cancel(ctx context.Context, op *ModelOperation) (*ModelOperation, error)
40+
// SupportsCancel returns whether the background action supports cancellation.
41+
SupportsCancel() bool
42+
}
43+
44+
// backgroundModel is the concrete implementation of BackgroundModel interface.
45+
type backgroundModel struct {
46+
core.BackgroundActionDef[*ModelRequest, *ModelResponse]
47+
}
48+
49+
// ModelOperation is a background operation for a model.
50+
type ModelOperation = core.Operation[*ModelResponse]
51+
52+
// StartModelOpFunc starts a background model operation.
53+
type StartModelOpFunc = func(ctx context.Context, req *ModelRequest) (*ModelOperation, error)
54+
55+
// CheckOperationFunc checks the status of a background model operation.
56+
type CheckModelOpFunc = func(ctx context.Context, op *ModelOperation) (*ModelOperation, error)
57+
58+
// CancelOperationFunc cancels a background model operation.
59+
type CancelModelOpFunc = func(ctx context.Context, op *ModelOperation) (*ModelOperation, error)
60+
61+
// BackgroundModelOptions holds configuration for defining a background model
62+
type BackgroundModelOptions struct {
63+
ModelOptions
64+
Cancel CancelModelOpFunc // Function that cancels a background model operation.
65+
Metadata map[string]any // Additional metadata.
66+
}
67+
68+
// LookupBackgroundModel looks up a BackgroundAction registered by [DefineBackgroundModel].
69+
// It returns nil if the background model was not found.
70+
func LookupBackgroundModel(r api.Registry, name string) BackgroundModel {
71+
key := api.KeyFromName(api.ActionTypeBackgroundModel, name)
72+
action := core.LookupBackgroundAction[*ModelRequest, *ModelResponse](r, key)
73+
if action == nil {
74+
return nil
75+
}
76+
return &backgroundModel{*action}
77+
}
78+
79+
// NewBackgroundModel defines a new model that runs in the background.
80+
func NewBackgroundModel(name string, opts *BackgroundModelOptions, startFn StartModelOpFunc, checkFn CheckModelOpFunc) BackgroundModel {
81+
if name == "" {
82+
panic("ai.NewBackgroundModel: name is required")
83+
}
84+
if startFn == nil {
85+
panic("ai.NewBackgroundModel: startFn is required")
86+
}
87+
if checkFn == nil {
88+
panic("ai.NewBackgroundModel: checkFn is required")
89+
}
90+
91+
if opts == nil {
92+
opts = &BackgroundModelOptions{}
93+
}
94+
if opts.Label == "" {
95+
opts.Label = name
96+
}
97+
if opts.Supports == nil {
98+
opts.Supports = &ModelSupports{}
99+
}
100+
101+
metadata := map[string]any{
102+
"type": api.ActionTypeBackgroundModel,
103+
"model": map[string]any{
104+
"label": opts.Label,
105+
"supports": map[string]any{
106+
"media": opts.Supports.Media,
107+
"context": opts.Supports.Context,
108+
"multiturn": opts.Supports.Multiturn,
109+
"systemRole": opts.Supports.SystemRole,
110+
"tools": opts.Supports.Tools,
111+
"toolChoice": opts.Supports.ToolChoice,
112+
"constrained": opts.Supports.Constrained,
113+
"output": opts.Supports.Output,
114+
"contentType": opts.Supports.ContentType,
115+
"longRunning": opts.Supports.LongRunning,
116+
},
117+
"versions": opts.Versions,
118+
"stage": opts.Stage,
119+
"customOptions": opts.ConfigSchema,
120+
},
121+
}
122+
123+
inputSchema := core.InferSchemaMap(ModelRequest{})
124+
if inputSchema != nil && opts.ConfigSchema != nil {
125+
if props, ok := inputSchema["properties"].(map[string]any); ok {
126+
props["config"] = opts.ConfigSchema
127+
}
128+
}
129+
130+
mws := []ModelMiddleware{
131+
simulateSystemPrompt(&opts.ModelOptions, nil),
132+
augmentWithContext(&opts.ModelOptions, nil),
133+
validateSupport(name, &opts.ModelOptions),
134+
addAutomaticTelemetry(),
135+
}
136+
fn := core.ChainMiddleware(mws...)(backgroundModelToModelFn(startFn))
137+
138+
wrappedFn := func(ctx context.Context, req *ModelRequest) (*ModelOperation, error) {
139+
resp, err := fn(ctx, req, nil)
140+
if err != nil {
141+
return nil, err
142+
}
143+
144+
return modelOpFromResponse(resp)
145+
}
146+
147+
return &backgroundModel{*core.NewBackgroundAction(name, api.ActionTypeBackgroundModel, metadata, wrappedFn, checkFn, opts.Cancel)}
148+
}
149+
150+
// DefineBackgroundModel defines and registers a new model that runs in the background.
151+
func DefineBackgroundModel(r *registry.Registry, name string, opts *BackgroundModelOptions, fn StartModelOpFunc, checkFn CheckModelOpFunc) BackgroundModel {
152+
m := NewBackgroundModel(name, opts, fn, checkFn)
153+
m.Register(r)
154+
return m
155+
}
156+
157+
// GenerateOperation generates a model response as a long-running operation based on the provided options.
158+
func GenerateOperation(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*ModelOperation, error) {
159+
resp, err := Generate(ctx, r, opts...)
160+
if err != nil {
161+
return nil, err
162+
}
163+
164+
return modelOpFromResponse(resp)
165+
}
166+
167+
// CheckModelOperation checks the status of a background model operation by looking up the model and calling its Check method.
168+
func CheckModelOperation(ctx context.Context, r api.Registry, op *ModelOperation) (*ModelOperation, error) {
169+
return core.CheckOperation[*ModelRequest](ctx, r, op)
170+
}
171+
172+
// backgroundModelToModelFn wraps a background model start function into a [ModelFunc] for middleware compatibility.
173+
func backgroundModelToModelFn(startFn StartModelOpFunc) ModelFunc {
174+
return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
175+
op, err := startFn(ctx, req)
176+
if err != nil {
177+
return nil, err
178+
}
179+
180+
var opError *OperationError
181+
if op.Error != nil {
182+
opError = &OperationError{Message: op.Error.Error()}
183+
}
184+
185+
metadata := op.Metadata
186+
if metadata == nil {
187+
metadata = make(map[string]any)
188+
}
189+
190+
return &ModelResponse{
191+
Operation: &Operation{
192+
Action: op.Action,
193+
Id: op.ID,
194+
Done: op.Done,
195+
Output: op.Output,
196+
Error: opError,
197+
Metadata: metadata,
198+
},
199+
Request: req,
200+
}, nil
201+
}
202+
}
203+
204+
// modelOpFromResponse extracts a [ModelOperation] from a [ModelResponse].
205+
func modelOpFromResponse(resp *ModelResponse) (*ModelOperation, error) {
206+
if resp.Operation == nil {
207+
return nil, core.NewError(core.FAILED_PRECONDITION, "background model did not return an operation")
208+
}
209+
210+
op := &ModelOperation{
211+
Action: resp.Operation.Action,
212+
ID: resp.Operation.Id,
213+
Done: resp.Operation.Done,
214+
Metadata: resp.Operation.Metadata,
215+
}
216+
217+
if resp.Operation.Error != nil {
218+
op.Error = errors.New(resp.Operation.Error.Message)
219+
}
220+
221+
if resp.Operation.Output != nil {
222+
if modelResp, ok := resp.Operation.Output.(*ModelResponse); ok {
223+
op.Output = modelResp
224+
} else {
225+
return nil, core.NewError(core.INTERNAL, "operation output is not a model response")
226+
}
227+
}
228+
229+
return op, nil
230+
}

go/ai/gen.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ type ModelSupports struct {
223223
Constrained ConstrainedSupport `json:"constrained,omitempty"`
224224
ContentType []string `json:"contentType,omitempty"`
225225
Context bool `json:"context,omitempty"`
226+
LongRunning bool `json:"longRunning,omitempty"`
226227
Media bool `json:"media,omitempty"`
227228
Multiturn bool `json:"multiturn,omitempty"`
228229
Output []string `json:"output,omitempty"`
@@ -257,8 +258,10 @@ type ModelResponse struct {
257258
FinishMessage string `json:"finishMessage,omitempty"`
258259
FinishReason FinishReason `json:"finishReason,omitempty"`
259260
// LatencyMs is the time the request took in milliseconds.
260-
LatencyMs float64 `json:"latencyMs,omitempty"`
261-
Message *Message `json:"message,omitempty"`
261+
LatencyMs float64 `json:"latencyMs,omitempty"`
262+
Message *Message `json:"message,omitempty"`
263+
Operation *Operation `json:"operation,omitempty"`
264+
Raw any `json:"raw,omitempty"`
262265
// Request is the [ModelRequest] struct used to trigger this response.
263266
Request *ModelRequest `json:"request,omitempty"`
264267
// Usage describes how many resources were used by this generation request.
@@ -275,6 +278,19 @@ type ModelResponseChunk struct {
275278
Role Role `json:"role,omitempty"`
276279
}
277280

281+
type Operation struct {
282+
Action string `json:"action,omitempty"`
283+
Done bool `json:"done,omitempty"`
284+
Error *OperationError `json:"error,omitempty"`
285+
Id string `json:"id,omitempty"`
286+
Metadata map[string]any `json:"metadata,omitempty"`
287+
Output any `json:"output,omitempty"`
288+
}
289+
290+
type OperationError struct {
291+
Message string `json:"message,omitempty"`
292+
}
293+
278294
// OutputConfig describes the structure that the model's output
279295
// should conform to. If Format is [OutputFormatJSON], then Schema
280296
// can describe the desired form of the generated JSON.

0 commit comments

Comments
 (0)