Skip to content

Commit

Permalink
feat: added support for A21 and Amazon Titan models via bedrock api (k…
Browse files Browse the repository at this point in the history
…8sgpt-ai#1101)

* feat: added support for A21 and Amazon Titan models via bedrock api

Signed-off-by: Yomesh Shah <yomesh@gmail.com>

* fix: response type for diffrent models and use of constant for top_P

Signed-off-by: Yomesh Shah <yomesh@gmail.com>

* fix: constant for top_P as int vs string

Signed-off-by: Yomesh Shah <yomesh@gmail.com>

* feat: moved topP and maxTokens to config rather than being constants in the code

Signed-off-by: Yomesh Shah <yomesh@gmail.com>

---------

Signed-off-by: Yomesh Shah <yomesh@gmail.com>
Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
  • Loading branch information
awsyshah and AlexsJones authored Sep 17, 2024
1 parent d30563d commit 4f3ecf0
Showing 1 changed file with 90 additions and 20 deletions.
110 changes: 90 additions & 20 deletions pkg/ai/amazonbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,8 @@ type AmazonBedRockClient struct {
client *bedrockruntime.BedrockRuntime
model string
temperature float32
}

// InvokeModelResponseBody represents the response body structure from the model invocation.
type InvokeModelResponseBody struct {
Completion string `json:"completion"`
Stop_reason string `json:"stop_reason"`
topP float32
maxTokens int
}

// Amazon BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
Expand All @@ -52,14 +48,22 @@ const (
ModelAnthropicClaudeV2 = "anthropic.claude-v2"
ModelAnthropicClaudeV1 = "anthropic.claude-v1"
ModelAnthropicClaudeInstantV1 = "anthropic.claude-instant-v1"
ModelA21J2UltraV1 = "ai21.j2-ultra-v1"
ModelA21J2JumboInstruct = "ai21.j2-jumbo-instruct"
ModelAmazonTitanExpressV1 = "amazon.titan-text-express-v1"
)

var BEDROCK_MODELS = []string{
ModelAnthropicClaudeV2,
ModelAnthropicClaudeV1,
ModelAnthropicClaudeInstantV1,
ModelA21J2UltraV1,
ModelA21J2JumboInstruct,
ModelAmazonTitanExpressV1,
}

//const TOPP = 0.9 moved to config

// GetModelOrDefault check config model
func GetModelOrDefault(model string) string {

Expand Down Expand Up @@ -109,21 +113,46 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
a.client = bedrockruntime.New(sess)
a.model = GetModelOrDefault(config.GetModel())
a.temperature = config.GetTemperature()
a.topP = config.GetTopP()
a.maxTokens = config.GetMaxTokens()

return nil
}

// GetCompletion sends a request to the model for generating completion based on the provided prompt.
func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {

// Prepare the input data for the model invocation
request := map[string]interface{}{
"prompt": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
"max_tokens_to_sample": 1024,
"temperature": a.temperature,
"top_p": 0.9,
// Prepare the input data for the model invocation based on the model & the Response Body per model as well.
var request map[string]interface{}
switch a.model {
case ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
request = map[string]interface{}{
"prompt": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
"max_tokens_to_sample": a.maxTokens,
"temperature": a.temperature,
"top_p": a.topP,
}
case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
request = map[string]interface{}{
"prompt": prompt,
"maxTokens": a.maxTokens,
"temperature": a.temperature,
"topP": a.topP,
}
case ModelAmazonTitanExpressV1:
request = map[string]interface{}{
"inputText": fmt.Sprintf("\n\nUser: %s", prompt),
"textGenerationConfig": map[string]interface{}{
"maxTokenCount": a.maxTokens,
"temperature": a.temperature,
"topP": a.topP,
},
}
default:
return "", fmt.Errorf("model %s not supported", a.model)
}


body, err := json.Marshal(request)
if err != nil {
return "", err
Expand All @@ -142,15 +171,56 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
if err != nil {
return "", err
}
// Parse the response body
output := &InvokeModelResponseBody{}
err = json.Unmarshal(resp.Body, output)
if err != nil {
return "", err
}
return output.Completion, nil

// Response type changes as per model
switch a.model {
case ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
type InvokeModelResponseBody struct {
Completion string `json:"completion"`
Stop_reason string `json:"stop_reason"`
}
output := &InvokeModelResponseBody{}
err = json.Unmarshal(resp.Body, output)
if err != nil {
return "", err
}
return output.Completion, nil
case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
type Data struct {
Text string `json:"text"`
}
type Completion struct {
Data Data `json:"data"`
}
type InvokeModelResponseBody struct {
Completions []Completion `json:"completions"`
}
output := &InvokeModelResponseBody{}
err = json.Unmarshal(resp.Body, output)
if err != nil {
return "", err
}
return output.Completions[0].Data.Text, nil
case ModelAmazonTitanExpressV1:
type Result struct {
TokenCount int `json:"tokenCount"`
OutputText string `json:"outputText"`
CompletionReason string `json:"completionReason"`
}
type InvokeModelResponseBody struct {
InputTextTokenCount int `json:"inputTextTokenCount"`
Results []Result `json:"results"`
}
output := &InvokeModelResponseBody{}
err = json.Unmarshal(resp.Body, output)
if err != nil {
return "", err
}
return output.Results[0].OutputText, nil
default:
return "", fmt.Errorf("model %s not supported", a.model)
}
}

// GetName returns the name of the AmazonBedRockClient.
func (a *AmazonBedRockClient) GetName() string {
return amazonbedrockAIClientName
Expand Down

0 comments on commit 4f3ecf0

Please sign in to comment.