diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index 62e9c47f..e504ec17 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -23,7 +23,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
- go-version: "1.21"
+ go-version: "1.22"
check-latest: true
- name: Install
@@ -56,7 +56,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
- go-version: "1.21"
+ go-version: "1.22"
check-latest: true
- name: Build
run: go build -v ./...
@@ -70,12 +70,9 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
- go-version: "1.21"
+ go-version: "1.22"
check-latest: true
- - name: Install staticcheck
- run: go install honnef.co/go/tools/cmd/staticcheck@latest
-
- name: Install nilaway
run: go install go.uber.org/nilaway/cmd/nilaway@latest
@@ -85,8 +82,6 @@ jobs:
version: latest
args: --timeout 5m
- - name: Staticcheck
- run: staticcheck ./...
# TODO: Ignore the issue in https://github.com/modelgateway/Glide/issues/32
# - name: Nilaway
# run: nilaway ./...
@@ -100,7 +95,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
- go-version: "1.21"
+ go-version: "1.22"
check-latest: true
- name: Test
@@ -126,7 +121,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
- go-version: "1.21"
+ go-version: "1.22"
check-latest: true
- name: Generate OpenAPI Schema
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index f97309b8..69f01d09 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -23,7 +23,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
- go-version: 1.21
+ go-version: 1.22
- name: Checkout
uses: actions/checkout@v4
@@ -69,9 +69,6 @@ jobs:
- name: login into Github Container Registry
run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u $ --password-stdin
- - name: login into Github Container Registry
- run: echo "${{ secrets.DOCKER_HUB_TOKEN }}" | docker login -u einstack --password-stdin
-
- name: login into Docker
run: echo "${{ secrets.DOCKER_HUB_TOKEN }}" | docker login -u einstack --password-stdin
diff --git a/.github/workflows/vuln.yaml b/.github/workflows/vuln.yaml
index 9e363c8e..aa26598c 100644
--- a/.github/workflows/vuln.yaml
+++ b/.github/workflows/vuln.yaml
@@ -27,20 +27,22 @@ jobs:
- name: Install Go
uses: actions/setup-go@v4
with:
- go-version: '1.21.5'
+ go-version: '1.22.1'
check-latest: true
- name: Checkout
uses: actions/checkout@v3
- - name: Install govulncheck
- run: go install golang.org/x/vuln/cmd/govulncheck@latest
+# TODO: enable in https://github.com/EinStack/glide/issues/169
+# - name: Install govulncheck
+# run: go install golang.org/x/vuln/cmd/govulncheck@latest
- name: Install gosec
run: go install github.com/securego/gosec/v2/cmd/gosec@latest
- - name: Govulncheck
- run: govulncheck -test ./...
+# TODO: enable in https://github.com/EinStack/glide/issues/169
+# - name: Govulncheck
+# run: govulncheck -test ./...
- - name: Govulncheck
+ - name: Gosec
run: gosec ./...
diff --git a/.go-version b/.go-version
index d2ab029d..71f7f51d 100644
--- a/.go-version
+++ b/.go-version
@@ -1 +1 @@
-1.21
+1.22
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a84dbc27..f6c2c138 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,75 +1,138 @@
# Changelog
-The changelog consists of three categories:
-- **Features** - a new functionality that brings value to users
-- **Improvements** - bugfixes, performance and other types of improvements to existing functionality
+The changelog consists of seven categories:
+- **Added** - new functionality that brings value to users
+- **Changed** - changes to existing functionality, performance, and other improvements
+- **Fixed** - bugfixes
+- **Deprecated** - soon-to-be removed user-facing features
+- **Removed** - earlier deprecated, now removed user-facing features
+- **Security** - fixing CVEs in the gateway or dependencies
- **Miscellaneous** - all other updates like build, release, CLI, etc.
-## 0.0.2-rc.2, 0.0.2 (Feb 22nd, 2024)
+See [keepachangelog.com](https://keepachangelog.com/en/1.1.0/) for more information.
-### Features
+## [Unreleased]
-- ✨ #142: [Lang Chat Router] Ollama Support (@mkrueger12)
-- ✨ #131: [Lang Chat Router] AWS Bedrock Support (@mkrueger12)
+### Added
+
+TBU
+
+### Changed
+
+TBU
+
+### Fixed
+
+TBU
+
+### Deprecated
+
+TBU
+
+### Removed
+
+TBU
+
+### Security
+
+TBU
### Miscellaneous
-- 👷 #155 Fixing the dockerhub authorization step in the release workflow (@roma-glushko)
-- ♻️ #151: Moved specific provider schemas closer to provider's packages (@roma-glushko)
+TBU
+
+## [0.0.3-rc.1] (Apr 7th, 2024)
-## 0.0.2-rc.1 (Feb 12th, 2024)
+Bringing support for streaming chat in Glide.
-### Features
+### Added
-- ✨#117 Allow to load dotenv files (@roma-glushko)
+- ✨ Streaming Chat Workflow #149 #163 #161 (@roma-glushko)
+- ✨ Streaming Support for Azure OpenAI #173 (@mkrueger12)
+- ✨ Cohere Streaming Chat Support #171 (@mkrueger12)
+- ✨ Start counting token usage in Anthropic Chat #183 (@roma-glushko)
+- ✨ Handle unauthorized error in health tracker #170 (@roma-glushko)
-### Improvements
+### Fixed
-- ✨👷#91 Support for Windows (@roma-glushko)
-- 👷 #139 Build Glide for OpenBSD and ppc65le, s390x, riscv64 architectures (@roma-glushko)
+- 🐛 Fix Anthropic API key header #183 (@roma-glushko)
+
+### Security
+
+- 🔒 Update crypto lib, golang, fiber #148 (@roma-glushko)
### Miscellaneous
-- 👷 #92 Release binaries to Snapcraft (@roma-glushko)
-- 👷 #123 publish images to DockerHub (@roma-glushko)
-- 🔧 #136 Migrated all API to Fiber (@roma-glushko)
-- 👷 #139 Create a image tag with pure version (without distro suffix) (@roma-glushko)
+- 📝 Update README.md to fix helm chart location #167 (@arjunnair22)
+- 🔧 Updated .go-version (@roma-glushko)
+- ✅ Covered the telemetry by tests #146 (@roma-glushko)
+- 📝 Separate and list all supported capabilities per provider #190 (@roma-glushko)
-## 0.0.1 (Jan 31st, 2024)
+## [0.0.2-rc.2], [0.0.2] (Feb 22nd, 2024)
-### Features
+### Added
-- ✨ #81: Allow to chat message based for specific models (@mkrueger12)
+- ✨ [Lang Chat Router] Ollama Support #142 (@mkrueger12)
+- ✨ [Lang Chat Router] AWS Bedrock Support #131 (@mkrueger12)
-### Improvements
+### Miscellaneous
-- 🔧 #78: Normalize response latency by response token count (@roma-glushko)
-- 📝 #112 added the CLI banner info (@roma-glushko)
+- 👷 Fixing the dockerhub authorization step in the release workflow #155 (@roma-glushko)
+- ♻️ Moved specific provider schemas closer to provider's packages #151 (@roma-glushko)
+
+## [0.0.2-rc.1] (Feb 12th, 2024)
+
+### Added
+
+- ✨ Allow to load dotenv files #117 (@roma-glushko)
+
+### Changed
+
+- ✨👷 Support for Windows #91 (@roma-glushko)
+- 👷 Build Glide for OpenBSD and ppc64le, s390x, riscv64 architectures #139 (@roma-glushko)
+
+### Miscellaneous
+
+- 👷 Release binaries to Snapcraft #92 (@roma-glushko)
+- 👷 Publish images to DockerHub #123 (@roma-glushko)
+- 🔧 Migrated all API to Fiber #136 (@roma-glushko)
+- 👷 Create an image tag with pure version (without distro suffix) #139 (@roma-glushko)
+
+## [0.0.1] (Jan 31st, 2024)
+
+### Added
+
+- ✨ Allow overriding the chat message for specific models #81 (@mkrueger12)
+
+### Changed
+
+- 🔧 Normalize response latency by response token count #78 (@roma-glushko)
+- 📝 Added the CLI banner info #112 (@roma-glushko)
### Miscellaneous
- 📝 #114 Make links actual across the project (@roma-glushko)
-## 0.0.1-rc.2 (Jan 22nd, 2024)
+## [0.0.1-rc.2] (Jan 22nd, 2024)
-### Improvements
+### Added
- ⚙️ [config] Added validation for config file content #40 (@roma-glushko)
- ⚙️ [config] Allowed to pass HTTP server configs from config file #41 (@roma-glushko)
- 👷 [build] Allowed building Homebrew taps for release candidates #99 (@roma-glushko)
-## 0.0.1-rc.1 (Jan 21st, 2024)
+## [0.0.1-rc.1] (Jan 21st, 2024)
-### Features
+### Added
- ✨ [providers] Support for OpenAI Chat API #3 (@mkrueger12)
- ✨ [API] Unified Chat API #54 (@mkrueger12)
- ✨ [providers] Support for Cohere Chat API #5 (@mkrueger12)
- ✨ [providers] Support for Azure OpenAI Chat API #4 (@mkrueger12)
- ✨ [providers] Support for OctoML Chat API #58 (@mkrueger12)
- ✨ [routing] The Routing Mechanism, Adaptive Health Tracking, and Fallbacks #42 #43 #51 (@roma-glushko)
-- ✨ [routing] Support for round robin routing strategy #44 (@roma-glushko)
+- ✨ [routing] Support for round-robin routing strategy #44 (@roma-glushko)
- ✨ [routing] Support for the least latency routing strategy #46 (@roma-glushko)
-- ✨ [routing] Support for weighted round robin routing strategy #45 (@roma-glushko)
+- ✨ [routing] Support for weighted round-robin routing strategy #45 (@roma-glushko)
- ✨ [providers] Support for Anthropic Chat API #60 (@mkrueger12)
- ✨ [docs] OpenAPI specifications #22 (@roma-glushko)
@@ -80,5 +143,14 @@ The changelog consists of three categories:
- 🔧 [chores] Inited Glide's CLI #12 (@roma-glushko)
- 👷 [chores] Setup CI workflows #8 (@roma-glushko)
- ⚙️ [config] Inited configs #11 (@roma-glushko)
-- 🔧 [chores] Automatic coverage reports #39 (@roma-glushko)
+- 🔧 [chores] Automatic coverage reports #39 (@roma-glushko)
- 👷 [build] Setup release workflows #9 (@roma-glushko)
+
+[unreleased]: https://github.com/EinStack/glide/compare/0.0.3-rc.1...HEAD
+[0.0.3-rc.1]: https://github.com/EinStack/glide/compare/0.0.2..0.0.3-rc.1
+[0.0.2]: https://github.com/EinStack/glide/compare/0.0.2-rc.1..0.0.2
+[0.0.2-rc.2]: https://github.com/EinStack/glide/compare/0.0.2-rc.1..0.0.2-rc.2
+[0.0.2-rc.1]: https://github.com/EinStack/glide/compare/0.0.1..0.0.2-rc.1
+[0.0.1]: https://github.com/EinStack/glide/compare/0.0.1-rc.2..0.0.1
+[0.0.1-rc.2]: https://github.com/EinStack/glide/compare/0.0.1-rc.1..0.0.1-rc.2
+[0.0.1-rc.1]: https://github.com/EinStack/glide/releases/tag/0.0.1-rc.1
diff --git a/Makefile b/Makefile
index ecf858bb..4cebf96b 100644
--- a/Makefile
+++ b/Makefile
@@ -33,7 +33,7 @@ lint: install-checkers ## Lint the source code
vuln: install-checkers ## Check for vulnerabilities
@echo "π Checking for vulnerabilities"
- @$(CHECKER_BIN)/govulncheck -test ./...
+ @# TODO: re-enable once https://github.com/EinStack/glide/issues/169 is resolved: $(CHECKER_BIN)/govulncheck -test ./...
@$(CHECKER_BIN)/gosec -quiet -exclude=G104 ./...
run: ## Run Glide
diff --git a/README.md b/README.md
index c3597081..797f7409 100644
--- a/README.md
+++ b/README.md
@@ -42,16 +42,16 @@ Check out our [documentation](https://glide.einstack.ai)!
### Large Language Models
-| | Provider | Support Status |
-|-----------------------------------------------------|---------------|-----------------|
-| | Anthropic | 👍 Supported |
-| | Azure OpenAI | 👍 Supported |
-| | AWS Bedrock (Titan) | 👍 Supported |
-| | Cohere | 👍 Supported |
-| | Google Gemini | 🏗️ Coming Soon |
-| | OctoML | 👍 Supported |
-| | Ollama | 👍 Supported |
-| | OpenAI | 👍 Supported |
+| Provider | Supported Capabilities |
+|-----------------------------------------------------------------------|-------------------------------------------|
+| OpenAI | ✅ Chat<br/>✅ Streaming Chat |
+| Anthropic | ✅ Chat<br/>🏗️ Streaming Chat (coming soon) |
+| Azure OpenAI | ✅ Chat<br/>🏗️ Streaming Chat (coming soon) |
+| AWS Bedrock (Titan) | ✅ Chat |
+| Cohere | ✅ Chat<br/>🏗️ Streaming Chat (coming soon) |
+| Google Gemini | 🏗️ Chat (coming soon) |
+| OctoML | ✅ Chat |
+| Ollama | ✅ Chat |
## Get Started
@@ -183,7 +183,7 @@ docker pull ghcr.io/einstack/glide:latest-redhat
Add the EinStack repository:
```bash
-helm repo add einstack https://einstack.github.io/helm-charts
+helm repo add einstack https://einstack.github.io/charts
helm repo update
```
diff --git "a/docs/api/Language API/\360\237\222\254 Chat Stream.bru" "b/docs/api/Language API/\360\237\222\254 Chat Stream.bru"
new file mode 100644
index 00000000..e544c061
--- /dev/null
+++ "b/docs/api/Language API/\360\237\222\254 Chat Stream.bru"
@@ -0,0 +1,11 @@
+meta {
+ name: 💬 Chat Stream
+ type: http
+ seq: 2
+}
+
+get {
+ url: {{base_url}}/language/default/chatStream
+ body: none
+ auth: none
+}
diff --git a/docs/api/[Lang] Chat.bru "b/docs/api/Language API/\360\237\222\254 Chat.bru"
similarity index 71%
rename from docs/api/[Lang] Chat.bru
rename to "docs/api/Language API/\360\237\222\254 Chat.bru"
index d3a31a71..6ea21147 100644
--- a/docs/api/[Lang] Chat.bru
+++ "b/docs/api/Language API/\360\237\222\254 Chat.bru"
@@ -1,11 +1,11 @@
meta {
- name: [Lang] Chat
+ name: 💬 Chat
type: http
- seq: 2
+ seq: 1
}
post {
- url: {{base_url}}/v1/language/myrouter/chat/
+ url: {{base_url}}/language/default/chat/
body: json
auth: none
}
diff --git a/docs/api/[Lang] Router List.bru "b/docs/api/Language API/\360\237\224\247 Router List.bru"
similarity index 76%
rename from docs/api/[Lang] Router List.bru
rename to "docs/api/Language API/\360\237\224\247 Router List.bru"
index 81ccec75..0545245f 100644
--- a/docs/api/[Lang] Router List.bru
+++ "b/docs/api/Language API/\360\237\224\247 Router List.bru"
@@ -1,11 +1,11 @@
meta {
- name: [Lang] Router List
+ name: 🔧 Router List
type: http
seq: 3
}
get {
- url: {{base_url}}/v1/language/
+ url: {{base_url}}/language/
body: json
auth: none
}
diff --git a/docs/api/environments/Development.bru b/docs/api/environments/Development.bru
index 732c80cf..a8abb8bf 100644
--- a/docs/api/environments/Development.bru
+++ b/docs/api/environments/Development.bru
@@ -1,3 +1,3 @@
vars {
- base_url: http://127.0.0.1:9099
+ base_url: http://127.0.0.1:9099/v1
}
diff --git a/docs/api/Health.bru "b/docs/api/\360\237\224\247 Health.bru"
similarity index 57%
rename from docs/api/Health.bru
rename to "docs/api/\360\237\224\247 Health.bru"
index 0486a046..df4d4d10 100644
--- a/docs/api/Health.bru
+++ "b/docs/api/\360\237\224\247 Health.bru"
@@ -1,11 +1,11 @@
meta {
- name: Health
+ name: 🔧 Health
type: http
seq: 1
}
get {
- url: {{base_url}}/health
+ url: {{base_url}}/health/
body: none
auth: none
}
diff --git a/docs/docs.go b/docs/docs.go
index 96ddabc5..d575a50a 100644
--- a/docs/docs.go
+++ b/docs/docs.go
@@ -72,7 +72,7 @@ const docTemplate = `{
},
"/v1/language/{router}/chat": {
"post": {
- "description": "Talk to different LLMs Chat API via unified endpoint",
+ "description": "Talk to different LLM Chat APIs via unified endpoint",
"consumes": [
"application/json"
],
@@ -123,17 +123,85 @@ const docTemplate = `{
}
}
}
+ },
+ "/v1/language/{router}/chatStream": {
+ "get": {
+ "description": "Talk to different LLM Stream Chat APIs via a unified websocket endpoint",
+ "consumes": [
+ "application/json"
+ ],
+ "tags": [
+ "Language"
+ ],
+ "summary": "Language Chat",
+ "operationId": "glide-language-chat-stream",
+ "parameters": [
+ {
+ "type": "string",
+ "description": "Router ID",
+ "name": "router",
+ "in": "path",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Websocket Connection Type",
+ "name": "Connection",
+ "in": "header",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Upgrade header",
+ "name": "Upgrade",
+ "in": "header",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Websocket Security Token",
+ "name": "Sec-WebSocket-Key",
+ "in": "header",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Websocket Security Token",
+ "name": "Sec-WebSocket-Version",
+ "in": "header",
+ "required": true
+ }
+ ],
+ "responses": {
+ "101": {
+ "description": "Switching Protocols"
+ },
+ "404": {
+ "description": "Not Found",
+ "schema": {
+ "$ref": "#/definitions/http.ErrorSchema"
+ }
+ },
+ "426": {
+ "description": "Upgrade Required"
+ }
+ }
+ }
}
},
"definitions": {
"anthropic.Config": {
"type": "object",
"required": [
+ "apiVersion",
"baseUrl",
"chatEndpoint",
"model"
],
"properties": {
+ "apiVersion": {
+ "type": "string"
+ },
"baseUrl": {
"type": "string"
},
@@ -377,7 +445,6 @@ const docTemplate = `{
"type": "boolean"
},
"stream": {
- "description": "unsupported right now",
"type": "boolean"
},
"temperature": {
@@ -742,6 +809,10 @@ const docTemplate = `{
},
"schemas.ChatMessage": {
"type": "object",
+ "required": [
+ "content",
+ "role"
+ ],
"properties": {
"content": {
"description": "The content of the message.",
@@ -759,6 +830,9 @@ const docTemplate = `{
},
"schemas.ChatRequest": {
"type": "object",
+ "required": [
+ "message"
+ ],
"properties": {
"message": {
"$ref": "#/definitions/schemas.ChatMessage"
@@ -790,7 +864,7 @@ const docTemplate = `{
"type": "string"
},
"modelResponse": {
- "$ref": "#/definitions/schemas.ProviderResponse"
+ "$ref": "#/definitions/schemas.ModelResponse"
},
"model_id": {
"type": "string"
@@ -803,18 +877,7 @@ const docTemplate = `{
}
}
},
- "schemas.OverrideChatRequest": {
- "type": "object",
- "properties": {
- "message": {
- "$ref": "#/definitions/schemas.ChatMessage"
- },
- "model_id": {
- "type": "string"
- }
- }
- },
- "schemas.ProviderResponse": {
+ "schemas.ModelResponse": {
"type": "object",
"properties": {
"message": {
@@ -831,17 +894,32 @@ const docTemplate = `{
}
}
},
+ "schemas.OverrideChatRequest": {
+ "type": "object",
+ "required": [
+ "message",
+ "model_id"
+ ],
+ "properties": {
+ "message": {
+ "$ref": "#/definitions/schemas.ChatMessage"
+ },
+ "model_id": {
+ "type": "string"
+ }
+ }
+ },
"schemas.TokenUsage": {
"type": "object",
"properties": {
"promptTokens": {
- "type": "number"
+ "type": "integer"
},
"responseTokens": {
- "type": "number"
+ "type": "integer"
},
"totalTokens": {
- "type": "number"
+ "type": "integer"
}
}
}
diff --git a/docs/swagger.json b/docs/swagger.json
index aee257b6..ce439bc8 100644
--- a/docs/swagger.json
+++ b/docs/swagger.json
@@ -69,7 +69,7 @@
},
"/v1/language/{router}/chat": {
"post": {
- "description": "Talk to different LLMs Chat API via unified endpoint",
+ "description": "Talk to different LLM Chat APIs via unified endpoint",
"consumes": [
"application/json"
],
@@ -120,17 +120,85 @@
}
}
}
+ },
+ "/v1/language/{router}/chatStream": {
+ "get": {
+ "description": "Talk to different LLM Stream Chat APIs via a unified websocket endpoint",
+ "consumes": [
+ "application/json"
+ ],
+ "tags": [
+ "Language"
+ ],
+ "summary": "Language Chat",
+ "operationId": "glide-language-chat-stream",
+ "parameters": [
+ {
+ "type": "string",
+ "description": "Router ID",
+ "name": "router",
+ "in": "path",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Websocket Connection Type",
+ "name": "Connection",
+ "in": "header",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Upgrade header",
+ "name": "Upgrade",
+ "in": "header",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Websocket Security Token",
+ "name": "Sec-WebSocket-Key",
+ "in": "header",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Websocket Security Token",
+ "name": "Sec-WebSocket-Version",
+ "in": "header",
+ "required": true
+ }
+ ],
+ "responses": {
+ "101": {
+ "description": "Switching Protocols"
+ },
+ "404": {
+ "description": "Not Found",
+ "schema": {
+ "$ref": "#/definitions/http.ErrorSchema"
+ }
+ },
+ "426": {
+ "description": "Upgrade Required"
+ }
+ }
+ }
}
},
"definitions": {
"anthropic.Config": {
"type": "object",
"required": [
+ "apiVersion",
"baseUrl",
"chatEndpoint",
"model"
],
"properties": {
+ "apiVersion": {
+ "type": "string"
+ },
"baseUrl": {
"type": "string"
},
@@ -374,7 +442,6 @@
"type": "boolean"
},
"stream": {
- "description": "unsupported right now",
"type": "boolean"
},
"temperature": {
@@ -739,6 +806,10 @@
},
"schemas.ChatMessage": {
"type": "object",
+ "required": [
+ "content",
+ "role"
+ ],
"properties": {
"content": {
"description": "The content of the message.",
@@ -756,6 +827,9 @@
},
"schemas.ChatRequest": {
"type": "object",
+ "required": [
+ "message"
+ ],
"properties": {
"message": {
"$ref": "#/definitions/schemas.ChatMessage"
@@ -787,7 +861,7 @@
"type": "string"
},
"modelResponse": {
- "$ref": "#/definitions/schemas.ProviderResponse"
+ "$ref": "#/definitions/schemas.ModelResponse"
},
"model_id": {
"type": "string"
@@ -800,18 +874,7 @@
}
}
},
- "schemas.OverrideChatRequest": {
- "type": "object",
- "properties": {
- "message": {
- "$ref": "#/definitions/schemas.ChatMessage"
- },
- "model_id": {
- "type": "string"
- }
- }
- },
- "schemas.ProviderResponse": {
+ "schemas.ModelResponse": {
"type": "object",
"properties": {
"message": {
@@ -828,17 +891,32 @@
}
}
},
+ "schemas.OverrideChatRequest": {
+ "type": "object",
+ "required": [
+ "message",
+ "model_id"
+ ],
+ "properties": {
+ "message": {
+ "$ref": "#/definitions/schemas.ChatMessage"
+ },
+ "model_id": {
+ "type": "string"
+ }
+ }
+ },
"schemas.TokenUsage": {
"type": "object",
"properties": {
"promptTokens": {
- "type": "number"
+ "type": "integer"
},
"responseTokens": {
- "type": "number"
+ "type": "integer"
},
"totalTokens": {
- "type": "number"
+ "type": "integer"
}
}
}
diff --git a/docs/swagger.yaml b/docs/swagger.yaml
index d5fb088f..00e800eb 100644
--- a/docs/swagger.yaml
+++ b/docs/swagger.yaml
@@ -2,6 +2,8 @@ basePath: /
definitions:
anthropic.Config:
properties:
+ apiVersion:
+ type: string
baseUrl:
type: string
chatEndpoint:
@@ -11,6 +13,7 @@ definitions:
model:
type: string
required:
+ - apiVersion
- baseUrl
- chatEndpoint
- model
@@ -171,7 +174,6 @@ definitions:
search_queries_only:
type: boolean
stream:
- description: unsupported right now
type: boolean
temperature:
type: number
@@ -428,6 +430,9 @@ definitions:
description: The role of the author of this message. One of system, user,
or assistant.
type: string
+ required:
+ - content
+ - role
type: object
schemas.ChatRequest:
properties:
@@ -439,6 +444,8 @@ definitions:
type: array
override:
$ref: '#/definitions/schemas.OverrideChatRequest'
+ required:
+ - message
type: object
schemas.ChatResponse:
properties:
@@ -453,20 +460,13 @@ definitions:
model_id:
type: string
modelResponse:
- $ref: '#/definitions/schemas.ProviderResponse'
+ $ref: '#/definitions/schemas.ModelResponse'
provider:
type: string
router:
type: string
type: object
- schemas.OverrideChatRequest:
- properties:
- message:
- $ref: '#/definitions/schemas.ChatMessage'
- model_id:
- type: string
- type: object
- schemas.ProviderResponse:
+ schemas.ModelResponse:
properties:
message:
$ref: '#/definitions/schemas.ChatMessage'
@@ -477,14 +477,24 @@ definitions:
tokenCount:
$ref: '#/definitions/schemas.TokenUsage'
type: object
+ schemas.OverrideChatRequest:
+ properties:
+ message:
+ $ref: '#/definitions/schemas.ChatMessage'
+ model_id:
+ type: string
+ required:
+ - message
+ - model_id
+ type: object
schemas.TokenUsage:
properties:
promptTokens:
- type: number
+ type: integer
responseTokens:
- type: number
+ type: integer
totalTokens:
- type: number
+ type: integer
type: object
externalDocs:
description: Documentation
@@ -538,7 +548,7 @@ paths:
post:
consumes:
- application/json
- description: Talk to different LLMs Chat API via unified endpoint
+ description: Talk to different LLM Chat APIs via a unified endpoint
operationId: glide-language-chat
parameters:
- description: Router ID
@@ -570,6 +580,51 @@ paths:
summary: Language Chat
tags:
- Language
+ /v1/language/{router}/chatStream:
+ get:
+ consumes:
+ - application/json
+ description: Talk to different LLM Stream Chat APIs via a unified websocket
+ endpoint
+ operationId: glide-language-chat-stream
+ parameters:
+ - description: Router ID
+ in: path
+ name: router
+ required: true
+ type: string
+ - description: Websocket Connection Type
+ in: header
+ name: Connection
+ required: true
+ type: string
+ - description: Upgrade header
+ in: header
+ name: Upgrade
+ required: true
+ type: string
+ - description: Websocket Security Token
+ in: header
+ name: Sec-WebSocket-Key
+ required: true
+ type: string
+ - description: Websocket Protocol Version
+ in: header
+ name: Sec-WebSocket-Version
+ required: true
+ type: string
+ responses:
+ "101":
+ description: Switching Protocols
+ "404":
+ description: Not Found
+ schema:
+ $ref: '#/definitions/http.ErrorSchema'
+ "426":
+ description: Upgrade Required
+ summary: Language Stream Chat
+ tags:
+ - Language
schemes:
- http
swagger: "2.0"
diff --git a/go.mod b/go.mod
index 73a99abe..c90e943c 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module glide
-go 1.21.5
+go 1.22.1
require (
github.com/aws/aws-sdk-go-v2 v1.24.1
@@ -10,9 +10,11 @@ require (
github.com/go-playground/validator/v10 v10.17.0
github.com/gofiber/contrib/fiberzap/v2 v2.1.2
github.com/gofiber/contrib/swagger v1.1.1
- github.com/gofiber/fiber/v2 v2.52.0
+ github.com/gofiber/contrib/websocket v1.3.0
+ github.com/gofiber/fiber/v2 v2.52.2
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
+ github.com/r3labs/sse/v2 v2.10.0
github.com/spf13/cobra v1.8.0
github.com/stretchr/testify v1.8.4
github.com/swaggo/swag v1.16.2
@@ -24,7 +26,7 @@ require (
require (
github.com/KyleBanks/depth v1.2.1 // indirect
- github.com/andybalholm/brotli v1.0.5 // indirect
+ github.com/andybalholm/brotli v1.1.0 // indirect
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.4 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 // indirect
@@ -38,6 +40,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 // indirect
github.com/aws/smithy-go v1.19.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/fasthttp/websocket v1.5.7 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/go-openapi/analysis v0.21.4 // indirect
github.com/go-openapi/errors v0.20.3 // indirect
@@ -53,7 +56,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
- github.com/klauspost/compress v1.17.0 // indirect
+ github.com/klauspost/compress v1.17.6 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
@@ -62,17 +65,19 @@ require (
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
- github.com/rivo/uniseg v0.2.0 // indirect
+ github.com/rivo/uniseg v0.4.7 // indirect
+ github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
- github.com/valyala/fasthttp v1.51.0 // indirect
+ github.com/valyala/fasthttp v1.52.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect
go.mongodb.org/mongo-driver v1.11.3 // indirect
- golang.org/x/crypto v0.17.0 // indirect
- golang.org/x/net v0.19.0 // indirect
- golang.org/x/sys v0.15.0 // indirect
+ golang.org/x/crypto v0.21.0 // indirect
+ golang.org/x/net v0.21.0 // indirect
+ golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.16.1 // indirect
+ gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
diff --git a/go.sum b/go.sum
index a64851f6..c59f3570 100644
--- a/go.sum
+++ b/go.sum
@@ -3,8 +3,8 @@ github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc
github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE=
github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
-github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
-github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
+github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
+github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
@@ -43,6 +43,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/fasthttp/websocket v1.5.7 h1:0a6o2OfeATvtGgoMKleURhLT6JqWPg7fYfWnH4KHau4=
+github.com/fasthttp/websocket v1.5.7/go.mod h1:bC4fxSono9czeXHQUVKxsC0sNjbm7lPJR04GDFqClfU=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/go-openapi/analysis v0.21.2/go.mod h1:HZwRk4RRisyG8vx2Oe6aqeSQcoxRp47Xkp3+K6q+LdY=
@@ -119,8 +121,10 @@ github.com/gofiber/contrib/fiberzap/v2 v2.1.2 h1:7Z1BqS1sYK9e9jTwqPcWx9qQt46PI8o
github.com/gofiber/contrib/fiberzap/v2 v2.1.2/go.mod h1:ulCCQOdDYABGsOQfbndASmCsCN86hsC96iKoOTNYfy8=
github.com/gofiber/contrib/swagger v1.1.1 h1:on+D2fbXkvm0H0lur1rx69mpxLdX1wIH/FrTRZ99b9Y=
github.com/gofiber/contrib/swagger v1.1.1/go.mod h1:pa9awsFSz/3BbSnyTe/drNZaiFfnhC4hk3m9BVet7Co=
-github.com/gofiber/fiber/v2 v2.52.0 h1:S+qXi7y+/Pgvqq4DrSmREGiFwtB7Bu6+QFLuIHYw/UE=
-github.com/gofiber/fiber/v2 v2.52.0/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ=
+github.com/gofiber/contrib/websocket v1.3.0 h1:XADFAGorer1VJ1bqC4UkCjqS37kwRTV0415+050NrMk=
+github.com/gofiber/contrib/websocket v1.3.0/go.mod h1:xguaOzn2ZZ759LavtosEP+rcxIgBEE/rdumPINhR+Xo=
+github.com/gofiber/fiber/v2 v2.52.2 h1:b0rYH6b06Df+4NyrbdptQL8ifuxw/Tf2DgfkZkDaxEo=
+github.com/gofiber/fiber/v2 v2.52.2/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
@@ -139,8 +143,8 @@ github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFF
github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4=
github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA=
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
-github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM=
-github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
+github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
+github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -180,14 +184,19 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
+github.com/r3labs/sse/v2 v2.10.0 h1:hFEkLLFY4LDifoHdiCN/LlGBAdVJYsANaLqNYa1l/v0=
+github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEmkNJ7I=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
+github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
+github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk=
+github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g=
github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
@@ -217,8 +226,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
-github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
-github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g=
+github.com/valyala/fasthttp v1.52.0 h1:wqBQpxH71XW0e2g+Og4dzQM8pk34aFYlA1Ga8db7gU0=
+github.com/valyala/fasthttp v1.52.0/go.mod h1:hf5C4QnVMkNXMspnsUlfM3WitlgYflyhHYoKol/szxQ=
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
@@ -243,16 +252,17 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
-golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
+golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
+golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
-golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
+golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
+golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -271,8 +281,8 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
-golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
+golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -288,6 +298,8 @@ golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgw
golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA=
golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y=
+gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/images/alpine.Dockerfile b/images/alpine.Dockerfile
index 1454c0a7..26b20357 100644
--- a/images/alpine.Dockerfile
+++ b/images/alpine.Dockerfile
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1
-FROM golang:1.21-alpine as build
+FROM golang:1.22-alpine as build
ARG VERSION
ARG COMMIT
diff --git a/images/distroless.Dockerfile b/images/distroless.Dockerfile
index 34776aca..23167563 100644
--- a/images/distroless.Dockerfile
+++ b/images/distroless.Dockerfile
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1
-FROM golang:1.21-alpine as build
+FROM golang:1.22-alpine as build
ARG VERSION
ARG COMMIT
diff --git a/images/redhat.Dockerfile b/images/redhat.Dockerfile
index f55c9534..c3a16248 100644
--- a/images/redhat.Dockerfile
+++ b/images/redhat.Dockerfile
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1
-FROM golang:1.21-alpine as build
+FROM golang:1.22-alpine as build
ARG VERSION
ARG COMMIT
diff --git a/images/ubuntu.Dockerfile b/images/ubuntu.Dockerfile
index 7db2cb62..fe238b79 100644
--- a/images/ubuntu.Dockerfile
+++ b/images/ubuntu.Dockerfile
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1
-FROM golang:1.21-alpine as build
+FROM golang:1.22-alpine as build
ARG VERSION
ARG COMMIT
diff --git a/pkg/api/http/config.go b/pkg/api/http/config.go
index 8e261098..e9098266 100644
--- a/pkg/api/http/config.go
+++ b/pkg/api/http/config.go
@@ -41,7 +41,7 @@ func (cfg *ServerConfig) ToServer() *fiber.App {
// More configs are listed on https://docs.gofiber.io/api/fiber
// TODO: Consider alternative JSON marshallers that provides better performance over the standard marshaller
serverConfig := fiber.Config{
- AppName: "glide",
+ AppName: "Glide",
DisableDefaultDate: true,
ServerHeader: fmt.Sprintf("glide/%v", version.Version),
StreamRequestBody: true,
diff --git a/pkg/api/http/config_test.go b/pkg/api/http/config_test.go
new file mode 100644
index 00000000..8d8667a6
--- /dev/null
+++ b/pkg/api/http/config_test.go
@@ -0,0 +1,14 @@
+package http
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestHTTPConfig_DefaultConfig(t *testing.T) {
+ config := DefaultServerConfig()
+
+ require.NotEmpty(t, config.Address())
+ require.NotNil(t, config.ToServer())
+}
diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go
index c97f542b..5e325ad0 100644
--- a/pkg/api/http/handlers.go
+++ b/pkg/api/http/handlers.go
@@ -1,10 +1,15 @@
package http
import (
+ "context"
"errors"
+ "sync"
- "github.com/gofiber/fiber/v2"
+ "glide/pkg/telemetry"
+ "go.uber.org/zap"
+ "github.com/gofiber/contrib/websocket"
+ "github.com/gofiber/fiber/v2"
"glide/pkg/api/schemas"
"glide/pkg/routers"
)
@@ -18,7 +23,7 @@ type Handler = func(c *fiber.Ctx) error
//
// @id glide-language-chat
// @Summary Language Chat
-// @Description Talk to different LLMs Chat API via unified endpoint
+// @Description Talk to different LLM Chat APIs via a unified endpoint
// @tags Language
// @Param router path string true "Router ID"
// @Param payload body schemas.ChatRequest true "Request Data"
@@ -30,6 +35,12 @@ type Handler = func(c *fiber.Ctx) error
// @Router /v1/language/{router}/chat [POST]
func LangChatHandler(routerManager *routers.RouterManager) Handler {
return func(c *fiber.Ctx) error {
+ if !c.Is("json") {
+ return c.Status(fiber.StatusBadRequest).JSON(ErrorSchema{
+ Message: "Glide accepts only JSON payloads",
+ })
+ }
+
// Unmarshal request body
var req *schemas.ChatRequest
@@ -65,6 +76,106 @@ func LangChatHandler(routerManager *routers.RouterManager) Handler {
}
}
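+// LangStreamRouterValidator is a middleware that rejects non-websocket requests
+// and makes sure the requested router exists before the connection is upgraded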
+func LangStreamRouterValidator(routerManager *routers.RouterManager) Handler {
+ return func(c *fiber.Ctx) error {
+ if websocket.IsWebSocketUpgrade(c) {
+ routerID := c.Params("router")
+
+ _, err := routerManager.GetLangRouter(routerID)
+ if err != nil {
+ return c.Status(fiber.StatusNotFound).JSON(ErrorSchema{
+ Message: err.Error(),
+ })
+ }
+
+ return c.Next()
+ }
+
+ return fiber.ErrUpgradeRequired
+ }
+}
+
+// LangStreamChatHandler
+//
+// @id glide-language-chat-stream
+// @Summary Language Stream Chat
+// @Description Talk to different LLM Stream Chat APIs via a unified websocket endpoint
+// @tags Language
+// @Param router path string true "Router ID"
+// @Param Connection header string true "Websocket Connection Type"
+// @Param Upgrade header string true "Upgrade header"
+// @Param Sec-WebSocket-Key header string true "Websocket Security Token"
+// @Param Sec-WebSocket-Version header string true "Websocket Protocol Version"
+// @Accept json
+// @Success 101
+// @Failure 426
+// @Failure 404 {object} http.ErrorSchema
+// @Router /v1/language/{router}/chatStream [GET]
+func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.RouterManager) Handler {
+ // TODO: expose websocket connection configs https://github.com/gofiber/contrib/tree/main/websocket
+ return websocket.New(func(c *websocket.Conn) {
+ routerID := c.Params("router")
+ // websocket.Conn bindings https://pkg.go.dev/github.com/fasthttp/websocket?tab=doc#pkg-index
+
+ var (
+ err error
+ wg sync.WaitGroup
+ )
+
+ chunkResultC := make(chan *schemas.ChatStreamResult)
+
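+ // The router is guaranteed to exist here: LangStreamRouterValidator has already checked it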
+ router, _ := routerManager.GetLangRouter(routerID)
+
+ defer close(chunkResultC)
+ defer c.Conn.Close()
+
+ // A single writer goroutine serializes all websocket writes. It is deliberately
+ // left out of the WaitGroup: it only exits after chunkResultC is closed, and that
+ // close is deferred until the handler returns, so tracking it would deadlock wg.Wait()
+ go func() {
+ for chunkResult := range chunkResultC {
+ if chunkResult.Error() != nil {
+ if err := c.WriteJSON(chunkResult.Error()); err != nil {
+ break
+ }
+
+ continue
+ }
+
+ if err := c.WriteJSON(chunkResult.Chunk()); err != nil {
+ break
+ }
+ }
+ }()
+
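+ // Read loop: each incoming request is handed off to a producer goroutine that
+ // streams response chunks into chunkResultC until the client disconnects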
+ for {
+ var chatRequest schemas.ChatStreamRequest
+
+ if err = c.ReadJSON(&chatRequest); err != nil {
+ if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
+ tel.L().Warn("Streaming Chat connection is closed", zap.Error(err), zap.String("routerID", routerID))
+ }
+
+ tel.L().Debug("Streaming chat connection is closed by client", zap.Error(err), zap.String("routerID", routerID))
+
+ break
+ }
+
+ // TODO: handle termination gracefully
+ wg.Add(1)
+
+ go func(chatRequest schemas.ChatStreamRequest) {
+ defer wg.Done()
+
+ router.ChatStream(context.Background(), &chatRequest, chunkResultC)
+ }(chatRequest)
+ }
+
+ wg.Wait()
+ })
+}
+
// LangRoutersHandler
//
// @id glide-language-routers
diff --git a/pkg/api/http/server.go b/pkg/api/http/server.go
index d44beac1..e49c8fec 100644
--- a/pkg/api/http/server.go
+++ b/pkg/api/http/server.go
@@ -41,7 +41,7 @@ func (srv *Server) Run() error {
Title: "Glide API Docs",
BasePath: "/v1/",
Path: "swagger",
- FilePath: "./docs/swagger.json",
+ FilePath: "./docs/swagger.yaml",
}))
srv.server.Use(fiberzap.New(fiberzap.Config{
@@ -53,6 +53,9 @@ func (srv *Server) Run() error {
v1.Get("/language/", LangRoutersHandler(srv.routerManager))
v1.Post("/language/:router/chat/", LangChatHandler(srv.routerManager))
+ v1.Use("/language/:router/chatStream", LangStreamRouterValidator(srv.routerManager))
+ v1.Get("/language/:router/chatStream", LangStreamChatHandler(srv.telemetry, srv.routerManager))
+
v1.Get("/health/", HealthHandler)
srv.server.Use(NotFoundHandler)
diff --git a/pkg/api/schemas/chat.go b/pkg/api/schemas/chat.go
new file mode 100644
index 00000000..4be88692
--- /dev/null
+++ b/pkg/api/schemas/chat.go
@@ -0,0 +1,60 @@
+package schemas
+
+// ChatRequest defines Glide's Chat Request Schema unified across all language models
+type ChatRequest struct {
+ Message ChatMessage `json:"message" validate:"required"`
+ MessageHistory []ChatMessage `json:"messageHistory"`
+ Override *OverrideChatRequest `json:"override,omitempty"`
+}
+
+type OverrideChatRequest struct {
+ Model string `json:"model_id" validate:"required"`
+ Message ChatMessage `json:"message" validate:"required"`
+}
+
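+// NewChatFromStr creates a chat request from a bare message string,
+// defaulting the role to "user" and the sender name to "glide"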
+func NewChatFromStr(message string) *ChatRequest {
+ return &ChatRequest{
+ Message: ChatMessage{
+ "user",
+ message,
+ "glide",
+ },
+ }
+}
+
+// ChatResponse defines Glide's Chat Response Schema unified across all language models
+type ChatResponse struct {
+ ID string `json:"id,omitempty"`
+ Created int `json:"created,omitempty"`
+ Provider string `json:"provider,omitempty"`
+ RouterID string `json:"router,omitempty"`
+ ModelID string `json:"model_id,omitempty"`
+ ModelName string `json:"model,omitempty"`
+ Cached bool `json:"cached,omitempty"`
+ ModelResponse ModelResponse `json:"modelResponse,omitempty"`
+}
+
+// ModelResponse is the unified response from the provider.
+type ModelResponse struct {
+ SystemID map[string]string `json:"responseId,omitempty"`
+ Message ChatMessage `json:"message"`
+ TokenUsage TokenUsage `json:"tokenCount"`
+}
+
+type TokenUsage struct {
+ PromptTokens int `json:"promptTokens"`
+ ResponseTokens int `json:"responseTokens"`
+ TotalTokens int `json:"totalTokens"`
+}
+
+// ChatMessage is a message in a chat request.
+type ChatMessage struct {
+ // The role of the author of this message. One of system, user, or assistant.
+ Role string `json:"role" validate:"required"`
+ // The content of the message.
+ Content string `json:"content" validate:"required"`
+ // The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores,
+ // with a maximum length of 64 characters.
+ Name string `json:"name,omitempty"`
+}
diff --git a/pkg/api/schemas/chat_stream.go b/pkg/api/schemas/chat_stream.go
new file mode 100644
index 00000000..d77310c7
--- /dev/null
+++ b/pkg/api/schemas/chat_stream.go
@@ -0,0 +1,80 @@
+package schemas
+
+type (
+ Metadata = map[string]any
+ FinishReason = string
+)
+
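+// Complete marks the final chunk of a streamed model response.
+// It is a var rather than a const so its address can be taken for ModelChunkResponse.FinishReason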
+var Complete FinishReason = "complete"
+
+// ChatStreamRequest defines a message that requests a new streaming chat
+type ChatStreamRequest struct {
+ ID string `json:"id" validate:"required"`
+ Message ChatMessage `json:"message" validate:"required"`
+ MessageHistory []ChatMessage `json:"messageHistory" validate:"required"`
+ Override *OverrideChatRequest `json:"overrideMessage,omitempty"`
+ Metadata *Metadata `json:"metadata,omitempty"`
+}
+
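+// NewChatStreamFromStr creates a streaming chat request from a bare message string,
+// defaulting the role to "user" and the sender name to "glide"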
+func NewChatStreamFromStr(message string) *ChatStreamRequest {
+ return &ChatStreamRequest{
+ Message: ChatMessage{
+ "user",
+ message,
+ "glide",
+ },
+ }
+}
+
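+// ModelChunkResponse is the model-specific part of a single streamed response chunk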
+type ModelChunkResponse struct {
+ Metadata *Metadata `json:"metadata,omitempty"`
+ Message ChatMessage `json:"message"`
+ FinishReason *FinishReason `json:"finishReason,omitempty"`
+}
+
+// ChatStreamChunk defines a message for a chunk of streaming chat response
+type ChatStreamChunk struct {
+ ID string `json:"id"`
+ CreatedAt int `json:"createdAt"`
+ Provider string `json:"providerId"`
+ RouterID string `json:"routerId"`
+ ModelID string `json:"modelId"`
+ Cached bool `json:"cached"`
+ ModelName string `json:"modelName"`
+ Metadata *Metadata `json:"metadata,omitempty"`
+ ModelResponse ModelChunkResponse `json:"modelResponse"`
+}
+
+type ChatStreamError struct {
+ ID string `json:"id"`
+ ErrCode string `json:"errCode"`
+ Message string `json:"message"`
+ Metadata *Metadata `json:"metadata,omitempty"`
+}
+
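+// ChatStreamResult carries either a response chunk or a stream error, never both:
+// exactly one of Chunk() and Error() returns a non-nil value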
+type ChatStreamResult struct {
+ chunk *ChatStreamChunk
+ err *ChatStreamError
+}
+
+func (r *ChatStreamResult) Chunk() *ChatStreamChunk {
+ return r.chunk
+}
+
+func (r *ChatStreamResult) Error() *ChatStreamError {
+ return r.err
+}
+
+func NewChatStreamResult(chunk *ChatStreamChunk) *ChatStreamResult {
+ return &ChatStreamResult{
+ chunk: chunk,
+ err: nil,
+ }
+}
+
+func NewChatStreamErrorResult(err *ChatStreamError) *ChatStreamResult {
+ return &ChatStreamResult{
+ chunk: nil,
+ err: err,
+ }
+}
diff --git a/pkg/api/schemas/language.go b/pkg/api/schemas/language.go
deleted file mode 100644
index 7e2a2cdc..00000000
--- a/pkg/api/schemas/language.go
+++ /dev/null
@@ -1,60 +0,0 @@
-package schemas
-
-// ChatRequest defines Glide's Chat Request Schema unified across all language models
-type ChatRequest struct {
- Message ChatMessage `json:"message"`
- MessageHistory []ChatMessage `json:"messageHistory"`
- Override OverrideChatRequest `json:"override,omitempty"`
-}
-
-type OverrideChatRequest struct {
- Model string `json:"model_id"`
- Message ChatMessage `json:"message"`
-}
-
-func NewChatFromStr(message string) *ChatRequest {
- return &ChatRequest{
- Message: ChatMessage{
- "human",
- message,
- "roma",
- },
- }
-}
-
-// ChatResponse defines Glide's Chat Response Schema unified across all language models
-type ChatResponse struct {
- ID string `json:"id,omitempty"`
- Created int `json:"created,omitempty"`
- Provider string `json:"provider,omitempty"`
- RouterID string `json:"router,omitempty"`
- ModelID string `json:"model_id,omitempty"`
- Model string `json:"model,omitempty"`
- Cached bool `json:"cached,omitempty"`
- ModelResponse ProviderResponse `json:"modelResponse,omitempty"`
-}
-
-// ProviderResponse is the unified response from the provider.
-
-type ProviderResponse struct {
- SystemID map[string]string `json:"responseId,omitempty"`
- Message ChatMessage `json:"message"`
- TokenUsage TokenUsage `json:"tokenCount"`
-}
-
-type TokenUsage struct {
- PromptTokens float64 `json:"promptTokens"`
- ResponseTokens float64 `json:"responseTokens"`
- TotalTokens float64 `json:"totalTokens"`
-}
-
-// ChatMessage is a message in a chat request.
-type ChatMessage struct {
- // The role of the author of this message. One of system, user, or assistant.
- Role string `json:"role"`
- // The content of the message.
- Content string `json:"content"`
- // The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores,
- // with a maximum length of 64 characters.
- Name string `json:"name,omitempty"`
-}
diff --git a/pkg/cmd/cli.go b/pkg/cmd/cli.go
index 154b83ca..60068884 100644
--- a/pkg/cmd/cli.go
+++ b/pkg/cmd/cli.go
@@ -49,7 +49,7 @@ func NewCLI() *cobra.Command {
if err != nil {
log.Println("β οΈfailed to load dotenv file: ", err) // don't have an inited logger at this moment
} else {
- log.Printf("π§dot env file loaded (%v)", dotEnvFile)
+ log.Printf("π§dot env file is loaded (%v)", dotEnvFile)
}
_, err = configProvider.Load(cfgFile)
diff --git a/pkg/config/expander_test.go b/pkg/config/expander_test.go
index f5d2930f..5686ad4b 100644
--- a/pkg/config/expander_test.go
+++ b/pkg/config/expander_test.go
@@ -62,3 +62,24 @@ func TestExpander_EnvVarExpanded(t *testing.T) {
assert.Equal(t, topP, cfg.Params[0].Value)
assert.Equal(t, fmt.Sprintf("$%v", budget), cfg.Params[1].Value)
}
+
+func TestExpander_FileContentExpanded(t *testing.T) {
+ content, err := os.ReadFile(filepath.Clean(filepath.Join(".", "testdata", "expander.file.yaml")))
+ require.NoError(t, err)
+
+ expander := Expander{}
+ updatedContent := string(expander.Expand(content))
+
+ require.NotContains(t, updatedContent, "${file:")
+ require.Contains(t, updatedContent, "sk-fakeapi-token")
+}
+
+func TestExpander_FileDoesntExist(t *testing.T) {
+ content, err := os.ReadFile(filepath.Clean(filepath.Join(".", "testdata", "expander.file.notfound.yaml")))
+ require.NoError(t, err)
+
+ expander := Expander{}
+ updatedContent := string(expander.Expand(content))
+
+ require.NotContains(t, updatedContent, "${file:")
+}
diff --git a/pkg/config/testdata/expander.file.notfound.yaml b/pkg/config/testdata/expander.file.notfound.yaml
new file mode 100644
index 00000000..4a61af87
--- /dev/null
+++ b/pkg/config/testdata/expander.file.notfound.yaml
@@ -0,0 +1,6 @@
+name: "OpenAI"
+api_key: "${file:./testdata/doesntexist}"
+
+params:
+ - name: budget
+ value: "$$${file:./testdata/doesntexist}"
diff --git a/pkg/config/testdata/expander.file.yaml b/pkg/config/testdata/expander.file.yaml
new file mode 100644
index 00000000..30dc7dc8
--- /dev/null
+++ b/pkg/config/testdata/expander.file.yaml
@@ -0,0 +1,6 @@
+name: "OpenAI"
+api_key: "${file:./testdata/openai_key}"
+
+params:
+ - name: budget
+ value: "$$${file:./testdata/openai_key}"
diff --git a/pkg/config/testdata/openai_key b/pkg/config/testdata/openai_key
new file mode 100644
index 00000000..607f1d9a
--- /dev/null
+++ b/pkg/config/testdata/openai_key
@@ -0,0 +1 @@
+sk-fakeapi-token
diff --git a/pkg/gateway.go b/pkg/gateway.go
index a27eddbd..6d2a6092 100644
--- a/pkg/gateway.go
+++ b/pkg/gateway.go
@@ -25,8 +25,8 @@ import (
type Gateway struct {
// configProvider holds all configurations
configProvider *config.Provider
- // telemetry holds logger, meter, and tracer
- telemetry *telemetry.Telemetry
+ // tel holds logger, meter, and tracer
+ tel *telemetry.Telemetry
// serverManager controls API over different protocols
serverManager *api.ServerManager
// signalChannel is used to receive termination signals from the OS.
@@ -43,8 +43,8 @@ func NewGateway(configProvider *config.Provider) (*Gateway, error) {
return nil, err
}
- tel.Logger.Info("π¦Glide is starting up", zap.String("version", version.FullVersion))
- tel.Logger.Debug("β
config loaded successfully:\n" + configProvider.GetStr())
+ tel.L().Info("π¦Glide is starting up", zap.String("version", version.FullVersion))
+ tel.L().Debug("β
config loaded successfully:\n" + configProvider.GetStr())
routerManager, err := routers.NewManager(&cfg.Routers, tel)
if err != nil {
@@ -58,7 +58,7 @@ func NewGateway(configProvider *config.Provider) (*Gateway, error) {
return &Gateway{
configProvider: configProvider,
- telemetry: tel,
+ tel: tel,
serverManager: serverManager,
signalC: make(chan os.Signal, 3), // equal to number of signal types we expect to receive
shutdownC: make(chan struct{}),
@@ -78,13 +78,13 @@ LOOP:
select {
// TODO: Watch for config updates
case sig := <-gw.signalC:
- gw.telemetry.Logger.Info("received signal from os", zap.String("signal", sig.String()))
+ gw.tel.L().Info("received signal from os", zap.String("signal", sig.String()))
break LOOP
case <-gw.shutdownC:
- gw.telemetry.Logger.Info("received shutdown request")
+ gw.tel.L().Info("received shutdown request")
break LOOP
case <-ctx.Done():
- gw.telemetry.Logger.Info("context done, terminating process")
+ gw.tel.L().Info("context done, terminating process")
// Call shutdown with background context as the passed in context has been canceled
return gw.shutdown(context.Background()) //nolint:contextcheck
}
diff --git a/pkg/providers/anthropic/chat.go b/pkg/providers/anthropic/chat.go
index 5a8d8ee3..dc8f857d 100644
--- a/pkg/providers/anthropic/chat.go
+++ b/pkg/providers/anthropic/chat.go
@@ -9,8 +9,6 @@ import (
"net/http"
"time"
- "glide/pkg/providers/clients"
-
"glide/pkg/api/schemas"
"go.uber.org/zap"
)
@@ -63,6 +61,8 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessa
}
// Chat sends a chat request to the specified anthropic model.
+//
+// Ref: https://docs.anthropic.com/claude/reference/messages_post
func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
// Create a new chat request
chatRequest := c.createChatRequestSchema(request)
@@ -72,10 +72,6 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem
return nil, err
}
- if len(chatResponse.ModelResponse.Message.Content) == 0 {
- return nil, ErrEmptyResponse
- }
-
return chatResponse, nil
}
@@ -99,12 +95,13 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
return nil, fmt.Errorf("unable to create anthropic chat request: %w", err)
}
- req.Header.Set("Authorization", "Bearer "+string(c.config.APIKey))
+ req.Header.Set("x-api-key", string(c.config.APIKey)) // must be in lower case
+ req.Header.Set("anthropic-version", c.apiVersion)
req.Header.Set("Content-Type", "application/json")
// TODO: this could leak information from messages which may not be a desired thing to have
- c.telemetry.Logger.Debug(
- "anthropic chat request",
+ c.tel.L().Debug(
+ "Anthropic chat request",
zap.String("chat_url", c.chatURL),
zap.Any("payload", payload),
)
@@ -117,71 +114,49 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- c.telemetry.Logger.Error("failed to read anthropic chat response", zap.Error(err))
- }
-
- c.telemetry.Logger.Error(
- "anthropic chat request failed",
- zap.Int("status_code", resp.StatusCode),
- zap.String("response", string(bodyBytes)),
- zap.Any("headers", resp.Header),
- )
-
- if resp.StatusCode == http.StatusTooManyRequests {
- // Read the value of the "Retry-After" header to get the cooldown delay
- retryAfter := resp.Header.Get("Retry-After")
-
- // Parse the value to get the duration
- cooldownDelay, err := time.ParseDuration(retryAfter)
- if err != nil {
- return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
- }
-
- return nil, clients.NewRateLimitError(&cooldownDelay)
- }
-
- // Server & client errors result in the same error to keep gateway resilient
- return nil, clients.ErrProviderUnavailable
+ return nil, c.errMapper.Map(resp)
}
// Read the response body into a byte slice
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
- c.telemetry.Logger.Error("failed to read anthropic chat response", zap.Error(err))
+ c.tel.L().Error("Failed to read anthropic chat response", zap.Error(err))
return nil, err
}
// Parse the response JSON
- var anthropicCompletion ChatCompletion
+ var anthropicResponse ChatCompletion
- err = json.Unmarshal(bodyBytes, &anthropicCompletion)
+ err = json.Unmarshal(bodyBytes, &anthropicResponse)
if err != nil {
- c.telemetry.Logger.Error("failed to parse anthropic chat response", zap.Error(err))
+ c.tel.L().Error("Failed to parse anthropic chat response", zap.Error(err))
return nil, err
}
+ if len(anthropicResponse.Content) == 0 {
+ return nil, ErrEmptyResponse
+ }
+
+ completion := anthropicResponse.Content[0]
+ usage := anthropicResponse.Usage
+
// Map response to ChatResponse schema
response := schemas.ChatResponse{
- ID: anthropicCompletion.ID,
- Created: int(time.Now().UTC().Unix()), // not provided by anthropic
- Provider: providerName,
- Model: anthropicCompletion.Model,
- Cached: false,
- ModelResponse: schemas.ProviderResponse{
- SystemID: map[string]string{
- "system_fingerprint": anthropicCompletion.ID,
- },
+ ID: anthropicResponse.ID,
+ Created: int(time.Now().UTC().Unix()), // not provided by anthropic
+ Provider: providerName,
+ ModelName: anthropicResponse.Model,
+ Cached: false,
+ ModelResponse: schemas.ModelResponse{
+ SystemID: map[string]string{},
Message: schemas.ChatMessage{
- Role: anthropicCompletion.Content[0].Type,
- Content: anthropicCompletion.Content[0].Text,
- Name: "",
+ Role: completion.Type,
+ Content: completion.Text,
},
TokenUsage: schemas.TokenUsage{
- PromptTokens: 0, // Anthropic doesn't send prompt tokens
- ResponseTokens: 0,
- TotalTokens: 0,
+ PromptTokens: usage.InputTokens,
+ ResponseTokens: usage.OutputTokens,
+ TotalTokens: usage.InputTokens + usage.OutputTokens,
},
},
}
diff --git a/pkg/providers/anthropic/chat_stream.go b/pkg/providers/anthropic/chat_stream.go
new file mode 100644
index 00000000..d5a31bc0
--- /dev/null
+++ b/pkg/providers/anthropic/chat_stream.go
@@ -0,0 +1,16 @@
+package anthropic
+
+import (
+ "context"
+
+ "glide/pkg/api/schemas"
+ "glide/pkg/providers/clients"
+)
+
+func (c *Client) SupportChatStream() bool {
+ return false
+}
+
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+ return nil, clients.ErrChatStreamNotImplemented
+}
diff --git a/pkg/providers/anthropic/client.go b/pkg/providers/anthropic/client.go
index c7131455..1af21262 100644
--- a/pkg/providers/anthropic/client.go
+++ b/pkg/providers/anthropic/client.go
@@ -22,10 +22,12 @@ var (
type Client struct {
baseURL string
chatURL string
+ apiVersion string
chatRequestTemplate *ChatRequest
+ errMapper *ErrorMapper
config *Config
httpClient *http.Client
- telemetry *telemetry.Telemetry
+ tel *telemetry.Telemetry
}
// NewClient creates a new OpenAI client for the OpenAI API.
@@ -38,8 +40,10 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
c := &Client{
baseURL: providerConfig.BaseURL,
chatURL: chatURL,
+ apiVersion: providerConfig.APIVersion,
config: providerConfig,
chatRequestTemplate: NewChatRequestFromConfig(providerConfig),
+ errMapper: NewErrorMapper(tel),
httpClient: &http.Client{
Timeout: *clientConfig.Timeout,
// TODO: use values from the config
@@ -48,7 +52,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
MaxIdleConnsPerHost: 2,
},
},
- telemetry: tel,
+ tel: tel,
}
return c, nil
diff --git a/pkg/providers/anthropic/config.go b/pkg/providers/anthropic/config.go
index beb98734..0de765ac 100644
--- a/pkg/providers/anthropic/config.go
+++ b/pkg/providers/anthropic/config.go
@@ -14,7 +14,6 @@ type Params struct {
MaxTokens int `yaml:"max_tokens,omitempty" json:"max_tokens"`
StopSequences []string `yaml:"stop,omitempty" json:"stop"`
Metadata *string `yaml:"metadata,omitempty" json:"metadata"`
- // Stream bool `json:"stream,omitempty"` // TODO: we are not supporting this at the moment
}
func DefaultParams() Params {
@@ -38,6 +37,7 @@ func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error {
type Config struct {
BaseURL string `yaml:"baseUrl" json:"baseUrl" validate:"required"`
+ APIVersion string `yaml:"apiVersion" json:"apiVersion" validate:"required"`
ChatEndpoint string `yaml:"chatEndpoint" json:"chatEndpoint" validate:"required"`
Model string `yaml:"model" json:"model" validate:"required"`
APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"`
@@ -50,6 +50,7 @@ func DefaultConfig() *Config {
return &Config{
BaseURL: "https://api.anthropic.com/v1",
+ APIVersion: "2023-06-01",
ChatEndpoint: "/messages",
Model: "claude-instant-1.2",
DefaultParams: &defaultParams,
diff --git a/pkg/providers/anthropic/errors.go b/pkg/providers/anthropic/errors.go
new file mode 100644
index 00000000..44016560
--- /dev/null
+++ b/pkg/providers/anthropic/errors.go
@@ -0,0 +1,56 @@
+package anthropic
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+ "go.uber.org/zap"
+)
+
+type ErrorMapper struct {
+ tel *telemetry.Telemetry
+}
+
+func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper {
+ return &ErrorMapper{
+ tel: tel,
+ }
+}
+
+func (m *ErrorMapper) Map(resp *http.Response) error {
+ bodyBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ m.tel.Logger.Error("failed to read anthropic chat response", zap.Error(err))
+ }
+
+ m.tel.Logger.Error(
+ "anthropic chat request failed",
+ zap.Int("status_code", resp.StatusCode),
+ zap.String("response", string(bodyBytes)),
+ zap.Any("headers", resp.Header),
+ )
+
+ if resp.StatusCode == http.StatusTooManyRequests {
+ // Read the value of the "Retry-After" header to get the cooldown delay
+ retryAfter := resp.Header.Get("Retry-After")
+
+ // Parse the value to get the duration
+ cooldownDelay, err := time.ParseDuration(retryAfter)
+ if err != nil {
+ return fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
+ }
+
+ return clients.NewRateLimitError(&cooldownDelay)
+ }
+
+ if resp.StatusCode == http.StatusUnauthorized {
+ return clients.ErrUnauthorized
+ }
+
+ // Server & client errors result in the same error to keep gateway resilient
+ return clients.ErrProviderUnavailable
+}
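
One caveat that applies to this mapper and to the identical Azure OpenAI mapper further down: time.ParseDuration only accepts unit-suffixed values such as "30s", while RFC 7231 defines Retry-After as either delay-seconds (a bare "30") or an HTTP-date, so spec-compliant headers will fail the parse above. A hedged sketch of a more tolerant parser (hypothetical helper, not part of this change):

    package clients

    import (
        "fmt"
        "net/http"
        "strconv"
        "time"
    )

    // parseRetryAfter accepts both Retry-After forms from RFC 7231:
    // delay-seconds ("30") and an HTTP-date.
    func parseRetryAfter(v string) (time.Duration, error) {
        if secs, err := strconv.Atoi(v); err == nil {
            return time.Duration(secs) * time.Second, nil
        }

        if t, err := http.ParseTime(v); err == nil {
            return time.Until(t), nil
        }

        return 0, fmt.Errorf("unsupported Retry-After value: %q", v)
    }
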
diff --git a/pkg/providers/anthropic/schamas.go b/pkg/providers/anthropic/schamas.go
index 69b00248..2f915a0c 100644
--- a/pkg/providers/anthropic/schamas.go
+++ b/pkg/providers/anthropic/schamas.go
@@ -1,6 +1,16 @@
package anthropic
-// Anthropic Chat Response
+type Content struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+}
+
+type Usage struct {
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+}
+
+// ChatCompletion is an Anthropic Chat Response
type ChatCompletion struct {
ID string `json:"id"`
Type string `json:"type"`
@@ -9,9 +19,5 @@ type ChatCompletion struct {
Content []Content `json:"content"`
StopReason string `json:"stop_reason"`
StopSequence string `json:"stop_sequence"`
-}
-
-type Content struct {
- Type string `json:"type"`
- Text string `json:"text"`
+ Usage Usage `json:"usage"`
}
diff --git a/pkg/providers/anthropic/testdata/chat.success.json b/pkg/providers/anthropic/testdata/chat.success.json
index eaf0f6c9..f4921bd4 100644
--- a/pkg/providers/anthropic/testdata/chat.success.json
+++ b/pkg/providers/anthropic/testdata/chat.success.json
@@ -1,7 +1,7 @@
{
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"type": "message",
- "model": "claude-2.1",
+ "model": "claude-instant-1.2",
"role": "assistant",
"content": [
{
@@ -10,5 +10,9 @@
}
],
"stop_reason": "end_turn",
- "stop_sequence": null
-}
\ No newline at end of file
+ "stop_sequence": null,
+ "usage":{
+ "input_tokens": 24,
+ "output_tokens": 13
+ }
+}
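
For reference, the usage mapping in pkg/providers/anthropic/chat.go turns this fixture into:

    schemas.TokenUsage{
        PromptTokens:   24, // usage.input_tokens
        ResponseTokens: 13, // usage.output_tokens
        TotalTokens:    37, // 24 + 13; Anthropic does not return a total
    }
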
diff --git a/pkg/providers/azureopenai/chat.go b/pkg/providers/azureopenai/chat.go
index f961587c..b3563eb2 100644
--- a/pkg/providers/azureopenai/chat.go
+++ b/pkg/providers/azureopenai/chat.go
@@ -7,40 +7,13 @@ import (
"fmt"
"io"
"net/http"
- "time"
"glide/pkg/api/schemas"
"glide/pkg/providers/openai"
- "glide/pkg/providers/clients"
-
"go.uber.org/zap"
)
-type ChatMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
-}
-
-// ChatRequest is an Azure openai-specific request schema
-type ChatRequest struct {
- Messages []ChatMessage `json:"messages"`
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- N int `json:"n,omitempty"`
- StopWords []string `json:"stop,omitempty"`
- Stream bool `json:"stream,omitempty"`
- FrequencyPenalty int `json:"frequency_penalty,omitempty"`
- PresencePenalty int `json:"presence_penalty,omitempty"`
- LogitBias *map[int]float64 `json:"logit_bias,omitempty"`
- User *string `json:"user,omitempty"`
- Seed *int `json:"seed,omitempty"`
- Tools []string `json:"tools,omitempty"`
- ToolChoice interface{} `json:"tool_choice,omitempty"`
- ResponseFormat interface{} `json:"response_format,omitempty"`
-}
-
// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
return &ChatRequest{
@@ -49,7 +22,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
MaxTokens: cfg.DefaultParams.MaxTokens,
N: cfg.DefaultParams.N,
StopWords: cfg.DefaultParams.StopWords,
- Stream: false, // unsupported right now
+ Stream: false,
FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty,
PresencePenalty: cfg.DefaultParams.PresencePenalty,
LogitBias: cfg.DefaultParams.LogitBias,
@@ -61,23 +34,10 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
}
}
-func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage {
- messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)
-
- // Add items from messageHistory first and the new chat message last
- for _, message := range request.MessageHistory {
- messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
- }
-
- messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
-
- return messages
-}
-
// Chat sends a chat request to the specified azure openai model.
func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
// Create a new chat request
- chatRequest := c.createChatRequestSchema(request)
+ chatRequest := c.createRequestSchema(request)
chatResponse, err := c.doChatRequest(ctx, chatRequest)
if err != nil {
@@ -91,12 +51,21 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem
return chatResponse, nil
}
-func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
+// createRequestSchema creates a new ChatRequest object based on the given request.
+func (c *Client) createRequestSchema(request *schemas.ChatRequest) *ChatRequest {
// TODO: consider using objectpool to optimize memory allocation
- chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
- chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
+ chatRequest := *c.chatRequestTemplate // dereference to get a shallow copy of the template struct
+
+ chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
+
+ // Add items from messageHistory first and the new chat message last
+ for _, message := range request.MessageHistory {
+ chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
+ }
+
+ chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
- return chatRequest
+ return &chatRequest
}
func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
@@ -115,7 +84,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
req.Header.Set("Content-Type", "application/json")
// TODO: this could leak information from messages which may not be a desired thing to have
- c.telemetry.Logger.Debug(
+ c.tel.Logger.Debug(
"azure openai chat request",
zap.String("chat_url", c.chatURL),
zap.Any("payload", payload),
@@ -129,39 +98,13 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- c.telemetry.Logger.Error("failed to read azure openai chat response", zap.Error(err))
- }
-
- c.telemetry.Logger.Error(
- "azure openai chat request failed",
- zap.Int("status_code", resp.StatusCode),
- zap.String("response", string(bodyBytes)),
- zap.Any("headers", resp.Header),
- )
-
- if resp.StatusCode == http.StatusTooManyRequests {
- // Read the value of the "Retry-After" header to get the cooldown delay
- retryAfter := resp.Header.Get("Retry-After")
-
- // Parse the value to get the duration
- cooldownDelay, err := time.ParseDuration(retryAfter)
- if err != nil {
- return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
- }
-
- return nil, clients.NewRateLimitError(&cooldownDelay)
- }
-
- // Server & client errors result in the same error to keep gateway resilient
- return nil, clients.ErrProviderUnavailable
+ return nil, c.errMapper.Map(resp)
}
// Read the response body into a byte slice
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
- c.telemetry.Logger.Error("failed to read azure openai chat response", zap.Error(err))
+ c.tel.Logger.Error("failed to read azure openai chat response", zap.Error(err))
return nil, err
}
@@ -170,7 +113,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
err = json.Unmarshal(bodyBytes, &openAICompletion)
if err != nil {
- c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err))
+ c.tel.Logger.Error("failed to parse openai chat response", zap.Error(err))
return nil, err
}
@@ -178,19 +121,18 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
// Map response to UnifiedChatResponse schema
response := schemas.ChatResponse{
- ID: openAICompletion.ID,
- Created: openAICompletion.Created,
- Provider: providerName,
- Model: openAICompletion.Model,
- Cached: false,
- ModelResponse: schemas.ProviderResponse{
+ ID: openAICompletion.ID,
+ Created: openAICompletion.Created,
+ Provider: providerName,
+ ModelName: openAICompletion.ModelName,
+ Cached: false,
+ ModelResponse: schemas.ModelResponse{
SystemID: map[string]string{
"system_fingerprint": openAICompletion.SystemFingerprint,
},
Message: schemas.ChatMessage{
Role: openAICompletion.Choices[0].Message.Role,
Content: openAICompletion.Choices[0].Message.Content,
- Name: "",
},
TokenUsage: schemas.TokenUsage{
PromptTokens: openAICompletion.Usage.PromptTokens,
diff --git a/pkg/providers/azureopenai/chat_stream.go b/pkg/providers/azureopenai/chat_stream.go
new file mode 100644
index 00000000..a3a64bff
--- /dev/null
+++ b/pkg/providers/azureopenai/chat_stream.go
@@ -0,0 +1,231 @@
+package azureopenai
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+
+ "github.com/r3labs/sse/v2"
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+
+ "go.uber.org/zap"
+
+ "glide/pkg/api/schemas"
+)
+
+var (
+ StopReason = "stop"
+ streamDoneMarker = []byte("[DONE]")
+)
+
+// ChatStream represents a chat stream for a specific request
+type ChatStream struct {
+ tel *telemetry.Telemetry
+ client *http.Client
+ req *http.Request
+ reqID string
+ reqMetadata *schemas.Metadata
+ resp *http.Response
+ reader *sse.EventStreamReader
+ errMapper *ErrorMapper
+}
+
+func NewChatStream(
+ tel *telemetry.Telemetry,
+ client *http.Client,
+ req *http.Request,
+ reqID string,
+ reqMetadata *schemas.Metadata,
+ errMapper *ErrorMapper,
+) *ChatStream {
+ return &ChatStream{
+ tel: tel,
+ client: client,
+ req: req,
+ reqID: reqID,
+ reqMetadata: reqMetadata,
+ errMapper: errMapper,
+ }
+}
+
+// Open initializes and opens a ChatStream.
+func (s *ChatStream) Open() error {
+ resp, err := s.client.Do(s.req) //nolint:bodyclose
+ if err != nil {
+ return err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return s.errMapper.Map(resp)
+ }
+
+ s.resp = resp
+ s.reader = sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize?
+
+ return nil
+}
+
+// Recv receives a chat stream chunk from the ChatStream and returns a ChatStreamChunk object.
+func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
+ var completionChunk ChatCompletionChunk
+
+ for {
+ rawEvent, err := s.reader.ReadEvent()
+ if err != nil {
+ s.tel.L().Warn(
+ "Chat stream is unexpectedly disconnected",
+ zap.String("provider", providerName),
+ zap.Error(err),
+ )
+
+ // even an io.EOF here means the stream was interrupted unexpectedly,
+ // since normal termination is signaled by the streamDoneMarker ([DONE]) event
+
+ return nil, clients.ErrProviderUnavailable
+ }
+
+ s.tel.L().Debug(
+ "Raw chat stream chunk",
+ zap.String("provider", providerName),
+ zap.ByteString("rawChunk", rawEvent),
+ )
+
+ event, err := clients.ParseSSEvent(rawEvent)
+ // check the parse error before touching event.Data to avoid a nil dereference
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse chat stream message: %v", err)
+ }
+
+ if bytes.Equal(event.Data, streamDoneMarker) {
+ s.tel.L().Info(
+ "EOF: [DONE] marker found in chat stream",
+ zap.String("provider", providerName),
+ )
+
+ return nil, io.EOF
+ }
+
+ if !event.HasContent() {
+ s.tel.L().Debug(
+ "Received an empty message in chat stream, skipping it",
+ zap.String("provider", providerName),
+ zap.Any("msg", event),
+ )
+
+ continue
+ }
+
+ err = json.Unmarshal(event.Data, &completionChunk)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal AzureOpenAI chat stream chunk: %v", err)
+ }
+
+ responseChunk := completionChunk.Choices[0]
+
+ var finishReason *schemas.FinishReason
+
+ if responseChunk.FinishReason == StopReason {
+ finishReason = &schemas.Complete
+ }
+
+ // TODO: use objectpool here
+ return &schemas.ChatStreamChunk{
+ ID: s.reqID,
+ Provider: providerName,
+ Cached: false,
+ ModelName: completionChunk.ModelName,
+ Metadata: s.reqMetadata,
+ ModelResponse: schemas.ModelChunkResponse{
+ Metadata: &schemas.Metadata{
+ "response_id": completionChunk.ID,
+ "system_fingerprint": completionChunk.SystemFingerprint,
+ },
+ Message: schemas.ChatMessage{
+ Role: responseChunk.Delta.Role,
+ Content: responseChunk.Delta.Content,
+ },
+ FinishReason: finishReason,
+ },
+ }, nil
+ }
+}
+
+func (s *ChatStream) Close() error {
+ if s.resp != nil {
+ return s.resp.Body.Close()
+ }
+
+ return nil
+}
+
+func (c *Client) SupportChatStream() bool {
+ return true
+}
+
+func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+ // Create a new chat request
+ httpRequest, err := c.makeStreamReq(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+
+ return NewChatStream(
+ c.tel,
+ c.httpClient,
+ httpRequest,
+ req.ID,
+ req.Metadata,
+ c.errMapper,
+ ), nil
+}
+
+func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
+ // TODO: consider using objectpool to optimize memory allocation
+ chatRequest := *c.chatRequestTemplate // dereference to get a shallow copy of the template struct
+
+ chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
+
+ // Add items from messageHistory first and the new chat message last
+ for _, message := range request.MessageHistory {
+ chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
+ }
+
+ chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
+
+ return &chatRequest
+}
+
+func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
+ chatRequest := c.createRequestFromStream(req)
+
+ chatRequest.Stream = true
+
+ rawPayload, err := json.Marshal(chatRequest)
+ if err != nil {
+ return nil, fmt.Errorf("unable to marshal AzureOpenAI chat stream request payload: %w", err)
+ }
+
+ request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
+ if err != nil {
+ return nil, fmt.Errorf("unable to create AzureOpenAI stream chat request: %w", err)
+ }
+
+ request.Header.Set("Content-Type", "application/json")
+ request.Header.Set("api-key", string(c.config.APIKey))
+ request.Header.Set("Cache-Control", "no-cache")
+ request.Header.Set("Accept", "text/event-stream")
+ request.Header.Set("Connection", "keep-alive")
+
+ // TODO: this could leak information from messages which may not be a desired thing to have
+ c.tel.L().Debug(
+ "Stream chat request",
+ zap.String("chatURL", c.chatURL),
+ zap.Any("payload", chatRequest),
+ )
+
+ return request, nil
+}
diff --git a/pkg/providers/azureopenai/chat_stream_test.go b/pkg/providers/azureopenai/chat_stream_test.go
new file mode 100644
index 00000000..080ffb24
--- /dev/null
+++ b/pkg/providers/azureopenai/chat_stream_test.go
@@ -0,0 +1,157 @@
+package azureopenai
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "glide/pkg/api/schemas"
+
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAzureOpenAIClient_ChatStreamSupported(t *testing.T) {
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ require.True(t, client.SupportChatStream())
+}
+
+func TestAzureOpenAIClient_ChatStreamRequest(t *testing.T) {
+ tests := map[string]string{
+ "success stream": "./testdata/chat_stream.success.txt",
+ }
+
+ for name, streamFile := range tests {
+ t.Run(name, func(t *testing.T) {
+ AzureOpenAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ rawPayload, _ := io.ReadAll(r.Body)
+
+ var data interface{}
+ // Parse the JSON body
+ err := json.Unmarshal(rawPayload, &data)
+ if err != nil {
+ t.Errorf("error decoding payload (%q): %v", string(rawPayload), err)
+ }
+
+ chatResponse, err := os.ReadFile(filepath.Clean(streamFile))
+ if err != nil {
+ t.Errorf("error reading azureopenai chat mock response: %v", err)
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ _, err = w.Write(chatResponse)
+ if err != nil {
+ t.Errorf("error on sending chat response: %v", err)
+ }
+ })
+
+ AzureopenAIServer := httptest.NewServer(AzureOpenAIMock)
+ defer AzureopenAIServer.Close()
+
+ ctx := context.Background()
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ providerCfg.BaseURL = AzureopenAIServer.URL
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
+ stream, err := client.ChatStream(ctx, req)
+ require.NoError(t, err)
+
+ err = stream.Open()
+ require.NoError(t, err)
+
+ for {
+ chunk, err := stream.Recv()
+
+ if err == io.EOF {
+ return
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, chunk)
+ }
+ })
+ }
+}
+
+func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) {
+ tests := map[string]string{
+ "success stream, but no last done message": "./testdata/chat_stream.nodone.txt",
+ "success stream, but with empty event": "./testdata/chat_stream.empty.txt",
+ }
+
+ for name, streamFile := range tests {
+ t.Run(name, func(t *testing.T) {
+ openAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ rawPayload, _ := io.ReadAll(r.Body)
+
+ var data interface{}
+ // Parse the JSON body
+ err := json.Unmarshal(rawPayload, &data)
+ if err != nil {
+ t.Errorf("error decoding payload (%q): %v", string(rawPayload), err)
+ }
+
+ chatResponse, err := os.ReadFile(filepath.Clean(streamFile))
+ if err != nil {
+ t.Errorf("error reading openai chat mock response: %v", err)
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ _, err = w.Write(chatResponse)
+ if err != nil {
+ t.Errorf("error on sending chat response: %v", err)
+ }
+ })
+
+ openAIServer := httptest.NewServer(openAIMock)
+ defer openAIServer.Close()
+
+ ctx := context.Background()
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ providerCfg.BaseURL = openAIServer.URL
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
+
+ stream, err := client.ChatStream(ctx, req)
+ require.NoError(t, err)
+
+ err = stream.Open()
+ require.NoError(t, err)
+
+ for {
+ chunk, err := stream.Recv()
+ if err != nil {
+ require.ErrorIs(t, err, clients.ErrProviderUnavailable)
+ return
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, chunk)
+ }
+ })
+ }
+}
diff --git a/pkg/providers/azureopenai/client.go b/pkg/providers/azureopenai/client.go
index cd05ba90..9a15aeb8 100644
--- a/pkg/providers/azureopenai/client.go
+++ b/pkg/providers/azureopenai/client.go
@@ -23,9 +23,10 @@ type Client struct {
baseURL string // The name of your Azure OpenAI Resource (e.g https://glide-test.openai.azure.com/)
chatURL string
chatRequestTemplate *ChatRequest
+ errMapper *ErrorMapper
config *Config
httpClient *http.Client
- telemetry *telemetry.Telemetry
+ tel *telemetry.Telemetry
}
// NewClient creates a new Azure OpenAI client for the OpenAI API.
@@ -42,6 +43,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
chatURL: chatURL,
config: providerConfig,
chatRequestTemplate: NewChatRequestFromConfig(providerConfig),
+ errMapper: NewErrorMapper(tel),
httpClient: &http.Client{
// TODO: use values from the config
Timeout: *clientConfig.Timeout,
@@ -50,7 +52,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
MaxIdleConnsPerHost: 2,
},
},
- telemetry: tel,
+ tel: tel,
}
return c, nil
diff --git a/pkg/providers/azureopenai/client_test.go b/pkg/providers/azureopenai/client_test.go
index 8f5de037..be98c316 100644
--- a/pkg/providers/azureopenai/client_test.go
+++ b/pkg/providers/azureopenai/client_test.go
@@ -106,21 +106,19 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) {
defer mockServer.Close()
- // Create a new client with the mock server URL
- client := &Client{
- httpClient: http.DefaultClient,
- chatURL: mockServer.URL,
- config: &Config{APIKey: "dummy_key"},
- telemetry: telemetry.NewTelemetryMock(),
- }
+ ctx := context.Background()
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ providerCfg.BaseURL = mockServer.URL
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
// Create a chat request payload
- payload := &ChatRequest{
- Messages: []ChatMessage{{Role: "human", Content: "Hello"}},
- }
+ payload := schemas.NewChatFromStr("What's the dealio?")
- // Call the doChatRequest function
- _, err := client.doChatRequest(context.Background(), payload)
+ _, err = client.Chat(ctx, payload)
require.Error(t, err)
require.Contains(t, err.Error(), "provider is not available")
diff --git a/pkg/providers/azureopenai/errors.go b/pkg/providers/azureopenai/errors.go
new file mode 100644
index 00000000..0e55c0b0
--- /dev/null
+++ b/pkg/providers/azureopenai/errors.go
@@ -0,0 +1,56 @@
+package azureopenai
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+ "go.uber.org/zap"
+)
+
+type ErrorMapper struct {
+ tel *telemetry.Telemetry
+}
+
+func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper {
+ return &ErrorMapper{
+ tel: tel,
+ }
+}
+
+func (m *ErrorMapper) Map(resp *http.Response) error {
+ bodyBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ m.tel.L().Error("failed to read azure openai chat response", zap.Error(err))
+ }
+
+ m.tel.L().Error(
+ "azure openai chat request failed",
+ zap.Int("status_code", resp.StatusCode),
+ zap.String("response", string(bodyBytes)),
+ zap.Any("headers", resp.Header),
+ )
+
+ if resp.StatusCode == http.StatusTooManyRequests {
+ // Read the value of the "Retry-After" header to get the cooldown delay
+ retryAfter := resp.Header.Get("Retry-After")
+
+ // Parse the value to get the duration
+ cooldownDelay, err := time.ParseDuration(retryAfter)
+ if err != nil {
+ return fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
+ }
+
+ return clients.NewRateLimitError(&cooldownDelay)
+ }
+
+ if resp.StatusCode == http.StatusUnauthorized {
+ return clients.ErrUnauthorized
+ }
+
+ // Server & client errors result in the same error to keep gateway resilient
+ return clients.ErrProviderUnavailable
+}
diff --git a/pkg/providers/azureopenai/schemas.go b/pkg/providers/azureopenai/schemas.go
new file mode 100644
index 00000000..993bb8d7
--- /dev/null
+++ b/pkg/providers/azureopenai/schemas.go
@@ -0,0 +1,67 @@
+package azureopenai
+
+type ChatMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+// ChatRequest is an Azure openai-specific request schema
+type ChatRequest struct {
+ Messages []ChatMessage `json:"messages"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ N int `json:"n,omitempty"`
+ StopWords []string `json:"stop,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ FrequencyPenalty int `json:"frequency_penalty,omitempty"`
+ PresencePenalty int `json:"presence_penalty,omitempty"`
+ LogitBias *map[int]float64 `json:"logit_bias,omitempty"`
+ User *string `json:"user,omitempty"`
+ Seed *int `json:"seed,omitempty"`
+ Tools []string `json:"tools,omitempty"`
+ ToolChoice interface{} `json:"tool_choice,omitempty"`
+ ResponseFormat interface{} `json:"response_format,omitempty"`
+}
+
+// ChatCompletion is an Azure OpenAI chat completion response
+// Ref: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
+type ChatCompletion struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int `json:"created"`
+ ModelName string `json:"model"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Choices []Choice `json:"choices"`
+ Usage Usage `json:"usage"`
+}
+
+type Choice struct {
+ Index int `json:"index"`
+ Message ChatMessage `json:"message"`
+ Logprobs interface{} `json:"logprobs"`
+ FinishReason string `json:"finish_reason"`
+}
+
+type Usage struct {
+ PromptTokens float64 `json:"prompt_tokens"`
+ CompletionTokens float64 `json:"completion_tokens"`
+ TotalTokens float64 `json:"total_tokens"`
+}
+
+// ChatCompletionChunk represents an SSE chunk that a streamed chat response is broken into
+// Ref: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
+type ChatCompletionChunk struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int `json:"created"`
+ ModelName string `json:"model"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Choices []StreamChoice `json:"choices"`
+}
+
+type StreamChoice struct {
+ Index int `json:"index"`
+ Delta ChatMessage `json:"delta"`
+ FinishReason string `json:"finish_reason"`
+}
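
To make the chunk schema concrete, here is one "data:" payload from the stream fixtures below fed through it (illustrative):

    package main

    import (
        "encoding/json"
        "fmt"

        "glide/pkg/providers/azureopenai"
    )

    func main() {
        raw := []byte(`{"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"The"},"finish_reason":null}]}`)

        var chunk azureopenai.ChatCompletionChunk
        if err := json.Unmarshal(raw, &chunk); err != nil {
            panic(err)
        }

        fmt.Println(chunk.ModelName, chunk.Choices[0].Delta.Content) // gpt-3.5-turbo-0125 The
    }
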
diff --git a/pkg/providers/azureopenai/testdata/chat_stream.empty.txt b/pkg/providers/azureopenai/testdata/chat_stream.empty.txt
new file mode 100644
index 00000000..a04fb787
--- /dev/null
+++ b/pkg/providers/azureopenai/testdata/chat_stream.empty.txt
@@ -0,0 +1,6 @@
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
+
+data:
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
+
diff --git a/pkg/providers/azureopenai/testdata/chat_stream.nodone.txt b/pkg/providers/azureopenai/testdata/chat_stream.nodone.txt
new file mode 100644
index 00000000..63152c29
--- /dev/null
+++ b/pkg/providers/azureopenai/testdata/chat_stream.nodone.txt
@@ -0,0 +1,22 @@
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"The"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" capital"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" of"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" the"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" United"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" Kingdom"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" is"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" London"},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"."},"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
+
diff --git a/pkg/providers/azureopenai/testdata/chat_stream.success.txt b/pkg/providers/azureopenai/testdata/chat_stream.success.txt
new file mode 100644
index 00000000..1e673eaf
--- /dev/null
+++ b/pkg/providers/azureopenai/testdata/chat_stream.success.txt
@@ -0,0 +1,24 @@
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"The"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" capital"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" of"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" the"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" United"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" Kingdom"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" is"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" London"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
+
+data: [DONE]
+
diff --git a/pkg/providers/bedrock/chat.go b/pkg/providers/bedrock/chat.go
index 14feb9bc..bb17cb11 100644
--- a/pkg/providers/bedrock/chat.go
+++ b/pkg/providers/bedrock/chat.go
@@ -99,16 +99,17 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
err = json.Unmarshal(result.Body, &bedrockCompletion)
if err != nil {
c.telemetry.Logger.Error("failed to parse bedrock chat response", zap.Error(err))
+
return nil, err
}
response := schemas.ChatResponse{
- ID: uuid.NewString(),
- Created: int(time.Now().Unix()),
- Provider: "aws-bedrock",
- Model: c.config.Model,
- Cached: false,
- ModelResponse: schemas.ProviderResponse{
+ ID: uuid.NewString(),
+ Created: int(time.Now().Unix()),
+ Provider: "aws-bedrock",
+ ModelName: c.config.Model,
+ Cached: false,
+ ModelResponse: schemas.ModelResponse{
SystemID: map[string]string{
"system_fingerprint": "none",
},
@@ -118,9 +119,9 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
Name: "",
},
TokenUsage: schemas.TokenUsage{
- PromptTokens: float64(bedrockCompletion.Results[0].TokenCount),
+ PromptTokens: bedrockCompletion.Results[0].TokenCount,
ResponseTokens: -1,
- TotalTokens: float64(bedrockCompletion.Results[0].TokenCount),
+ TotalTokens: bedrockCompletion.Results[0].TokenCount,
},
},
}
diff --git a/pkg/providers/bedrock/chat_stream.go b/pkg/providers/bedrock/chat_stream.go
new file mode 100644
index 00000000..35918a0c
--- /dev/null
+++ b/pkg/providers/bedrock/chat_stream.go
@@ -0,0 +1,16 @@
+package bedrock
+
+import (
+ "context"
+
+ "glide/pkg/api/schemas"
+ "glide/pkg/providers/clients"
+)
+
+func (c *Client) SupportChatStream() bool {
+ return false
+}
+
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+ return nil, clients.ErrChatStreamNotImplemented
+}
diff --git a/pkg/providers/clients/config_test.go b/pkg/providers/clients/config_test.go
new file mode 100644
index 00000000..0e725201
--- /dev/null
+++ b/pkg/providers/clients/config_test.go
@@ -0,0 +1,13 @@
+package clients
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestClientConfig_DefaultConfig(t *testing.T) {
+ config := DefaultClientConfig()
+
+ require.NotEmpty(t, config.Timeout)
+}
diff --git a/pkg/providers/clients/errors.go b/pkg/providers/clients/errors.go
index 8c704a3f..deaf00bc 100644
--- a/pkg/providers/clients/errors.go
+++ b/pkg/providers/clients/errors.go
@@ -6,7 +6,11 @@ import (
"time"
)
-var ErrProviderUnavailable = errors.New("provider is not available")
+var (
+ ErrProviderUnavailable = errors.New("provider is not available")
+ ErrUnauthorized = errors.New("API key is wrong or not set")
+ ErrChatStreamNotImplemented = errors.New("streaming chat API is not implemented for provider")
+)
type RateLimitError struct {
untilReset time.Duration
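
The new sentinels give routing code something stable to branch on instead of matching error strings; a usage sketch (illustrative, not repository code):

    package main

    import (
        "errors"
        "fmt"

        "glide/pkg/providers/clients"
    )

    // classify shows how callers might react to the new sentinel errors.
    func classify(err error) string {
        switch {
        case errors.Is(err, clients.ErrUnauthorized):
            return "check the configured API key"
        case errors.Is(err, clients.ErrChatStreamNotImplemented):
            return "route to a provider that supports streaming"
        case errors.Is(err, clients.ErrProviderUnavailable):
            return "retry on another model"
        default:
            return "unknown failure"
        }
    }

    func main() {
        fmt.Println(classify(clients.ErrUnauthorized)) // check the configured API key
    }
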
diff --git a/pkg/providers/clients/errors_test.go b/pkg/providers/clients/errors_test.go
new file mode 100644
index 00000000..5a570a13
--- /dev/null
+++ b/pkg/providers/clients/errors_test.go
@@ -0,0 +1,16 @@
+package clients
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestRateLimitError(t *testing.T) {
+ duration := 5 * time.Minute
+ err := NewRateLimitError(&duration)
+
+ require.Equal(t, duration, err.UntilReset())
+ require.Contains(t, err.Error(), "rate limit reached")
+}
diff --git a/pkg/providers/clients/sse.go b/pkg/providers/clients/sse.go
new file mode 100644
index 00000000..5619876f
--- /dev/null
+++ b/pkg/providers/clients/sse.go
@@ -0,0 +1,82 @@
+package clients
+
+import (
+ "bytes"
+ "errors"
+)
+
+// Taken from https://github.com/r3labs/sse/blob/master/client.go#L322
+
+var (
+ headerID = []byte("id:")
+ headerData = []byte("data:")
+ headerEvent = []byte("event:")
+ headerRetry = []byte("retry:")
+)
+
+// Event holds all the event source fields
+type Event struct {
+ ID []byte
+ Data []byte
+ Event []byte
+ Retry []byte
+ Comment []byte
+}
+
+func (e *Event) HasContent() bool {
+ return len(e.ID) > 0 || len(e.Data) > 0 || len(e.Event) > 0 || len(e.Retry) > 0
+}
+
+func ParseSSEvent(msg []byte) (event *Event, err error) {
+ var e Event
+
+ if len(msg) < 1 {
+ return nil, errors.New("event message was empty")
+ }
+
+ // Normalize the crlf to lf to make it easier to split the lines.
+ // Split the line by "\n" or "\r", per the spec.
+ for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) {
+ switch {
+ case bytes.HasPrefix(line, headerID):
+ e.ID = append([]byte(nil), trimHeader(len(headerID), line)...)
+ case bytes.HasPrefix(line, headerData):
+ // The spec allows multiple data fields per event; they are concatenated with "\n".
+ e.Data = append(e.Data[:], append(trimHeader(len(headerData), line), byte('\n'))...)
+ // The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body.
+ case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))):
+ e.Data = append(e.Data, byte('\n'))
+ case bytes.HasPrefix(line, headerEvent):
+ e.Event = append([]byte(nil), trimHeader(len(headerEvent), line)...)
+ case bytes.HasPrefix(line, headerRetry):
+ e.Retry = append([]byte(nil), trimHeader(len(headerRetry), line)...)
+ default:
+ // Ignore any garbage that doesn't match what we're looking for.
+ } //nolint:wsl
+ }
+
+ // Trim the last "\n" per the spec.
+ e.Data = bytes.TrimSuffix(e.Data, []byte("\n"))
+
+ return &e, err
+}
+
+func trimHeader(size int, data []byte) []byte {
+ if data == nil || len(data) < size {
+ return data
+ }
+
+ data = data[size:]
+
+ // Remove optional leading whitespace
+ if len(data) > 0 && data[0] == 32 {
+ data = data[1:]
+ }
+
+ // Remove trailing new line
+ if len(data) > 0 && data[len(data)-1] == 10 {
+ data = data[:len(data)-1]
+ }
+
+ return data
+}
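
A quick illustration of the data-field rule implemented above: multiple data fields in one event are joined with "\n" and the trailing newline is trimmed (runnable sketch):

    package main

    import (
        "fmt"

        "glide/pkg/providers/clients"
    )

    func main() {
        event, err := clients.ParseSSEvent([]byte("data: hello\ndata: world\n"))
        if err != nil {
            panic(err)
        }

        fmt.Printf("%q\n", event.Data) // "hello\nworld"
    }
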
diff --git a/pkg/providers/clients/sse_test.go b/pkg/providers/clients/sse_test.go
new file mode 100644
index 00000000..c749737e
--- /dev/null
+++ b/pkg/providers/clients/sse_test.go
@@ -0,0 +1,27 @@
+package clients
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseSSEvent_ValidEvents(t *testing.T) {
+ tests := []struct {
+ name string
+ rawMsg string
+ data string
+ }{
+ {"data only", "data: {\"id\":\"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg\"}\n", "{\"id\":\"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg\"}"},
+ {"empty data", "data:", ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ event, err := ParseSSEvent([]byte(tt.rawMsg))
+
+ require.NoError(t, err)
+ require.Equal(t, []byte(tt.data), event.Data)
+ })
+ }
+}
diff --git a/pkg/providers/clients/stream.go b/pkg/providers/clients/stream.go
new file mode 100644
index 00000000..ff150e8f
--- /dev/null
+++ b/pkg/providers/clients/stream.go
@@ -0,0 +1,31 @@
+package clients
+
+import (
+ "glide/pkg/api/schemas"
+)
+
+type ChatStream interface {
+ Open() error
+ Recv() (*schemas.ChatStreamChunk, error)
+ Close() error
+}
+
+type ChatStreamResult struct {
+ chunk *schemas.ChatStreamChunk
+ err error
+}
+
+func (r *ChatStreamResult) Chunk() *schemas.ChatStreamChunk {
+ return r.chunk
+}
+
+func (r *ChatStreamResult) Error() error {
+ return r.err
+}
+
+func NewChatStreamResult(chunk *schemas.ChatStreamChunk, err error) *ChatStreamResult {
+ return &ChatStreamResult{
+ chunk: chunk,
+ err: err,
+ }
+}
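
Every provider stream is consumed through the same Open/Recv/Close cycle; a sketch of the loop, mirroring the pattern the tests in this change use (usage sketch, not repository code):

    package consumers // hypothetical example package

    import (
        "io"

        "glide/pkg/api/schemas"
        "glide/pkg/providers/clients"
    )

    // drain reads a ChatStream until the provider signals completion.
    func drain(stream clients.ChatStream) ([]*schemas.ChatStreamChunk, error) {
        if err := stream.Open(); err != nil {
            return nil, err
        }
        defer stream.Close()

        var chunks []*schemas.ChatStreamChunk

        for {
            chunk, err := stream.Recv()
            if err == io.EOF {
                return chunks, nil // normal termination, e.g. after the [DONE] marker
            }

            if err != nil {
                return chunks, err
            }

            chunks = append(chunks, chunk)
        }
    }
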
diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go
index 165b67bd..86573041 100644
--- a/pkg/providers/cohere/chat.go
+++ b/pkg/providers/cohere/chat.go
@@ -15,40 +15,6 @@ import (
"go.uber.org/zap"
)
-type ChatMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
-}
-
-type ChatHistory struct {
- Role string `json:"role"`
- Message string `json:"message"`
- User string `json:"user,omitempty"`
-}
-
-// ChatRequest is a request to complete a chat completion..
-type ChatRequest struct {
- Model string `json:"model"`
- Message string `json:"message"`
- Temperature float64 `json:"temperature,omitempty"`
- PreambleOverride string `json:"preamble_override,omitempty"`
- ChatHistory []ChatHistory `json:"chat_history,omitempty"`
- ConversationID string `json:"conversation_id,omitempty"`
- PromptTruncation string `json:"prompt_truncation,omitempty"`
- Connectors []string `json:"connectors,omitempty"`
- SearchQueriesOnly bool `json:"search_queries_only,omitempty"`
- CitiationQuality string `json:"citiation_quality,omitempty"`
-
- // Stream bool `json:"stream,omitempty"`
-}
-
-type Connectors struct {
- ID string `json:"id"`
- UserAccessToken string `json:"user_access_token"`
- ContOnFail string `json:"continue_on_failure"`
- Options map[string]string `json:"options"`
-}
-
// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
return &ChatRequest{
@@ -61,13 +27,14 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
Connectors: cfg.DefaultParams.Connectors,
SearchQueriesOnly: cfg.DefaultParams.SearchQueriesOnly,
CitiationQuality: cfg.DefaultParams.CitiationQuality,
+ Stream: false,
}
}
// Chat sends a chat request to the specified cohere model.
func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
// Create a new chat request
- chatRequest := c.createChatRequestSchema(request)
+ chatRequest := c.createRequestSchema(request)
chatResponse, err := c.doChatRequest(ctx, chatRequest)
if err != nil {
@@ -81,9 +48,9 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem
return chatResponse, nil
}
-func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
+func (c *Client) createRequestSchema(request *schemas.ChatRequest) *ChatRequest {
// TODO: consider using objectpool to optimize memory allocation
- chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
+ chatRequest := *c.chatRequestTemplate // dereference to get a shallow copy of the template struct
chatRequest.Message = request.Message.Content
// Build the Cohere specific ChatHistory
@@ -100,7 +67,7 @@ func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequ
}
}
- return chatRequest
+ return &chatRequest
}
func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
@@ -119,7 +86,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
req.Header.Set("Content-Type", "application/json")
// TODO: this could leak information from messages which may not be a desired thing to have
- c.telemetry.Logger.Debug(
+ c.tel.Logger.Debug(
"cohere chat request",
zap.String("chat_url", c.chatURL),
zap.Any("payload", payload),
@@ -135,10 +102,10 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
if resp.StatusCode != http.StatusOK {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
- c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err))
+ c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err))
}
- c.telemetry.Logger.Error(
+ c.tel.Logger.Error(
"cohere chat request failed",
zap.Int("status_code", resp.StatusCode),
zap.String("response", string(bodyBytes)),
@@ -156,7 +123,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
// Read the response body into a byte slice
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
- c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err))
+ c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err))
return nil, err
}
@@ -165,7 +132,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
err = json.Unmarshal(bodyBytes, &responseJSON)
if err != nil {
- c.telemetry.Logger.Error("failed to parse cohere chat response", zap.Error(err))
+ c.tel.Logger.Error("failed to parse cohere chat response", zap.Error(err))
return nil, err
}
@@ -174,24 +141,24 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
err = json.Unmarshal(bodyBytes, &cohereCompletion)
if err != nil {
- c.telemetry.Logger.Error("failed to parse cohere chat response", zap.Error(err))
+ c.tel.Logger.Error("failed to parse cohere chat response", zap.Error(err))
return nil, err
}
// Map response to ChatResponse schema
response := schemas.ChatResponse{
- ID: cohereCompletion.ResponseID,
- Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this
- Provider: providerName,
- Model: c.config.Model,
- Cached: false,
- ModelResponse: schemas.ProviderResponse{
+ ID: cohereCompletion.ResponseID,
+ Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this
+ Provider: providerName,
+ ModelName: c.config.Model,
+ Cached: false,
+ ModelResponse: schemas.ModelResponse{
SystemID: map[string]string{
"generationId": cohereCompletion.GenerationID,
"responseId": cohereCompletion.ResponseID,
},
Message: schemas.ChatMessage{
- Role: "model", // TODO: Does this need to change?
+ Role: "model",
Content: cohereCompletion.Text,
Name: "",
},
@@ -209,11 +176,11 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
func (c *Client) handleErrorResponse(resp *http.Response) (*schemas.ChatResponse, error) {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
- c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err))
+ c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err))
return nil, err
}
- c.telemetry.Logger.Error(
+ c.tel.Logger.Error(
"cohere chat request failed",
zap.Int("status_code", resp.StatusCode),
zap.String("response", string(bodyBytes)),
@@ -229,6 +196,10 @@ func (c *Client) handleErrorResponse(resp *http.Response) (*schemas.ChatResponse
return nil, clients.NewRateLimitError(&cooldownDelay)
}
+ if resp.StatusCode == http.StatusUnauthorized {
+ return nil, clients.ErrUnauthorized
+ }
+
return nil, clients.ErrProviderUnavailable
}
diff --git a/pkg/providers/cohere/chat_stream.go b/pkg/providers/cohere/chat_stream.go
new file mode 100644
index 00000000..f6f9e9a8
--- /dev/null
+++ b/pkg/providers/cohere/chat_stream.go
@@ -0,0 +1,242 @@
+package cohere
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+
+ "github.com/r3labs/sse/v2"
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+
+ "go.uber.org/zap"
+
+ "glide/pkg/api/schemas"
+)
+
+var StopReason = "stream-end"
+
+// ChatStream represents a Cohere chat stream for a specific request
+type ChatStream struct {
+ tel *telemetry.Telemetry
+ client *http.Client
+ req *http.Request
+ reqID string
+ reqMetadata *schemas.Metadata
+ resp *http.Response
+ reader *sse.EventStreamReader
+ errMapper *ErrorMapper
+}
+
+func NewChatStream(
+ tel *telemetry.Telemetry,
+ client *http.Client,
+ req *http.Request,
+ reqID string,
+ reqMetadata *schemas.Metadata,
+ errMapper *ErrorMapper,
+) *ChatStream {
+ return &ChatStream{
+ tel: tel,
+ client: client,
+ req: req,
+ reqID: reqID,
+ reqMetadata: reqMetadata,
+ errMapper: errMapper,
+ }
+}
+
+func (s *ChatStream) Open() error {
+ resp, err := s.client.Do(s.req) //nolint:bodyclose
+ if err != nil {
+ return err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return s.errMapper.Map(resp)
+ }
+
+ s.resp = resp
+ s.reader = sse.NewEventStreamReader(resp.Body, 8192) // TODO: should we expose maxBufferSize?
+
+ return nil
+}
+
+func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
+ var completionChunk ChatCompletionChunk
+
+ for {
+ rawEvent, err := s.reader.ReadEvent()
+ if err != nil {
+ // Cohere has no [DONE] marker; it closes the connection after the final
+ // is_finished event, so io.EOF here is normal stream termination
+ if err == io.EOF {
+ return nil, io.EOF
+ }
+
+ s.tel.L().Warn(
+ "Chat stream is unexpectedly disconnected",
+ zap.String("provider", providerName),
+ zap.Error(err),
+ )
+
+ return nil, clients.ErrProviderUnavailable
+ }
+
+ s.tel.L().Debug(
+ "Raw chat stream chunk",
+ zap.String("provider", providerName),
+ zap.ByteString("rawChunk", rawEvent),
+ )
+
+ event, err := clients.ParseSSEvent(rawEvent)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse chat stream message: %v", err)
+ }
+
+ if !event.HasContent() {
+ s.tel.L().Debug(
+ "Received an empty message in chat stream, skipping it",
+ zap.String("provider", providerName),
+ zap.Any("msg", event),
+ )
+
+ continue
+ }
+
+ err = json.Unmarshal(event.Data, &completionChunk)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err)
+ }
+
+ responseChunk := completionChunk
+
+ var finishReason *schemas.FinishReason
+
+ if responseChunk.IsFinished {
+ finishReason = &schemas.Complete
+
+ return &schemas.ChatStreamChunk{
+ ID: s.reqID,
+ Provider: providerName,
+ Cached: false,
+ ModelName: "NA",
+ Metadata: s.reqMetadata,
+ ModelResponse: schemas.ModelChunkResponse{
+ Metadata: &schemas.Metadata{
+ "generationId": responseChunk.Response.GenerationID,
+ "responseId": responseChunk.Response.ResponseID,
+ },
+ Message: schemas.ChatMessage{
+ Role: "model",
+ Content: responseChunk.Text,
+ },
+ FinishReason: finishReason,
+ },
+ }, nil
+ }
+
+ // TODO: use objectpool here
+ return &schemas.ChatStreamChunk{
+ ID: s.reqID,
+ Provider: providerName,
+ Cached: false,
+ ModelName: "NA",
+ Metadata: s.reqMetadata,
+ ModelResponse: schemas.ModelChunkResponse{
+ Message: schemas.ChatMessage{
+ Role: "model",
+ Content: responseChunk.Text,
+ },
+ FinishReason: finishReason,
+ },
+ }, nil
+ }
+}
+
+func (s *ChatStream) Close() error {
+ if s.resp != nil {
+ return s.resp.Body.Close()
+ }
+
+ return nil
+}
+
+func (c *Client) SupportChatStream() bool {
+ return true
+}
+
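+// ChatStream prepares a chat stream request; the returned stream must be started via Open() before receiving chunks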
+func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+ // Create a new chat request
+ httpRequest, err := c.makeStreamReq(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+
+ return NewChatStream(
+ c.tel,
+ c.httpClient,
+ httpRequest,
+ req.ID,
+ req.Metadata,
+ c.errMapper,
+ ), nil
+}
+
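+// createRequestFromStream maps the unified stream request onto the Cohere-specific request schema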
+func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
+ // TODO: consider using objectpool to optimize memory allocation
+ chatRequest := *c.chatRequestTemplate // dereferencing copies the template struct, so mutations below don't affect it
+
+ chatRequest.Message = request.Message.Content
+
+ // Build the Cohere-specific ChatHistory
+ if len(request.MessageHistory) > 0 {
+ chatRequest.ChatHistory = make([]ChatHistory, len(request.MessageHistory))
+ for i, message := range request.MessageHistory {
+ // map each unified message onto Cohere's chat_history entry
+ chatRequest.ChatHistory[i] = ChatHistory{
+ Role: message.Role,
+ Message: message.Content,
+ User: "",
+ }
+ }
+ }
+
+ return &chatRequest
+}
+
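+// makeStreamReq builds the HTTP request for Cohere chat streaming, setting the SSE-related headers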
+func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
+ chatRequest := c.createRequestFromStream(req)
+
+ chatRequest.Stream = true
+
+ rawPayload, err := json.Marshal(chatRequest)
+ if err != nil {
+ return nil, fmt.Errorf("unable to marshal cohere chat stream request payload: %w", err)
+ }
+
+ request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
+ if err != nil {
+ return nil, fmt.Errorf("unable to create cohere stream chat request: %w", err)
+ }
+
+ request.Header.Set("Content-Type", "application/json")
+ request.Header.Set("Authorization", fmt.Sprintf("Bearer %v", string(c.config.APIKey)))
+ request.Header.Set("Cache-Control", "no-cache")
+ request.Header.Set("Accept", "text/event-stream")
+ request.Header.Set("Connection", "keep-alive")
+
+ // TODO: this could leak information from messages which may not be a desired thing to have
+ c.tel.L().Debug(
+ "Stream chat request",
+ zap.String("chatURL", c.chatURL),
+ zap.Any("payload", chatRequest),
+ )
+
+ return request, nil
+}
diff --git a/pkg/providers/cohere/chat_stream_test.go b/pkg/providers/cohere/chat_stream_test.go
new file mode 100644
index 00000000..3552b45b
--- /dev/null
+++ b/pkg/providers/cohere/chat_stream_test.go
@@ -0,0 +1,156 @@
+package cohere
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "glide/pkg/api/schemas"
+
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestCohere_ChatStreamSupported(t *testing.T) {
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ require.True(t, client.SupportChatStream())
+}
+
+func TestCohere_ChatStreamRequest(t *testing.T) {
+ tests := map[string]string{
+ "success stream": "./testdata/chat_stream.success.txt",
+ }
+
+ for name, streamFile := range tests {
+ t.Run(name, func(t *testing.T) {
+ cohereMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ rawPayload, _ := io.ReadAll(r.Body)
+
+ var data interface{}
+ // Parse the JSON body
+ err := json.Unmarshal(rawPayload, &data)
+ if err != nil {
+ t.Errorf("error decoding payload (%q): %v", string(rawPayload), err)
+ }
+
+ chatResponse, err := os.ReadFile(filepath.Clean(streamFile))
+ if err != nil {
+ t.Errorf("error reading cohere chat mock response: %v", err)
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ _, err = w.Write(chatResponse)
+ if err != nil {
+ t.Errorf("error on sending chat response: %v", err)
+ }
+ })
+
+ cohereServer := httptest.NewServer(cohereMock)
+ defer cohereServer.Close()
+
+ ctx := context.Background()
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ providerCfg.BaseURL = cohereServer.URL
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
+
+ stream, err := client.ChatStream(ctx, req)
+ require.NoError(t, err)
+
+ err = stream.Open()
+ require.NoError(t, err)
+
+ for {
+ chunk, err := stream.Recv()
+
+ if err == io.EOF {
+ return
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, chunk)
+ }
+ })
+ }
+}
+
+func TestCohere_ChatStreamRequestInterrupted(t *testing.T) {
+ tests := map[string]string{
+ "success stream, but with empty event": "./testdata/chat_stream.empty.txt",
+ "success stream, but no final stream-end event": "./testdata/chat_stream.nodone.txt",
+ }
+
+ for name, streamFile := range tests {
+ t.Run(name, func(t *testing.T) {
+ cohereMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ rawPayload, _ := io.ReadAll(r.Body)
+
+ var data interface{}
+ // Parse the JSON body
+ err := json.Unmarshal(rawPayload, &data)
+ if err != nil {
+ t.Errorf("error decoding payload (%q): %v", string(rawPayload), err)
+ }
+
+ chatResponse, err := os.ReadFile(filepath.Clean(streamFile))
+ if err != nil {
+ t.Errorf("error reading cohere chat mock response: %v", err)
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ _, err = w.Write(chatResponse)
+ if err != nil {
+ t.Errorf("error on sending chat response: %v", err)
+ }
+ })
+
+ cohereServer := httptest.NewServer(cohereMock)
+ defer cohereServer.Close()
+
+ ctx := context.Background()
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ providerCfg.BaseURL = cohereServer.URL
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
+ stream, err := client.ChatStream(ctx, req)
+ require.NoError(t, err)
+
+ err = stream.Open()
+ require.NoError(t, err)
+
+ for {
+ chunk, err := stream.Recv()
+ if err != nil {
+ require.ErrorIs(t, err, io.EOF)
+ return
+ }
+
+ require.NotNil(t, chunk)
+ }
+ })
+ }
+}
diff --git a/pkg/providers/cohere/client.go b/pkg/providers/cohere/client.go
index a6cc9cf5..ec778c15 100644
--- a/pkg/providers/cohere/client.go
+++ b/pkg/providers/cohere/client.go
@@ -23,9 +23,10 @@ type Client struct {
baseURL string
chatURL string
chatRequestTemplate *ChatRequest
+ errMapper *ErrorMapper
config *Config
httpClient *http.Client
- telemetry *telemetry.Telemetry
+ tel *telemetry.Telemetry
}
// NewClient creates a new Cohere client for the Cohere API.
@@ -48,7 +49,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
MaxIdleConnsPerHost: 2,
},
},
- telemetry: tel,
+ tel: tel,
}
return c, nil
diff --git a/pkg/providers/cohere/config.go b/pkg/providers/cohere/config.go
index 1a38aefa..50dbcde4 100644
--- a/pkg/providers/cohere/config.go
+++ b/pkg/providers/cohere/config.go
@@ -8,7 +8,7 @@ import (
// TODO: Add validations
type Params struct {
Temperature float64 `json:"temperature,omitempty"`
- Stream bool `json:"stream,omitempty"` // unsupported right now
+ Stream bool `json:"stream,omitempty"`
PreambleOverride string `json:"preamble_override,omitempty"`
ChatHistory []ChatHistory `json:"chat_history,omitempty"`
ConversationID string `json:"conversation_id,omitempty"`
diff --git a/pkg/providers/cohere/error.go b/pkg/providers/cohere/error.go
new file mode 100644
index 00000000..3e5ae89e
--- /dev/null
+++ b/pkg/providers/cohere/error.go
@@ -0,0 +1,64 @@
+package cohere
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+ "go.uber.org/zap"
+)
+
+type ErrorMapper struct {
+ tel *telemetry.Telemetry
+}
+
+func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper {
+ return &ErrorMapper{
+ tel: tel,
+ }
+}
+
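+// Map maps a failed HTTP response to a unified client error, so rate limits and auth failures surface as dedicated error types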
+func (m *ErrorMapper) Map(resp *http.Response) error {
+ bodyBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ m.tel.Logger.Error(
+ "Failed to unmarshal chat response error",
+ zap.String("provider", providerName),
+ zap.Error(err),
+ zap.ByteString("rawResponse", bodyBytes),
+ )
+
+ return clients.ErrProviderUnavailable
+ }
+
+ m.tel.Logger.Error(
+ "Chat request failed",
+ zap.String("provider", providerName),
+ zap.Int("statusCode", resp.StatusCode),
+ zap.String("response", string(bodyBytes)),
+ zap.Any("headers", resp.Header),
+ )
+
+ if resp.StatusCode == http.StatusTooManyRequests {
+ // Read the value of the "Retry-After" header to get the cooldown delay
+ retryAfter := resp.Header.Get("Retry-After")
+
+ // Parse the value to get the duration
+ cooldownDelay, err := time.ParseDuration(retryAfter)
+ if err != nil {
+ return fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
+ }
+
+ return clients.NewRateLimitError(&cooldownDelay)
+ }
+
+ if resp.StatusCode == http.StatusUnauthorized {
+ return clients.ErrUnauthorized
+ }
+
+ // Server & client errors result in the same error to keep gateway resilient
+ return clients.ErrProviderUnavailable
+}
diff --git a/pkg/providers/cohere/schemas.go b/pkg/providers/cohere/schemas.go
index c807aa56..f6fc310e 100644
--- a/pkg/providers/cohere/schemas.go
+++ b/pkg/providers/cohere/schemas.go
@@ -15,10 +15,10 @@ type ChatCompletion struct {
}
type TokenCount struct {
- PromptTokens float64 `json:"prompt_tokens"`
- ResponseTokens float64 `json:"response_tokens"`
- TotalTokens float64 `json:"total_tokens"`
- BilledTokens float64 `json:"billed_tokens"`
+ PromptTokens int `json:"prompt_tokens"`
+ ResponseTokens int `json:"response_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ BilledTokens int `json:"billed_tokens"`
}
type Meta struct {
@@ -65,3 +65,54 @@ type ConnectorsResponse struct {
ContOnFail string `json:"continue_on_failure"`
Options map[string]string `json:"options"`
}
+
+// ChatCompletionChunk represents an SSE event that a chat response is broken into during chat streaming
+// Ref: https://docs.cohere.com/reference/about
+type ChatCompletionChunk struct {
+ IsFinished bool `json:"is_finished"`
+ EventType string `json:"event_type"`
+ Text string `json:"text"`
+ Response FinalResponse `json:"response,omitempty"`
+}
+
+type FinalResponse struct {
+ ResponseID string `json:"response_id"`
+ Text string `json:"text"`
+ GenerationID string `json:"generation_id"`
+ TokenCount TokenCount `json:"token_count"`
+ Meta Meta `json:"meta"`
+ FinishReason string `json:"finish_reason"`
+}
+
+type ChatMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+type ChatHistory struct {
+ Role string `json:"role"`
+ Message string `json:"message"`
+ User string `json:"user,omitempty"`
+}
+
+// ChatRequest is a Cohere-specific request schema for chat completions.
+type ChatRequest struct {
+ Model string `json:"model"`
+ Message string `json:"message"`
+ Temperature float64 `json:"temperature,omitempty"`
+ PreambleOverride string `json:"preamble_override,omitempty"`
+ ChatHistory []ChatHistory `json:"chat_history,omitempty"`
+ ConversationID string `json:"conversation_id,omitempty"`
+ PromptTruncation string `json:"prompt_truncation,omitempty"`
+ Connectors []string `json:"connectors,omitempty"`
+ SearchQueriesOnly bool `json:"search_queries_only,omitempty"`
+ CitationQuality string `json:"citation_quality,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+}
+
+type Connectors struct {
+ ID string `json:"id"`
+ UserAccessToken string `json:"user_access_token"`
+ ContOnFail string `json:"continue_on_failure"`
+ Options map[string]string `json:"options"`
+}
diff --git a/pkg/providers/cohere/testdata/chat_stream.empty.txt b/pkg/providers/cohere/testdata/chat_stream.empty.txt
new file mode 100644
index 00000000..38471d95
--- /dev/null
+++ b/pkg/providers/cohere/testdata/chat_stream.empty.txt
@@ -0,0 +1,277 @@
+{
+ "is_finished": false,
+ "event_type": "stream-start",
+ "generation_id": "d686011c-e1bb-41c1-9964-823d9b94d394"
+}
+{
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " capital"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " of"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " the"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " United"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " Kingdom"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " is"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " London"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": "."
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " London"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " is"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " a"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " vibrant"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " city"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " full"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " of"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " diverse"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " culture"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " history"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " and"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " iconic"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " landmarks"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": "."
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " It"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": "'s"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " a"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " global"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " city"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " that"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " has"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " influence"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " in"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " the"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " fields"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " of"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " art"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " fashion"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " finance"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " media"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " and"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " politics"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": "."
+}
+{
+ "is_finished": true,
+ "event_type": "stream-end",
+ "response": {
+ "response_id": "d4a5e49e-b892-41c5-950d-a97162d19393",
+ "text": "The capital of the United Kingdom is London. London is a vibrant city full of diverse culture, history, and iconic landmarks. It's a global city that has influence in the fields of art, fashion, finance, media, and politics.",
+ "generation_id": "d686011c-e1bb-41c1-9964-823d9b94d394",
+ "chat_history": [
+ {
+ "role": "USER",
+ "message": "What's the capital of the United Kingdom?"
+ },
+ {
+ "role": "CHATBOT",
+ "message": "The capital of the United Kingdom is London. London is a vibrant city full of diverse culture, history, and iconic landmarks. It's a global city that has influence in the fields of art, fashion, finance, media, and politics."
+ }
+ ],
+ "token_count": {
+ "prompt_tokens": 75,
+ "response_tokens": 48,
+ "total_tokens": 123,
+ "billed_tokens": 57
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 9,
+ "output_tokens": 48
+ }
+ }
+ },
+ "finish_reason": "COMPLETE"
+}
\ No newline at end of file
diff --git a/pkg/providers/cohere/testdata/chat_stream.nodone.txt b/pkg/providers/cohere/testdata/chat_stream.nodone.txt
new file mode 100644
index 00000000..785fe492
--- /dev/null
+++ b/pkg/providers/cohere/testdata/chat_stream.nodone.txt
@@ -0,0 +1,25 @@
+{
+ "is_finished": false,
+ "event_type": "stream-start",
+ "generation_id": "d686011c-e1bb-41c1-9964-823d9b94d394"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": "The"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " capital"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " of"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " the"
+}
diff --git a/pkg/providers/cohere/testdata/chat_stream.success.txt b/pkg/providers/cohere/testdata/chat_stream.success.txt
new file mode 100644
index 00000000..d68c7eda
--- /dev/null
+++ b/pkg/providers/cohere/testdata/chat_stream.success.txt
@@ -0,0 +1,280 @@
+{
+ "is_finished": false,
+ "event_type": "stream-start",
+ "generation_id": "d686011c-e1bb-41c1-9964-823d9b94d394"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": "The"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " capital"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " of"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " the"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " United"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " Kingdom"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " is"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " London"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": "."
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " London"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " is"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " a"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " vibrant"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " city"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " full"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " of"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " diverse"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " culture"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " history"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " and"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " iconic"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " landmarks"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": "."
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " It"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": "'s"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " a"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " global"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " city"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " that"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " has"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " influence"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " in"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " the"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " fields"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " of"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " art"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " fashion"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " finance"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " media"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": ","
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " and"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": " politics"
+}
+{
+ "is_finished": false,
+ "event_type": "text-generation",
+ "text": "."
+}
+{
+ "is_finished": true,
+ "event_type": "stream-end",
+ "response": {
+ "response_id": "d4a5e49e-b892-41c5-950d-a97162d19393",
+ "text": "The capital of the United Kingdom is London. London is a vibrant city full of diverse culture, history, and iconic landmarks. It's a global city that has influence in the fields of art, fashion, finance, media, and politics.",
+ "generation_id": "d686011c-e1bb-41c1-9964-823d9b94d394",
+ "chat_history": [
+ {
+ "role": "USER",
+ "message": "What's the capital of the United Kingdom?"
+ },
+ {
+ "role": "CHATBOT",
+ "message": "The capital of the United Kingdom is London. London is a vibrant city full of diverse culture, history, and iconic landmarks. It's a global city that has influence in the fields of art, fashion, finance, media, and politics."
+ }
+ ],
+ "token_count": {
+ "prompt_tokens": 75,
+ "response_tokens": 48,
+ "total_tokens": 123,
+ "billed_tokens": 57
+ },
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 9,
+ "output_tokens": 48
+ }
+ }
+ },
+ "finish_reason": "COMPLETE"
+}
\ No newline at end of file
diff --git a/pkg/providers/config.go b/pkg/providers/config.go
index 55d885f6..58aebff7 100644
--- a/pkg/providers/config.go
+++ b/pkg/providers/config.go
@@ -49,18 +49,18 @@ func DefaultLangModelConfig() *LangModelConfig {
}
}
-func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (*LangModel, error) {
+func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) {
client, err := c.initClient(tel)
if err != nil {
return nil, fmt.Errorf("error initializing client: %v", err)
}
- return NewLangModel(c.ID, client, *c.ErrorBudget, *c.Latency, c.Weight), nil
+ return NewLangModel(c.ID, client, c.ErrorBudget, *c.Latency, c.Weight), nil
}
// initClient initializes the language model client based on the provided configuration.
// It takes a telemetry object as input and returns a LangModelProvider and an error.
-func (c *LangModelConfig) initClient(tel *telemetry.Telemetry) (LangModelProvider, error) {
+func (c *LangModelConfig) initClient(tel *telemetry.Telemetry) (LangProvider, error) {
switch {
case c.OpenAI != nil:
return openai.NewClient(c.OpenAI, c.Client, tel)
diff --git a/pkg/providers/lang.go b/pkg/providers/lang.go
new file mode 100644
index 00000000..f4ab258b
--- /dev/null
+++ b/pkg/providers/lang.go
@@ -0,0 +1,177 @@
+package providers
+
+import (
+ "context"
+ "io"
+ "time"
+
+ "glide/pkg/routers/health"
+
+ "glide/pkg/api/schemas"
+ "glide/pkg/providers/clients"
+ "glide/pkg/routers/latency"
+)
+
+// LangProvider defines an interface a provider should fulfill to be able to serve language chat requests
+type LangProvider interface {
+ ModelProvider
+
+ SupportChatStream() bool
+
+ Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error)
+ ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error)
+}
+
+type LangModel interface {
+ Model
+ Provider() string
+ Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error)
+ ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (<-chan *clients.ChatStreamResult, error)
+}
+
+// LanguageModel wraps a provider client and extends it with health & latency tracking
+//
+// The model health is assumed to be independent of model actions (e.g. chat & chatStream)
+// The latency is assumed to be action-specific (e.g. streaming chat chunks have much lower latency than the full chat action)
+type LanguageModel struct {
+ modelID string
+ weight int
+ client LangProvider
+ healthTracker *health.Tracker
+ chatLatency *latency.MovingAverage
+ chatStreamLatency *latency.MovingAverage
+ latencyUpdateInterval *time.Duration
+}
+
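+// NewLangModel wraps a provider client into a LanguageModel with health & latency tracking attached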
+func NewLangModel(modelID string, client LangProvider, budget *health.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel {
+ return &LanguageModel{
+ modelID: modelID,
+ client: client,
+ healthTracker: health.NewTracker(budget),
+ chatLatency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples),
+ chatStreamLatency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples),
+ latencyUpdateInterval: latencyConfig.UpdateInterval,
+ weight: weight,
+ }
+}
+
+func (m LanguageModel) ID() string {
+ return m.modelID
+}
+
+func (m LanguageModel) Healthy() bool {
+ return m.healthTracker.Healthy()
+}
+
+func (m LanguageModel) Weight() int {
+ return m.weight
+}
+
+func (m LanguageModel) LatencyUpdateInterval() *time.Duration {
+ return m.latencyUpdateInterval
+}
+
+func (m *LanguageModel) SupportChatStream() bool {
+ return m.client.SupportChatStream()
+}
+
+func (m LanguageModel) ChatLatency() *latency.MovingAverage {
+ return m.chatLatency
+}
+
+func (m LanguageModel) ChatStreamLatency() *latency.MovingAverage {
+ return m.chatStreamLatency
+}
+
+func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
+ startedAt := time.Now()
+ resp, err := m.client.Chat(ctx, request)
+
+ if err == nil {
+ // record latency per response token to normalize measurements across responses of different lengths
+ m.chatLatency.Add(float64(time.Since(startedAt)) / float64(resp.ModelResponse.TokenUsage.ResponseTokens))
+
+ // successful response
+ resp.ModelID = m.modelID
+
+ return resp, err
+ }
+
+ m.healthTracker.TrackErr(err)
+
+ return resp, err
+}
+
+func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (<-chan *clients.ChatStreamResult, error) {
+ stream, err := m.client.ChatStream(ctx, req)
+ if err != nil {
+ m.healthTracker.TrackErr(err)
+
+ return nil, err
+ }
+
+ startedAt := time.Now()
+ err = stream.Open()
+ chunkLatency := time.Since(startedAt)
+
+ if err != nil {
+ m.healthTracker.TrackErr(err)
+
+ // if the connection was not even opened, we should not send our clients any messages about this failure
+
+ return nil, err
+ }
+
+ // record the time to open the stream as the first latency sample
+ m.chatStreamLatency.Add(float64(chunkLatency))
+
+ streamResultC := make(chan *clients.ChatStreamResult)
+
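+ // pump stream chunks into the result channel in the background; the channel is closed once the stream ends or errors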
+ go func() {
+ defer close(streamResultC)
+ defer stream.Close()
+
+ for {
+ startedAt = time.Now()
+ chunk, err := stream.Recv()
+ chunkLatency = time.Since(startedAt)
+
+ if err != nil {
+ if err == io.EOF {
+ // end of the stream
+ return
+ }
+
+ streamResultC <- clients.NewChatStreamResult(nil, err)
+
+ m.healthTracker.TrackErr(err)
+
+ return
+ }
+
+ streamResultC <- clients.NewChatStreamResult(chunk, nil)
+
+ if chunkLatency > 1*time.Millisecond {
+ // All events are read in bigger chunks of bytes, so one chunk may contain more than one event.
+ // Each byte chunk is then parsed, so there is no easy way to precisely measure latency per event.
+ // We assume that if we spent more than 1ms waiting for a chunk, it's likely
+ // we were actually reading from the connection (otherwise, it would take nanoseconds)
+ m.chatStreamLatency.Add(float64(chunkLatency))
+ }
+ }
+ }()
+
+ return streamResultC, nil
+}
+
+func (m *LanguageModel) Provider() string {
+ return m.client.Provider()
+}
+
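+// ChatLatency & ChatStreamLatency extract latency trackers from a generic Model;
+// they assume the Model is a LanguageModel value and will panic otherwise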
+func ChatLatency(model Model) *latency.MovingAverage {
+ return model.(LanguageModel).ChatLatency()
+}
+
+func ChatStreamLatency(model Model) *latency.MovingAverage {
+ return model.(LanguageModel).ChatStreamLatency()
+}
diff --git a/pkg/providers/octoml/chat.go b/pkg/providers/octoml/chat.go
index 4860a0b9..5a0aa2bc 100644
--- a/pkg/providers/octoml/chat.go
+++ b/pkg/providers/octoml/chat.go
@@ -7,12 +7,9 @@ import (
"fmt"
"io"
"net/http"
- "time"
"glide/pkg/providers/openai"
- "glide/pkg/providers/clients"
-
"glide/pkg/api/schemas"
"go.uber.org/zap"
)
@@ -117,33 +114,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- c.telemetry.Logger.Error("failed to read octoml chat response", zap.Error(err))
- }
-
- c.telemetry.Logger.Error(
- "octoml chat request failed",
- zap.Int("status_code", resp.StatusCode),
- zap.String("response", string(bodyBytes)),
- zap.Any("headers", resp.Header),
- )
-
- if resp.StatusCode == http.StatusTooManyRequests {
- // Read the value of the "Retry-After" header to get the cooldown delay
- retryAfter := resp.Header.Get("Retry-After")
-
- // Parse the value to get the duration
- cooldownDelay, err := time.ParseDuration(retryAfter)
- if err != nil {
- return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
- }
-
- return nil, clients.NewRateLimitError(&cooldownDelay)
- }
-
- // Server & client errors result in the same error to keep gateway resilient
- return nil, clients.ErrProviderUnavailable
+ return nil, c.errMapper.Map(resp)
}
// Read the response body into a byte slice
@@ -164,19 +135,18 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
// Map response to UnifiedChatResponse schema
response := schemas.ChatResponse{
- ID: openAICompletion.ID,
- Created: openAICompletion.Created,
- Provider: providerName,
- Model: openAICompletion.Model,
- Cached: false,
- ModelResponse: schemas.ProviderResponse{
+ ID: openAICompletion.ID,
+ Created: openAICompletion.Created,
+ Provider: providerName,
+ ModelName: openAICompletion.ModelName,
+ Cached: false,
+ ModelResponse: schemas.ModelResponse{
SystemID: map[string]string{
"system_fingerprint": openAICompletion.SystemFingerprint,
},
Message: schemas.ChatMessage{
Role: openAICompletion.Choices[0].Message.Role,
Content: openAICompletion.Choices[0].Message.Content,
- Name: "",
},
TokenUsage: schemas.TokenUsage{
PromptTokens: openAICompletion.Usage.PromptTokens,
diff --git a/pkg/providers/octoml/chat_stream.go b/pkg/providers/octoml/chat_stream.go
new file mode 100644
index 00000000..2885195b
--- /dev/null
+++ b/pkg/providers/octoml/chat_stream.go
@@ -0,0 +1,16 @@
+package octoml
+
+import (
+ "context"
+
+ "glide/pkg/api/schemas"
+ "glide/pkg/providers/clients"
+)
+
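+// SupportChatStream tells the router that OctoML chat streaming is not implemented yet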
+func (c *Client) SupportChatStream() bool {
+ return false
+}
+
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+ return nil, clients.ErrChatStreamNotImplemented
+}
diff --git a/pkg/providers/octoml/client.go b/pkg/providers/octoml/client.go
index df8cff5b..66e71502 100644
--- a/pkg/providers/octoml/client.go
+++ b/pkg/providers/octoml/client.go
@@ -23,6 +23,7 @@ type Client struct {
baseURL string
chatURL string
chatRequestTemplate *ChatRequest
+ errMapper *ErrorMapper
config *Config
httpClient *http.Client
telemetry *telemetry.Telemetry
@@ -40,6 +41,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
chatURL: chatURL,
config: providerConfig,
chatRequestTemplate: NewChatRequestFromConfig(providerConfig),
+ errMapper: NewErrorMapper(tel),
httpClient: &http.Client{
Timeout: *clientConfig.Timeout,
// TODO: use values from the config
diff --git a/pkg/providers/octoml/client_test.go b/pkg/providers/octoml/client_test.go
index c8a438c1..b6f41b95 100644
--- a/pkg/providers/octoml/client_test.go
+++ b/pkg/providers/octoml/client_test.go
@@ -63,7 +63,7 @@ func TestOctoMLClient_ChatRequest(t *testing.T) {
response, err := client.Chat(ctx, &request)
require.NoError(t, err)
- require.Equal(t, providerCfg.Model, response.Model)
+ require.Equal(t, providerCfg.Model, response.ModelName)
require.Equal(t, "cmpl-8ea213aece0747aca6d0608b02b57196", response.ID)
}
@@ -112,21 +112,19 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) {
defer mockServer.Close()
// Create a new client with the mock server URL
- client := &Client{
- httpClient: http.DefaultClient,
- chatURL: mockServer.URL,
- config: &Config{APIKey: "dummy_key"},
- telemetry: telemetry.NewTelemetryMock(),
- }
+ ctx := context.Background()
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ providerCfg.BaseURL = mockServer.URL
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
// Create a chat request payload
- payload := &ChatRequest{
- Model: "dummy_model",
- Messages: []ChatMessage{{Role: "human", Content: "Hello"}},
- }
+ payload := schemas.NewChatFromStr("What's the dealio?")
- // Call the doChatRequest function
- _, err := client.doChatRequest(context.Background(), payload)
+ _, err = client.Chat(ctx, payload)
require.Error(t, err)
require.Contains(t, err.Error(), "provider is not available")
diff --git a/pkg/providers/octoml/errors.go b/pkg/providers/octoml/errors.go
new file mode 100644
index 00000000..23657076
--- /dev/null
+++ b/pkg/providers/octoml/errors.go
@@ -0,0 +1,56 @@
+package octoml
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+ "go.uber.org/zap"
+)
+
+type ErrorMapper struct {
+ tel *telemetry.Telemetry
+}
+
+func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper {
+ return &ErrorMapper{
+ tel: tel,
+ }
+}
+
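+// Map maps a failed HTTP response to a unified client error to keep provider failures consistent across the gateway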
+func (m *ErrorMapper) Map(resp *http.Response) error {
+ bodyBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ m.tel.L().Error("failed to read octoml chat response", zap.Error(err))
+ }
+
+ m.tel.L().Error(
+ "octoml chat request failed",
+ zap.Int("status_code", resp.StatusCode),
+ zap.String("response", string(bodyBytes)),
+ zap.Any("headers", resp.Header),
+ )
+
+ if resp.StatusCode == http.StatusTooManyRequests {
+ // Read the value of the "Retry-After" header to get the cooldown delay
+ retryAfter := resp.Header.Get("Retry-After")
+
+ // Parse the value to get the duration
+ cooldownDelay, err := time.ParseDuration(retryAfter)
+ if err != nil {
+ return fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
+ }
+
+ return clients.NewRateLimitError(&cooldownDelay)
+ }
+
+ if resp.StatusCode == http.StatusUnauthorized {
+ return clients.ErrUnauthorized
+ }
+
+ // Server & client errors result in the same error to keep gateway resilient
+ return clients.ErrProviderUnavailable
+}
diff --git a/pkg/providers/ollama/chat.go b/pkg/providers/ollama/chat.go
index f2247dd7..dd4a22fc 100644
--- a/pkg/providers/ollama/chat.go
+++ b/pkg/providers/ollama/chat.go
@@ -181,24 +181,23 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
// Map response to UnifiedChatResponse schema
response := schemas.ChatResponse{
- ID: uuid.NewString(),
- Created: int(time.Now().Unix()),
- Provider: providerName,
- Model: ollamaCompletion.Model,
- Cached: false,
- ModelResponse: schemas.ProviderResponse{
+ ID: uuid.NewString(),
+ Created: int(time.Now().Unix()),
+ Provider: providerName,
+ ModelName: ollamaCompletion.Model,
+ Cached: false,
+ ModelResponse: schemas.ModelResponse{
SystemID: map[string]string{
"system_fingerprint": "",
},
Message: schemas.ChatMessage{
Role: ollamaCompletion.Message.Role,
Content: ollamaCompletion.Message.Content,
- Name: "",
},
TokenUsage: schemas.TokenUsage{
- PromptTokens: float64(ollamaCompletion.EvalCount),
- ResponseTokens: float64(ollamaCompletion.EvalCount),
- TotalTokens: float64(ollamaCompletion.EvalCount),
+ PromptTokens: ollamaCompletion.EvalCount,
+ ResponseTokens: ollamaCompletion.EvalCount,
+ TotalTokens: ollamaCompletion.EvalCount,
},
},
}
diff --git a/pkg/providers/ollama/chat_stream.go b/pkg/providers/ollama/chat_stream.go
new file mode 100644
index 00000000..2bf0b87f
--- /dev/null
+++ b/pkg/providers/ollama/chat_stream.go
@@ -0,0 +1,16 @@
+package ollama
+
+import (
+ "context"
+
+ "glide/pkg/api/schemas"
+ "glide/pkg/providers/clients"
+)
+
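+// SupportChatStream tells the router that Ollama chat streaming is not implemented yet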
+func (c *Client) SupportChatStream() bool {
+ return false
+}
+
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+ return nil, clients.ErrChatStreamNotImplemented
+}
diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go
index bbcc4ff4..fcd75db9 100644
--- a/pkg/providers/openai/chat.go
+++ b/pkg/providers/openai/chat.go
@@ -7,39 +7,11 @@ import (
"fmt"
"io"
"net/http"
- "time"
-
- "glide/pkg/providers/clients"
"glide/pkg/api/schemas"
"go.uber.org/zap"
)
-type ChatMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
-}
-
-// ChatRequest is an OpenAI-specific request schema
-type ChatRequest struct {
- Model string `json:"model"`
- Messages []ChatMessage `json:"messages"`
- Temperature float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- N int `json:"n,omitempty"`
- StopWords []string `json:"stop,omitempty"`
- Stream bool `json:"stream,omitempty"`
- FrequencyPenalty int `json:"frequency_penalty,omitempty"`
- PresencePenalty int `json:"presence_penalty,omitempty"`
- LogitBias *map[int]float64 `json:"logit_bias,omitempty"`
- User *string `json:"user,omitempty"`
- Seed *int `json:"seed,omitempty"`
- Tools []string `json:"tools,omitempty"`
- ToolChoice interface{} `json:"tool_choice,omitempty"`
- ResponseFormat interface{} `json:"response_format,omitempty"`
-}
-
// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
return &ChatRequest{
@@ -49,7 +21,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
MaxTokens: cfg.DefaultParams.MaxTokens,
N: cfg.DefaultParams.N,
StopWords: cfg.DefaultParams.StopWords,
- Stream: false, // unsupported right now
+ Stream: false,
FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty,
PresencePenalty: cfg.DefaultParams.PresencePenalty,
LogitBias: cfg.DefaultParams.LogitBias,
@@ -61,23 +33,11 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
}
}
-func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage {
- messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)
-
- // Add items from messageHistory first and the new chat message last
- for _, message := range request.MessageHistory {
- messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
- }
-
- messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
-
- return messages
-}
-
// Chat sends a chat request to the specified OpenAI model.
func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
// Create a new chat request
- chatRequest := c.createChatRequestSchema(request)
+ chatRequest := c.createRequestSchema(request)
+ chatRequest.Stream = false
chatResponse, err := c.doChatRequest(ctx, chatRequest)
if err != nil {
@@ -91,12 +51,21 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem
return chatResponse, nil
}
-func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
+// createRequestSchema creates a new ChatRequest object based on the given request.
+func (c *Client) createRequestSchema(request *schemas.ChatRequest) *ChatRequest {
// TODO: consider using objectpool to optimize memory allocation
- chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
- chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
+ chatRequest := *c.chatRequestTemplate // dereferencing copies the template struct, so mutations below don't affect it
+
+ chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
- return chatRequest
+ // Add items from messageHistory first and the new chat message last
+ for _, message := range request.MessageHistory {
+ chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
+ }
+
+ chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
+
+ return &chatRequest
}
func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
@@ -111,13 +80,14 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
return nil, fmt.Errorf("unable to create openai chat request: %w", err)
}
- req.Header.Set("Authorization", "Bearer "+string(c.config.APIKey))
req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", string(c.config.APIKey)))
// TODO: this could leak information from messages which may not be a desired thing to have
- c.telemetry.Logger.Debug(
- "openai chat request",
- zap.String("chat_url", c.chatURL),
+ c.tel.Logger.Debug(
+ "Chat Request",
+ zap.String("provider", c.Provider()),
+ zap.String("chatURL", c.chatURL),
zap.Any("payload", payload),
)
@@ -129,71 +99,55 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
- bodyBytes, err := io.ReadAll(resp.Body)
- if err != nil {
- c.telemetry.Logger.Error("failed to read openai chat response", zap.Error(err))
- }
-
- c.telemetry.Logger.Error(
- "openai chat request failed",
- zap.Int("status_code", resp.StatusCode),
- zap.String("response", string(bodyBytes)),
- zap.Any("headers", resp.Header),
- )
-
- if resp.StatusCode == http.StatusTooManyRequests {
- // Read the value of the "Retry-After" header to get the cooldown delay
- retryAfter := resp.Header.Get("Retry-After")
-
- // Parse the value to get the duration
- cooldownDelay, err := time.ParseDuration(retryAfter)
- if err != nil {
- return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
- }
-
- return nil, clients.NewRateLimitError(&cooldownDelay)
- }
-
- // Server & client errors result in the same error to keep gateway resilient
- return nil, clients.ErrProviderUnavailable
+ return nil, c.errMapper.Map(resp)
}
// Read the response body into a byte slice
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
- c.telemetry.Logger.Error("failed to read openai chat response", zap.Error(err))
+ c.tel.Logger.Error(
+ "Failed to read chat response",
+ zap.String("provider", c.Provider()), zap.Error(err),
+ zap.ByteString("rawResponse", bodyBytes),
+ )
+
return nil, err
}
// Parse the response JSON
- var openAICompletion ChatCompletion
+ var chatCompletion ChatCompletion
- err = json.Unmarshal(bodyBytes, &openAICompletion)
+ err = json.Unmarshal(bodyBytes, &chatCompletion)
if err != nil {
- c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err))
+ c.tel.Logger.Error(
+ "Failed to unmarshal chat response",
+ zap.String("provider", c.Provider()),
+ zap.ByteString("rawResponse", bodyBytes),
+ zap.Error(err),
+ )
+
return nil, err
}
// Map response to ChatResponse schema
response := schemas.ChatResponse{
- ID: openAICompletion.ID,
- Created: openAICompletion.Created,
- Provider: providerName,
- Model: openAICompletion.Model,
- Cached: false,
- ModelResponse: schemas.ProviderResponse{
+ ID: chatCompletion.ID,
+ Created: chatCompletion.Created,
+ Provider: providerName,
+ ModelName: chatCompletion.ModelName,
+ Cached: false,
+ ModelResponse: schemas.ModelResponse{
SystemID: map[string]string{
- "system_fingerprint": openAICompletion.SystemFingerprint,
+ "system_fingerprint": chatCompletion.SystemFingerprint,
},
Message: schemas.ChatMessage{
- Role: openAICompletion.Choices[0].Message.Role,
- Content: openAICompletion.Choices[0].Message.Content,
- Name: "",
+ Role: chatCompletion.Choices[0].Message.Role,
+ Content: chatCompletion.Choices[0].Message.Content,
},
TokenUsage: schemas.TokenUsage{
- PromptTokens: openAICompletion.Usage.PromptTokens,
- ResponseTokens: openAICompletion.Usage.CompletionTokens,
- TotalTokens: openAICompletion.Usage.TotalTokens,
+ PromptTokens: chatCompletion.Usage.PromptTokens,
+ ResponseTokens: chatCompletion.Usage.CompletionTokens,
+ TotalTokens: chatCompletion.Usage.TotalTokens,
},
},
}
diff --git a/pkg/providers/openai/chat_stream.go b/pkg/providers/openai/chat_stream.go
new file mode 100644
index 00000000..fb8be776
--- /dev/null
+++ b/pkg/providers/openai/chat_stream.go
@@ -0,0 +1,224 @@
+package openai
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+
+ "github.com/r3labs/sse/v2"
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+
+ "go.uber.org/zap"
+
+ "glide/pkg/api/schemas"
+)
+
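+// OpenAI signals normal stream termination with a special [DONE] data message rather than by closing the connection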
+var (
+ StopReason = "stop"
+ streamDoneMarker = []byte("[DONE]")
+)
+
+// ChatStream represents an OpenAI chat stream for a specific request
+type ChatStream struct {
+ tel *telemetry.Telemetry
+ client *http.Client
+ req *http.Request
+ reqID string
+ reqMetadata *schemas.Metadata
+ resp *http.Response
+ reader *sse.EventStreamReader
+ errMapper *ErrorMapper
+}
+
+func NewChatStream(
+ tel *telemetry.Telemetry,
+ client *http.Client,
+ req *http.Request,
+ reqID string,
+ reqMetadata *schemas.Metadata,
+ errMapper *ErrorMapper,
+) *ChatStream {
+ return &ChatStream{
+ tel: tel,
+ client: client,
+ req: req,
+ reqID: reqID,
+ reqMetadata: reqMetadata,
+ errMapper: errMapper,
+ }
+}
+
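+// Open sends the chat request and wraps the response body into an SSE event reader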
+func (s *ChatStream) Open() error {
+ resp, err := s.client.Do(s.req) //nolint:bodyclose
+ if err != nil {
+ return err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return s.errMapper.Map(resp)
+ }
+
+ s.resp = resp
+ s.reader = sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize?
+
+ return nil
+}
+
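+// Recv reads the next SSE event from the stream and maps it into the unified chat stream chunk schema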
+func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
+ var completionChunk ChatCompletionChunk
+
+ for {
+ rawEvent, err := s.reader.ReadEvent()
+ if err != nil {
+ s.tel.L().Warn(
+ "Chat stream is unexpectedly disconnected",
+ zap.String("provider", providerName),
+ zap.Error(err),
+ )
+
+ // if err is io.EOF, this still means that the stream is interrupted unexpectedly
+ // because the normal stream termination is done via finding out streamDoneMarker
+
+ return nil, clients.ErrProviderUnavailable
+ }
+
+ s.tel.L().Debug(
+ "Raw chat stream chunk",
+ zap.String("provider", providerName),
+ zap.ByteString("rawChunk", rawEvent),
+ )
+
+ event, err := clients.ParseSSEvent(rawEvent)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse chat stream message: %v", err)
+ }
+
+ // the [DONE] marker signals normal stream termination
+ if bytes.Equal(event.Data, streamDoneMarker) {
+ return nil, io.EOF
+ }
+
+ if !event.HasContent() {
+ s.tel.L().Debug(
+ "Received an empty message in chat stream, skipping it",
+ zap.String("provider", providerName),
+ zap.Any("msg", event),
+ )
+
+ continue
+ }
+
+ err = json.Unmarshal(event.Data, &completionChunk)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err)
+ }
+
+ responseChunk := completionChunk.Choices[0]
+
+ var finishReason *schemas.FinishReason
+
+ if responseChunk.FinishReason == StopReason {
+ finishReason = &schemas.Complete
+ }
+
+ // TODO: use objectpool here
+ return &schemas.ChatStreamChunk{
+ ID: s.reqID,
+ Provider: providerName,
+ Cached: false,
+ ModelName: completionChunk.ModelName,
+ Metadata: s.reqMetadata,
+ ModelResponse: schemas.ModelChunkResponse{
+ Metadata: &schemas.Metadata{
+ "response_id": completionChunk.ID,
+ "system_fingerprint": completionChunk.SystemFingerprint,
+ },
+ Message: schemas.ChatMessage{
+ Role: responseChunk.Delta.Role,
+ Content: responseChunk.Delta.Content,
+ },
+ FinishReason: finishReason,
+ },
+ }, nil
+ }
+}
+
+func (s *ChatStream) Close() error {
+ if s.resp != nil {
+ return s.resp.Body.Close()
+ }
+
+ return nil
+}
+
+func (c *Client) SupportChatStream() bool {
+ return true
+}
+
+func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+ // Create a new chat request
+ httpRequest, err := c.makeStreamReq(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+
+ return NewChatStream(
+ c.tel,
+ c.httpClient,
+ httpRequest,
+ req.ID,
+ req.Metadata,
+ c.errMapper,
+ ), nil
+}
+
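+// createRequestFromStream maps the unified stream request onto the OpenAI-specific request schema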
+func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
+ // TODO: consider using objectpool to optimize memory allocation
+ chatRequest := *c.chatRequestTemplate // dereferencing copies the template struct, so mutations below don't affect it
+
+ chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
+
+ // Add items from messageHistory first and the new chat message last
+ for _, message := range request.MessageHistory {
+ chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
+ }
+
+ chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
+
+ return &chatRequest
+}
+
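+// makeStreamReq builds the HTTP request for OpenAI chat streaming, setting the SSE-related headers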
+func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
+ chatRequest := c.createRequestFromStream(req)
+
+ chatRequest.Stream = true
+
+ rawPayload, err := json.Marshal(chatRequest)
+ if err != nil {
+ return nil, fmt.Errorf("unable to marshal openAI chat stream request payload: %w", err)
+ }
+
+ request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
+ if err != nil {
+ return nil, fmt.Errorf("unable to create OpenAI stream chat request: %w", err)
+ }
+
+ request.Header.Set("Content-Type", "application/json")
+ request.Header.Set("Authorization", fmt.Sprintf("Bearer %v", string(c.config.APIKey)))
+ request.Header.Set("Cache-Control", "no-cache")
+ request.Header.Set("Accept", "text/event-stream")
+ request.Header.Set("Connection", "keep-alive")
+
+ // TODO: this could leak information from messages which may not be a desired thing to have
+ c.tel.L().Debug(
+ "Stream chat request",
+ zap.String("chatURL", c.chatURL),
+ zap.Any("payload", chatRequest),
+ )
+
+ return request, nil
+}
diff --git a/pkg/providers/openai/chat_stream_test.go b/pkg/providers/openai/chat_stream_test.go
new file mode 100644
index 00000000..ba1542da
--- /dev/null
+++ b/pkg/providers/openai/chat_stream_test.go
@@ -0,0 +1,155 @@
+package openai
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "glide/pkg/api/schemas"
+
+ "github.com/stretchr/testify/require"
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+)
+
+func TestOpenAIClient_ChatStreamSupported(t *testing.T) {
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ require.True(t, client.SupportChatStream())
+}
+
+func TestOpenAIClient_ChatStreamRequest(t *testing.T) {
+ tests := map[string]string{
+ "success stream": "./testdata/chat_stream.success.txt",
+ }
+
+ for name, streamFile := range tests {
+ t.Run(name, func(t *testing.T) {
+ openAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ rawPayload, _ := io.ReadAll(r.Body)
+
+ var data interface{}
+ // Parse the JSON body
+ err := json.Unmarshal(rawPayload, &data)
+ if err != nil {
+ t.Errorf("error decoding payload (%q): %v", string(rawPayload), err)
+ }
+
+ chatResponse, err := os.ReadFile(filepath.Clean(streamFile))
+ if err != nil {
+ t.Errorf("error reading openai chat mock response: %v", err)
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ _, err = w.Write(chatResponse)
+ if err != nil {
+ t.Errorf("error on sending chat response: %v", err)
+ }
+ })
+
+ openAIServer := httptest.NewServer(openAIMock)
+ defer openAIServer.Close()
+
+ ctx := context.Background()
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ providerCfg.BaseURL = openAIServer.URL
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
+ stream, err := client.ChatStream(ctx, req)
+ require.NoError(t, err)
+
+ err = stream.Open()
+ require.NoError(t, err)
+
+ for {
+ chunk, err := stream.Recv()
+
+ if err == io.EOF {
+ return
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, chunk)
+ }
+ })
+ }
+}
+
+func TestOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) {
+ tests := map[string]string{
+ "success stream, but no last done message": "./testdata/chat_stream.nodone.txt",
+ "success stream, but with empty event": "./testdata/chat_stream.empty.txt",
+ }
+
+ for name, streamFile := range tests {
+ t.Run(name, func(t *testing.T) {
+ openAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ rawPayload, _ := io.ReadAll(r.Body)
+
+ var data interface{}
+ // Parse the JSON body
+ err := json.Unmarshal(rawPayload, &data)
+ if err != nil {
+ t.Errorf("error decoding payload (%q): %v", string(rawPayload), err)
+ }
+
+ chatResponse, err := os.ReadFile(filepath.Clean(streamFile))
+ if err != nil {
+ t.Errorf("error reading openai chat mock response: %v", err)
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ _, err = w.Write(chatResponse)
+ if err != nil {
+ t.Errorf("error on sending chat response: %v", err)
+ }
+ })
+
+ openAIServer := httptest.NewServer(openAIMock)
+ defer openAIServer.Close()
+
+ ctx := context.Background()
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ providerCfg.BaseURL = openAIServer.URL
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
+ stream, err := client.ChatStream(ctx, req)
+ require.NoError(t, err)
+
+ err = stream.Open()
+ require.NoError(t, err)
+
+ for {
+ chunk, err := stream.Recv()
+ if err != nil {
+ require.ErrorIs(t, err, clients.ErrProviderUnavailable)
+ return
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, chunk)
+ }
+ })
+ }
+}
diff --git a/pkg/providers/openai/client_test.go b/pkg/providers/openai/chat_test.go
similarity index 67%
rename from pkg/providers/openai/client_test.go
rename to pkg/providers/openai/chat_test.go
index 6bd8298d..bde4486b 100644
--- a/pkg/providers/openai/client_test.go
+++ b/pkg/providers/openai/chat_test.go
@@ -10,9 +10,8 @@ import (
"path/filepath"
"testing"
- "glide/pkg/providers/clients"
-
"glide/pkg/api/schemas"
+ "glide/pkg/providers/clients"
"glide/pkg/telemetry"
@@ -66,3 +65,32 @@ func TestOpenAIClient_ChatRequest(t *testing.T) {
require.Equal(t, "chatcmpl-123", response.ID)
}
+
+func TestOpenAIClient_RateLimit(t *testing.T) {
+ openAIMock := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Retry-After", "5m")
+ w.WriteHeader(http.StatusTooManyRequests)
+ })
+
+ openAIServer := httptest.NewServer(openAIMock)
+ defer openAIServer.Close()
+
+ ctx := context.Background()
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ providerCfg.BaseURL = openAIServer.URL
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ request := schemas.ChatRequest{Message: schemas.ChatMessage{
+ Role: "user",
+ Content: "What's the biggest animal?",
+ }}
+
+ _, err = client.Chat(ctx, &request)
+
+ require.Error(t, err)
+ require.IsType(t, &clients.RateLimitError{}, err)
+}
diff --git a/pkg/providers/openai/client.go b/pkg/providers/openai/client.go
index c56f227b..22d68d45 100644
--- a/pkg/providers/openai/client.go
+++ b/pkg/providers/openai/client.go
@@ -23,9 +23,10 @@ type Client struct {
baseURL string
chatURL string
chatRequestTemplate *ChatRequest
+ errMapper *ErrorMapper
config *Config
httpClient *http.Client
- telemetry *telemetry.Telemetry
+ tel *telemetry.Telemetry
}
// NewClient creates a new OpenAI client for the OpenAI API.
@@ -40,6 +41,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
chatURL: chatURL,
config: providerConfig,
chatRequestTemplate: NewChatRequestFromConfig(providerConfig),
+ errMapper: NewErrorMapper(tel),
httpClient: &http.Client{
Timeout: *clientConfig.Timeout,
// TODO: use values from the config
@@ -48,7 +50,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
MaxIdleConnsPerHost: 2,
},
},
- telemetry: tel,
+ tel: tel,
}
return c, nil
diff --git a/pkg/providers/openai/config.go b/pkg/providers/openai/config.go
index 86854f3e..49781d60 100644
--- a/pkg/providers/openai/config.go
+++ b/pkg/providers/openai/config.go
@@ -20,7 +20,6 @@ type Params struct {
Tools []string `yaml:"tools,omitempty" json:"tools"`
ToolChoice interface{} `yaml:"tool_choice,omitempty" json:"tool_choice"`
ResponseFormat interface{} `yaml:"response_format,omitempty" json:"response_format"` // TODO: should this be a part of the chat request API?
- // Stream bool `json:"stream,omitempty"` // TODO: we are not supporting this at the moment
}
func DefaultParams() Params {
diff --git a/pkg/providers/openai/errors.go b/pkg/providers/openai/errors.go
new file mode 100644
index 00000000..94ef5418
--- /dev/null
+++ b/pkg/providers/openai/errors.go
@@ -0,0 +1,64 @@
+package openai
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+ "go.uber.org/zap"
+)
+
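+// ErrorMapper maps OpenAI API errors onto the unified set of client errors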
+type ErrorMapper struct {
+ tel *telemetry.Telemetry
+}
+
+func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper {
+ return &ErrorMapper{
+ tel: tel,
+ }
+}
+
+func (m *ErrorMapper) Map(resp *http.Response) error {
+ bodyBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ m.tel.Logger.Error(
+ "Failed to unmarshal chat response error",
+ zap.String("provider", providerName),
+ zap.Error(err),
+ zap.ByteString("rawResponse", bodyBytes),
+ )
+
+ return clients.ErrProviderUnavailable
+ }
+
+ m.tel.Logger.Error(
+ "Chat request failed",
+ zap.String("provider", providerName),
+ zap.Int("statusCode", resp.StatusCode),
+ zap.String("response", string(bodyBytes)),
+ zap.Any("headers", resp.Header),
+ )
+
+ if resp.StatusCode == http.StatusTooManyRequests {
+ // Read the value of the "Retry-After" header to get the cooldown delay
+ retryAfter := resp.Header.Get("Retry-After")
+
+ // Parse the value to get the duration
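+ // NOTE: time.ParseDuration expects a unit suffix (e.g. "5m", "30s"),
+ // so a bare number of seconds, which some APIs send in Retry-After, would fail to parse here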
+ cooldownDelay, err := time.ParseDuration(retryAfter)
+ if err != nil {
+ return fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
+ }
+
+ return clients.NewRateLimitError(&cooldownDelay)
+ }
+
+ if resp.StatusCode == http.StatusUnauthorized {
+ return clients.ErrUnauthorized
+ }
+
+ // Server & client errors result in the same error to keep gateway resilient
+ return clients.ErrProviderUnavailable
+}
diff --git a/pkg/providers/openai/schemas.go b/pkg/providers/openai/schemas.go
index cf41aebf..022d510e 100644
--- a/pkg/providers/openai/schemas.go
+++ b/pkg/providers/openai/schemas.go
@@ -2,11 +2,38 @@ package openai
// OpenAI Chat Response (also used by Azure OpenAI and OctoML)
+type ChatMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+// ChatRequest is an OpenAI-specific request schema
+type ChatRequest struct {
+ Model string `json:"model"`
+ Messages []ChatMessage `json:"messages"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ N int `json:"n,omitempty"`
+ StopWords []string `json:"stop,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ FrequencyPenalty int `json:"frequency_penalty,omitempty"`
+ PresencePenalty int `json:"presence_penalty,omitempty"`
+ LogitBias *map[int]float64 `json:"logit_bias,omitempty"`
+ User *string `json:"user,omitempty"`
+ Seed *int `json:"seed,omitempty"`
+ Tools []string `json:"tools,omitempty"`
+ ToolChoice interface{} `json:"tool_choice,omitempty"`
+ ResponseFormat interface{} `json:"response_format,omitempty"`
+}
+
+// ChatCompletion
+// Ref: https://platform.openai.com/docs/api-reference/chat/object
type ChatCompletion struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
- Model string `json:"model"`
+ ModelName string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage"`
@@ -20,7 +47,25 @@ type Choice struct {
}
type Usage struct {
- PromptTokens float64 `json:"prompt_tokens"`
- CompletionTokens float64 `json:"completion_tokens"`
- TotalTokens float64 `json:"total_tokens"`
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+}
+
+// ChatCompletionChunk represents an SSE event a chat response is broken into during chat streaming
+// Ref: https://platform.openai.com/docs/api-reference/chat/streaming
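+// Chunks arrive as SSE events ("data: {...}") terminated by a final "data: [DONE]" event
+// (see testdata/chat_stream.success.txt)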
+type ChatCompletionChunk struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int `json:"created"`
+ ModelName string `json:"model"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Choices []StreamChoice `json:"choices"`
+}
+
+type StreamChoice struct {
+ Index int `json:"index"`
+ Delta ChatMessage `json:"delta"`
+ Logprobs interface{} `json:"logprobs"`
+ FinishReason string `json:"finish_reason"`
}
diff --git a/pkg/providers/openai/testdata/chat_stream.empty.txt b/pkg/providers/openai/testdata/chat_stream.empty.txt
new file mode 100644
index 00000000..ff2a38eb
--- /dev/null
+++ b/pkg/providers/openai/testdata/chat_stream.empty.txt
@@ -0,0 +1,6 @@
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
+
+data:
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
+
diff --git a/pkg/providers/openai/testdata/chat_stream.nodone.txt b/pkg/providers/openai/testdata/chat_stream.nodone.txt
new file mode 100644
index 00000000..785fe492
--- /dev/null
+++ b/pkg/providers/openai/testdata/chat_stream.nodone.txt
@@ -0,0 +1,22 @@
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"The"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" capital"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" of"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" the"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" United"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" Kingdom"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" is"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" London"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
+
diff --git a/pkg/providers/openai/testdata/chat_stream.success.txt b/pkg/providers/openai/testdata/chat_stream.success.txt
new file mode 100644
index 00000000..1e673eaf
--- /dev/null
+++ b/pkg/providers/openai/testdata/chat_stream.success.txt
@@ -0,0 +1,24 @@
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"The"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" capital"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" of"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" the"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" United"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" Kingdom"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" is"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" London"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
+
+data: [DONE]
+
diff --git a/pkg/providers/provider.go b/pkg/providers/provider.go
index 399d6ee7..0fffa08b 100644
--- a/pkg/providers/provider.go
+++ b/pkg/providers/provider.go
@@ -1,106 +1,18 @@
package providers
import (
- "context"
- "errors"
"time"
-
- "glide/pkg/providers/clients"
- "glide/pkg/routers/health"
- "glide/pkg/routers/latency"
-
- "glide/pkg/api/schemas"
)
-// LangModelProvider defines an interface a provider should fulfill to be able to serve language chat requests
-type LangModelProvider interface {
+// ModelProvider exposes provider-level context such as the provider name
+type ModelProvider interface {
Provider() string
- Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error)
}
+// Model represents a configured external modality-agnostic model with its routing properties and status
type Model interface {
ID() string
Healthy() bool
- Latency() *latency.MovingAverage
LatencyUpdateInterval() *time.Duration
Weight() int
}
-
-type LanguageModel interface {
- Model
- LangModelProvider
-}
-
-// LangModel wraps provider client and expend it with health & latency tracking
-type LangModel struct {
- modelID string
- weight int
- client LangModelProvider
- rateLimit *health.RateLimitTracker
- errorBudget *health.TokenBucket // TODO: centralize provider API health tracking in the registry
- latency *latency.MovingAverage
- latencyUpdateInterval *time.Duration
-}
-
-func NewLangModel(modelID string, client LangModelProvider, budget health.ErrorBudget, latencyConfig latency.Config, weight int) *LangModel {
- return &LangModel{
- modelID: modelID,
- client: client,
- rateLimit: health.NewRateLimitTracker(),
- errorBudget: health.NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget()),
- latency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples),
- latencyUpdateInterval: latencyConfig.UpdateInterval,
- weight: weight,
- }
-}
-
-func (m *LangModel) ID() string {
- return m.modelID
-}
-
-func (m *LangModel) Provider() string {
- return m.client.Provider()
-}
-
-func (m *LangModel) Latency() *latency.MovingAverage {
- return m.latency
-}
-
-func (m *LangModel) LatencyUpdateInterval() *time.Duration {
- return m.latencyUpdateInterval
-}
-
-func (m *LangModel) Healthy() bool {
- return !m.rateLimit.Limited() && m.errorBudget.HasTokens()
-}
-
-func (m *LangModel) Weight() int {
- return m.weight
-}
-
-func (m *LangModel) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
- startedAt := time.Now()
- resp, err := m.client.Chat(ctx, request)
-
- if err == nil {
- // record latency per token to normalize measurements
- m.latency.Add(float64(time.Since(startedAt)) / resp.ModelResponse.TokenUsage.ResponseTokens)
-
- // successful response
- resp.ModelID = m.modelID
-
- return resp, err
- }
-
- var rle *clients.RateLimitError
-
- if errors.As(err, &rle) {
- m.rateLimit.SetLimited(rle.UntilReset())
-
- return resp, err
- }
-
- _ = m.errorBudget.Take(1)
-
- return resp, err
-}
diff --git a/pkg/providers/testing.go b/pkg/providers/testing.go
deleted file mode 100644
index 890421a0..00000000
--- a/pkg/providers/testing.go
+++ /dev/null
@@ -1,100 +0,0 @@
-package providers
-
-import (
- "context"
- "time"
-
- "glide/pkg/routers/latency"
-
- "glide/pkg/api/schemas"
-)
-
-type ResponseMock struct {
- Msg string
- Err *error
-}
-
-func (m *ResponseMock) Resp() *schemas.ChatResponse {
- return &schemas.ChatResponse{
- ID: "rsp0001",
- ModelResponse: schemas.ProviderResponse{
- SystemID: map[string]string{
- "ID": "0001",
- },
- Message: schemas.ChatMessage{
- Content: m.Msg,
- },
- },
- }
-}
-
-type ProviderMock struct {
- idx int
- responses []ResponseMock
-}
-
-func NewProviderMock(responses []ResponseMock) *ProviderMock {
- return &ProviderMock{
- idx: 0,
- responses: responses,
- }
-}
-
-func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatRequest) (*schemas.ChatResponse, error) {
- response := c.responses[c.idx]
- c.idx++
-
- if response.Err != nil {
- return nil, *response.Err
- }
-
- return response.Resp(), nil
-}
-
-func (c *ProviderMock) Provider() string {
- return "provider_mock"
-}
-
-type LangModelMock struct {
- modelID string
- healthy bool
- latency *latency.MovingAverage
- weight int
-}
-
-func NewLangModelMock(ID string, healthy bool, avgLatency float64, weight int) *LangModelMock {
- movingAverage := latency.NewMovingAverage(0.06, 3)
-
- if avgLatency > 0.0 {
- movingAverage.Set(avgLatency)
- }
-
- return &LangModelMock{
- modelID: ID,
- healthy: healthy,
- latency: movingAverage,
- weight: weight,
- }
-}
-
-func (m *LangModelMock) ID() string {
- return m.modelID
-}
-
-func (m *LangModelMock) Healthy() bool {
- return m.healthy
-}
-
-func (m *LangModelMock) Latency() *latency.MovingAverage {
- return m.latency
-}
-
-func (m *LangModelMock) LatencyUpdateInterval() *time.Duration {
- updateInterval := 30 * time.Second
-
- return &updateInterval
-}
-
-func (m *LangModelMock) Weight() int {
- return m.weight
-}
diff --git a/pkg/providers/testing/lang.go b/pkg/providers/testing/lang.go
new file mode 100644
index 00000000..405b3c53
--- /dev/null
+++ b/pkg/providers/testing/lang.go
@@ -0,0 +1,154 @@
+package testing
+
+import (
+ "context"
+ "io"
+
+ "glide/pkg/api/schemas"
+ "glide/pkg/providers/clients"
+)
+
+// RespMock mocks a chat response or a streaming chat chunk
+type RespMock struct {
+ Msg string
+ Err *error
+}
+
+func (m *RespMock) Resp() *schemas.ChatResponse {
+ return &schemas.ChatResponse{
+ ID: "rsp0001",
+ ModelResponse: schemas.ModelResponse{
+ SystemID: map[string]string{
+ "ID": "0001",
+ },
+ Message: schemas.ChatMessage{
+ Content: m.Msg,
+ },
+ },
+ }
+}
+
+func (m *RespMock) RespChunk() *schemas.ChatStreamChunk {
+ return &schemas.ChatStreamChunk{
+ ID: "rsp0001",
+ ModelResponse: schemas.ModelChunkResponse{
+ Message: schemas.ChatMessage{
+ Content: m.Msg,
+ },
+ },
+ }
+}
+
+// RespStreamMock mocks a chat stream
+type RespStreamMock struct {
+ idx int
+ OpenErr error
+ Chunks *[]RespMock
+}
+
+func NewRespStreamMock(chunk *[]RespMock) RespStreamMock {
+ return RespStreamMock{
+ idx: 0,
+ OpenErr: nil,
+ Chunks: chunk,
+ }
+}
+
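+// NewRespStreamWithOpenErr mocks a chat stream that fails on Open()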
+func NewRespStreamWithOpenErr(openErr error) RespStreamMock {
+ return RespStreamMock{
+ idx: 0,
+ OpenErr: openErr,
+ Chunks: nil,
+ }
+}
+
+func (m *RespStreamMock) Open() error {
+ if m.OpenErr != nil {
+ return m.OpenErr
+ }
+
+ return nil
+}
+
+func (m *RespStreamMock) Recv() (*schemas.ChatStreamChunk, error) {
+ // a nil chunk list is treated the same as a drained one to avoid dereferencing nil below
+ if m.Chunks == nil || m.idx >= len(*m.Chunks) {
+ return nil, io.EOF
+ }
+
+ chunks := *m.Chunks
+
+ chunk := chunks[m.idx]
+ m.idx++
+
+ if chunk.Err != nil {
+ return nil, *chunk.Err
+ }
+
+ return chunk.RespChunk(), nil
+}
+
+func (m *RespStreamMock) Close() error {
+ return nil
+}
+
+// ProviderMock mocks a model provider
+type ProviderMock struct {
+ idx int
+ chatResps *[]RespMock
+ chatStreams *[]RespStreamMock
+ supportStreaming bool
+}
+
+func NewProviderMock(responses []RespMock) *ProviderMock {
+ return &ProviderMock{
+ idx: 0,
+ chatResps: &responses,
+ supportStreaming: false,
+ }
+}
+
+func NewStreamProviderMock(chatStreams []RespStreamMock) *ProviderMock {
+ return &ProviderMock{
+ idx: 0,
+ chatStreams: &chatStreams,
+ supportStreaming: true,
+ }
+}
+
+func (c *ProviderMock) SupportChatStream() bool {
+ return c.supportStreaming
+}
+
+func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatRequest) (*schemas.ChatResponse, error) {
+ if c.chatResps == nil {
+ return nil, clients.ErrProviderUnavailable
+ }
+
+ responses := *c.chatResps
+
+ response := responses[c.idx]
+ c.idx++
+
+ if response.Err != nil {
+ return nil, *response.Err
+ }
+
+ return response.Resp(), nil
+}
+
+func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+ if c.chatStreams == nil || c.idx >= len(*c.chatStreams) {
+ return nil, clients.ErrProviderUnavailable
+ }
+
+ streams := *c.chatStreams
+
+ stream := streams[c.idx]
+ c.idx++
+
+ return &stream, nil
+}
+
+func (c *ProviderMock) Provider() string {
+ return "provider_mock"
+}
diff --git a/pkg/providers/testing/models.go b/pkg/providers/testing/models.go
new file mode 100644
index 00000000..c5f8d6d6
--- /dev/null
+++ b/pkg/providers/testing/models.go
@@ -0,0 +1,57 @@
+package testing
+
+import (
+ "time"
+
+ "glide/pkg/providers"
+ "glide/pkg/routers/latency"
+)
+
+// LangModelMock mocks a configured language model with predefined health and latency stats
+type LangModelMock struct {
+ modelID string
+ healthy bool
+ chatLatency *latency.MovingAverage
+ weight int
+}
+
+func NewLangModelMock(ID string, healthy bool, avgLatency float64, weight int) LangModelMock {
+ chatLatency := latency.NewMovingAverage(0.06, 3)
+
+ if avgLatency > 0.0 {
+ chatLatency.Set(avgLatency)
+ }
+
+ return LangModelMock{
+ modelID: ID,
+ healthy: healthy,
+ chatLatency: chatLatency,
+ weight: weight,
+ }
+}
+
+func (m LangModelMock) ID() string {
+ return m.modelID
+}
+
+func (m LangModelMock) Healthy() bool {
+ return m.healthy
+}
+
+func (m *LangModelMock) ChatLatency() *latency.MovingAverage {
+ return m.chatLatency
+}
+
+func (m LangModelMock) LatencyUpdateInterval() *time.Duration {
+ updateInterval := 30 * time.Second
+
+ return &updateInterval
+}
+
+func (m LangModelMock) Weight() int {
+ return m.weight
+}
+
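+// ChatMockLatency exposes the mock's chat latency; it satisfies the LatencyGetter
+// signature used by least-latency routing tests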
+func ChatMockLatency(model providers.Model) *latency.MovingAverage {
+ return model.(LangModelMock).chatLatency
+}
diff --git a/pkg/routers/config.go b/pkg/routers/config.go
index 025f0c6d..f17ed52e 100644
--- a/pkg/routers/config.go
+++ b/pkg/routers/config.go
@@ -29,11 +29,11 @@ func (c *Config) BuildLangRouters(tel *telemetry.Telemetry) ([]*LangRouter, erro
seenIDs[routerConfig.ID] = true
if !routerConfig.Enabled {
- tel.Logger.Info(fmt.Sprintf("Router \"%v\" is disabled, skipping", routerConfig.ID))
+ tel.L().Info(fmt.Sprintf("Router \"%v\" is disabled, skipping", routerConfig.ID))
continue
}
- tel.Logger.Debug("Init router", zap.String("routerID", routerConfig.ID))
+ tel.L().Debug("Init router", zap.String("routerID", routerConfig.ID))
router, err := NewLangRouter(&c.LanguageRouters[idx], tel)
if err != nil {
@@ -63,15 +63,16 @@ type LangRouterConfig struct {
}
-// BuildModels creates LanguageModel slice out of the given config
+// BuildModels creates chat and streaming chat model lists out of the given config
-func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.LanguageModel, error) {
+func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*providers.LanguageModel, []*providers.LanguageModel, error) { //nolint: cyclop
var errs error
seenIDs := make(map[string]bool, len(c.Models))
- models := make([]providers.LanguageModel, 0, len(c.Models))
+ chatModels := make([]*providers.LanguageModel, 0, len(c.Models))
+ chatStreamModels := make([]*providers.LanguageModel, 0, len(c.Models))
for _, modelConfig := range c.Models {
if _, ok := seenIDs[modelConfig.ID]; ok {
- return nil, fmt.Errorf(
+ return nil, nil, fmt.Errorf(
"ID \"%v\" is specified for more than one model in router \"%v\", while it should be unique in scope of that pool",
modelConfig.ID,
c.ID,
@@ -81,7 +82,7 @@ func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.La
seenIDs[modelConfig.ID] = true
if !modelConfig.Enabled {
- tel.Logger.Info(
+ tel.L().Info(
"Model is disabled, skipping",
zap.String("router", c.ID),
zap.String("model", modelConfig.ID),
@@ -90,7 +91,7 @@ func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.La
continue
}
- tel.Logger.Debug(
+ tel.L().Debug(
"Init lang model",
zap.String("router", c.ID),
zap.String("model", modelConfig.ID),
@@ -102,19 +103,32 @@ func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.La
continue
}
- models = append(models, model)
+ chatModels = append(chatModels, model)
+
+ if !model.SupportChatStream() {
+ tel.L().Warn(
+ "Provider doesn't support or have not been yet integrated with streaming chat, it won't serve streaming chat requests",
+ zap.String("routerID", c.ID),
+ zap.String("modelID", model.ID()),
+ zap.String("provider", model.Provider()),
+ )
+
+ continue
+ }
+
+ chatStreamModels = append(chatStreamModels, model)
}
if errs != nil {
- return nil, errs
+ return nil, nil, errs
}
- if len(models) == 0 {
- return nil, fmt.Errorf("router \"%v\" must have at least one active model, zero defined", c.ID)
+ if len(chatModels) == 0 {
+ return nil, nil, fmt.Errorf("router \"%v\" must have at least one active model, zero defined", c.ID)
}
- if len(models) == 1 {
- tel.Logger.WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn(
+ if len(chatModels) == 1 {
+ tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn(
fmt.Sprintf("Router \"%v\" has only one active model defined. "+
"This is not recommended for production setups. "+
"Define at least a few models to leverage resiliency logic Glide provides",
@@ -123,7 +137,26 @@ func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.La
)
}
- return models, nil
+ if len(chatStreamModels) == 1 {
+ tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn(
+ fmt.Sprintf("Router \"%v\" has only one active model defined with streaming chat support. "+
+ "This is not recommended for production setups. "+
+ "Define at least a few models to leverage resiliency logic Glide provides",
+ c.ID,
+ ),
+ )
+ }
+
+ if len(chatStreamModels) == 0 {
+ tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn(
+ fmt.Sprintf("Router \"%v\" has only no model with streaming chat support. "+
+ "The streaming chat workflow won't work until you define any",
+ c.ID,
+ ),
+ )
+ }
+
+ return chatModels, chatStreamModels, nil
}
func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry {
@@ -137,24 +170,35 @@ func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry {
)
}
-func (c *LangRouterConfig) BuildRouting(models []providers.LanguageModel) (routing.LangModelRouting, error) {
- m := make([]providers.Model, 0, len(models))
- for _, model := range models {
- m = append(m, model)
+func (c *LangRouterConfig) BuildRouting(
+ chatModels []*providers.LanguageModel,
+ chatStreamModels []*providers.LanguageModel,
+) (routing.LangModelRouting, routing.LangModelRouting, error) {
+ chatModelPool := make([]providers.Model, 0, len(chatModels))
+ chatStreamModelPool := make([]providers.Model, 0, len(chatStreamModels))
+
+ for _, model := range chatModels {
+ chatModelPool = append(chatModelPool, model)
+ }
+
+ for _, model := range chatStreamModels {
+ chatStreamModelPool = append(chatStreamModelPool, model)
}
switch c.RoutingStrategy {
case routing.Priority:
- return routing.NewPriority(m), nil
+ return routing.NewPriority(chatModelPool), routing.NewPriority(chatStreamModelPool), nil
case routing.RoundRobin:
- return routing.NewRoundRobinRouting(m), nil
+ return routing.NewRoundRobinRouting(chatModelPool), routing.NewRoundRobinRouting(chatStreamModelPool), nil
case routing.WeightedRoundRobin:
- return routing.NewWeightedRoundRobin(m), nil
+ return routing.NewWeightedRoundRobin(chatModelPool), routing.NewWeightedRoundRobin(chatStreamModelPool), nil
case routing.LeastLatency:
- return routing.NewLeastLatencyRouting(m), nil
+ return routing.NewLeastLatencyRouting(providers.ChatLatency, chatModelPool),
+ routing.NewLeastLatencyRouting(providers.ChatStreamLatency, chatStreamModelPool),
+ nil
}
- return nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy)
+ return nil, nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy)
}
func DefaultLangRouterConfig() LangRouterConfig {
diff --git a/pkg/routers/config_test.go b/pkg/routers/config_test.go
index f2912164..28fb81fe 100644
--- a/pkg/routers/config_test.go
+++ b/pkg/routers/config_test.go
@@ -3,6 +3,8 @@ package routers
import (
"testing"
+ "glide/pkg/providers/cohere"
+
"github.com/stretchr/testify/require"
"glide/pkg/providers"
"glide/pkg/providers/clients"
@@ -64,10 +66,53 @@ func TestRouterConfig_BuildModels(t *testing.T) {
require.NoError(t, err)
require.Len(t, routers, 2)
- require.Len(t, routers[0].models, 1)
- require.IsType(t, &routing.PriorityRouting{}, routers[0].routing)
- require.Len(t, routers[1].models, 1)
- require.IsType(t, &routing.LeastLatencyRouting{}, routers[1].routing)
+ require.Len(t, routers[0].chatModels, 1)
+ require.IsType(t, &routing.PriorityRouting{}, routers[0].chatRouting)
+ require.Len(t, routers[1].chatModels, 1)
+ require.IsType(t, &routing.LeastLatencyRouting{}, routers[1].chatRouting)
+}
+
+func TestRouterConfig_BuildModelsPerType(t *testing.T) {
+ tel := telemetry.NewTelemetryMock()
+ openAIParams := openai.DefaultParams()
+ cohereParams := cohere.DefaultParams()
+
+ cfg := LangRouterConfig{
+ ID: "first_router",
+ Enabled: true,
+ RoutingStrategy: routing.Priority,
+ Retry: retry.DefaultExpRetryConfig(),
+ Models: []providers.LangModelConfig{
+ {
+ ID: "first_model",
+ Enabled: true,
+ Client: clients.DefaultClientConfig(),
+ ErrorBudget: health.DefaultErrorBudget(),
+ Latency: latency.DefaultConfig(),
+ OpenAI: &openai.Config{
+ APIKey: "ABC",
+ DefaultParams: &openAIParams,
+ },
+ },
+ {
+ ID: "second_model",
+ Enabled: true,
+ Client: clients.DefaultClientConfig(),
+ ErrorBudget: health.DefaultErrorBudget(),
+ Latency: latency.DefaultConfig(),
+ Cohere: &cohere.Config{
+ APIKey: "ABC",
+ DefaultParams: &cohereParams,
+ },
+ },
+ },
+ }
+
+ chatModels, streamChatModels, err := cfg.BuildModels(tel)
+
+ require.Len(t, chatModels, 2)
+ require.Len(t, streamChatModels, 2)
+ require.NoError(t, err)
}
func TestRouterConfig_InvalidSetups(t *testing.T) {
diff --git a/pkg/routers/health/tracker.go b/pkg/routers/health/tracker.go
new file mode 100644
index 00000000..f49e310b
--- /dev/null
+++ b/pkg/routers/health/tracker.go
@@ -0,0 +1,44 @@
+package health
+
+import (
+ "errors"
+
+ "glide/pkg/providers/clients"
+)
+
+// Tracker tracks errors and general health of model provider
+type Tracker struct {
+ unauthorized bool
+ errBudget *TokenBucket
+ rateLimit *RateLimitTracker
+}
+
+func NewTracker(budget *ErrorBudget) *Tracker {
+ return &Tracker{
+ unauthorized: false,
+ rateLimit: NewRateLimitTracker(),
+ errBudget: NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget()),
+ }
+}
+
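+// Healthy reports whether the provider is authorized, not rate-limited, and still has error budget left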
+func (t *Tracker) Healthy() bool {
+ return !t.unauthorized && !t.rateLimit.Limited() && t.errBudget.HasTokens()
+}
+
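+// TrackErr records a provider error: authorization failures mark the provider unauthorized,
+// rate limits pause it until reset, and any other error consumes a token from the error budget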
+func (t *Tracker) TrackErr(err error) {
+ var rateLimitErr *clients.RateLimitError
+
+ if errors.Is(err, clients.ErrUnauthorized) {
+ t.unauthorized = true
+
+ return
+ }
+
+ if errors.As(err, &rateLimitErr) {
+ t.rateLimit.SetLimited(rateLimitErr.UntilReset())
+
+ return
+ }
+
+ _ = t.errBudget.Take(1)
+}
diff --git a/pkg/routers/health/tracker_test.go b/pkg/routers/health/tracker_test.go
new file mode 100644
index 00000000..94733ce0
--- /dev/null
+++ b/pkg/routers/health/tracker_test.go
@@ -0,0 +1,37 @@
+package health
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "glide/pkg/providers/clients"
+)
+
+func TestHealthTracker_HealthyByDefault(t *testing.T) {
+ budget := NewErrorBudget(3, SEC)
+ tracker := NewTracker(budget)
+
+ require.True(t, tracker.Healthy())
+}
+
+func TestHealthTracker_UnhealthyWhenBudgetExceeded(t *testing.T) {
+ budget := NewErrorBudget(3, SEC)
+ tracker := NewTracker(budget)
+
+ for range 3 {
+ tracker.TrackErr(clients.ErrProviderUnavailable)
+ }
+
+ require.False(t, tracker.Healthy())
+}
+
+func TestHealthTracker_RateLimited(t *testing.T) {
+ budget := NewErrorBudget(3, SEC)
+ tracker := NewTracker(budget)
+
+ limitedUntil := 10 * time.Minute
+ tracker.TrackErr(clients.NewRateLimitError(&limitedUntil))
+
+ require.False(t, tracker.Healthy())
+}
diff --git a/pkg/routers/manager.go b/pkg/routers/manager.go
index e8e9e5a7..8dcecd60 100644
--- a/pkg/routers/manager.go
+++ b/pkg/routers/manager.go
@@ -10,7 +10,7 @@ var ErrRouterNotFound = errors.New("no router found with given ID")
type RouterManager struct {
Config *Config
- telemetry *telemetry.Telemetry
+ tel *telemetry.Telemetry
langRouterMap *map[string]*LangRouter
langRouters []*LangRouter
}
@@ -30,7 +30,7 @@ func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) {
manager := RouterManager{
Config: cfg,
- telemetry: tel,
+ tel: tel,
langRouters: langRouters,
langRouterMap: &langRouterMap,
}
diff --git a/pkg/routers/retry/config_test.go b/pkg/routers/retry/config_test.go
new file mode 100644
index 00000000..7fe409a3
--- /dev/null
+++ b/pkg/routers/retry/config_test.go
@@ -0,0 +1,13 @@
+package retry
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestRetryConfig_DefaultConfig(t *testing.T) {
+ config := DefaultExpRetryConfig()
+
+ require.NotNil(t, config)
+}
diff --git a/pkg/routers/router.go b/pkg/routers/router.go
index 13d89fa3..aa8ba067 100644
--- a/pkg/routers/router.go
+++ b/pkg/routers/router.go
@@ -20,32 +20,36 @@ var (
)
type LangRouter struct {
- routerID string
- Config *LangRouterConfig
- routing routing.LangModelRouting
- retry *retry.ExpRetry
- models []providers.LanguageModel
- telemetry *telemetry.Telemetry
+ routerID string
+ Config *LangRouterConfig
+ chatModels []*providers.LanguageModel
+ chatStreamModels []*providers.LanguageModel
+ chatRouting routing.LangModelRouting
+ chatStreamRouting routing.LangModelRouting
+ retry *retry.ExpRetry
+ tel *telemetry.Telemetry
}
func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter, error) {
- models, err := cfg.BuildModels(tel)
+ chatModels, chatStreamModels, err := cfg.BuildModels(tel)
if err != nil {
return nil, err
}
- strategy, err := cfg.BuildRouting(models)
+ chatRouting, chatStreamRouting, err := cfg.BuildRouting(chatModels, chatStreamModels)
if err != nil {
return nil, err
}
router := &LangRouter{
- routerID: cfg.ID,
- Config: cfg,
- models: models,
- retry: cfg.BuildRetry(),
- routing: strategy,
- telemetry: tel,
+ routerID: cfg.ID,
+ Config: cfg,
+ chatModels: chatModels,
+ chatStreamModels: chatStreamModels,
+ retry: cfg.BuildRetry(),
+ chatRouting: chatRouting,
+ chatStreamRouting: chatStreamRouting,
+ tel: tel,
}
return router, err
@@ -55,15 +59,15 @@ func (r *LangRouter) ID() string {
return r.routerID
}
-func (r *LangRouter) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
- if len(r.models) == 0 {
+func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) {
+ if len(r.chatModels) == 0 {
return nil, ErrNoModels
}
retryIterator := r.retry.Iterator()
for retryIterator.HasNext() {
- modelIterator := r.routing.Iterator()
+ modelIterator := r.chatRouting.Iterator()
for {
model, err := modelIterator.Next()
@@ -73,20 +77,20 @@ func (r *LangRouter) Chat(ctx context.Context, request *schemas.ChatRequest) (*s
break
}
- langModel := model.(providers.LanguageModel)
+ langModel := model.(providers.LangModel)
// Check if there is an override in the request
- if request.Override != (schemas.OverrideChatRequest{}) {
+ if req.Override != nil {
// Override the message if the language model ID matches the override model ID
- if langModel.ID() == request.Override.Model {
- request.Message = request.Override.Message
+ if langModel.ID() == req.Override.Model {
+ req.Message = req.Override.Message
}
}
- resp, err := langModel.Chat(ctx, request)
+ resp, err := langModel.Chat(ctx, req)
if err != nil {
- r.telemetry.Logger.Warn(
- "lang model failed processing chat request",
+ r.tel.L().Warn(
+ "Lang model failed processing chat request",
zap.String("routerID", r.ID()),
zap.String("modelID", langModel.ID()),
zap.String("provider", langModel.Provider()),
@@ -103,7 +107,7 @@ func (r *LangRouter) Chat(ctx context.Context, request *schemas.ChatRequest) (*s
// no providers were available to handle the request,
// so we have to wait a bit with a hope there is some available next time
- r.telemetry.Logger.Warn("no healthy model found, wait and retry", zap.String("routerID", r.ID()))
+ r.tel.L().Warn("No healthy model found to serve chat request, wait and retry", zap.String("routerID", r.ID()))
err := retryIterator.WaitNext(ctx)
if err != nil {
@@ -113,7 +117,116 @@ func (r *LangRouter) Chat(ctx context.Context, request *schemas.ChatRequest) (*s
}
// if we reach this part, then we are in trouble
- r.telemetry.Logger.Error("no model was available to handle request", zap.String("routerID", r.ID()))
+ r.tel.L().Error("No model was available to handle chat request", zap.String("routerID", r.ID()))
return nil, ErrNoModelAvailable
}
+
+func (r *LangRouter) ChatStream(
+ ctx context.Context,
+ req *schemas.ChatStreamRequest,
+ respC chan<- *schemas.ChatStreamResult,
+) {
+ if len(r.chatStreamModels) == 0 {
+ respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{
+ ID: req.ID,
+ ErrCode: "noModels",
+ Message: ErrNoModels.Error(),
+ Metadata: req.Metadata,
+ })
+
+ return
+ }
+
+ retryIterator := r.retry.Iterator()
+
+ for retryIterator.HasNext() {
+ modelIterator := r.chatStreamRouting.Iterator()
+
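+ // NextModel lets the chunk loop below abandon a failed stream and fall through to the next model in the pool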
+ NextModel:
+ for {
+ model, err := modelIterator.Next()
+
+ if errors.Is(err, routing.ErrNoHealthyModels) {
+ // no healthy model in the pool. Let's retry after some time
+ break
+ }
+
+ langModel := model.(providers.LangModel)
+ modelRespC, err := langModel.ChatStream(ctx, req)
+ if err != nil {
+ r.tel.L().Error(
+ "Lang model failed to create streaming chat request",
+ zap.String("routerID", r.ID()),
+ zap.String("modelID", langModel.ID()),
+ zap.String("provider", langModel.Provider()),
+ zap.Error(err),
+ )
+
+ continue
+ }
+
+ for chunkResult := range modelRespC {
+ err = chunkResult.Error()
+ if err != nil {
+ r.tel.L().Warn(
+ "Lang model failed processing streaming chat request",
+ zap.String("routerID", r.ID()),
+ zap.String("modelID", langModel.ID()),
+ zap.String("provider", langModel.Provider()),
+ zap.Error(err),
+ )
+
+ // It's challenging to hide an error in case of streaming chat, as consumer apps
+ // may have already used all chunks we have streamed so far (e.g. shown them to their users like the OpenAI UI does),
+ // so we cannot easily restart the process from scratch
+ respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{
+ ID: req.ID,
+ ErrCode: "modelUnavailable",
+ Message: err.Error(),
+ Metadata: req.Metadata,
+ })
+
+ continue NextModel
+ }
+
+ respC <- schemas.NewChatStreamResult(chunkResult.Chunk())
+ }
+
+ return
+ }
+
+ // no providers were available to handle the request,
+ // so we have to wait a bit in the hope some become available next time
+ r.tel.L().Warn(
+ "No healthy model found to serve streaming chat request, wait and retry",
+ zap.String("routerID", r.ID()),
+ )
+
+ err := retryIterator.WaitNext(ctx)
+ if err != nil {
+ // something has cancelled the context
+ respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{
+ ID: req.ID,
+ ErrCode: "other",
+ Message: err.Error(),
+ Metadata: req.Metadata,
+ })
+
+ return
+ }
+ }
+
+ // if we reach this part, then we are in trouble
+ r.tel.L().Error(
+ "No model was available to handle streaming chat request. Try to configure more fallback models to avoid this",
+ zap.String("routerID", r.ID()),
+ )
+
+ respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{
+ ID: req.ID,
+ ErrCode: "allModelsUnavailable",
+ Message: ErrNoModelAvailable.Error(),
+ Metadata: req.Metadata,
+ })
+}
diff --git a/pkg/routers/router_test.go b/pkg/routers/router_test.go
index 77fb7226..4fd50d28 100644
--- a/pkg/routers/router_test.go
+++ b/pkg/routers/router_test.go
@@ -12,28 +12,29 @@ import (
"github.com/stretchr/testify/require"
"glide/pkg/api/schemas"
"glide/pkg/providers"
+ ptesting "glide/pkg/providers/testing"
"glide/pkg/routers/health"
"glide/pkg/routers/retry"
"glide/pkg/routers/routing"
"glide/pkg/telemetry"
)
-func TestLangRouter_Priority_PickFistHealthy(t *testing.T) {
+func TestLangRouter_Chat_PickFirstHealthy(t *testing.T) {
budget := health.NewErrorBudget(3, health.SEC)
latConfig := latency.DefaultConfig()
- langModels := []providers.LanguageModel{
+ langModels := []*providers.LanguageModel{
providers.NewLangModel(
"first",
- providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
- *budget,
+ ptesting.NewProviderMock([]ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}),
+ budget,
*latConfig,
1,
),
providers.NewLangModel(
"second",
- providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}}),
- *budget,
+ ptesting.NewProviderMock([]ptesting.RespMock{{Msg: "1"}}),
+ budget,
*latConfig,
1,
),
@@ -45,12 +46,13 @@ func TestLangRouter_Priority_PickFistHealthy(t *testing.T) {
}
router := LangRouter{
- routerID: "test_router",
- Config: &LangRouterConfig{},
- retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
- routing: routing.NewPriority(models),
- models: langModels,
- telemetry: telemetry.NewTelemetryMock(),
+ routerID: "test_router",
+ Config: &LangRouterConfig{},
+ retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
+ chatRouting: routing.NewPriority(models),
+ chatModels: langModels,
+ chatStreamModels: langModels,
+ tel: telemetry.NewTelemetryMock(),
}
ctx := context.Background()
@@ -65,28 +67,28 @@ func TestLangRouter_Priority_PickFistHealthy(t *testing.T) {
}
}
-func TestLangRouter_Priority_PickThirdHealthy(t *testing.T) {
+func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) {
budget := health.NewErrorBudget(1, health.SEC)
latConfig := latency.DefaultConfig()
- langModels := []providers.LanguageModel{
+ langModels := []*providers.LanguageModel{
providers.NewLangModel(
"first",
- providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "3"}}),
- *budget,
+ ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Msg: "3"}}),
+ budget,
*latConfig,
1,
),
providers.NewLangModel(
"second",
- providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "4"}}),
- *budget,
+ ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Msg: "4"}}),
+ budget,
*latConfig,
1,
),
providers.NewLangModel(
"third",
- providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
- *budget,
+ ptesting.NewProviderMock([]ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}),
+ budget,
*latConfig,
1,
),
@@ -100,12 +102,14 @@ func TestLangRouter_Priority_PickThirdHealthy(t *testing.T) {
expectedModels := []string{"third", "third"}
router := LangRouter{
- routerID: "test_router",
- Config: &LangRouterConfig{},
- retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
- routing: routing.NewPriority(models),
- models: langModels,
- telemetry: telemetry.NewTelemetryMock(),
+ routerID: "test_router",
+ Config: &LangRouterConfig{},
+ retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
+ chatRouting: routing.NewPriority(models),
+ chatStreamRouting: routing.NewPriority(models),
+ chatModels: langModels,
+ chatStreamModels: langModels,
+ tel: telemetry.NewTelemetryMock(),
}
ctx := context.Background()
@@ -120,21 +124,21 @@ func TestLangRouter_Priority_PickThirdHealthy(t *testing.T) {
}
}
-func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) {
+func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) {
budget := health.NewErrorBudget(1, health.MILLI)
latConfig := latency.DefaultConfig()
- langModels := []providers.LanguageModel{
+ langModels := []*providers.LanguageModel{
providers.NewLangModel(
"first",
- providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "2"}}),
- *budget,
+ ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Msg: "2"}}),
+ budget,
*latConfig,
1,
),
providers.NewLangModel(
"second",
- providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "1"}}),
- *budget,
+ ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Msg: "1"}}),
+ budget,
*latConfig,
1,
),
@@ -146,12 +150,14 @@ func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) {
}
router := LangRouter{
- routerID: "test_router",
- Config: &LangRouterConfig{},
- retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
- routing: routing.NewPriority(models),
- models: langModels,
- telemetry: telemetry.NewTelemetryMock(),
+ routerID: "test_router",
+ Config: &LangRouterConfig{},
+ retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
+ chatRouting: routing.NewPriority(models),
+ chatStreamRouting: routing.NewPriority(models),
+ chatModels: langModels,
+ chatStreamModels: langModels,
+ tel: telemetry.NewTelemetryMock(),
}
resp, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke"))
@@ -161,21 +167,21 @@ func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) {
require.Equal(t, "test_router", resp.RouterID)
}
-func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) {
+func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) {
budget := health.NewErrorBudget(1, health.MIN)
latConfig := latency.DefaultConfig()
- langModels := []providers.LanguageModel{
+ langModels := []*providers.LanguageModel{
providers.NewLangModel(
"first",
- providers.NewProviderMock([]providers.ResponseMock{{Err: &clients.ErrProviderUnavailable}, {Msg: "3"}}),
- *budget,
+ ptesting.NewProviderMock([]ptesting.RespMock{{Err: &clients.ErrProviderUnavailable}, {Msg: "3"}}),
+ budget,
*latConfig,
1,
),
providers.NewLangModel(
"second",
- providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
- *budget,
+ ptesting.NewProviderMock([]ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}),
+ budget,
*latConfig,
1,
),
@@ -187,12 +193,14 @@ func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) {
}
router := LangRouter{
- routerID: "test_router",
- Config: &LangRouterConfig{},
- retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
- routing: routing.NewPriority(models),
- models: langModels,
- telemetry: telemetry.NewTelemetryMock(),
+ routerID: "test_router",
+ Config: &LangRouterConfig{},
+ retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
+ chatRouting: routing.NewPriority(models),
+ chatModels: langModels,
+ chatStreamModels: langModels,
+ chatStreamRouting: routing.NewPriority(models),
+ tel: telemetry.NewTelemetryMock(),
}
for i := 0; i < 2; i++ {
@@ -204,21 +212,21 @@ func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) {
}
}
-func TestLangRouter_Priority_AllModelsUnavailable(t *testing.T) {
+func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) {
budget := health.NewErrorBudget(1, health.SEC)
latConfig := latency.DefaultConfig()
- langModels := []providers.LanguageModel{
+ langModels := []*providers.LanguageModel{
providers.NewLangModel(
"first",
- providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
- *budget,
+ ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
+ budget,
*latConfig,
1,
),
providers.NewLangModel(
"second",
- providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
- *budget,
+ ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
+ budget,
*latConfig,
1,
),
@@ -230,15 +238,221 @@ func TestLangRouter_Priority_AllModelsUnavailable(t *testing.T) {
}
router := LangRouter{
- routerID: "test_router",
- Config: &LangRouterConfig{},
- retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil),
- routing: routing.NewPriority(models),
- models: langModels,
- telemetry: telemetry.NewTelemetryMock(),
+ routerID: "test_router",
+ Config: &LangRouterConfig{},
+ retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil),
+ chatRouting: routing.NewPriority(models),
+ chatModels: langModels,
+ chatStreamModels: langModels,
+ chatStreamRouting: routing.NewPriority(models),
+ tel: telemetry.NewTelemetryMock(),
}
_, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke"))
require.Error(t, err)
}
+
+func TestLangRouter_ChatStream(t *testing.T) {
+ budget := health.NewErrorBudget(3, health.SEC)
+ latConfig := latency.DefaultConfig()
+
+ langModels := []*providers.LanguageModel{
+ providers.NewLangModel(
+ "first",
+ ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
+ ptesting.NewRespStreamMock(&[]ptesting.RespMock{
+ {Msg: "Bill"},
+ {Msg: "Gates"},
+ {Msg: "entered"},
+ {Msg: "the"},
+ {Msg: "bar"},
+ }),
+ }),
+ budget,
+ *latConfig,
+ 1,
+ ),
+ providers.NewLangModel(
+ "second",
+ ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
+ ptesting.NewRespStreamMock(&[]ptesting.RespMock{
+ {Msg: "Knock"},
+ {Msg: "Knock"},
+ {Msg: "joke"},
+ }),
+ }),
+ budget,
+ *latConfig,
+ 1,
+ ),
+ }
+
+ models := make([]providers.Model, 0, len(langModels))
+ for _, model := range langModels {
+ models = append(models, model)
+ }
+
+ router := LangRouter{
+ routerID: "test_stream_router",
+ Config: &LangRouterConfig{},
+ retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
+ chatRouting: routing.NewPriority(models),
+ chatModels: langModels,
+ chatStreamRouting: routing.NewPriority(models),
+ chatStreamModels: langModels,
+ tel: telemetry.NewTelemetryMock(),
+ }
+
+ ctx := context.Background()
+ req := schemas.NewChatStreamFromStr("tell me a dad joke")
+ respC := make(chan *schemas.ChatStreamResult)
+
+ defer close(respC)
+
+ go router.ChatStream(ctx, req, respC)
+
+ chunks := make([]string, 0, 5)
+
+ for range 5 {
+ select { //nolint:gosimple
+ case chunk := <-respC:
+ require.Nil(t, chunk.Error())
+ require.NotNil(t, chunk.Chunk().ModelResponse.Message.Content)
+
+ chunks = append(chunks, chunk.Chunk().ModelResponse.Message.Content)
+ }
+ }
+
+ require.Equal(t, []string{"Bill", "Gates", "entered", "the", "bar"}, chunks)
+}
+
+func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) {
+ budget := health.NewErrorBudget(3, health.SEC)
+ latConfig := latency.DefaultConfig()
+
+ langModels := []*providers.LanguageModel{
+ providers.NewLangModel(
+ "first",
+ ptesting.NewStreamProviderMock(nil),
+ budget,
+ *latConfig,
+ 1,
+ ),
+ providers.NewLangModel(
+ "second",
+ ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
+ ptesting.NewRespStreamMock(
+ &[]ptesting.RespMock{
+ {Msg: "Knock"},
+ {Msg: "knock"},
+ {Msg: "joke"},
+ },
+ ),
+ }),
+ budget,
+ *latConfig,
+ 1,
+ ),
+ }
+
+ models := make([]providers.Model, 0, len(langModels))
+ for _, model := range langModels {
+ models = append(models, model)
+ }
+
+ router := LangRouter{
+ routerID: "test_stream_router",
+ Config: &LangRouterConfig{},
+ retry: retry.NewExpRetry(3, 2, 1*time.Second, nil),
+ chatRouting: routing.NewPriority(models),
+ chatModels: langModels,
+ chatStreamRouting: routing.NewPriority(models),
+ chatStreamModels: langModels,
+ tel: telemetry.NewTelemetryMock(),
+ }
+
+ ctx := context.Background()
+ req := schemas.NewChatStreamFromStr("tell me a dad joke")
+ respC := make(chan *schemas.ChatStreamResult)
+
+ defer close(respC)
+
+ go router.ChatStream(ctx, req, respC)
+
+ chunks := make([]string, 0, 3)
+
+ for range 3 {
+ select { //nolint:gosimple
+ case chunk := <-respC:
+ require.Nil(t, chunk.Error())
+ require.NotNil(t, chunk.Chunk().ModelResponse.Message.Content)
+
+ chunks = append(chunks, chunk.Chunk().ModelResponse.Message.Content)
+ }
+ }
+
+ require.Equal(t, []string{"Knock", "knock", "joke"}, chunks)
+}
+
+func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) {
+ budget := health.NewErrorBudget(1, health.SEC)
+ latConfig := latency.DefaultConfig()
+
+ langModels := []*providers.LanguageModel{
+ providers.NewLangModel(
+ "first",
+ ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
+ ptesting.NewRespStreamMock(&[]ptesting.RespMock{
+ {Err: &clients.ErrProviderUnavailable},
+ }),
+ }),
+ budget,
+ *latConfig,
+ 1,
+ ),
+ providers.NewLangModel(
+ "second",
+ ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
+ ptesting.NewRespStreamMock(&[]ptesting.RespMock{
+ {Err: &clients.ErrProviderUnavailable},
+ }),
+ }),
+ budget,
+ *latConfig,
+ 1,
+ ),
+ }
+
+ models := make([]providers.Model, 0, len(langModels))
+ for _, model := range langModels {
+ models = append(models, model)
+ }
+
+ router := LangRouter{
+ routerID: "test_router",
+ Config: &LangRouterConfig{},
+ retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil),
+ chatRouting: routing.NewPriority(models),
+ chatModels: langModels,
+ chatStreamModels: langModels,
+ chatStreamRouting: routing.NewPriority(models),
+ tel: telemetry.NewTelemetryMock(),
+ }
+
+ respC := make(chan *schemas.ChatStreamResult)
+ defer close(respC)
+
+ go router.ChatStream(context.Background(), schemas.NewChatStreamFromStr("tell me a dad joke"), respC)
+
+ errs := make([]string, 0, 3)
+
+ for range 3 {
+ result := <-respC
+ require.Nil(t, result.Chunk())
+
+ errs = append(errs, result.Error().ErrCode)
+ }
+
+ require.Equal(t, []string{"modelUnavailable", "modelUnavailable", "allModelsUnavailable"}, errs)
+}
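// A minimal sketch of the consumption pattern the tests above rely on:
// ChatStream runs in its own goroutine and publishes each result (a chunk
// or an error) on the channel. All names come from the test code except
// drainStream itself, which is hypothetical; the package name and schemas
// import path are assumed from the repo layout.
package routers

import (
	"context"
	"errors"

	"glide/pkg/api/schemas"
)

func drainStream(ctx context.Context, router *LangRouter, prompt string, n int) ([]string, error) {
	respC := make(chan *schemas.ChatStreamResult)
	defer close(respC)

	go router.ChatStream(ctx, schemas.NewChatStreamFromStr(prompt), respC)

	contents := make([]string, 0, n)

	// The mocks emit a fixed number of chunks, so read exactly n results.
	for i := 0; i < n; i++ {
		result := <-respC

		if result.Error() != nil {
			return contents, errors.New(result.Error().ErrCode)
		}

		contents = append(contents, result.Chunk().ModelResponse.Message.Content)
	}

	return contents, nil
}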
diff --git a/pkg/routers/routing/least_latency.go b/pkg/routers/routing/least_latency.go
index 2b65dc4c..422fd863 100644
--- a/pkg/routers/routing/least_latency.go
+++ b/pkg/routers/routing/least_latency.go
@@ -5,6 +5,8 @@ import (
"sync/atomic"
"time"
+ "glide/pkg/routers/latency"
+
"glide/pkg/providers"
)
@@ -12,6 +14,9 @@ const (
LeastLatency Strategy = "least_latency"
)
+// LatencyGetter defines how to retrieve the latency moving average for a specific model action
+type LatencyGetter = func(model providers.Model) *latency.MovingAverage
+
// ModelSchedule defines latency update schedule for models
type ModelSchedule struct {
mu sync.RWMutex
@@ -57,11 +62,12 @@ func (s *ModelSchedule) Update() {
// other models' latency may improve over time and outperform the current best one),
// so we need to send some traffic to other models from time to time to update their latency stats
type LeastLatencyRouting struct {
- warmupIdx atomic.Uint32
- schedules []*ModelSchedule
+ latencyGetter LatencyGetter
+ warmupIdx atomic.Uint32
+ schedules []*ModelSchedule
}
-func NewLeastLatencyRouting(models []providers.Model) *LeastLatencyRouting {
+func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []providers.Model) *LeastLatencyRouting {
schedules := make([]*ModelSchedule, 0, len(models))
for _, model := range models {
@@ -69,7 +75,8 @@ func NewLeastLatencyRouting(models []providers.Model) *LeastLatencyRouting {
}
return &LeastLatencyRouting{
- schedules: schedules,
+ latencyGetter: latencyGetter,
+ schedules: schedules,
}
}
@@ -125,7 +132,7 @@ func (r *LeastLatencyRouting) Next() (providers.Model, error) { //nolint:cyclop
}
if !schedule.Expired() && !nextSchedule.Expired() &&
- nextSchedule.model.Latency().Value() > schedule.model.Latency().Value() {
+ r.latencyGetter(nextSchedule.model).Value() > r.latencyGetter(schedule.model).Value() {
nextSchedule = schedule
}
}
@@ -143,7 +150,7 @@ func (r *LeastLatencyRouting) getColdModelSchedules() []*ModelSchedule {
coldModels := make([]*ModelSchedule, 0, len(r.schedules))
for _, schedule := range r.schedules {
- if schedule.model.Healthy() && !schedule.model.Latency().WarmedUp() {
+ if schedule.model.Healthy() && !r.latencyGetter(schedule.model).WarmedUp() {
coldModels = append(coldModels, schedule)
}
}
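// A minimal sketch of the new LatencyGetter indirection: the same scheduler
// can rank models by whichever per-action moving average the getter reads.
// providers.ChatLatency appears in the tests below; newLatencyRouters is a
// hypothetical wrapper, and the stream getter simply delegates to the chat
// one as a placeholder.
package routing

import (
	"glide/pkg/providers"
	"glide/pkg/routers/latency"
)

func newLatencyRouters(models []providers.Model) (*LeastLatencyRouting, *LeastLatencyRouting) {
	chat := NewLeastLatencyRouting(providers.ChatLatency, models)

	// Any func(providers.Model) *latency.MovingAverage fits, so tests can
	// plug in mocks such as ptesting.ChatMockLatency.
	stream := NewLeastLatencyRouting(func(m providers.Model) *latency.MovingAverage {
		return providers.ChatLatency(m)
	}, models)

	return chat, stream
}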
diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/routers/routing/least_latency_test.go
index dbbc699c..4b517afc 100644
--- a/pkg/routers/routing/least_latency_test.go
+++ b/pkg/routers/routing/least_latency_test.go
@@ -5,6 +5,8 @@ import (
"testing"
"time"
+ ptesting "glide/pkg/providers/testing"
+
"github.com/stretchr/testify/require"
"glide/pkg/providers"
)
@@ -33,10 +35,10 @@ func TestLeastLatencyRouting_Warmup(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))
for _, model := range tc.models {
- models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, model.latency, 1))
+ models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, model.latency, 1))
}
- routing := NewLeastLatencyRouting(models)
+ routing := NewLeastLatencyRouting(ptesting.ChatMockLatency, models)
iterator := routing.Iterator()
// loop three times over the whole pool to check that we wrap back to the beginning of the list
@@ -104,7 +106,7 @@ func TestLeastLatencyRouting_Routing(t *testing.T) {
for _, model := range tc.models {
schedules = append(schedules, &ModelSchedule{
- model: providers.NewLangModelMock(
+ model: ptesting.NewLangModelMock(
model.modelID,
model.healthy,
model.latency,
@@ -115,7 +117,8 @@ func TestLeastLatencyRouting_Routing(t *testing.T) {
}
routing := LeastLatencyRouting{
- schedules: schedules,
+ latencyGetter: ptesting.ChatMockLatency,
+ schedules: schedules,
}
iterator := routing.Iterator()
@@ -143,10 +146,10 @@ func TestLeastLatencyRouting_NoHealthyModels(t *testing.T) {
models := make([]providers.Model, 0, len(latencies))
for idx, latency := range latencies {
- models = append(models, providers.NewLangModelMock(strconv.Itoa(idx), false, latency, 1))
+ models = append(models, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1))
}
- routing := NewLeastLatencyRouting(models)
+ routing := NewLeastLatencyRouting(providers.ChatLatency, models)
iterator := routing.Iterator()
_, err := iterator.Next()
diff --git a/pkg/routers/routing/priority_test.go b/pkg/routers/routing/priority_test.go
index 4b0d8f94..2c4bcc1e 100644
--- a/pkg/routers/routing/priority_test.go
+++ b/pkg/routers/routing/priority_test.go
@@ -3,6 +3,8 @@ package routing
import (
"testing"
+ ptesting "glide/pkg/providers/testing"
+
"github.com/stretchr/testify/require"
"glide/pkg/providers"
)
@@ -29,7 +31,7 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))
for _, model := range tc.models {
- models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 100, 1))
+ models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1))
}
routing := NewPriority(models)
@@ -47,9 +49,9 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) {
func TestPriorityRouting_NoHealthyModels(t *testing.T) {
models := []providers.Model{
- providers.NewLangModelMock("first", false, 0, 1),
- providers.NewLangModelMock("second", false, 0, 1),
- providers.NewLangModelMock("third", false, 0, 1),
+ ptesting.NewLangModelMock("first", false, 0, 1),
+ ptesting.NewLangModelMock("second", false, 0, 1),
+ ptesting.NewLangModelMock("third", false, 0, 1),
}
routing := NewPriority(models)
diff --git a/pkg/routers/routing/round_robin_test.go b/pkg/routers/routing/round_robin_test.go
index becdf69f..f8853ebe 100644
--- a/pkg/routers/routing/round_robin_test.go
+++ b/pkg/routers/routing/round_robin_test.go
@@ -3,6 +3,8 @@ package routing
import (
"testing"
+ ptesting "glide/pkg/providers/testing"
+
"github.com/stretchr/testify/require"
"glide/pkg/providers"
)
@@ -30,7 +32,7 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))
for _, model := range tc.models {
- models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 100, 1))
+ models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1))
}
routing := NewRoundRobinRouting(models)
@@ -50,9 +52,9 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) {
func TestRoundRobinRouting_NoHealthyModels(t *testing.T) {
models := []providers.Model{
- providers.NewLangModelMock("first", false, 0, 1),
- providers.NewLangModelMock("second", false, 0, 1),
- providers.NewLangModelMock("third", false, 0, 1),
+ ptesting.NewLangModelMock("first", false, 0, 1),
+ ptesting.NewLangModelMock("second", false, 0, 1),
+ ptesting.NewLangModelMock("third", false, 0, 1),
}
routing := NewRoundRobinRouting(models)
diff --git a/pkg/routers/routing/weighted_round_robin_test.go b/pkg/routers/routing/weighted_round_robin_test.go
index 71a412b3..af258bb3 100644
--- a/pkg/routers/routing/weighted_round_robin_test.go
+++ b/pkg/routers/routing/weighted_round_robin_test.go
@@ -3,6 +3,8 @@ package routing
import (
"testing"
+ ptesting "glide/pkg/providers/testing"
+
"github.com/stretchr/testify/require"
"glide/pkg/providers"
)
@@ -116,7 +118,7 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) {
models := make([]providers.Model, 0, len(tc.models))
for _, model := range tc.models {
- models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 0, model.weight))
+ models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 0, model.weight))
}
routing := NewWeightedRoundRobin(models)
@@ -140,9 +142,9 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) {
func TestWRoundRobinRouting_NoHealthyModels(t *testing.T) {
models := []providers.Model{
- providers.NewLangModelMock("first", false, 0, 1),
- providers.NewLangModelMock("second", false, 0, 2),
- providers.NewLangModelMock("third", false, 0, 3),
+ ptesting.NewLangModelMock("first", false, 0, 1),
+ ptesting.NewLangModelMock("second", false, 0, 2),
+ ptesting.NewLangModelMock("third", false, 0, 3),
}
routing := NewWeightedRoundRobin(models)
diff --git a/pkg/telemetry/logging_test.go b/pkg/telemetry/logging_test.go
new file mode 100644
index 00000000..c7a61178
--- /dev/null
+++ b/pkg/telemetry/logging_test.go
@@ -0,0 +1,27 @@
+package telemetry
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestLogging_PlainOutputSetup(t *testing.T) {
+ config := LogConfig{
+ Encoding: "console",
+ }
+ zapConfig := config.ToZapConfig()
+
+ require.Equal(t, "console", config.Encoding)
+ require.NotNil(t, zapConfig)
+ require.Equal(t, "console", zapConfig.Encoding)
+}
+
+func TestLogging_JSONOutputSetup(t *testing.T) {
+ config := DefaultLogConfig()
+ zapConfig := config.ToZapConfig()
+
+ require.Equal(t, "json", config.Encoding)
+ require.NotNil(t, zapConfig)
+ require.Equal(t, "json", zapConfig.Encoding)
+}
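// Sketch of the path both tests above exercise end to end: ToZapConfig
// yields an ordinary zap.Config, which zap can build into a logger. Only
// Encoding is set here; the remaining fields are assumed to be filled with
// defaults by ToZapConfig, as the console test above suggests.
package telemetry

import "go.uber.org/zap"

func newConsoleLogger() (*zap.Logger, error) {
	config := LogConfig{Encoding: "console"}

	return config.ToZapConfig().Build()
}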
diff --git a/pkg/telemetry/telemetry.go b/pkg/telemetry/telemetry.go
index 868dfaeb..a97f87ac 100644
--- a/pkg/telemetry/telemetry.go
+++ b/pkg/telemetry/telemetry.go
@@ -13,6 +13,10 @@ type Telemetry struct {
// TODO: add OTEL meter, tracer
}
+func (t Telemetry) L() *zap.Logger {
+ return t.Logger
+}
+
func DefaultConfig() *Config {
return &Config{
LogConfig: DefaultLogConfig(),
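// Minimal usage sketch for the new L() accessor (same package): build
// telemetry from the default config, as telemetry_test.go below does, and
// log through the underlying zap.Logger. logStartup is hypothetical.
func logStartup() error {
	tel, err := NewTelemetry(DefaultConfig())
	if err != nil {
		return err
	}

	tel.L().Info("telemetry initialized")

	return nil
}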
diff --git a/pkg/telemetry/telemetry_test.go b/pkg/telemetry/telemetry_test.go
new file mode 100644
index 00000000..2f69ef89
--- /dev/null
+++ b/pkg/telemetry/telemetry_test.go
@@ -0,0 +1,12 @@
+package telemetry
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestTelemetry_Creation(t *testing.T) {
+ _, err := NewTelemetry(DefaultConfig())
+ require.NoError(t, err)
+}