Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions cmd/llm/chat2.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,19 @@ import (
// TYPES

type Chat2Cmd struct {
Model string `arg:"" help:"Model name"`
Token string `env:"TELEGRAM_TOKEN" help:"Telegram token" required:""`
Model string `arg:"" help:"Model name"`
TelegramToken string `env:"TELEGRAM_TOKEN" help:"Telegram token" required:""`
System string `flag:"system" help:"Set the system prompt"`
}

type Server struct {
sync.RWMutex
*telegram.Client

// Model and toolkit
toolkit llm.ToolKit
model llm.Model
toolkit llm.ToolKit
opts []llm.Opt

// Map of active sessions
sessions map[string]llm.Context
Expand All @@ -35,11 +37,15 @@ type Server struct {
////////////////////////////////////////////////////////////////////////////////
// LIFECYCLE

func NewTelegramServer(token string, model llm.Model, toolkit llm.ToolKit, opts ...telegram.Opt) (*Server, error) {
func NewTelegramServer(token string, model llm.Model, system string, toolkit llm.ToolKit, opts ...telegram.Opt) (*Server, error) {
server := new(Server)
server.sessions = make(map[string]llm.Context)
server.model = model
server.toolkit = toolkit
server.opts = []llm.Opt{
llm.WithToolKit(toolkit),
llm.WithSystemPrompt(system),
}

// Create a new telegram client
opts = append(opts, telegram.WithCallback(server.receive))
Expand All @@ -58,12 +64,12 @@ func NewTelegramServer(token string, model llm.Model, toolkit llm.ToolKit, opts

func (cmd *Chat2Cmd) Run(globals *Globals) error {
return run(globals, cmd.Model, func(ctx context.Context, model llm.Model) error {
server, err := NewTelegramServer(cmd.Token, model, globals.toolkit, telegram.WithDebug(globals.Debug))
server, err := NewTelegramServer(cmd.TelegramToken, model, cmd.System, globals.toolkit, telegram.WithDebug(globals.Debug))
if err != nil {
return err
}

log.Printf("Running Telegram bot %q\n", server.Client.Name())
log.Printf("Running Telegram bot %q with model %q\n", server.Client.Name(), model.Name())

var result error
var wg sync.WaitGroup
Expand Down Expand Up @@ -103,7 +109,7 @@ func (telegram *Server) Purge() {
telegram.Lock()
defer telegram.Unlock()
for user, session := range telegram.sessions {
if session.SinceLast() > 10*time.Minute {
if session.SinceLast() > 5*time.Minute {
log.Printf("Purging session for %q\n", user)
delete(telegram.sessions, user)
}
Expand All @@ -116,10 +122,7 @@ func (telegram *Server) session(user string) llm.Context {
if session, exists := telegram.sessions[user]; exists {
return session
}
session := telegram.model.Context(
llm.WithToolKit(telegram.toolkit),
llm.WithSystemPrompt("Please reply to messages in markdown format."),
)
session := telegram.model.Context(telegram.opts...)
telegram.sessions[user] = session
return session
}
Expand All @@ -130,7 +133,6 @@ func (telegram *Server) receive(ctx context.Context, msg telegram.Message) error

// Process the message
text := msg.Text()
text += "\n\nPlease reply in markdown format."
if err := session.FromUser(ctx, text); err != nil {
return err
}
Expand All @@ -144,7 +146,7 @@ func (telegram *Server) receive(ctx context.Context, msg telegram.Message) error
if text := session.Text(0); text != "" {
msg.Reply(ctx, text, false)
} else {
msg.Reply(ctx, "_Gathering information_", true)
msg.Reply(ctx, "Gathering information", true)
}

results, err := telegram.toolkit.Run(ctx, calls...)
Expand Down