Skip to content

Commit

Permalink
Merge pull request weaviate#4906 from weaviate/octoai_header
Browse files Browse the repository at this point in the history
Octoai header fixes + image generation
  • Loading branch information
dirkkul authored May 13, 2024
2 parents f7e440e + bfc69dd commit fc01e49
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 50 deletions.
79 changes: 65 additions & 14 deletions modules/generative-octoai/clients/octoai.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (v *octoai) GenerateAllResults(ctx context.Context, textProperties []map[st
func (v *octoai) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) {
settings := config.NewClassSettings(cfg)

octoAIUrl, err := v.getOctoAIUrl(ctx, settings.BaseURL())
octoAIUrl, isImage, err := v.getOctoAIUrl(ctx, settings.BaseURL())
if err != nil {
return nil, errors.Wrap(err, "join OctoAI API host and path")
}
Expand All @@ -78,11 +78,29 @@ func (v *octoai) Generate(ctx context.Context, cfg moduletools.ClassConfig, prom
{"role": "user", "content": prompt},
}

input := generateInput{
Messages: octoAIPrompt,
Model: settings.Model(),
MaxTokens: settings.MaxTokens(),
Temperature: settings.Temperature(),
var input interface{}
if !isImage {
input = generateInputText{
Messages: octoAIPrompt,
Model: settings.Model(),
MaxTokens: settings.MaxTokens(),
Temperature: settings.Temperature(),
}
} else {
input = generateInputImage{
Prompt: prompt,
NegativePrompt: "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, signature, cut off, draft",
Sampler: "DDIM",
CfgScale: 11,
Height: 1024,
Width: 1024,
Seed: 0,
Steps: 20,
NumImages: 1,
HighNoiseFrac: 0.7,
Strength: 0.92,
UseRefiner: true,
}
}

body, err := json.Marshal(input)
Expand Down Expand Up @@ -125,19 +143,29 @@ func (v *octoai) Generate(ctx context.Context, cfg moduletools.ClassConfig, prom
return nil, errors.Errorf("connection to OctoAI API failed with status: %d", res.StatusCode)
}

textResponse := resBody.Choices[0].Message.Content
var textResponse string
if isImage {
textResponse = resBody.Images[0].Image
} else {
textResponse = resBody.Choices[0].Message.Content
}

return &generativemodels.GenerateResponse{
Result: &textResponse,
}, nil
}

func (v *octoai) getOctoAIUrl(ctx context.Context, baseURL string) (string, error) {
func (v *octoai) getOctoAIUrl(ctx context.Context, baseURL string) (string, bool, error) {
passedBaseURL := baseURL
if headerBaseURL := v.getValueFromContext(ctx, "X-Octoai-Baseurl"); headerBaseURL != "" {
passedBaseURL = headerBaseURL
}
return url.JoinPath(passedBaseURL, "/v1/chat/completions")
if strings.Contains(passedBaseURL, "image") {
urlTmp, err := url.JoinPath(passedBaseURL, "/generate/sdxl")
return urlTmp, true, err
}
urlTmp, err := url.JoinPath(passedBaseURL, "/v1/chat/completions")
return urlTmp, false, err
}

func (v *octoai) generatePromptForTask(textProperties []map[string]string, task string) (string, error) {
Expand Down Expand Up @@ -177,24 +205,40 @@ func (v *octoai) getValueFromContext(ctx context.Context, key string) string {
}

func (v *octoai) getApiKey(ctx context.Context) (string, error) {
if apiKey := v.getValueFromContext(ctx, "X-OctoAI-Api-Key"); apiKey != "" {
return apiKey, nil
}
if v.apiKey != "" {
return v.apiKey, nil
}
if apiKey := modulecomponents.GetValueFromContext(ctx, "X-OctoAI-Api-Key"); apiKey != "" {
return apiKey, nil
}
return "", errors.New("no api key found " +
"neither in request header: X-OctoAI-Api-Key " +
"nor in environment variable under OCTOAI_APIKEY")
}

type generateInput struct {
type generateInputText struct {
Model string `json:"model"`
Messages []map[string]string `json:"messages"`
MaxTokens int `json:"max_tokens"`
Temperature int `json:"temperature"`
}

type generateInputImage struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt"`
Sampler string `json:"sampler"`
CfgScale int `json:"cfg_scale"`
Height int `json:"height"`
Width int `json:"width"`
Seed int `json:"seed"`
Steps int `json:"steps"`
NumImages int `json:"num_images"`
HighNoiseFrac float64 `json:"high_noise_frac"`
Strength float64 `json:"strength"`
UseRefiner bool `json:"use_refiner"`
// StylePreset string `json:"style_preset"`
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Expand All @@ -205,9 +249,16 @@ type Choice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
}

type Image struct {
Image string `json:"image_b64"`
}

type generateResponse struct {
Choices []Choice
Error *octoaiApiError `json:"error,omitempty"`
Images []Image

Error *octoaiApiError `json:"error,omitempty"`
}

type octoaiApiError struct {
Expand Down
14 changes: 12 additions & 2 deletions modules/generative-octoai/clients/octoai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,23 @@ func TestGetAnswer(t *testing.T) {
}
})
}
t.Run("when X-Octoai-BaseURL header is passed", func(t *testing.T) {
t.Run("when X-Octoai-BaseURL header is passed for text", func(t *testing.T) {
c := New("apiKey", 5*time.Second, nullLogger())
baseUrl := "https://text.octoai.run"
buildURL, err := c.getOctoAIUrl(context.Background(), baseUrl)
buildURL, isImage, err := c.getOctoAIUrl(context.Background(), baseUrl)
assert.Equal(t, nil, err)
assert.Equal(t, false, isImage)
assert.Equal(t, "https://text.octoai.run/v1/chat/completions", buildURL)
})

t.Run("when X-Octoai-BaseURL header is passed for image", func(t *testing.T) {
c := New("apiKey", 5*time.Second, nullLogger())
baseUrl := "https://image.octoai.run"
buildURL, isImage, err := c.getOctoAIUrl(context.Background(), baseUrl)
assert.Equal(t, nil, err)
assert.Equal(t, true, isImage)
assert.Equal(t, "https://image.octoai.run/generate/sdxl", buildURL)
})
}

type testAnswerHandler struct {
Expand Down
47 changes: 13 additions & 34 deletions modules/text2vec-octoai/clients/octoai.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ func buildUrl(config ent.VectorizationConfig) (string, error) {
}

type vectorizer struct {
octoAIApiKey string
httpClient *http.Client
buildUrlFn func(config ent.VectorizationConfig) (string, error)
logger logrus.FieldLogger
apiKey string
httpClient *http.Client
buildUrlFn func(config ent.VectorizationConfig) (string, error)
logger logrus.FieldLogger
}

func New(octoAIApiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
return &vectorizer{
octoAIApiKey: octoAIApiKey,
apiKey: apiKey,
httpClient: &http.Client{
Timeout: timeout,
},
Expand Down Expand Up @@ -182,36 +182,15 @@ func (v *vectorizer) getApiKeyHeaderAndValue(apiKey string) (string, string) {
}

func (v *vectorizer) getApiKey(ctx context.Context) (string, error) {
var apiKey, envVar string

apiKey = "X-OctoAI-Api-Key"
envVar = "OCTOAI_APIKEY"
if len(v.octoAIApiKey) > 0 {
return v.octoAIApiKey, nil
}

return v.getApiKeyFromContext(ctx, apiKey, envVar)
}

func (v *vectorizer) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) {
if apiKeyValue := v.getValueFromContext(ctx, apiKey); apiKeyValue != "" {
return apiKeyValue, nil
if v.apiKey != "" {
return v.apiKey, nil
}
return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar)
}

func (v *vectorizer) getValueFromContext(ctx context.Context, key string) string {
if value := ctx.Value(key); value != nil {
if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
return keyHeader[0]
}
if apiKey := modulecomponents.GetValueFromContext(ctx, "X-OctoAI-Api-Key"); apiKey != "" {
return apiKey, nil
}
// try getting header from GRPC if not successful
if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
return apiKey[0]
}

return ""
return "", errors.New("no api key found " +
"neither in request header: X-OctoAI-Api-Key " +
"nor in environment variable under OCTOAI_APIKEY")
}

func (v *vectorizer) GetApiKeyHash(ctx context.Context, config moduletools.ClassConfig) [32]byte {
Expand Down

0 comments on commit fc01e49

Please sign in to comment.