feat: parallelize suggestion requests (#185)
chase-crumbaugh authored Aug 10, 2023
1 parent aad4dca commit d027ce4
Showing 6 changed files with 440 additions and 274 deletions.
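At a glance, the suggest command gains --level and --serial flags, builds a suggestions.Config with the new Parallelize and Level fields, and hands it to a new suggestions.StartSuggest entry point instead of routing through validation.ValidateOpenAPI. A minimal sketch of that wiring, using only the fields and signature visible in the diff below; the literal values stand in for CLI flags and are not part of the commit, and the internal packages compile only within the speakeasy module:

package example

import (
	"context"

	"github.com/speakeasy-api/openapi-generation/v2/pkg/errors"
	"github.com/speakeasy-api/speakeasy/internal/suggestions"
)

// runSuggest sketches the new call path; the values below stand in for the
// --auto-approve, --model, --output-file, --serial, and --level flags.
func runSuggest(ctx context.Context, schemaPath string) error {
	cfg := suggestions.Config{
		AutoContinue: false,                // --auto-approve
		Model:        "gpt-4-0613",         // --model (default)
		OutputFile:   "fixed-openapi.yaml", // --output-file (hypothetical path)
		Parallelize:  true,                 // default; --serial flips this to false
		Level:        errors.SeverityWarn,  // new --level flag: error, warn, or hint
		// MaxSuggestions (from --max-suggestions) is left unset: no limit.
	}
	// The final argument mirrors the outputHints flag passed in cmd/suggest.go.
	return suggestions.StartSuggest(ctx, schemaPath, &cfg, true)
}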
29 changes: 25 additions & 4 deletions cmd/suggest.go
@@ -1,12 +1,12 @@
package cmd

import (
"errors"
"fmt"
"github.com/manifoldco/promptui"
"github.com/speakeasy-api/openapi-generation/v2/pkg/errors"
"github.com/speakeasy-api/speakeasy/internal/suggestions"
"github.com/speakeasy-api/speakeasy/internal/validation"
"github.com/spf13/cobra"
"golang.org/x/exp/slices"
"strings"
)

@@ -19,11 +19,15 @@ you must first create an API key via https://app.speakeasyapi.dev and then set t
RunE: suggestFixesOpenAPI,
}

var severities = fmt.Sprintf("%s, %s, or %s", errors.SeverityError, errors.SeverityWarn, errors.SeverityHint)

func suggestInit() {
suggestCmd.Flags().StringP("schema", "s", "", "path to the OpenAPI document")
suggestCmd.Flags().BoolP("auto-approve", "a", false, "auto continue through all prompts")
suggestCmd.Flags().StringP("output-file", "o", "", "output the modified file with suggested fixes applied to the specified path")
suggestCmd.Flags().IntP("max-suggestions", "n", -1, "maximum number of llm suggestions to fetch, the default is no limit")
suggestCmd.Flags().StringP("level", "l", "warn", fmt.Sprintf("%s. The minimum level of severity to request suggestions for", severities))
suggestCmd.Flags().BoolP("serial", "", false, "do not parallelize requesting suggestions")
suggestCmd.Flags().StringP("model", "m", "gpt-4-0613", "model to use when making llm suggestions (gpt-4-0613 recommended)")
_ = suggestCmd.MarkFlagRequired("schema")
rootCmd.AddCommand(suggestCmd)
@@ -42,6 +46,16 @@ func suggestFixesOpenAPI(cmd *cobra.Command, args []string) error {
return err
}

level, err := cmd.Flags().GetString("level")
if err != nil {
return err
}

severity := errors.Severity(level)
if !slices.Contains([]errors.Severity{errors.SeverityError, errors.SeverityWarn, errors.SeverityHint}, severity) {
return fmt.Errorf("level must be one of %s", severities)
}

outputFile, err := cmd.Flags().GetString("output-file")
if err != nil {
return err
@@ -58,13 +72,20 @@ func suggestFixesOpenAPI(cmd *cobra.Command, args []string) error {
}

if !strings.HasPrefix(modelName, "gpt-3.5") && !strings.HasPrefix(modelName, "gpt-4") {
return errors.New("only gpt3.5 and gpt4 based models supported")
return fmt.Errorf("only gpt3.5 and gpt4 based models supported")
}

dontParallelize, err := cmd.Flags().GetBool("serial")
if err != nil {
return err
}

suggestionConfig := suggestions.Config{
AutoContinue: autoApprove,
Model: modelName,
OutputFile: outputFile,
Parallelize: !dontParallelize,
Level: severity,
}

maxSuggestion, err := cmd.Flags().GetInt("max-suggestions")
@@ -76,7 +97,7 @@ func suggestFixesOpenAPI(cmd *cobra.Command, args []string) error {
suggestionConfig.MaxSuggestions = &maxSuggestion
}

if err := validation.ValidateOpenAPI(cmd.Context(), schemaPath, &suggestionConfig, true); err != nil {
if err := suggestions.StartSuggest(cmd.Context(), schemaPath, &suggestionConfig, true); err != nil {
rootCmd.SilenceUsage = true

return err
2 changes: 1 addition & 1 deletion cmd/validate.go
@@ -65,7 +65,7 @@ func validateOpenAPI(cmd *cobra.Command, args []string) error {
return err
}

if err := validation.ValidateOpenAPI(cmd.Context(), schemaPath, nil, outputHints); err != nil {
if err := validation.ValidateOpenAPI(cmd.Context(), schemaPath, outputHints); err != nil {
rootCmd.SilenceUsage = true

return err
19 changes: 3 additions & 16 deletions internal/suggestions/client.go
@@ -17,26 +17,20 @@ import (
)

const uploadTimeout = time.Minute * 2
const suggestionTimeout = time.Minute * 3
const suggestionTimeout = time.Minute * 15 // Very high because of parallelism (the server will go as fast as it can based on OpenAI's rate limits)

const ApiURL = "https://api.prod.speakeasyapi.dev"

var baseURL = ApiURL

type Suggestion struct {
SuggestedFix string `json:"suggested_fix"`
JSONPatch string `json:"json_patch"`
Reasoning string `json:"reasoning"`
}

type suggestionRequest struct {
Error string `json:"error"`
Severity errors.Severity `json:"severity"`
LineNumber int `json:"line_number"`
PreviousSuggestionContext *string `json:"previous_suggestion_context,omitempty"`
}

func Upload(schema []byte, filePath string) (string, string, error) {
func Upload(schema []byte, filePath string, model string) (string, string, error) {
openAIKey, err := GetOpenAIKey()
if err != nil {
return "", "", err
@@ -75,6 +69,7 @@ func Upload(schema []byte, filePath string) (string, string, error) {
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("x-openai-key", openAIKey)
req.Header.Set("x-api-key", apiKey)
req.Header.Set("x-openai-model", model)

client := &http.Client{
Timeout: uploadTimeout,
@@ -109,14 +104,8 @@ func GetSuggestion(
severity errors.Severity,
lineNumber int,
fileType string,
model string,
previousSuggestionContext *string,
) (*Suggestion, error) {
openAIKey, err := GetOpenAIKey()
if err != nil {
return nil, err
}

apiKey, err := getSpeakeasyAPIKey()
if err != nil {
return nil, err
@@ -141,10 +130,8 @@

req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-session-token", token)
req.Header.Set("x-openai-key", openAIKey)
req.Header.Set("x-api-key", apiKey)
req.Header.Set("x-file-type", fileType)
req.Header.Set("x-openai-model", model)

client := &http.Client{
Timeout: suggestionTimeout,
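Net effect of the client.go changes above: the OpenAI model is supplied once at upload time (via the x-openai-model header), the per-suggestion requests drop the OpenAI key and model headers, and the suggestion timeout is raised to accommodate the server-side parallelism. A small sketch of the revised upload call, assuming only the Upload signature shown above; the helper name and file path are illustrative, not part of the commit:

package example

import (
	"os"

	"github.com/speakeasy-api/speakeasy/internal/suggestions"
)

// uploadForSuggestions sketches the new Upload call: it returns the session
// token used by later suggestion requests plus the detected file type.
func uploadForSuggestions(schemaPath, model string) (token, fileType string, err error) {
	schema, err := os.ReadFile(schemaPath)
	if err != nil {
		return "", "", err
	}
	// model (e.g. "gpt-4-0613") is forwarded as the x-openai-model header.
	return suggestions.Upload(schema, schemaPath, model)
}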
238 changes: 238 additions & 0 deletions internal/suggestions/suggest.go
@@ -0,0 +1,238 @@
package suggestions

import (
"context"
goerr "errors"
"fmt"
"github.com/manifoldco/promptui"
"github.com/speakeasy-api/openapi-generation/v2/pkg/errors"
"github.com/speakeasy-api/speakeasy/internal/auth"
"github.com/speakeasy-api/speakeasy/internal/log"
"github.com/speakeasy-api/speakeasy/internal/validation"
"go.uber.org/zap"
"math"
"os"
"os/signal"
"syscall"
)

var ErrNoSuggestionFound = goerr.New("no suggestion found")

const suggestionBatchSize = 5

func StartSuggest(ctx context.Context, schemaPath string, suggestionsConfig *Config, outputHints bool) error {
fmt.Println("Validating OpenAPI spec...")
fmt.Println()

schema, err := os.ReadFile(schemaPath)
if err != nil {
return fmt.Errorf("failed to read schema file %s: %w", schemaPath, err)
}

schema, err = ReformatFile(schema, DetectFileType(schemaPath))
if err != nil {
return fmt.Errorf("failed to reformat schema file %s: %w", schemaPath, err)
}

vErrs, vWarns, vInfo, err := validation.Validate(schema, schemaPath, outputHints)
if err != nil {
return err
}

printValidationSummary(vErrs, vWarns, vInfo)

toSuggestFor := vErrs
switch suggestionsConfig.Level {
case errors.SeverityWarn:
toSuggestFor = append(toSuggestFor, vWarns...)
break
case errors.SeverityHint:
toSuggestFor = append(append(toSuggestFor, vWarns...), vInfo...)
break
}

// Limit the number of errors to MaxSuggestions
if suggestionsConfig.MaxSuggestions != nil && *suggestionsConfig.MaxSuggestions < len(toSuggestFor) {
toSuggestFor = toSuggestFor[:*suggestionsConfig.MaxSuggestions]
}

if len(toSuggestFor) > 0 {
err = Suggest(schema, schemaPath, toSuggestFor, *suggestionsConfig)
if err != nil {
fmt.Println(promptui.Styler(promptui.FGRed, promptui.FGBold)(fmt.Sprintf("cannot fetch llm suggestions: %s", err.Error())))
return err
}

if suggestionsConfig.OutputFile != "" && suggestionsConfig.AutoContinue {
fmt.Println(promptui.Styler(promptui.FGWhite)("Suggestions applied and written to " + suggestionsConfig.OutputFile))
fmt.Println()
}
} else {
fmt.Println(promptui.Styler(promptui.FGGreen, promptui.FGBold)("Congrats! 🎊 Your spec had no issues we could detect."))
}

return nil
}

func Suggest(schema []byte, schemaPath string, errs []error, config Config) error {
suggestionToken := ""
fileType := ""
totalSuggestions := 0

l := log.NewLogger(schemaPath)

// local authentication should just be set in env variable
if os.Getenv("SPEAKEASY_SERVER_URL") != "http://localhost:35290" {
if err := auth.Authenticate(false); err != nil {
return err
}
}

if _, err := GetOpenAIKey(); err != nil {
return err
}

suggestionToken, fileType, err := Upload(schema, schemaPath, config.Model)
if err != nil {
return err
} else {
// Cleanup Memory Usage in LLM
defer func() {
Clear(suggestionToken)
}()

// Handle Signal Exit
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
Clear(suggestionToken)
os.Exit(0)
}()
}

suggest, err := New(suggestionToken, schemaPath, fileType, schema, config)
if err != nil {
return err
}

/**
* Parallelized suggestions
*/
if config.Parallelize {
fmt.Println("Getting suggestions...")
fmt.Println()

suggest.Verbose = false

// Request suggestions in parallel, in batches of suggestionBatchSize
suggestions := make([]*Suggestion, 0, len(errs)) // zero length so the appends below keep suggestions aligned with errs
for i := 0; i < len(errs); i += suggestionBatchSize {
end := int(math.Min(float64(i+suggestionBatchSize), float64(len(errs))))
res, err := suggest.FindSuggestions(errs[i:end])
if err != nil {
return err
}

suggestions = append(suggestions, res...)
}

for i, err := range errs {
suggestion := suggestions[i]

printVErr(l, err)
fmt.Println() // Spacing
if suggestion != nil {
suggestion.Print()
fmt.Println(promptui.Styler(promptui.FGGreen, promptui.FGBold)("✓ Suggestion is valid and resolves the error"))
fmt.Println() // Spacing

if suggest.AwaitShouldApply() {
newFile, err := suggest.ApplySuggestion(*suggestion)
if err != nil {
return err
}

err = suggest.CommitSuggestion(newFile)
if err != nil {
return err
}
}
}
}

return nil
}

/**
* Non-parallelized suggestions
*/
for _, validationErr := range errs {
if suggest.ShouldSkip(validationErr) {
continue
}

printVErr(l, validationErr)

_, newFile, err := suggest.GetSuggestionAndRevalidate(validationErr, nil)

if err != nil {
if goerr.Is(err, ErrNoSuggestionFound) {
fmt.Println("Did not find a suggestion for error.")
suggest.Skip(validationErr)
continue
} else {
return err
}
}

if suggest.AwaitShouldApply() {
err := suggest.CommitSuggestion(newFile)
if err != nil {
return err
}
} else {
suggest.Skip(validationErr)
}

totalSuggestions++
}

return nil
}

func printVErr(l *log.Logger, sourceErr error) {
vErr := errors.GetValidationErr(sourceErr)

if vErr != nil {
if vErr.Severity == errors.SeverityError {
l.Error("", zap.Error(sourceErr))
} else if vErr.Severity == errors.SeverityWarn {
l.Warn("", zap.Error(sourceErr))
} else if vErr.Severity == errors.SeverityHint {
l.Info("", zap.Error(sourceErr))
}
}
}

func printValidationSummary(errs []error, warns []error, info []error) {
pluralize := func(s string, n int) string {
if n == 1 {
return s
} else {
return s + "s"
}
}

stringify := func(s string, errs []error) string {
return fmt.Sprintf("%d %s", len(errs), pluralize(s, len(errs)))
}

fmt.Printf(
"Found %s, %s, and %s.\n\n",
promptui.Styler(promptui.FGRed, promptui.FGBold)(stringify("error", errs)),
promptui.Styler(promptui.FGYellow, promptui.FGBold)(stringify("warning", warns)),
promptui.Styler(promptui.FGBlue, promptui.FGBold)(stringify("hint", info)),
)
}