Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions go/ai/_test_data/prompts/example.prompt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
model: test-model
maxTurns: 5
description: A test prompt
toolChoice: required
returnToolRequests: true
input:
schema:
type: object
properties:
name:
type: string
default:
name: world
output:
format: text
schema:
type: string
---
Hello, {{name}}!
24 changes: 18 additions & 6 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"errors"
"fmt"
"io/fs"
"log/slog"
"maps"
"os"
Expand Down Expand Up @@ -505,12 +506,17 @@ func LoadPromptDir(r api.Registry, dir string, namespace string) {
return
}

loadPromptDir(r, path, namespace)
loadPromptDir(r, os.DirFS(dir), ".", namespace)
}

// LoadPromptFS loads prompts and partials from the given filesystem for the given namespace.
func LoadPromptFS(r api.Registry, fsys fs.FS, dir string, namespace string) {
loadPromptDir(r, fsys, dir, namespace)
}

// loadPromptDir recursively loads prompts and partials from the directory.
func loadPromptDir(r api.Registry, dir string, namespace string) {
entries, err := os.ReadDir(dir)
func loadPromptDir(r api.Registry, fsys fs.FS, dir, namespace string) {
entries, err := fs.ReadDir(fsys, dir)
if err != nil {
panic(fmt.Errorf("failed to read prompt directory structure: %w", err))
}
Expand All @@ -519,7 +525,7 @@ func loadPromptDir(r api.Registry, dir string, namespace string) {
filename := entry.Name()
path := filepath.Join(dir, filename)
if entry.IsDir() {
loadPromptDir(r, path, namespace)
loadPromptDir(r, fsys, path, namespace)
} else if strings.HasSuffix(filename, ".prompt") {
if strings.HasPrefix(filename, "_") {
partialName := strings.TrimSuffix(filename[1:], ".prompt")
Expand All @@ -531,19 +537,25 @@ func loadPromptDir(r api.Registry, dir string, namespace string) {
r.RegisterPartial(partialName, string(source))
slog.Debug("Registered Dotprompt partial", "name", partialName, "file", path)
} else {
LoadPrompt(r, dir, filename, namespace)
loadPrompt(r, fsys, dir, filename, namespace)
}
}
}
}

// LoadPrompt loads a single prompt into the registry.
func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
dir, rest := filepath.Split(dir)
return loadPrompt(r, os.DirFS(dir), rest, filename, namespace)
}

// loadPrompt uses provided fsys to load a single prompt into the registry.
func loadPrompt(r api.Registry, fsys fs.FS, dir, filename, namespace string) Prompt {
name := strings.TrimSuffix(filename, ".prompt")
name, variant, _ := strings.Cut(name, ".")

sourceFile := filepath.Join(dir, filename)
source, err := os.ReadFile(sourceFile)
source, err := fs.ReadFile(fsys, sourceFile)
if err != nil {
slog.Error("Failed to read prompt file", "file", sourceFile, "error", err)
return nil
Expand Down
13 changes: 13 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package ai

import (
"context"
"embed"
"fmt"
"os"
"path/filepath"
Expand All @@ -29,6 +30,9 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
)

//go:embed _test_data/prompts
var embededPrompts embed.FS

type InputOutput struct {
Text string `json:"text"`
}
Expand Down Expand Up @@ -877,6 +881,15 @@ func assertResponse(t *testing.T, resp *ModelResponse, want string) {
}
}

func TestLoadPrompt_FromFS(t *testing.T) {
reg := registry.New()
LoadPromptFS(reg, embededPrompts, "_test_data/prompts", "test-namespace")
prompt := LookupPrompt(reg, "test-namespace/example")
if prompt == nil {
t.Fatalf("Prompt was not registered")
}
}

func TestLoadPrompt(t *testing.T) {
// Create a temporary directory for testing
tempDir := t.TempDir()
Expand Down
21 changes: 20 additions & 1 deletion go/genkit/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"context"
"errors"
"fmt"
"io/fs"
"log/slog"
"os"
"os/signal"
Expand All @@ -46,6 +47,7 @@ type Genkit struct {
type genkitOptions struct {
DefaultModel string // Default model to use if no other model is specified.
PromptDir string // Directory where dotprompts are stored. Will be loaded automatically on initialization.
PromptFS fs.FS // Filesystem that will be used for PromptDir lookup.
Plugins []api.Plugin // Plugin to initialize automatically.
}

Expand All @@ -69,6 +71,13 @@ func (o *genkitOptions) apply(gOpts *genkitOptions) error {
gOpts.PromptDir = o.PromptDir
}

if o.PromptFS != nil {
if gOpts.PromptFS != nil {
return errors.New("cannot set prompt filesystem more than once (WithPromptFS)")
}
gOpts.PromptFS = o.PromptFS
}

if len(o.Plugins) > 0 {
if gOpts.Plugins != nil {
return errors.New("cannot set plugins more than once (WithPlugins)")
Expand Down Expand Up @@ -106,6 +115,12 @@ func WithPromptDir(dir string) GenkitOption {
return &genkitOptions{PromptDir: dir}
}

// WithPromptFS is a more generic version of `WithPromptDir` and accepts a filesytem
// instead of directory path
func WithPromptFS(fsys fs.FS) GenkitOption {
return &genkitOptions{PromptFS: fsys}
}

// Init creates and initializes a new [Genkit] instance with the provided options.
// It sets up the registry, initializes plugins ([WithPlugins]), loads prompts
// ([WithPromptDir]), and configures other settings like the default model
Expand Down Expand Up @@ -184,7 +199,11 @@ func Init(ctx context.Context, opts ...GenkitOption) *Genkit {

ai.ConfigureFormats(r)
ai.DefineGenerateAction(ctx, r)
ai.LoadPromptDir(r, gOpts.PromptDir, "")
if gOpts.PromptFS == nil {
ai.LoadPromptDir(r, gOpts.PromptDir, "")
} else {
ai.LoadPromptFS(r, gOpts.PromptFS, gOpts.PromptDir, "")
}

r.RegisterValue(api.DefaultModelKey, gOpts.DefaultModel)
r.RegisterValue(api.PromptDirKey, gOpts.PromptDir)
Expand Down
6 changes: 6 additions & 0 deletions go/samples/prompts-dir/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package main

import (
"context"
"embed"
"errors"

// Import Genkit and the Google AI plugin
Expand All @@ -14,12 +15,17 @@ import (
"github.com/firebase/genkit/go/plugins/googlegenai"
)

//go:embed prompts
var prompts embed.FS

func main() {
ctx := context.Background()

g := genkit.Init(ctx,
genkit.WithPlugins(&googlegenai.GoogleAI{}),
genkit.WithPromptDir("prompts"),
// Without it OS's filesystem will be used
genkit.WithPromptFS(prompts),
)

type greetingStyle struct {
Expand Down