diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index 62e9c47f..e504ec17 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -23,7 +23,7 @@ jobs:
       - name: Set up Go
         uses: actions/setup-go@v4
         with:
-          go-version: "1.21"
+          go-version: "1.22"
           check-latest: true

       - name: Install
@@ -56,7 +56,7 @@ jobs:
       - name: Set up Go
         uses: actions/setup-go@v4
         with:
-          go-version: "1.21"
+          go-version: "1.22"
           check-latest: true

       - name: Build
         run: go build -v ./...
@@ -70,12 +70,9 @@ jobs:
       - name: Set up Go
         uses: actions/setup-go@v4
         with:
-          go-version: "1.21"
+          go-version: "1.22"
           check-latest: true

-      - name: Install staticcheck
-        run: go install honnef.co/go/tools/cmd/staticcheck@latest
-
       - name: Install nilaway
         run: go install go.uber.org/nilaway/cmd/nilaway@latest

@@ -85,8 +82,6 @@ jobs:
           version: latest
           args: --timeout 5m

-      - name: Staticcheck
-        run: staticcheck ./...
       # TODO: Ignore the issue in https://github.com/modelgateway/Glide/issues/32
       # - name: Nilaway
       #   run: nilaway ./...
@@ -100,7 +95,7 @@ jobs:
       - name: Set up Go
        uses: actions/setup-go@v4
        with:
-          go-version: "1.21"
+          go-version: "1.22"
           check-latest: true

       - name: Test
@@ -126,7 +121,7 @@ jobs:
       - name: Set up Go
         uses: actions/setup-go@v4
         with:
-          go-version: "1.21"
+          go-version: "1.22"
           check-latest: true

       - name: Generate OpenAPI Schema
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index f97309b8..69f01d09 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -23,7 +23,7 @@ jobs:
       - name: Set up Go
         uses: actions/setup-go@v4
         with:
-          go-version: 1.21
+          go-version: 1.22

       - name: Checkout
         uses: actions/checkout@v4
@@ -69,9 +69,6 @@ jobs:
       - name: login into Github Container Registry
         run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u $ --password-stdin

-      - name: login into Github Container Registry
-        run: echo "${{ secrets.DOCKER_HUB_TOKEN }}" | docker login -u einstack --password-stdin
-
       - name: login into Docker
         run: echo "${{ secrets.DOCKER_HUB_TOKEN }}" | docker login -u einstack --password-stdin
diff --git a/.github/workflows/vuln.yaml b/.github/workflows/vuln.yaml
index 9e363c8e..aa26598c 100644
--- a/.github/workflows/vuln.yaml
+++ b/.github/workflows/vuln.yaml
@@ -27,20 +27,22 @@ jobs:
       - name: Install Go
         uses: actions/setup-go@v4
         with:
-          go-version: '1.21.5'
+          go-version: '1.22.1'
           check-latest: true

       - name: Checkout
         uses: actions/checkout@v3

-      - name: Install govulncheck
-        run: go install golang.org/x/vuln/cmd/govulncheck@latest
+# TODO: enable in https://github.com/EinStack/glide/issues/169
+#      - name: Install govulncheck
+#        run: go install golang.org/x/vuln/cmd/govulncheck@latest

       - name: Install gosec
         run: go install github.com/securego/gosec/v2/cmd/gosec@latest

-      - name: Govulncheck
-        run: govulncheck -test ./...
+# TODO: enable in https://github.com/EinStack/glide/issues/169
+#      - name: Govulncheck
+#        run: govulncheck -test ./...

       - name: Gosec
         run: gosec ./...
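An aside on the Go 1.22 bump that runs through these workflows: since Go 1.22, `for` loop variables are scoped per iteration, which removes a classic data race when loop values are captured by goroutines. A minimal sketch (not code from this changeset) of the pattern that the streaming handler later in this diff also guards against by passing the request as an argument:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup

	for _, word := range []string{"a", "b", "c"} {
		wg.Add(1)

		// Passing the loop variable as an argument pins the value for this
		// goroutine on any Go version; since Go 1.22 the variable is also
		// per-iteration, so even a direct capture would be safe.
		go func(w string) {
			defer wg.Done()
			fmt.Println(w)
		}(word)
	}

	wg.Wait()
}
```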
diff --git a/.go-version b/.go-version
index d2ab029d..71f7f51d 100644
--- a/.go-version
+++ b/.go-version
@@ -1 +1 @@
-1.21
+1.22
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a84dbc27..f6c2c138 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,75 +1,138 @@
 # Changelog

-The changelog consists of three categories:
-- **Features** - a new functionality that brings value to users
-- **Improvements** - bugfixes, performance and other types of improvements to existing functionality
+The changelog consists of seven categories:
+- **Added** - new functionality that brings value to users
+- **Changed** - changes in existing functionality, performance and other types of improvements
+- **Fixed** - bugfixes
+- **Deprecated** - soon-to-be removed user-facing features
+- **Removed** - earlier deprecated, now removed user-facing features
+- **Security** - fixing CVEs in the gateway or dependencies
 - **Miscellaneous** - all other updates like build, release, CLI, etc.

-## 0.0.2-rc.2, 0.0.2 (Feb 22nd, 2024)
+See [keepachangelog.com](https://keepachangelog.com/en/1.1.0/) for more information.

-### Features
+## [Unreleased]

-- ✨ #142: [Lang Chat Router] Ollama Support (@mkrueger12)
-- ✨ #131: [Lang Chat Router] AWS Bedrock Support (@mkrueger12)
+### Added
+
+TBU
+
+### Changed
+
+TBU
+
+### Fixed
+
+TBU
+
+### Deprecated
+
+TBU
+
+### Removed
+
+TBU
+
+### Security
+
+TBU

 ### Miscellaneous

-- πŸ‘· #155 Fixing the dockerhub authorization step in the release workflow (@roma-glushko)
-- ♻️ #151: Moved specific provider schemas closer to provider's packages (@roma-glushko)
+TBU
+
+## [0.0.3-rc.1] (Apr 7th, 2024)

-## 0.0.2-rc.1 (Feb 12th, 2024)
+Bringing support for streaming chat in Glide.

-### Features
+### Added

-- ✨#117 Allow to load dotenv files (@roma-glushko)
+- ✨ Streaming Chat Workflow #149 #163 #161 (@roma-glushko)
+- ✨ Streaming Support for Azure OpenAI #173 (@mkrueger12)
+- ✨ Cohere Streaming Chat Support #171 (@mkrueger12)
+- ✨ Start counting token usage in Anthropic Chat #183 (@roma-glushko)
+- ✨ Handle unauthorized error in health tracker #170 (@roma-glushko)

-### Improvements
+### Fixed

-- βœ¨πŸ‘·#91 Support for Windows (@roma-glushko)
-- πŸ‘· #139 Build Glide for OpenBSD and ppc65le, s390x, riscv64 architectures (@roma-glushko)
+- πŸ› Fix Anthropic API key header #183 (@roma-glushko)
+
+### Security
+
+- πŸ”“ Update crypto lib, golang, fiber #148 (@roma-glushko)

 ### Miscellaneous

-- πŸ‘· #92 Release binaries to Snapcraft (@roma-glushko)
-- πŸ‘· #123 publish images to DockerHub (@roma-glushko)
-- πŸ”§ #136 Migrated all API to Fiber (@roma-glushko)
-- πŸ‘· #139 Create a image tag with pure version (without distro suffix) (@roma-glushko)
+- πŸ› Update README.md to fix helm chart location #167 (@arjunnair22)
+- πŸ”§ Updated .go-version (@roma-glushko)
+- βœ… Covered the telemetry by tests #146 (@roma-glushko)
+- πŸ“ Separate and list all supported capabilities per provider #190 (@roma-glushko)

-## 0.0.1 (Jan 31st, 2024)
+## [0.0.2-rc.2], [0.0.2] (Feb 22nd, 2024)

-### Features
+### Added

-- ✨ #81: Allow to chat message based for specific models (@mkrueger12)
+- ✨ [Lang Chat Router] Ollama Support #142 (@mkrueger12)
+- ✨ [Lang Chat Router] AWS Bedrock Support #131 (@mkrueger12)

-### Improvements
+### Miscellaneous

-- πŸ”§ #78: Normalize response latency by response token count (@roma-glushko)
-- πŸ“ #112 added the CLI banner info (@roma-glushko)
+- πŸ‘· Fixing the dockerhub authorization step in the release workflow #155 (@roma-glushko)
+- ♻️ Moved specific provider schemas closer to provider's packages #151 (@roma-glushko)
+
+## [0.0.2-rc.1] (Feb 12th, 2024)
+
+### Added
+
+- ✨ Allow loading dotenv files #117 (@roma-glushko)
+
+### Changed
+
+- βœ¨πŸ‘· Support for Windows #91 (@roma-glushko)
+- πŸ‘· Build Glide for OpenBSD and ppc65le, s390x, riscv64 architectures #139 (@roma-glushko)
+
+### Miscellaneous
+
+- πŸ‘· Release binaries to Snapcraft #92 (@roma-glushko)
+- πŸ‘· Publish images to DockerHub #123 (@roma-glushko)
+- πŸ”§ Migrated all API to Fiber #136 (@roma-glushko)
+- πŸ‘· Create an image tag with the pure version (without distro suffix) #139 (@roma-glushko)
+
+## [0.0.1] (Jan 31st, 2024)
+
+### Added
+
+- ✨ Allow sending chat messages to specific models #81 (@mkrueger12)
+
+### Changed
+
+- πŸ”§ Normalize response latency by response token count #78 (@roma-glushko)
+- πŸ“ Added the CLI banner info #112 (@roma-glushko)

 ### Miscellaneous

 - πŸ“ #114 Make links actual across the project (@roma-glushko)

-## 0.0.1-rc.2 (Jan 22nd, 2024)
+## [0.0.1-rc.2] (Jan 22nd, 2024)

-### Improvements
+### Added

 - βš™οΈ [config] Added validation for config file content #40 (@roma-glushko)
 - βš™οΈ [config] Allowed to pass HTTP server configs from config file #41 (@roma-glushko)
 - πŸ‘· [build] Allowed building Homebrew taps for release candidates #99 (@roma-glushko)

-## 0.0.1-rc.1 (Jan 21st, 2024)
+## [0.0.1-rc.1] (Jan 21st, 2024)

-### Features
+### Added

 - ✨ [providers] Support for OpenAI Chat API #3 (@mkrueger12)
 - ✨ [API] Unified Chat API #54 (@mkrueger12)
 - ✨ [providers] Support for Cohere Chat API #5 (@mkrueger12)
 - ✨ [providers] Support for Azure OpenAI Chat API #4 (@mkrueger12)
 - ✨ [providers] Support for OctoML Chat API #58 (@mkrueger12)
 - ✨ [routing] The Routing Mechanism, Adaptive Health Tracking, and Fallbacks #42 #43 #51 (@roma-glushko)
-- ✨ [routing] Support for round robin routing strategy #44 (@roma-glushko)
+- ✨ [routing] Support for round-robin routing strategy #44 (@roma-glushko)
 - ✨ [routing] Support for the least latency routing strategy #46 (@roma-glushko)
-- ✨ [routing] Support for weighted round robin routing strategy #45 (@roma-glushko)
+- ✨ [routing] Support for weighted round-robin routing strategy #45 (@roma-glushko)
 - ✨ [providers] Support for Anthropic Chat API #60 (@mkrueger12)
 - ✨ [docs] OpenAPI specifications #22 (@roma-glushko)

@@ -80,5 +143,14 @@ The changelog consists of three categories:
 - πŸ”§ [chores] Inited Glide's CLI #12 (@roma-glushko)
 - πŸ‘· [chores] Setup CI workflows #8 (@roma-glushko)
 - βš™οΈ [config] Inited configs #11 (@roma-glushko)
-- πŸ”§ [chores] Automatic coverage reports #39 (@roma-glushko)
+- πŸ”§ [chores] Automatic coverage reports #39 (@roma-glushko)
 - πŸ‘· [build] Setup release workflows #9 (@roma-glushko)
+
+[unreleased]: https://github.com/EinStack/glide/compare/0.0.3-rc.1...HEAD
+[0.0.3-rc.1]: https://github.com/EinStack/glide/compare/0.0.2..0.0.3-rc.1
+[0.0.2]: https://github.com/EinStack/glide/compare/0.0.2-rc.1..0.0.2
+[0.0.2-rc.2]: https://github.com/EinStack/glide/compare/0.0.2-rc.1..0.0.2-rc.2
+[0.0.2-rc.1]: https://github.com/EinStack/glide/compare/0.0.1..0.0.2-rc.1
+[0.0.1]: https://github.com/EinStack/glide/compare/0.0.1-rc.2..0.0.1
+[0.0.1-rc.2]: https://github.com/EinStack/glide/compare/0.0.1-rc.1..0.0.1-rc.2
+[0.0.1-rc.1]: https://github.com/EinStack/glide/releases/tag/0.0.1-rc.1
diff --git a/Makefile b/Makefile
index ecf858bb..4cebf96b 100644
--- a/Makefile
+++ b/Makefile
@@ -33,7 +33,7 @@ lint: install-checkers ## Lint the source code

 vuln: install-checkers ## Check for vulnerabilities
	@echo "πŸ” Checking for vulnerabilities"
-	@$(CHECKER_BIN)/govulncheck -test ./...
+	@# TODO: enable in https://github.com/EinStack/glide/issues/169: $(CHECKER_BIN)/govulncheck -test ./...
	@$(CHECKER_BIN)/gosec -quiet -exclude=G104 ./...

 run: ## Run Glide
Checking for vulnerabilities" - @$(CHECKER_BIN)/govulncheck -test ./... + @#$(CHECKER_BIN)/govulncheck -test ./... enable in https://github.com/EinStack/glide/issues/169 @$(CHECKER_BIN)/gosec -quiet -exclude=G104 ./... run: ## Run Glide diff --git a/README.md b/README.md index c3597081..797f7409 100644 --- a/README.md +++ b/README.md @@ -42,16 +42,16 @@ Check out our [documentation](https://glide.einstack.ai)! ### Large Language Models -| | Provider | Support Status | -|-----------------------------------------------------|---------------|-----------------| -| | Anthropic | πŸ‘ Supported | -| | Azure OpenAI | πŸ‘ Supported | -| | AWS Bedrock (Titan) | πŸ‘ Supported | -| | Cohere | πŸ‘ Supported | -| | Google Gemini | πŸ—οΈ Coming Soon | -| | OctoML | πŸ‘ Supported | -| | Ollama | πŸ‘ Supported | -| | OpenAI | πŸ‘ Supported | +| Provider | Supported Capabilities | +|-----------------------------------------------------------------------|-------------------------------------------| +| OpenAI | βœ… Chat
βœ… Streaming Chat | +| Anthropic | βœ… Chat
πŸ—οΈ Streaming Chat (coming soon) | +| Azure OpenAI | βœ… Chat
πŸ—οΈ Streaming Chat (coming soon) | +| AWS Bedrock (Titan) | βœ… Chat | +| Cohere | βœ… Chat
πŸ—οΈ Streaming Chat (coming soon) | +| Google Gemini | πŸ—οΈ Chat (coming soon) | +| OctoML | βœ… Chat | +| Ollama | βœ… Chat | ## Get Started @@ -183,7 +183,7 @@ docker pull ghcr.io/einstack/glide:latest-redhat Add the EinStack repository: ```bash -helm repo add einstack https://einstack.github.io/helm-charts +helm repo add einstack https://einstack.github.io/charts helm repo update ``` diff --git "a/docs/api/Language API/\360\237\222\254 Chat Stream.bru" "b/docs/api/Language API/\360\237\222\254 Chat Stream.bru" new file mode 100644 index 00000000..e544c061 --- /dev/null +++ "b/docs/api/Language API/\360\237\222\254 Chat Stream.bru" @@ -0,0 +1,11 @@ +meta { + name: πŸ’¬ Chat Stream + type: http + seq: 2 +} + +get { + url: {{base_url}}/language/default/chatStream + body: none + auth: none +} diff --git a/docs/api/[Lang] Chat.bru "b/docs/api/Language API/\360\237\222\254 Chat.bru" similarity index 71% rename from docs/api/[Lang] Chat.bru rename to "docs/api/Language API/\360\237\222\254 Chat.bru" index d3a31a71..6ea21147 100644 --- a/docs/api/[Lang] Chat.bru +++ "b/docs/api/Language API/\360\237\222\254 Chat.bru" @@ -1,11 +1,11 @@ meta { - name: [Lang] Chat + name: πŸ’¬ Chat type: http - seq: 2 + seq: 1 } post { - url: {{base_url}}/v1/language/myrouter/chat/ + url: {{base_url}}/language/default/chat/ body: json auth: none } diff --git a/docs/api/[Lang] Router List.bru "b/docs/api/Language API/\360\237\224\247 Router List.bru" similarity index 76% rename from docs/api/[Lang] Router List.bru rename to "docs/api/Language API/\360\237\224\247 Router List.bru" index 81ccec75..0545245f 100644 --- a/docs/api/[Lang] Router List.bru +++ "b/docs/api/Language API/\360\237\224\247 Router List.bru" @@ -1,11 +1,11 @@ meta { - name: [Lang] Router List + name: πŸ”§ Router List type: http seq: 3 } get { - url: {{base_url}}/v1/language/ + url: {{base_url}}/language/ body: json auth: none } diff --git a/docs/api/environments/Development.bru b/docs/api/environments/Development.bru index 732c80cf..a8abb8bf 100644 --- a/docs/api/environments/Development.bru +++ b/docs/api/environments/Development.bru @@ -1,3 +1,3 @@ vars { - base_url: http://127.0.0.1:9099 + base_url: http://127.0.0.1:9099/v1 } diff --git a/docs/api/Health.bru "b/docs/api/\360\237\224\247 Health.bru" similarity index 57% rename from docs/api/Health.bru rename to "docs/api/\360\237\224\247 Health.bru" index 0486a046..df4d4d10 100644 --- a/docs/api/Health.bru +++ "b/docs/api/\360\237\224\247 Health.bru" @@ -1,11 +1,11 @@ meta { - name: Health + name: πŸ”§ Health type: http seq: 1 } get { - url: {{base_url}}/health + url: {{base_url}}/health/ body: none auth: none } diff --git a/docs/docs.go b/docs/docs.go index 96ddabc5..d575a50a 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -72,7 +72,7 @@ const docTemplate = `{ }, "/v1/language/{router}/chat": { "post": { - "description": "Talk to different LLMs Chat API via unified endpoint", + "description": "Talk to different LLM Chat APIs via unified endpoint", "consumes": [ "application/json" ], @@ -123,17 +123,85 @@ const docTemplate = `{ } } } + }, + "/v1/language/{router}/chatStream": { + "get": { + "description": "Talk to different LLM Stream Chat APIs via a unified websocket endpoint", + "consumes": [ + "application/json" + ], + "tags": [ + "Language" + ], + "summary": "Language Chat", + "operationId": "glide-language-chat-stream", + "parameters": [ + { + "type": "string", + "description": "Router ID", + "name": "router", + "in": "path", + "required": true + }, + { + "type": "string", + 
"description": "Websocket Connection Type", + "name": "Connection", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "Upgrade header", + "name": "Upgrade", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "Websocket Security Token", + "name": "Sec-WebSocket-Key", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "Websocket Security Token", + "name": "Sec-WebSocket-Version", + "in": "header", + "required": true + } + ], + "responses": { + "101": { + "description": "Switching Protocols" + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/http.ErrorSchema" + } + }, + "426": { + "description": "Upgrade Required" + } + } + } } }, "definitions": { "anthropic.Config": { "type": "object", "required": [ + "apiVersion", "baseUrl", "chatEndpoint", "model" ], "properties": { + "apiVersion": { + "type": "string" + }, "baseUrl": { "type": "string" }, @@ -377,7 +445,6 @@ const docTemplate = `{ "type": "boolean" }, "stream": { - "description": "unsupported right now", "type": "boolean" }, "temperature": { @@ -742,6 +809,10 @@ const docTemplate = `{ }, "schemas.ChatMessage": { "type": "object", + "required": [ + "content", + "role" + ], "properties": { "content": { "description": "The content of the message.", @@ -759,6 +830,9 @@ const docTemplate = `{ }, "schemas.ChatRequest": { "type": "object", + "required": [ + "message" + ], "properties": { "message": { "$ref": "#/definitions/schemas.ChatMessage" @@ -790,7 +864,7 @@ const docTemplate = `{ "type": "string" }, "modelResponse": { - "$ref": "#/definitions/schemas.ProviderResponse" + "$ref": "#/definitions/schemas.ModelResponse" }, "model_id": { "type": "string" @@ -803,18 +877,7 @@ const docTemplate = `{ } } }, - "schemas.OverrideChatRequest": { - "type": "object", - "properties": { - "message": { - "$ref": "#/definitions/schemas.ChatMessage" - }, - "model_id": { - "type": "string" - } - } - }, - "schemas.ProviderResponse": { + "schemas.ModelResponse": { "type": "object", "properties": { "message": { @@ -831,17 +894,32 @@ const docTemplate = `{ } } }, + "schemas.OverrideChatRequest": { + "type": "object", + "required": [ + "message", + "model_id" + ], + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "model_id": { + "type": "string" + } + } + }, "schemas.TokenUsage": { "type": "object", "properties": { "promptTokens": { - "type": "number" + "type": "integer" }, "responseTokens": { - "type": "number" + "type": "integer" }, "totalTokens": { - "type": "number" + "type": "integer" } } } diff --git a/docs/swagger.json b/docs/swagger.json index aee257b6..ce439bc8 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -69,7 +69,7 @@ }, "/v1/language/{router}/chat": { "post": { - "description": "Talk to different LLMs Chat API via unified endpoint", + "description": "Talk to different LLM Chat APIs via unified endpoint", "consumes": [ "application/json" ], @@ -120,17 +120,85 @@ } } } + }, + "/v1/language/{router}/chatStream": { + "get": { + "description": "Talk to different LLM Stream Chat APIs via a unified websocket endpoint", + "consumes": [ + "application/json" + ], + "tags": [ + "Language" + ], + "summary": "Language Chat", + "operationId": "glide-language-chat-stream", + "parameters": [ + { + "type": "string", + "description": "Router ID", + "name": "router", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Websocket Connection Type", + "name": 
"Connection", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "Upgrade header", + "name": "Upgrade", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "Websocket Security Token", + "name": "Sec-WebSocket-Key", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "Websocket Security Token", + "name": "Sec-WebSocket-Version", + "in": "header", + "required": true + } + ], + "responses": { + "101": { + "description": "Switching Protocols" + }, + "404": { + "description": "Not Found", + "schema": { + "$ref": "#/definitions/http.ErrorSchema" + } + }, + "426": { + "description": "Upgrade Required" + } + } + } } }, "definitions": { "anthropic.Config": { "type": "object", "required": [ + "apiVersion", "baseUrl", "chatEndpoint", "model" ], "properties": { + "apiVersion": { + "type": "string" + }, "baseUrl": { "type": "string" }, @@ -374,7 +442,6 @@ "type": "boolean" }, "stream": { - "description": "unsupported right now", "type": "boolean" }, "temperature": { @@ -739,6 +806,10 @@ }, "schemas.ChatMessage": { "type": "object", + "required": [ + "content", + "role" + ], "properties": { "content": { "description": "The content of the message.", @@ -756,6 +827,9 @@ }, "schemas.ChatRequest": { "type": "object", + "required": [ + "message" + ], "properties": { "message": { "$ref": "#/definitions/schemas.ChatMessage" @@ -787,7 +861,7 @@ "type": "string" }, "modelResponse": { - "$ref": "#/definitions/schemas.ProviderResponse" + "$ref": "#/definitions/schemas.ModelResponse" }, "model_id": { "type": "string" @@ -800,18 +874,7 @@ } } }, - "schemas.OverrideChatRequest": { - "type": "object", - "properties": { - "message": { - "$ref": "#/definitions/schemas.ChatMessage" - }, - "model_id": { - "type": "string" - } - } - }, - "schemas.ProviderResponse": { + "schemas.ModelResponse": { "type": "object", "properties": { "message": { @@ -828,17 +891,32 @@ } } }, + "schemas.OverrideChatRequest": { + "type": "object", + "required": [ + "message", + "model_id" + ], + "properties": { + "message": { + "$ref": "#/definitions/schemas.ChatMessage" + }, + "model_id": { + "type": "string" + } + } + }, "schemas.TokenUsage": { "type": "object", "properties": { "promptTokens": { - "type": "number" + "type": "integer" }, "responseTokens": { - "type": "number" + "type": "integer" }, "totalTokens": { - "type": "number" + "type": "integer" } } } diff --git a/docs/swagger.yaml b/docs/swagger.yaml index d5fb088f..00e800eb 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -2,6 +2,8 @@ basePath: / definitions: anthropic.Config: properties: + apiVersion: + type: string baseUrl: type: string chatEndpoint: @@ -11,6 +13,7 @@ definitions: model: type: string required: + - apiVersion - baseUrl - chatEndpoint - model @@ -171,7 +174,6 @@ definitions: search_queries_only: type: boolean stream: - description: unsupported right now type: boolean temperature: type: number @@ -428,6 +430,9 @@ definitions: description: The role of the author of this message. One of system, user, or assistant. 
type: string + required: + - content + - role type: object schemas.ChatRequest: properties: @@ -439,6 +444,8 @@ definitions: type: array override: $ref: '#/definitions/schemas.OverrideChatRequest' + required: + - message type: object schemas.ChatResponse: properties: @@ -453,20 +460,13 @@ definitions: model_id: type: string modelResponse: - $ref: '#/definitions/schemas.ProviderResponse' + $ref: '#/definitions/schemas.ModelResponse' provider: type: string router: type: string type: object - schemas.OverrideChatRequest: - properties: - message: - $ref: '#/definitions/schemas.ChatMessage' - model_id: - type: string - type: object - schemas.ProviderResponse: + schemas.ModelResponse: properties: message: $ref: '#/definitions/schemas.ChatMessage' @@ -477,14 +477,24 @@ definitions: tokenCount: $ref: '#/definitions/schemas.TokenUsage' type: object + schemas.OverrideChatRequest: + properties: + message: + $ref: '#/definitions/schemas.ChatMessage' + model_id: + type: string + required: + - message + - model_id + type: object schemas.TokenUsage: properties: promptTokens: - type: number + type: integer responseTokens: - type: number + type: integer totalTokens: - type: number + type: integer type: object externalDocs: description: Documentation @@ -538,7 +548,7 @@ paths: post: consumes: - application/json - description: Talk to different LLMs Chat API via unified endpoint + description: Talk to different LLM Chat APIs via unified endpoint operationId: glide-language-chat parameters: - description: Router ID @@ -570,6 +580,51 @@ paths: summary: Language Chat tags: - Language + /v1/language/{router}/chatStream: + get: + consumes: + - application/json + description: Talk to different LLM Stream Chat APIs via a unified websocket + endpoint + operationId: glide-language-chat-stream + parameters: + - description: Router ID + in: path + name: router + required: true + type: string + - description: Websocket Connection Type + in: header + name: Connection + required: true + type: string + - description: Upgrade header + in: header + name: Upgrade + required: true + type: string + - description: Websocket Security Token + in: header + name: Sec-WebSocket-Key + required: true + type: string + - description: Websocket Security Token + in: header + name: Sec-WebSocket-Version + required: true + type: string + responses: + "101": + description: Switching Protocols + "404": + description: Not Found + schema: + $ref: '#/definitions/http.ErrorSchema' + "426": + description: Upgrade Required + summary: Language Chat + tags: + - Language schemes: - http swagger: "2.0" diff --git a/go.mod b/go.mod index 73a99abe..c90e943c 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module glide -go 1.21.5 +go 1.22.1 require ( github.com/aws/aws-sdk-go-v2 v1.24.1 @@ -10,9 +10,11 @@ require ( github.com/go-playground/validator/v10 v10.17.0 github.com/gofiber/contrib/fiberzap/v2 v2.1.2 github.com/gofiber/contrib/swagger v1.1.1 - github.com/gofiber/fiber/v2 v2.52.0 + github.com/gofiber/contrib/websocket v1.3.0 + github.com/gofiber/fiber/v2 v2.52.2 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 + github.com/r3labs/sse/v2 v2.10.0 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 github.com/swaggo/swag v1.16.2 @@ -24,7 +26,7 @@ require ( require ( github.com/KyleBanks/depth v1.2.1 // indirect - github.com/andybalholm/brotli v1.0.5 // indirect + github.com/andybalholm/brotli v1.1.0 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect 
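The chatStream operation documented above is a plain WebSocket upgrade, so any standard client can drive it. A hedged sketch using the gorilla/websocket package (an assumption for illustration, it is not a dependency of this repo) that sends a trimmed-down schemas.ChatStreamRequest and prints the returned chunks:

```go
package main

import (
	"fmt"
	"log"

	"github.com/gorilla/websocket"
)

// Trimmed mirrors of schemas.ChatMessage / schemas.ChatStreamRequest.
type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type streamRequest struct {
	ID      string  `json:"id"`
	Message message `json:"message"`
}

func main() {
	// Assumes a local Glide instance with a router named "default".
	conn, _, err := websocket.DefaultDialer.Dial(
		"ws://127.0.0.1:9099/v1/language/default/chatStream", nil)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	req := streamRequest{
		ID:      "req-1",
		Message: message{Role: "user", Content: "Stream me a haiku"},
	}

	if err := conn.WriteJSON(req); err != nil {
		log.Fatal(err)
	}

	// Each frame is either a ChatStreamChunk or a ChatStreamError.
	for {
		var chunk map[string]any

		if err := conn.ReadJSON(&chunk); err != nil {
			return // connection closed
		}

		fmt.Println(chunk["modelResponse"])
	}
}
```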
diff --git a/go.mod b/go.mod
index 73a99abe..c90e943c 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
 module glide

-go 1.21.5
+go 1.22.1

 require (
 	github.com/aws/aws-sdk-go-v2 v1.24.1
@@ -10,9 +10,11 @@ require (
 	github.com/go-playground/validator/v10 v10.17.0
 	github.com/gofiber/contrib/fiberzap/v2 v2.1.2
 	github.com/gofiber/contrib/swagger v1.1.1
-	github.com/gofiber/fiber/v2 v2.52.0
+	github.com/gofiber/contrib/websocket v1.3.0
+	github.com/gofiber/fiber/v2 v2.52.2
 	github.com/google/uuid v1.6.0
 	github.com/joho/godotenv v1.5.1
+	github.com/r3labs/sse/v2 v2.10.0
 	github.com/spf13/cobra v1.8.0
 	github.com/stretchr/testify v1.8.4
 	github.com/swaggo/swag v1.16.2
@@ -24,7 +26,7 @@ require (

 require (
 	github.com/KyleBanks/depth v1.2.1 // indirect
-	github.com/andybalholm/brotli v1.0.5 // indirect
+	github.com/andybalholm/brotli v1.1.0 // indirect
 	github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
 	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.4 // indirect
 	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 // indirect
@@ -38,6 +40,7 @@ require (
 	github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 // indirect
 	github.com/aws/smithy-go v1.19.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/fasthttp/websocket v1.5.7 // indirect
 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect
 	github.com/go-openapi/analysis v0.21.4 // indirect
 	github.com/go-openapi/errors v0.20.3 // indirect
@@ -53,7 +56,7 @@ require (
 	github.com/go-playground/universal-translator v0.18.1 // indirect
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 	github.com/josharian/intern v1.0.0 // indirect
-	github.com/klauspost/compress v1.17.0 // indirect
+	github.com/klauspost/compress v1.17.6 // indirect
 	github.com/leodido/go-urn v1.2.4 // indirect
 	github.com/mailru/easyjson v0.7.7 // indirect
 	github.com/mattn/go-colorable v0.1.13 // indirect
@@ -62,17 +65,19 @@ require (
 	github.com/mitchellh/mapstructure v1.5.0 // indirect
 	github.com/oklog/ulid v1.3.1 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/rivo/uniseg v0.2.0 // indirect
+	github.com/rivo/uniseg v0.4.7 // indirect
+	github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
 	github.com/spf13/pflag v1.0.5 // indirect
 	github.com/tidwall/pretty v1.2.0 // indirect
 	github.com/valyala/bytebufferpool v1.0.0 // indirect
-	github.com/valyala/fasthttp v1.51.0 // indirect
+	github.com/valyala/fasthttp v1.52.0 // indirect
 	github.com/valyala/tcplisten v1.0.0 // indirect
 	go.mongodb.org/mongo-driver v1.11.3 // indirect
-	golang.org/x/crypto v0.17.0 // indirect
-	golang.org/x/net v0.19.0 // indirect
-	golang.org/x/sys v0.15.0 // indirect
+	golang.org/x/crypto v0.21.0 // indirect
+	golang.org/x/net v0.21.0 // indirect
+	golang.org/x/sys v0.18.0 // indirect
 	golang.org/x/text v0.14.0 // indirect
 	golang.org/x/tools v0.16.1 // indirect
+	gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect
 	gopkg.in/yaml.v2 v2.4.0 // indirect
 )
diff --git a/go.sum b/go.sum
index a64851f6..c59f3570 100644
--- a/go.sum
+++ b/go.sum
@@ -3,8 +3,8 @@ github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc
 github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE=
 github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
 github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
-github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs=
-github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
+github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
+github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
 github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
 github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so=
 github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
@@ -43,6 +43,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/fasthttp/websocket v1.5.7 h1:0a6o2OfeATvtGgoMKleURhLT6JqWPg7fYfWnH4KHau4=
+github.com/fasthttp/websocket v1.5.7/go.mod h1:bC4fxSono9czeXHQUVKxsC0sNjbm7lPJR04GDFqClfU=
 github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
 github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
 github.com/go-openapi/analysis v0.21.2/go.mod h1:HZwRk4RRisyG8vx2Oe6aqeSQcoxRp47Xkp3+K6q+LdY=
@@ -119,8 +121,10 @@ github.com/gofiber/contrib/fiberzap/v2 v2.1.2 h1:7Z1BqS1sYK9e9jTwqPcWx9qQt46PI8o
 github.com/gofiber/contrib/fiberzap/v2 v2.1.2/go.mod h1:ulCCQOdDYABGsOQfbndASmCsCN86hsC96iKoOTNYfy8=
 github.com/gofiber/contrib/swagger v1.1.1 h1:on+D2fbXkvm0H0lur1rx69mpxLdX1wIH/FrTRZ99b9Y=
 github.com/gofiber/contrib/swagger v1.1.1/go.mod h1:pa9awsFSz/3BbSnyTe/drNZaiFfnhC4hk3m9BVet7Co=
-github.com/gofiber/fiber/v2 v2.52.0 h1:S+qXi7y+/Pgvqq4DrSmREGiFwtB7Bu6+QFLuIHYw/UE=
-github.com/gofiber/fiber/v2 v2.52.0/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ=
+github.com/gofiber/contrib/websocket v1.3.0 h1:XADFAGorer1VJ1bqC4UkCjqS37kwRTV0415+050NrMk=
+github.com/gofiber/contrib/websocket v1.3.0/go.mod h1:xguaOzn2ZZ759LavtosEP+rcxIgBEE/rdumPINhR+Xo=
+github.com/gofiber/fiber/v2 v2.52.2 h1:b0rYH6b06Df+4NyrbdptQL8ifuxw/Tf2DgfkZkDaxEo=
+github.com/gofiber/fiber/v2 v2.52.2/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ=
 github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
@@ -139,8 +143,8 @@ github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFF
 github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4=
 github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA=
 github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
-github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM=
-github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
+github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
+github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -180,14 +184,19 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
+github.com/r3labs/sse/v2 v2.10.0 h1:hFEkLLFY4LDifoHdiCN/LlGBAdVJYsANaLqNYa1l/v0=
+github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEmkNJ7I=
 github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
+github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
+github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
 github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
 github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
 github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
 github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
 github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk=
+github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g=
 github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
 github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
 github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
@@ -217,8 +226,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
 github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
 github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
 github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
-github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
-github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g=
+github.com/valyala/fasthttp v1.52.0 h1:wqBQpxH71XW0e2g+Og4dzQM8pk34aFYlA1Ga8db7gU0=
+github.com/valyala/fasthttp v1.52.0/go.mod h1:hf5C4QnVMkNXMspnsUlfM3WitlgYflyhHYoKol/szxQ=
 github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
 github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
 github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
@@ -243,16 +252,17 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
 golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
-golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
+golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
+golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
 golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
 golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
 golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM=
 golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
-golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
+golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
+golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
 golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -271,8 +281,8 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
-golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
+golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -288,6 +298,8 @@ golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgw
 golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA=
 golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y=
+gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
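Among the new dependencies above, github.com/r3labs/sse/v2 is worth a note: several provider streaming APIs deliver chunks as Server-Sent Events, which is presumably what this client library was added for. A rough sketch of its client-side API (the endpoint URL is a placeholder, and the exact wiring inside Glide's providers may differ):

```go
package main

import (
	"fmt"

	"github.com/r3labs/sse/v2"
)

func main() {
	// Hypothetical SSE endpoint; real provider streams also need auth headers.
	client := sse.NewClient("http://localhost:8080/events")

	// SubscribeRaw invokes the handler once per received event until the
	// stream ends or the connection drops.
	err := client.SubscribeRaw(func(msg *sse.Event) {
		fmt.Printf("event=%s data=%s\n", msg.Event, msg.Data)
	})
	if err != nil {
		fmt.Println("subscribe failed:", err)
	}
}
```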
diff --git a/images/alpine.Dockerfile b/images/alpine.Dockerfile
index 1454c0a7..26b20357 100644
--- a/images/alpine.Dockerfile
+++ b/images/alpine.Dockerfile
@@ -1,5 +1,5 @@
 # syntax=docker/dockerfile:1
-FROM golang:1.21-alpine as build
+FROM golang:1.22-alpine as build

 ARG VERSION
 ARG COMMIT
diff --git a/images/distroless.Dockerfile b/images/distroless.Dockerfile
index 34776aca..23167563 100644
--- a/images/distroless.Dockerfile
+++ b/images/distroless.Dockerfile
@@ -1,5 +1,5 @@
 # syntax=docker/dockerfile:1
-FROM golang:1.21-alpine as build
+FROM golang:1.22-alpine as build

 ARG VERSION
 ARG COMMIT
diff --git a/images/redhat.Dockerfile b/images/redhat.Dockerfile
index f55c9534..c3a16248 100644
--- a/images/redhat.Dockerfile
+++ b/images/redhat.Dockerfile
@@ -1,5 +1,5 @@
 # syntax=docker/dockerfile:1
-FROM golang:1.21-alpine as build
+FROM golang:1.22-alpine as build

 ARG VERSION
 ARG COMMIT
diff --git a/images/ubuntu.Dockerfile b/images/ubuntu.Dockerfile
index 7db2cb62..fe238b79 100644
--- a/images/ubuntu.Dockerfile
+++ b/images/ubuntu.Dockerfile
@@ -1,5 +1,5 @@
 # syntax=docker/dockerfile:1
-FROM golang:1.21-alpine as build
+FROM golang:1.22-alpine as build

 ARG VERSION
 ARG COMMIT
diff --git a/pkg/api/http/config.go b/pkg/api/http/config.go
index 8e261098..e9098266 100644
--- a/pkg/api/http/config.go
+++ b/pkg/api/http/config.go
@@ -41,7 +41,7 @@ func (cfg *ServerConfig) ToServer() *fiber.App {
 	// More configs are listed on https://docs.gofiber.io/api/fiber
 	// TODO: Consider alternative JSON marshallers that provides better performance over the standard marshaller
 	serverConfig := fiber.Config{
-		AppName:            "glide",
+		AppName:            "Glide",
 		DisableDefaultDate: true,
 		ServerHeader:       fmt.Sprintf("glide/%v", version.Version),
 		StreamRequestBody:  true,
diff --git a/pkg/api/http/config_test.go b/pkg/api/http/config_test.go
new file mode 100644
index 00000000..8d8667a6
--- /dev/null
+++ b/pkg/api/http/config_test.go
@@ -0,0 +1,14 @@
+package http
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestHTTPConfig_DefaultConfig(t *testing.T) {
+	config := DefaultServerConfig()
+
+	require.NotNil(t, config.Address())
+	require.NotNil(t, config.ToServer())
+}
diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go
index c97f542b..5e325ad0 100644
--- a/pkg/api/http/handlers.go
+++ b/pkg/api/http/handlers.go
@@ -1,10 +1,15 @@
 package http

 import (
+	"context"
 	"errors"
+	"sync"

-	"github.com/gofiber/fiber/v2"
+	"glide/pkg/telemetry"
+	"go.uber.org/zap"

+	"github.com/gofiber/contrib/websocket"
+	"github.com/gofiber/fiber/v2"
 	"glide/pkg/api/schemas"
 	"glide/pkg/routers"
 )
@@ -18,7 +23,7 @@ type Handler = func(c *fiber.Ctx) error
 //
 // @id glide-language-chat
 // @Summary Language Chat
-// @Description Talk to different LLMs Chat API via unified endpoint
+// @Description Talk to different LLM Chat APIs via unified endpoint
 // @tags Language
 // @Param router path string true "Router ID"
 // @Param payload body schemas.ChatRequest true "Request Data"
@@ -30,6 +35,12 @@ type Handler = func(c *fiber.Ctx) error
 // @Router /v1/language/{router}/chat [POST]
 func LangChatHandler(routerManager *routers.RouterManager) Handler {
 	return func(c *fiber.Ctx) error {
+		if !c.Is("json") {
+			return c.Status(fiber.StatusBadRequest).JSON(ErrorSchema{
+				Message: "Glide accepts only JSON payloads",
+			})
+		}
+
 		// Unmarshal request body
 		var req *schemas.ChatRequest

@@ -65,6 +76,106 @@ func LangChatHandler(routerManager *routers.RouterManager) Handler {
 	}
 }

+func LangStreamRouterValidator(routerManager *routers.RouterManager) Handler {
+	return func(c *fiber.Ctx) error {
+		if websocket.IsWebSocketUpgrade(c) {
+			routerID := c.Params("router")
+
+			_, err := routerManager.GetLangRouter(routerID)
+			if err != nil {
+				return c.Status(fiber.StatusNotFound).JSON(ErrorSchema{
+					Message: err.Error(),
+				})
+			}
+
+			return c.Next()
+		}
+
+		return fiber.ErrUpgradeRequired
+	}
+}
+
+// LangStreamChatHandler
+//
+// @id glide-language-chat-stream
+// @Summary Language Chat
+// @Description Talk to different LLM Stream Chat APIs via a unified websocket endpoint
+// @tags Language
+// @Param router path string true "Router ID"
+// @Param Connection header string true "Websocket Connection Type"
+// @Param Upgrade header string true "Upgrade header"
+// @Param Sec-WebSocket-Key header string true "Websocket Security Token"
+// @Param Sec-WebSocket-Version header string true "Websocket Security Token"
+// @Accept json
+// @Success 101
+// @Failure 426
+// @Failure 404 {object} http.ErrorSchema
+// @Router /v1/language/{router}/chatStream [GET]
+func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.RouterManager) Handler {
+	// TODO: expose websocket connection configs https://github.com/gofiber/contrib/tree/main/websocket
+	return websocket.New(func(c *websocket.Conn) {
+		routerID := c.Params("router")
+		// websocket.Conn bindings https://pkg.go.dev/github.com/fasthttp/websocket?tab=doc#pkg-index
+
+		var (
+			err error
+			wg  sync.WaitGroup
+		)
+
+		chunkResultC := make(chan *schemas.ChatStreamResult)
+
+		router, _ := routerManager.GetLangRouter(routerID)
+
+		defer close(chunkResultC)
+		defer c.Conn.Close()
+
+		wg.Add(1)
+
+		go func() {
+			defer wg.Done()
+
+			for chunkResult := range chunkResultC {
+				if chunkResult.Error() != nil {
+					if err = c.WriteJSON(chunkResult.Error()); err != nil {
+						break
+					}
+
+					continue
+				}
+
+				if err = c.WriteJSON(chunkResult.Chunk()); err != nil {
+					break
+				}
+			}
+		}()
+
+		for {
+			var chatRequest schemas.ChatStreamRequest
+
+			if err = c.ReadJSON(&chatRequest); err != nil {
+				if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
+					tel.L().Warn("Streaming Chat connection is closed", zap.Error(err), zap.String("routerID", routerID))
+				}
+
+				tel.L().Debug("Streaming chat connection is closed by client", zap.Error(err), zap.String("routerID", routerID))
+
+				break
+			}
+
+			// TODO: handle termination gracefully
+			wg.Add(1)
+
+			go func(chatRequest schemas.ChatStreamRequest) {
+				defer wg.Done()
+
+				router.ChatStream(context.Background(), &chatRequest, chunkResultC)
+			}(chatRequest)
+		}
+
+		wg.Wait()
+	})
+}
+
 // LangRoutersHandler
 //
 // @id glide-language-routers
diff --git a/pkg/api/http/server.go b/pkg/api/http/server.go
index d44beac1..e49c8fec 100644
--- a/pkg/api/http/server.go
+++ b/pkg/api/http/server.go
@@ -41,7 +41,7 @@ func (srv *Server) Run() error {
 		Title:    "Glide API Docs",
 		BasePath: "/v1/",
 		Path:     "swagger",
-		FilePath: "./docs/swagger.json",
+		FilePath: "./docs/swagger.yaml",
 	}))

 	srv.server.Use(fiberzap.New(fiberzap.Config{
@@ -53,6 +53,9 @@ func (srv *Server) Run() error {
 	v1.Get("/language/", LangRoutersHandler(srv.routerManager))
 	v1.Post("/language/:router/chat/", LangChatHandler(srv.routerManager))

+	v1.Use("/language/:router/chatStream", LangStreamRouterValidator(srv.routerManager))
+	v1.Get("/language/:router/chatStream", LangStreamChatHandler(srv.telemetry, srv.routerManager))
+
 	v1.Get("/health/", HealthHandler)

 	srv.server.Use(NotFoundHandler)
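One design choice worth calling out in LangStreamChatHandler above: websocket connections are not safe for concurrent writes, so every chunk is funneled through a single channel consumed by one writer goroutine, while each inbound request gets its own producer goroutine. A stripped-down sketch of that fan-in shape (generic code, not from this changeset):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	results := make(chan string)

	var producers sync.WaitGroup

	// Many producers: one goroutine per inbound request, mirroring
	// the go func(chatRequest) calls in the handler.
	for i := 0; i < 3; i++ {
		producers.Add(1)

		go func(id int) {
			defer producers.Done()
			results <- fmt.Sprintf("chunk from request %d", id)
		}(i)
	}

	// Single consumer: the only place that "writes to the socket",
	// so writes are never concurrent.
	done := make(chan struct{})

	go func() {
		defer close(done)

		for r := range results {
			fmt.Println(r)
		}
	}()

	producers.Wait()
	close(results)
	<-done
}
```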
`json:"model_id,omitempty"` + ModelName string `json:"model,omitempty"` + Cached bool `json:"cached,omitempty"` + ModelResponse ModelResponse `json:"modelResponse,omitempty"` +} + +// ModelResponse is the unified response from the provider. + +type ModelResponse struct { + SystemID map[string]string `json:"responseId,omitempty"` + Message ChatMessage `json:"message"` + TokenUsage TokenUsage `json:"tokenCount"` +} + +type TokenUsage struct { + PromptTokens int `json:"promptTokens"` + ResponseTokens int `json:"responseTokens"` + TotalTokens int `json:"totalTokens"` +} + +// ChatMessage is a message in a chat request. +type ChatMessage struct { + // The role of the author of this message. One of system, user, or assistant. + Role string `json:"role" validate:"required"` + // The content of the message. + Content string `json:"content" validate:"required"` + // The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, + // with a maximum length of 64 characters. + Name string `json:"name,omitempty"` +} diff --git a/pkg/api/schemas/chat_stream.go b/pkg/api/schemas/chat_stream.go new file mode 100644 index 00000000..d77310c7 --- /dev/null +++ b/pkg/api/schemas/chat_stream.go @@ -0,0 +1,80 @@ +package schemas + +type ( + Metadata = map[string]any + FinishReason = string +) + +var Complete FinishReason = "complete" + +// ChatStreamRequest defines a message that requests a new streaming chat +type ChatStreamRequest struct { + ID string `json:"id" validate:"required"` + Message ChatMessage `json:"message" validate:"required"` + MessageHistory []ChatMessage `json:"messageHistory" validate:"required"` + Override *OverrideChatRequest `json:"overrideMessage,omitempty"` + Metadata *Metadata `json:"metadata,omitempty"` +} + +func NewChatStreamFromStr(message string) *ChatStreamRequest { + return &ChatStreamRequest{ + Message: ChatMessage{ + "user", + message, + "glide", + }, + } +} + +type ModelChunkResponse struct { + Metadata *Metadata `json:"metadata,omitempty"` + Message ChatMessage `json:"message"` + FinishReason *FinishReason `json:"finishReason,omitempty"` +} + +// ChatStreamChunk defines a message for a chunk of streaming chat response +type ChatStreamChunk struct { + ID string `json:"id"` + CreatedAt int `json:"createdAt"` + Provider string `json:"providerId"` + RouterID string `json:"routerId"` + ModelID string `json:"modelId"` + Cached bool `json:"cached"` + ModelName string `json:"modelName"` + Metadata *Metadata `json:"metadata,omitempty"` + ModelResponse ModelChunkResponse `json:"modelResponse"` +} + +type ChatStreamError struct { + ID string `json:"id"` + ErrCode string `json:"errCode"` + Message string `json:"message"` + Metadata *Metadata `json:"metadata,omitempty"` +} + +type ChatStreamResult struct { + chunk *ChatStreamChunk + err *ChatStreamError +} + +func (r *ChatStreamResult) Chunk() *ChatStreamChunk { + return r.chunk +} + +func (r *ChatStreamResult) Error() *ChatStreamError { + return r.err +} + +func NewChatStreamResult(chunk *ChatStreamChunk) *ChatStreamResult { + return &ChatStreamResult{ + chunk: chunk, + err: nil, + } +} + +func NewChatStreamErrorResult(err *ChatStreamError) *ChatStreamResult { + return &ChatStreamResult{ + chunk: nil, + err: err, + } +} diff --git a/pkg/api/schemas/language.go b/pkg/api/schemas/language.go deleted file mode 100644 index 7e2a2cdc..00000000 --- a/pkg/api/schemas/language.go +++ /dev/null @@ -1,60 +0,0 @@ -package schemas - -// ChatRequest defines Glide's Chat Request Schema unified across all language models -type 
ChatRequest struct { - Message ChatMessage `json:"message"` - MessageHistory []ChatMessage `json:"messageHistory"` - Override OverrideChatRequest `json:"override,omitempty"` -} - -type OverrideChatRequest struct { - Model string `json:"model_id"` - Message ChatMessage `json:"message"` -} - -func NewChatFromStr(message string) *ChatRequest { - return &ChatRequest{ - Message: ChatMessage{ - "human", - message, - "roma", - }, - } -} - -// ChatResponse defines Glide's Chat Response Schema unified across all language models -type ChatResponse struct { - ID string `json:"id,omitempty"` - Created int `json:"created,omitempty"` - Provider string `json:"provider,omitempty"` - RouterID string `json:"router,omitempty"` - ModelID string `json:"model_id,omitempty"` - Model string `json:"model,omitempty"` - Cached bool `json:"cached,omitempty"` - ModelResponse ProviderResponse `json:"modelResponse,omitempty"` -} - -// ProviderResponse is the unified response from the provider. - -type ProviderResponse struct { - SystemID map[string]string `json:"responseId,omitempty"` - Message ChatMessage `json:"message"` - TokenUsage TokenUsage `json:"tokenCount"` -} - -type TokenUsage struct { - PromptTokens float64 `json:"promptTokens"` - ResponseTokens float64 `json:"responseTokens"` - TotalTokens float64 `json:"totalTokens"` -} - -// ChatMessage is a message in a chat request. -type ChatMessage struct { - // The role of the author of this message. One of system, user, or assistant. - Role string `json:"role"` - // The content of the message. - Content string `json:"content"` - // The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, - // with a maximum length of 64 characters. - Name string `json:"name,omitempty"` -} diff --git a/pkg/cmd/cli.go b/pkg/cmd/cli.go index 154b83ca..60068884 100644 --- a/pkg/cmd/cli.go +++ b/pkg/cmd/cli.go @@ -49,7 +49,7 @@ func NewCLI() *cobra.Command { if err != nil { log.Println("⚠️failed to load dotenv file: ", err) // don't have an inited logger at this moment } else { - log.Printf("πŸ”§dot env file loaded (%v)", dotEnvFile) + log.Printf("πŸ”§dot env file is loaded (%v)", dotEnvFile) } _, err = configProvider.Load(cfgFile) diff --git a/pkg/config/expander_test.go b/pkg/config/expander_test.go index f5d2930f..5686ad4b 100644 --- a/pkg/config/expander_test.go +++ b/pkg/config/expander_test.go @@ -62,3 +62,24 @@ func TestExpander_EnvVarExpanded(t *testing.T) { assert.Equal(t, topP, cfg.Params[0].Value) assert.Equal(t, fmt.Sprintf("$%v", budget), cfg.Params[1].Value) } + +func TestExpander_FileContentExpanded(t *testing.T) { + content, err := os.ReadFile(filepath.Clean(filepath.Join(".", "testdata", "expander.file.yaml"))) + require.NoError(t, err) + + expander := Expander{} + updatedContent := string(expander.Expand(content)) + + require.NotContains(t, updatedContent, "${file:") + require.Contains(t, updatedContent, "sk-fakeapi-token") +} + +func TestExpander_FileDoesntExist(t *testing.T) { + content, err := os.ReadFile(filepath.Clean(filepath.Join(".", "testdata", "expander.file.notfound.yaml"))) + require.NoError(t, err) + + expander := Expander{} + updatedContent := string(expander.Expand(content)) + + require.NotContains(t, updatedContent, "${file:") +} diff --git a/pkg/config/testdata/expander.file.notfound.yaml b/pkg/config/testdata/expander.file.notfound.yaml new file mode 100644 index 00000000..4a61af87 --- /dev/null +++ b/pkg/config/testdata/expander.file.notfound.yaml @@ -0,0 +1,6 @@ +name: "OpenAI" +api_key: "${file:./testdata/doesntexist}" + 
diff --git a/pkg/config/testdata/expander.file.notfound.yaml b/pkg/config/testdata/expander.file.notfound.yaml
new file mode 100644
index 00000000..4a61af87
--- /dev/null
+++ b/pkg/config/testdata/expander.file.notfound.yaml
@@ -0,0 +1,6 @@
+name: "OpenAI"
+api_key: "${file:./testdata/doesntexist}"
+
+params:
+  - name: budget
+    value: "$$${file:./testdata/doesntexist}"
diff --git a/pkg/config/testdata/expander.file.yaml b/pkg/config/testdata/expander.file.yaml
new file mode 100644
index 00000000..30dc7dc8
--- /dev/null
+++ b/pkg/config/testdata/expander.file.yaml
@@ -0,0 +1,6 @@
+name: "OpenAI"
+api_key: "${file:./testdata/openai_key}"
+
+params:
+  - name: budget
+    value: "$$${file:./testdata/openai_key}"
diff --git a/pkg/config/testdata/openai_key b/pkg/config/testdata/openai_key
new file mode 100644
index 00000000..607f1d9a
--- /dev/null
+++ b/pkg/config/testdata/openai_key
@@ -0,0 +1 @@
+sk-fakeapi-token
diff --git a/pkg/gateway.go b/pkg/gateway.go
index a27eddbd..6d2a6092 100644
--- a/pkg/gateway.go
+++ b/pkg/gateway.go
@@ -25,8 +25,8 @@ import (
 type Gateway struct {
 	// configProvider holds all configurations
 	configProvider *config.Provider
-	// telemetry holds logger, meter, and tracer
-	telemetry *telemetry.Telemetry
+	// tel holds logger, meter, and tracer
+	tel *telemetry.Telemetry
 	// serverManager controls API over different protocols
 	serverManager *api.ServerManager
 	// signalChannel is used to receive termination signals from the OS.
@@ -43,8 +43,8 @@ func NewGateway(configProvider *config.Provider) (*Gateway, error) {
 		return nil, err
 	}

-	tel.Logger.Info("🐦Glide is starting up", zap.String("version", version.FullVersion))
-	tel.Logger.Debug("βœ… config loaded successfully:\n" + configProvider.GetStr())
+	tel.L().Info("🐦Glide is starting up", zap.String("version", version.FullVersion))
+	tel.L().Debug("βœ… config loaded successfully:\n" + configProvider.GetStr())

 	routerManager, err := routers.NewManager(&cfg.Routers, tel)
 	if err != nil {
@@ -58,7 +58,7 @@ func NewGateway(configProvider *config.Provider) (*Gateway, error) {

 	return &Gateway{
 		configProvider: configProvider,
-		telemetry:      tel,
+		tel:            tel,
 		serverManager:  serverManager,
 		signalC:        make(chan os.Signal, 3), // equal to number of signal types we expect to receive
 		shutdownC:      make(chan struct{}),
@@ -78,13 +78,13 @@ LOOP:
 		select {
 		// TODO: Watch for config updates
 		case sig := <-gw.signalC:
-			gw.telemetry.Logger.Info("received signal from os", zap.String("signal", sig.String()))
+			gw.tel.L().Info("received signal from os", zap.String("signal", sig.String()))
 			break LOOP
 		case <-gw.shutdownC:
-			gw.telemetry.Logger.Info("received shutdown request")
+			gw.tel.L().Info("received shutdown request")
 			break LOOP
 		case <-ctx.Done():
-			gw.telemetry.Logger.Info("context done, terminating process")
+			gw.tel.L().Info("context done, terminating process")
 			// Call shutdown with background context as the passed in context has been canceled
 			return gw.shutdown(context.Background()) //nolint:contextcheck
 		}
diff --git a/pkg/providers/anthropic/chat.go b/pkg/providers/anthropic/chat.go
index 5a8d8ee3..dc8f857d 100644
--- a/pkg/providers/anthropic/chat.go
+++ b/pkg/providers/anthropic/chat.go
@@ -9,8 +9,6 @@ import (
 	"net/http"
 	"time"

-	"glide/pkg/providers/clients"
-
 	"glide/pkg/api/schemas"
 	"go.uber.org/zap"
 )
@@ -63,6 +61,8 @@ func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessa
 }

 // Chat sends a chat request to the specified anthropic model.
+//
+// Ref: https://docs.anthropic.com/claude/reference/messages_post
 func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
 	// Create a new chat request
 	chatRequest := c.createChatRequestSchema(request)
@@ -72,10 +72,6 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem
 		return nil, err
 	}

-	if len(chatResponse.ModelResponse.Message.Content) == 0 {
-		return nil, ErrEmptyResponse
-	}
-
 	return chatResponse, nil
 }

@@ -99,12 +95,13 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 		return nil, fmt.Errorf("unable to create anthropic chat request: %w", err)
 	}

-	req.Header.Set("Authorization", "Bearer "+string(c.config.APIKey))
+	req.Header.Set("x-api-key", string(c.config.APIKey)) // must be in lower case
+	req.Header.Set("anthropic-version", c.apiVersion)
 	req.Header.Set("Content-Type", "application/json")

 	// TODO: this could leak information from messages which may not be a desired thing to have
-	c.telemetry.Logger.Debug(
-		"anthropic chat request",
+	c.tel.L().Debug(
+		"Anthropic chat request",
 		zap.String("chat_url", c.chatURL),
 		zap.Any("payload", payload),
 	)
@@ -117,71 +114,49 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
 	defer resp.Body.Close()

 	if resp.StatusCode != http.StatusOK {
-		bodyBytes, err := io.ReadAll(resp.Body)
-		if err != nil {
-			c.telemetry.Logger.Error("failed to read anthropic chat response", zap.Error(err))
-		}
-
-		c.telemetry.Logger.Error(
-			"anthropic chat request failed",
-			zap.Int("status_code", resp.StatusCode),
-			zap.String("response", string(bodyBytes)),
-			zap.Any("headers", resp.Header),
-		)
-
-		if resp.StatusCode == http.StatusTooManyRequests {
-			// Read the value of the "Retry-After" header to get the cooldown delay
-			retryAfter := resp.Header.Get("Retry-After")
-
-			// Parse the value to get the duration
-			cooldownDelay, err := time.ParseDuration(retryAfter)
-			if err != nil {
-				return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
-			}
-
-			return nil, clients.NewRateLimitError(&cooldownDelay)
-		}
-
-		// Server & client errors result in the same error to keep gateway resilient
-		return nil, clients.ErrProviderUnavailable
+		return nil, c.errMapper.Map(resp)
 	}

 	// Read the response body into a byte slice
 	bodyBytes, err := io.ReadAll(resp.Body)
 	if err != nil {
-		c.telemetry.Logger.Error("failed to read anthropic chat response", zap.Error(err))
+		c.tel.L().Error("Failed to read anthropic chat response", zap.Error(err))
 		return nil, err
 	}

 	// Parse the response JSON
-	var anthropicCompletion ChatCompletion
+	var anthropicResponse ChatCompletion

-	err = json.Unmarshal(bodyBytes, &anthropicCompletion)
+	err = json.Unmarshal(bodyBytes, &anthropicResponse)
 	if err != nil {
-		c.telemetry.Logger.Error("failed to parse anthropic chat response", zap.Error(err))
+		c.tel.L().Error("Failed to parse anthropic chat response", zap.Error(err))
 		return nil, err
 	}

+	if len(anthropicResponse.Content) == 0 {
+		return nil, ErrEmptyResponse
+	}
+
+	completion := anthropicResponse.Content[0]
+	usage := anthropicResponse.Usage
+
 	// Map response to ChatResponse schema
 	response := schemas.ChatResponse{
-		ID:       anthropicCompletion.ID,
-		Created:  int(time.Now().UTC().Unix()), // not provided by anthropic
-		Provider: providerName,
-		Model:    anthropicCompletion.Model,
-		Cached:   false,
-		ModelResponse: schemas.ProviderResponse{
-			SystemID: map[string]string{
-				"system_fingerprint": anthropicCompletion.ID,
-			},
+		ID:        anthropicResponse.ID,
+		Created:   int(time.Now().UTC().Unix()), // not provided by anthropic
+		Provider:  providerName,
+		ModelName: anthropicResponse.Model,
+		Cached:    false,
+		ModelResponse: schemas.ModelResponse{
+			SystemID: map[string]string{},
 			Message: schemas.ChatMessage{
-				Role:    anthropicCompletion.Content[0].Type,
-				Content: anthropicCompletion.Content[0].Text,
-				Name:    "",
+				Role:    completion.Type,
+				Content: completion.Text,
 			},
 			TokenUsage: schemas.TokenUsage{
-				PromptTokens:   0, // Anthropic doesn't send prompt tokens
-				ResponseTokens: 0,
-				TotalTokens:    0,
+				PromptTokens:   usage.InputTokens,
+				ResponseTokens: usage.OutputTokens,
+				TotalTokens:    usage.InputTokens + usage.OutputTokens,
 			},
 		},
 	}
diff --git a/pkg/providers/anthropic/chat_stream.go b/pkg/providers/anthropic/chat_stream.go
new file mode 100644
index 00000000..d5a31bc0
--- /dev/null
+++ b/pkg/providers/anthropic/chat_stream.go
@@ -0,0 +1,16 @@
+package anthropic
+
+import (
+	"context"
+
+	"glide/pkg/api/schemas"
+	"glide/pkg/providers/clients"
+)
+
+func (c *Client) SupportChatStream() bool {
+	return false
+}
+
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+	return nil, clients.ErrChatStreamNotImplemented
+}
diff --git a/pkg/providers/anthropic/client.go b/pkg/providers/anthropic/client.go
index c7131455..1af21262 100644
--- a/pkg/providers/anthropic/client.go
+++ b/pkg/providers/anthropic/client.go
@@ -22,10 +22,12 @@ var (
 type Client struct {
 	baseURL             string
 	chatURL             string
+	apiVersion          string
 	chatRequestTemplate *ChatRequest
+	errMapper           *ErrorMapper
 	config              *Config
 	httpClient          *http.Client
-	telemetry           *telemetry.Telemetry
+	tel                 *telemetry.Telemetry
 }

 // NewClient creates a new Anthropic client for the Anthropic API.
@@ -38,8 +40,10 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
 	c := &Client{
 		baseURL:             providerConfig.BaseURL,
 		chatURL:             chatURL,
+		apiVersion:          providerConfig.APIVersion,
 		config:              providerConfig,
 		chatRequestTemplate: NewChatRequestFromConfig(providerConfig),
+		errMapper:           NewErrorMapper(tel),
 		httpClient: &http.Client{
 			Timeout: *clientConfig.Timeout,
 			// TODO: use values from the config
@@ -48,7 +52,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
 				MaxIdleConnsPerHost: 2,
 			},
 		},
-		telemetry: tel,
+		tel: tel,
 	}

 	return c, nil
"claude-instant-1.2", DefaultParams: &defaultParams, diff --git a/pkg/providers/anthropic/errors.go b/pkg/providers/anthropic/errors.go new file mode 100644 index 00000000..44016560 --- /dev/null +++ b/pkg/providers/anthropic/errors.go @@ -0,0 +1,56 @@ +package anthropic + +import ( + "fmt" + "io" + "net/http" + "time" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" + "go.uber.org/zap" +) + +type ErrorMapper struct { + tel *telemetry.Telemetry +} + +func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper { + return &ErrorMapper{ + tel: tel, + } +} + +func (m *ErrorMapper) Map(resp *http.Response) error { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + m.tel.Logger.Error("failed to read anthropic chat response", zap.Error(err)) + } + + m.tel.Logger.Error( + "anthropic chat request failed", + zap.Int("status_code", resp.StatusCode), + zap.String("response", string(bodyBytes)), + zap.Any("headers", resp.Header), + ) + + if resp.StatusCode == http.StatusTooManyRequests { + // Read the value of the "Retry-After" header to get the cooldown delay + retryAfter := resp.Header.Get("Retry-After") + + // Parse the value to get the duration + cooldownDelay, err := time.ParseDuration(retryAfter) + if err != nil { + return fmt.Errorf("failed to parse cooldown delay from headers: %w", err) + } + + return clients.NewRateLimitError(&cooldownDelay) + } + + if resp.StatusCode == http.StatusUnauthorized { + return clients.ErrUnauthorized + } + + // Server & client errors result in the same error to keep gateway resilient + return clients.ErrProviderUnavailable +} diff --git a/pkg/providers/anthropic/schamas.go b/pkg/providers/anthropic/schamas.go index 69b00248..2f915a0c 100644 --- a/pkg/providers/anthropic/schamas.go +++ b/pkg/providers/anthropic/schamas.go @@ -1,6 +1,16 @@ package anthropic -// Anthropic Chat Response +type Content struct { + Type string `json:"type"` + Text string `json:"text"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// ChatCompletion is an Anthropic Chat Response type ChatCompletion struct { ID string `json:"id"` Type string `json:"type"` @@ -9,9 +19,5 @@ type ChatCompletion struct { Content []Content `json:"content"` StopReason string `json:"stop_reason"` StopSequence string `json:"stop_sequence"` -} - -type Content struct { - Type string `json:"type"` - Text string `json:"text"` + Usage Usage `json:"usage"` } diff --git a/pkg/providers/anthropic/testdata/chat.success.json b/pkg/providers/anthropic/testdata/chat.success.json index eaf0f6c9..f4921bd4 100644 --- a/pkg/providers/anthropic/testdata/chat.success.json +++ b/pkg/providers/anthropic/testdata/chat.success.json @@ -1,7 +1,7 @@ { "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF", "type": "message", - "model": "claude-2.1", + "model": "claude-instant-1.2", "role": "assistant", "content": [ { @@ -10,5 +10,9 @@ } ], "stop_reason": "end_turn", - "stop_sequence": null -} \ No newline at end of file + "stop_sequence": null, + "usage":{ + "input_tokens": 24, + "output_tokens": 13 + } +} diff --git a/pkg/providers/azureopenai/chat.go b/pkg/providers/azureopenai/chat.go index f961587c..b3563eb2 100644 --- a/pkg/providers/azureopenai/chat.go +++ b/pkg/providers/azureopenai/chat.go @@ -7,40 +7,13 @@ import ( "fmt" "io" "net/http" - "time" "glide/pkg/api/schemas" "glide/pkg/providers/openai" - "glide/pkg/providers/clients" - "go.uber.org/zap" ) -type ChatMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -// ChatRequest is an 
Azure openai-specific request schema -type ChatRequest struct { - Messages []ChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` - LogitBias *map[int]float64 `json:"logit_bias,omitempty"` - User *string `json:"user,omitempty"` - Seed *int `json:"seed,omitempty"` - Tools []string `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` -} - // NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives func NewChatRequestFromConfig(cfg *Config) *ChatRequest { return &ChatRequest{ @@ -49,7 +22,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { MaxTokens: cfg.DefaultParams.MaxTokens, N: cfg.DefaultParams.N, StopWords: cfg.DefaultParams.StopWords, - Stream: false, // unsupported right now + Stream: false, FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty, PresencePenalty: cfg.DefaultParams.PresencePenalty, LogitBias: cfg.DefaultParams.LogitBias, @@ -61,23 +34,10 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } } -func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage { - messages := make([]ChatMessage, 0, len(request.MessageHistory)+1) - - // Add items from messageHistory first and the new chat message last - for _, message := range request.MessageHistory { - messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content}) - } - - messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content}) - - return messages -} - // Chat sends a chat request to the specified azure openai model. func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request - chatRequest := c.createChatRequestSchema(request) + chatRequest := c.createRequestSchema(request) chatResponse, err := c.doChatRequest(ctx, chatRequest) if err != nil { @@ -91,12 +51,21 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { +// createRequestSchema creates a new ChatRequest object based on the given request. 
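+// The chat request template is copied by value, so per-request messages don't mutate the shared defaults.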
+func (c *Client) createRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation - chatRequest := c.chatRequestTemplate // hoping to get a copy of the template - chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) + chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template + + chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1) + + // Add items from messageHistory first and the new chat message last + for _, message := range request.MessageHistory { + chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content}) + } + + chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content}) - return chatRequest + return &chatRequest } func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { @@ -115,7 +84,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche req.Header.Set("Content-Type", "application/json") // TODO: this could leak information from messages which may not be a desired thing to have - c.telemetry.Logger.Debug( + c.tel.Logger.Debug( "azure openai chat request", zap.String("chat_url", c.chatURL), zap.Any("payload", payload), @@ -129,39 +98,13 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - c.telemetry.Logger.Error("failed to read azure openai chat response", zap.Error(err)) - } - - c.telemetry.Logger.Error( - "azure openai chat request failed", - zap.Int("status_code", resp.StatusCode), - zap.String("response", string(bodyBytes)), - zap.Any("headers", resp.Header), - ) - - if resp.StatusCode == http.StatusTooManyRequests { - // Read the value of the "Retry-After" header to get the cooldown delay - retryAfter := resp.Header.Get("Retry-After") - - // Parse the value to get the duration - cooldownDelay, err := time.ParseDuration(retryAfter) - if err != nil { - return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err) - } - - return nil, clients.NewRateLimitError(&cooldownDelay) - } - - // Server & client errors result in the same error to keep gateway resilient - return nil, clients.ErrProviderUnavailable + return nil, c.errMapper.Map(resp) } // Read the response body into a byte slice bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - c.telemetry.Logger.Error("failed to read azure openai chat response", zap.Error(err)) + c.tel.Logger.Error("failed to read azure openai chat response", zap.Error(err)) return nil, err } @@ -170,7 +113,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche err = json.Unmarshal(bodyBytes, &openAICompletion) if err != nil { - c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err)) + c.tel.Logger.Error("failed to parse openai chat response", zap.Error(err)) return nil, err } @@ -178,19 +121,18 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche // Map response to UnifiedChatResponse schema response := schemas.ChatResponse{ - ID: openAICompletion.ID, - Created: openAICompletion.Created, - Provider: providerName, - Model: openAICompletion.Model, - Cached: false, - ModelResponse: schemas.ProviderResponse{ + ID: openAICompletion.ID, + Created: openAICompletion.Created, + 
Provider: providerName, + ModelName: openAICompletion.ModelName, + Cached: false, + ModelResponse: schemas.ModelResponse{ SystemID: map[string]string{ "system_fingerprint": openAICompletion.SystemFingerprint, }, Message: schemas.ChatMessage{ Role: openAICompletion.Choices[0].Message.Role, Content: openAICompletion.Choices[0].Message.Content, - Name: "", }, TokenUsage: schemas.TokenUsage{ PromptTokens: openAICompletion.Usage.PromptTokens, diff --git a/pkg/providers/azureopenai/chat_stream.go b/pkg/providers/azureopenai/chat_stream.go new file mode 100644 index 00000000..a3a64bff --- /dev/null +++ b/pkg/providers/azureopenai/chat_stream.go @@ -0,0 +1,231 @@ +package azureopenai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/r3labs/sse/v2" + "glide/pkg/providers/clients" + "glide/pkg/telemetry" + + "go.uber.org/zap" + + "glide/pkg/api/schemas" +) + +var ( + StopReason = "stop" + streamDoneMarker = []byte("[DONE]") +) + +// ChatStream represents chat stream for a specific request +type ChatStream struct { + tel *telemetry.Telemetry + client *http.Client + req *http.Request + reqID string + reqMetadata *schemas.Metadata + resp *http.Response + reader *sse.EventStreamReader + errMapper *ErrorMapper +} + +func NewChatStream( + tel *telemetry.Telemetry, + client *http.Client, + req *http.Request, + reqID string, + reqMetadata *schemas.Metadata, + errMapper *ErrorMapper, +) *ChatStream { + return &ChatStream{ + tel: tel, + client: client, + req: req, + reqID: reqID, + reqMetadata: reqMetadata, + errMapper: errMapper, + } +} + +// Open initializes and opens a ChatStream. +func (s *ChatStream) Open() error { + resp, err := s.client.Do(s.req) //nolint:bodyclose + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + return s.errMapper.Map(resp) + } + + s.resp = resp + s.reader = sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize? + + return nil +} + +// Recv receives a chat stream chunk from the ChatStream and returns a ChatStreamChunk object. 
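+// It skips empty SSE events and returns io.EOF once the [DONE] marker arrives, which signals normal stream termination.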
+func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
+ var completionChunk ChatCompletionChunk
+
+ for {
+ rawEvent, err := s.reader.ReadEvent()
+ if err != nil {
+ s.tel.L().Warn(
+ "Chat stream is unexpectedly disconnected",
+ zap.String("provider", providerName),
+ zap.Error(err),
+ )
+
+ // io.EOF here still means the stream was interrupted unexpectedly:
+ // normal termination is signaled by the [DONE] marker (streamDoneMarker), not by closing the connection
+
+ return nil, clients.ErrProviderUnavailable
+ }
+
+ s.tel.L().Debug(
+ "Raw chat stream chunk",
+ zap.String("provider", providerName),
+ zap.ByteString("rawChunk", rawEvent),
+ )
+
+ event, err := clients.ParseSSEvent(rawEvent)
+ if err != nil {
+ // check the parse error before touching event, which is nil on failure
+ return nil, fmt.Errorf("failed to parse chat stream message: %v", err)
+ }
+
+ if bytes.Equal(event.Data, streamDoneMarker) {
+ s.tel.L().Info(
+ "EOF: [DONE] marker found in chat stream",
+ zap.String("provider", providerName),
+ )
+
+ return nil, io.EOF
+ }
+
+ if !event.HasContent() {
+ s.tel.L().Debug(
+ "Received an empty message in chat stream, skipping it",
+ zap.String("provider", providerName),
+ zap.Any("msg", event),
+ )
+
+ continue
+ }
+
+ err = json.Unmarshal(event.Data, &completionChunk)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal AzureOpenAI chat stream chunk: %v", err)
+ }
+
+ responseChunk := completionChunk.Choices[0]
+
+ var finishReason *schemas.FinishReason
+
+ if responseChunk.FinishReason == StopReason {
+ finishReason = &schemas.Complete
+ }
+
+ // TODO: use objectpool here
+ return &schemas.ChatStreamChunk{
+ ID: s.reqID,
+ Provider: providerName,
+ Cached: false,
+ ModelName: completionChunk.ModelName,
+ Metadata: s.reqMetadata,
+ ModelResponse: schemas.ModelChunkResponse{
+ Metadata: &schemas.Metadata{
+ "response_id": completionChunk.ID,
+ "system_fingerprint": completionChunk.SystemFingerprint,
+ },
+ Message: schemas.ChatMessage{
+ Role: responseChunk.Delta.Role,
+ Content: responseChunk.Delta.Content,
+ },
+ FinishReason: finishReason,
+ },
+ }, nil
+ }
+}
+
+func (s *ChatStream) Close() error {
+ if s.resp != nil {
+ return s.resp.Body.Close()
+ }
+
+ return nil
+}
+
+func (c *Client) SupportChatStream() bool {
+ return true
+}
+
+func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+ // Create a new chat request
+ httpRequest, err := c.makeStreamReq(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+
+ return NewChatStream(
+ c.tel,
+ c.httpClient,
+ httpRequest,
+ req.ID,
+ req.Metadata,
+ c.errMapper,
+ ), nil
+}
+
+func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
+ // TODO: consider using objectpool to optimize memory allocation
+ chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template
+
+ chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
+
+ // Add items from messageHistory first and the new chat message last
+ for _, message := range request.MessageHistory {
+ chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
+ }
+
+ chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
+
+ return &chatRequest
+}
+
+func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
+ chatRequest := c.createRequestFromStream(req)
+
+ chatRequest.Stream = true
+
+ rawPayload, err := json.Marshal(chatRequest)
+ if err != 
nil { + return nil, fmt.Errorf("unable to marshal AzureOpenAI chat stream request payload: %w", err) + } + + request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload)) + if err != nil { + return nil, fmt.Errorf("unable to create AzureOpenAI stream chat request: %w", err) + } + + request.Header.Set("Content-Type", "application/json") + request.Header.Set("api-key", string(c.config.APIKey)) + request.Header.Set("Cache-Control", "no-cache") + request.Header.Set("Accept", "text/event-stream") + request.Header.Set("Connection", "keep-alive") + + // TODO: this could leak information from messages which may not be a desired thing to have + c.tel.L().Debug( + "Stream chat request", + zap.String("chatURL", c.chatURL), + zap.Any("payload", chatRequest), + ) + + return request, nil +} diff --git a/pkg/providers/azureopenai/chat_stream_test.go b/pkg/providers/azureopenai/chat_stream_test.go new file mode 100644 index 00000000..080ffb24 --- /dev/null +++ b/pkg/providers/azureopenai/chat_stream_test.go @@ -0,0 +1,157 @@ +package azureopenai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "glide/pkg/api/schemas" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" + + "github.com/stretchr/testify/require" +) + +func TestAzureOpenAIClient_ChatStreamSupported(t *testing.T) { + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + require.True(t, client.SupportChatStream()) +} + +func TestAzureOpenAIClient_ChatStreamRequest(t *testing.T) { + tests := map[string]string{ + "success stream": "./testdata/chat_stream.success.txt", + } + + for name, streamFile := range tests { + t.Run(name, func(t *testing.T) { + AzureOpenAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean(streamFile)) + if err != nil { + t.Errorf("error reading azureopenai chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "text/event-stream") + + _, err = w.Write(chatResponse) + if err != nil { + t.Errorf("error on sending chat response: %v", err) + } + }) + + AzureopenAIServer := httptest.NewServer(AzureOpenAIMock) + defer AzureopenAIServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = AzureopenAIServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?") + stream, err := client.ChatStream(ctx, req) + require.NoError(t, err) + + err = stream.Open() + require.NoError(t, err) + + for { + chunk, err := stream.Recv() + + if err == io.EOF { + return + } + + require.NoError(t, err) + require.NotNil(t, chunk) + } + }) + } +} + +func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { + tests := map[string]string{ + "success stream, but no last done message": "./testdata/chat_stream.nodone.txt", + "success stream, but with empty event": "./testdata/chat_stream.empty.txt", + } + + for name, streamFile := range tests { + t.Run(name, 
func(t *testing.T) { + openAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean(streamFile)) + if err != nil { + t.Errorf("error reading openai chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "text/event-stream") + + _, err = w.Write(chatResponse) + if err != nil { + t.Errorf("error on sending chat response: %v", err) + } + }) + + openAIServer := httptest.NewServer(openAIMock) + defer openAIServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = openAIServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?") + + stream, err := client.ChatStream(ctx, req) + require.NoError(t, err) + + err = stream.Open() + require.NoError(t, err) + + for { + chunk, err := stream.Recv() + if err != nil { + require.ErrorIs(t, err, clients.ErrProviderUnavailable) + return + } + + require.NoError(t, err) + require.NotNil(t, chunk) + } + }) + } +} diff --git a/pkg/providers/azureopenai/client.go b/pkg/providers/azureopenai/client.go index cd05ba90..9a15aeb8 100644 --- a/pkg/providers/azureopenai/client.go +++ b/pkg/providers/azureopenai/client.go @@ -23,9 +23,10 @@ type Client struct { baseURL string // The name of your Azure OpenAI Resource (e.g https://glide-test.openai.azure.com/) chatURL string chatRequestTemplate *ChatRequest + errMapper *ErrorMapper config *Config httpClient *http.Client - telemetry *telemetry.Telemetry + tel *telemetry.Telemetry } // NewClient creates a new Azure OpenAI client for the OpenAI API. 
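+// It also wires in an ErrorMapper that translates Azure OpenAI error responses into the unified client errors.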
@@ -42,6 +43,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * chatURL: chatURL, config: providerConfig, chatRequestTemplate: NewChatRequestFromConfig(providerConfig), + errMapper: NewErrorMapper(tel), httpClient: &http.Client{ // TODO: use values from the config Timeout: *clientConfig.Timeout, @@ -50,7 +52,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * MaxIdleConnsPerHost: 2, }, }, - telemetry: tel, + tel: tel, } return c, nil diff --git a/pkg/providers/azureopenai/client_test.go b/pkg/providers/azureopenai/client_test.go index 8f5de037..be98c316 100644 --- a/pkg/providers/azureopenai/client_test.go +++ b/pkg/providers/azureopenai/client_test.go @@ -106,21 +106,19 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { defer mockServer.Close() - // Create a new client with the mock server URL - client := &Client{ - httpClient: http.DefaultClient, - chatURL: mockServer.URL, - config: &Config{APIKey: "dummy_key"}, - telemetry: telemetry.NewTelemetryMock(), - } + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = mockServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) // Create a chat request payload - payload := &ChatRequest{ - Messages: []ChatMessage{{Role: "human", Content: "Hello"}}, - } + payload := schemas.NewChatFromStr("What's the dealio?") - // Call the doChatRequest function - _, err := client.doChatRequest(context.Background(), payload) + _, err = client.Chat(ctx, payload) require.Error(t, err) require.Contains(t, err.Error(), "provider is not available") diff --git a/pkg/providers/azureopenai/errors.go b/pkg/providers/azureopenai/errors.go new file mode 100644 index 00000000..0e55c0b0 --- /dev/null +++ b/pkg/providers/azureopenai/errors.go @@ -0,0 +1,56 @@ +package azureopenai + +import ( + "fmt" + "io" + "net/http" + "time" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" + "go.uber.org/zap" +) + +type ErrorMapper struct { + tel *telemetry.Telemetry +} + +func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper { + return &ErrorMapper{ + tel: tel, + } +} + +func (m *ErrorMapper) Map(resp *http.Response) error { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + m.tel.L().Error("failed to read azure openai chat response", zap.Error(err)) + } + + m.tel.L().Error( + "azure openai chat request failed", + zap.Int("status_code", resp.StatusCode), + zap.String("response", string(bodyBytes)), + zap.Any("headers", resp.Header), + ) + + if resp.StatusCode == http.StatusTooManyRequests { + // Read the value of the "Retry-After" header to get the cooldown delay + retryAfter := resp.Header.Get("Retry-After") + + // Parse the value to get the duration + cooldownDelay, err := time.ParseDuration(retryAfter) + if err != nil { + return fmt.Errorf("failed to parse cooldown delay from headers: %w", err) + } + + return clients.NewRateLimitError(&cooldownDelay) + } + + if resp.StatusCode == http.StatusUnauthorized { + return clients.ErrUnauthorized + } + + // Server & client errors result in the same error to keep gateway resilient + return clients.ErrProviderUnavailable +} diff --git a/pkg/providers/azureopenai/schemas.go b/pkg/providers/azureopenai/schemas.go new file mode 100644 index 00000000..993bb8d7 --- /dev/null +++ b/pkg/providers/azureopenai/schemas.go @@ -0,0 +1,67 @@ +package azureopenai + +type ChatMessage struct { + Role string `json:"role"` + 
Content string `json:"content"`
+}
+
+// ChatRequest is an Azure OpenAI-specific request schema
+type ChatRequest struct {
+ Messages []ChatMessage `json:"messages"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ N int `json:"n,omitempty"`
+ StopWords []string `json:"stop,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ FrequencyPenalty int `json:"frequency_penalty,omitempty"`
+ PresencePenalty int `json:"presence_penalty,omitempty"`
+ LogitBias *map[int]float64 `json:"logit_bias,omitempty"`
+ User *string `json:"user,omitempty"`
+ Seed *int `json:"seed,omitempty"`
+ Tools []string `json:"tools,omitempty"`
+ ToolChoice interface{} `json:"tool_choice,omitempty"`
+ ResponseFormat interface{} `json:"response_format,omitempty"`
+}
+
+// ChatCompletion is an Azure OpenAI chat completion response
+// Ref: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
+type ChatCompletion struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int `json:"created"`
+ ModelName string `json:"model"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Choices []Choice `json:"choices"`
+ Usage Usage `json:"usage"`
+}
+
+type Choice struct {
+ Index int `json:"index"`
+ Message ChatMessage `json:"message"`
+ Logprobs interface{} `json:"logprobs"`
+ FinishReason string `json:"finish_reason"`
+}
+
+type Usage struct {
+ PromptTokens float64 `json:"prompt_tokens"`
+ CompletionTokens float64 `json:"completion_tokens"`
+ TotalTokens float64 `json:"total_tokens"`
+}
+
+// ChatCompletionChunk represents an SSE event that carries a piece of a streamed chat response
+// Ref: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
+type ChatCompletionChunk struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int `json:"created"`
+ ModelName string `json:"model"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Choices []StreamChoice `json:"choices"`
+}
+
+type StreamChoice struct {
+ Index int `json:"index"`
+ Delta ChatMessage `json:"delta"`
+ FinishReason string `json:"finish_reason"`
+}
diff --git a/pkg/providers/azureopenai/testdata/chat_stream.empty.txt b/pkg/providers/azureopenai/testdata/chat_stream.empty.txt
new file mode 100644
index 00000000..a04fb787
--- /dev/null
+++ b/pkg/providers/azureopenai/testdata/chat_stream.empty.txt
@@ -0,0 +1,6 @@
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
+
+data:
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
+
diff --git a/pkg/providers/azureopenai/testdata/chat_stream.nodone.txt b/pkg/providers/azureopenai/testdata/chat_stream.nodone.txt
new file mode 100644
index 00000000..63152c29
--- /dev/null
+++ b/pkg/providers/azureopenai/testdata/chat_stream.nodone.txt
@@ -0,0 +1,22 @@
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
+
+data: 
{"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"The"},"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" capital"},"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" of"},"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" the"},"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" United"},"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" Kingdom"},"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" is"},"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" London"},"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"."},"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + diff --git a/pkg/providers/azureopenai/testdata/chat_stream.success.txt b/pkg/providers/azureopenai/testdata/chat_stream.success.txt new file mode 100644 index 00000000..1e673eaf --- /dev/null +++ b/pkg/providers/azureopenai/testdata/chat_stream.success.txt @@ -0,0 +1,24 @@ +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"The"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" 
capital"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" of"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" the"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" United"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" Kingdom"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" is"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" London"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + +data: [DONE] + diff --git a/pkg/providers/bedrock/chat.go b/pkg/providers/bedrock/chat.go index 14feb9bc..bb17cb11 100644 --- a/pkg/providers/bedrock/chat.go +++ b/pkg/providers/bedrock/chat.go @@ -99,16 +99,17 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche err = json.Unmarshal(result.Body, &bedrockCompletion) if err != nil { c.telemetry.Logger.Error("failed to parse bedrock chat response", zap.Error(err)) + return nil, err } response := schemas.ChatResponse{ - ID: uuid.NewString(), - Created: int(time.Now().Unix()), - Provider: "aws-bedrock", - Model: c.config.Model, - Cached: false, - ModelResponse: schemas.ProviderResponse{ + ID: uuid.NewString(), + Created: int(time.Now().Unix()), + Provider: "aws-bedrock", + ModelName: c.config.Model, + Cached: false, + ModelResponse: schemas.ModelResponse{ SystemID: map[string]string{ "system_fingerprint": "none", }, @@ -118,9 +119,9 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche Name: "", }, TokenUsage: schemas.TokenUsage{ - PromptTokens: float64(bedrockCompletion.Results[0].TokenCount), + PromptTokens: bedrockCompletion.Results[0].TokenCount, ResponseTokens: -1, - TotalTokens: float64(bedrockCompletion.Results[0].TokenCount), + TotalTokens: bedrockCompletion.Results[0].TokenCount, }, }, } diff --git a/pkg/providers/bedrock/chat_stream.go 
b/pkg/providers/bedrock/chat_stream.go
new file mode 100644
index 00000000..35918a0c
--- /dev/null
+++ b/pkg/providers/bedrock/chat_stream.go
@@ -0,0 +1,16 @@
+package bedrock
+
+import (
+ "context"
+
+ "glide/pkg/api/schemas"
+ "glide/pkg/providers/clients"
+)
+
+func (c *Client) SupportChatStream() bool {
+ return false
+}
+
+func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+ return nil, clients.ErrChatStreamNotImplemented
+}
diff --git a/pkg/providers/clients/config_test.go b/pkg/providers/clients/config_test.go
new file mode 100644
index 00000000..0e725201
--- /dev/null
+++ b/pkg/providers/clients/config_test.go
@@ -0,0 +1,13 @@
+package clients
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestClientConfig_DefaultConfig(t *testing.T) {
+ config := DefaultClientConfig()
+
+ require.NotEmpty(t, config.Timeout)
+}
diff --git a/pkg/providers/clients/errors.go b/pkg/providers/clients/errors.go
index 8c704a3f..deaf00bc 100644
--- a/pkg/providers/clients/errors.go
+++ b/pkg/providers/clients/errors.go
@@ -6,7 +6,11 @@ import (
 "time"
 )
 
-var ErrProviderUnavailable = errors.New("provider is not available")
+var (
+ ErrProviderUnavailable = errors.New("provider is not available")
+ ErrUnauthorized = errors.New("API key is wrong or not set")
+ ErrChatStreamNotImplemented = errors.New("streaming chat API is not implemented for provider")
+)
 
 type RateLimitError struct {
 untilReset time.Duration
diff --git a/pkg/providers/clients/errors_test.go b/pkg/providers/clients/errors_test.go
new file mode 100644
index 00000000..5a570a13
--- /dev/null
+++ b/pkg/providers/clients/errors_test.go
@@ -0,0 +1,16 @@
+package clients
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestRateLimitError(t *testing.T) {
+ duration := 5 * time.Minute
+ err := NewRateLimitError(&duration)
+
+ require.Equal(t, duration, err.UntilReset())
+ require.Contains(t, err.Error(), "rate limit reached")
+}
diff --git a/pkg/providers/clients/sse.go b/pkg/providers/clients/sse.go
new file mode 100644
index 00000000..5619876f
--- /dev/null
+++ b/pkg/providers/clients/sse.go
@@ -0,0 +1,82 @@
+package clients
+
+import (
+ "bytes"
+ "errors"
+)
+
+// Taken from https://github.com/r3labs/sse/blob/master/client.go#L322
+
+var (
+ headerID = []byte("id:")
+ headerData = []byte("data:")
+ headerEvent = []byte("event:")
+ headerRetry = []byte("retry:")
+)
+
+// Event holds all the event source fields
+type Event struct {
+ ID []byte
+ Data []byte
+ Event []byte
+ Retry []byte
+ Comment []byte
+}
+
+func (e *Event) HasContent() bool {
+ return len(e.ID) > 0 || len(e.Data) > 0 || len(e.Event) > 0 || len(e.Retry) > 0
+}
+
+func ParseSSEvent(msg []byte) (event *Event, err error) {
+ var e Event
+
+ if len(msg) < 1 {
+ return nil, errors.New("event message was empty")
+ }
+
+ // Split the message by "\n" or "\r", per the spec, which also normalizes CRLF line endings.
+ for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) {
+ switch {
+ case bytes.HasPrefix(line, headerID):
+ e.ID = append([]byte(nil), trimHeader(len(headerID), line)...)
+ case bytes.HasPrefix(line, headerData):
+ // The spec allows for multiple data fields per event, concatenating them with "\n".
+ e.Data = append(e.Data, append(trimHeader(len(headerData), line), byte('\n'))...)
+ // The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body. + case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))): + e.Data = append(e.Data, byte('\n')) + case bytes.HasPrefix(line, headerEvent): + e.Event = append([]byte(nil), trimHeader(len(headerEvent), line)...) + case bytes.HasPrefix(line, headerRetry): + e.Retry = append([]byte(nil), trimHeader(len(headerRetry), line)...) + default: + // Ignore any garbage that doesn't match what we're looking for. + } //nolint:wsl + } + + // Trim the last "\n" per the spec. + e.Data = bytes.TrimSuffix(e.Data, []byte("\n")) + + return &e, err +} + +func trimHeader(size int, data []byte) []byte { + if data == nil || len(data) < size { + return data + } + + data = data[size:] + + // Remove optional leading whitespace + if len(data) > 0 && data[0] == 32 { + data = data[1:] + } + + // Remove trailing new line + if len(data) > 0 && data[len(data)-1] == 10 { + data = data[:len(data)-1] + } + + return data +} diff --git a/pkg/providers/clients/sse_test.go b/pkg/providers/clients/sse_test.go new file mode 100644 index 00000000..c749737e --- /dev/null +++ b/pkg/providers/clients/sse_test.go @@ -0,0 +1,27 @@ +package clients + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseSSEvent_ValidEvents(t *testing.T) { + tests := []struct { + name string + rawMsg string + data string + }{ + {"data only", "data: {\"id\":\"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg\"}\n", "{\"id\":\"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg\"}"}, + {"empty data", "data:", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event, err := ParseSSEvent([]byte(tt.rawMsg)) + + require.NoError(t, err) + require.Equal(t, []byte(tt.data), event.Data) + }) + } +} diff --git a/pkg/providers/clients/stream.go b/pkg/providers/clients/stream.go new file mode 100644 index 00000000..ff150e8f --- /dev/null +++ b/pkg/providers/clients/stream.go @@ -0,0 +1,31 @@ +package clients + +import ( + "glide/pkg/api/schemas" +) + +type ChatStream interface { + Open() error + Recv() (*schemas.ChatStreamChunk, error) + Close() error +} + +type ChatStreamResult struct { + chunk *schemas.ChatStreamChunk + err error +} + +func (r *ChatStreamResult) Chunk() *schemas.ChatStreamChunk { + return r.chunk +} + +func (r *ChatStreamResult) Error() error { + return r.err +} + +func NewChatStreamResult(chunk *schemas.ChatStreamChunk, err error) *ChatStreamResult { + return &ChatStreamResult{ + chunk: chunk, + err: err, + } +} diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index 165b67bd..86573041 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -15,40 +15,6 @@ import ( "go.uber.org/zap" ) -type ChatMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type ChatHistory struct { - Role string `json:"role"` - Message string `json:"message"` - User string `json:"user,omitempty"` -} - -// ChatRequest is a request to complete a chat completion.. 
-type ChatRequest struct { - Model string `json:"model"` - Message string `json:"message"` - Temperature float64 `json:"temperature,omitempty"` - PreambleOverride string `json:"preamble_override,omitempty"` - ChatHistory []ChatHistory `json:"chat_history,omitempty"` - ConversationID string `json:"conversation_id,omitempty"` - PromptTruncation string `json:"prompt_truncation,omitempty"` - Connectors []string `json:"connectors,omitempty"` - SearchQueriesOnly bool `json:"search_queries_only,omitempty"` - CitiationQuality string `json:"citiation_quality,omitempty"` - - // Stream bool `json:"stream,omitempty"` -} - -type Connectors struct { - ID string `json:"id"` - UserAccessToken string `json:"user_access_token"` - ContOnFail string `json:"continue_on_failure"` - Options map[string]string `json:"options"` -} - // NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives func NewChatRequestFromConfig(cfg *Config) *ChatRequest { return &ChatRequest{ @@ -61,13 +27,14 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { Connectors: cfg.DefaultParams.Connectors, SearchQueriesOnly: cfg.DefaultParams.SearchQueriesOnly, CitiationQuality: cfg.DefaultParams.CitiationQuality, + Stream: false, } } // Chat sends a chat request to the specified cohere model. func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { // Create a new chat request - chatRequest := c.createChatRequestSchema(request) + chatRequest := c.createRequestSchema(request) chatResponse, err := c.doChatRequest(ctx, chatRequest) if err != nil { @@ -81,9 +48,9 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem return chatResponse, nil } -func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest { +func (c *Client) createRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation - chatRequest := c.chatRequestTemplate // hoping to get a copy of the template + chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template chatRequest.Message = request.Message.Content // Build the Cohere specific ChatHistory @@ -100,7 +67,7 @@ func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequ } } - return chatRequest + return &chatRequest } func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { @@ -119,7 +86,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche req.Header.Set("Content-Type", "application/json") // TODO: this could leak information from messages which may not be a desired thing to have - c.telemetry.Logger.Debug( + c.tel.Logger.Debug( "cohere chat request", zap.String("chat_url", c.chatURL), zap.Any("payload", payload), @@ -135,10 +102,10 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche if resp.StatusCode != http.StatusOK { bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err)) + c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err)) } - c.telemetry.Logger.Error( + c.tel.Logger.Error( "cohere chat request failed", zap.Int("status_code", resp.StatusCode), zap.String("response", string(bodyBytes)), @@ -156,7 +123,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche // Read the response body into a byte slice 
bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err)) + c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err)) return nil, err } @@ -165,7 +132,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche err = json.Unmarshal(bodyBytes, &responseJSON) if err != nil { - c.telemetry.Logger.Error("failed to parse cohere chat response", zap.Error(err)) + c.tel.Logger.Error("failed to parse cohere chat response", zap.Error(err)) return nil, err } @@ -174,24 +141,24 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche err = json.Unmarshal(bodyBytes, &cohereCompletion) if err != nil { - c.telemetry.Logger.Error("failed to parse cohere chat response", zap.Error(err)) + c.tel.Logger.Error("failed to parse cohere chat response", zap.Error(err)) return nil, err } // Map response to ChatResponse schema response := schemas.ChatResponse{ - ID: cohereCompletion.ResponseID, - Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this - Provider: providerName, - Model: c.config.Model, - Cached: false, - ModelResponse: schemas.ProviderResponse{ + ID: cohereCompletion.ResponseID, + Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this + Provider: providerName, + ModelName: c.config.Model, + Cached: false, + ModelResponse: schemas.ModelResponse{ SystemID: map[string]string{ "generationId": cohereCompletion.GenerationID, "responseId": cohereCompletion.ResponseID, }, Message: schemas.ChatMessage{ - Role: "model", // TODO: Does this need to change? + Role: "model", Content: cohereCompletion.Text, Name: "", }, @@ -209,11 +176,11 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche func (c *Client) handleErrorResponse(resp *http.Response) (*schemas.ChatResponse, error) { bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - c.telemetry.Logger.Error("failed to read cohere chat response", zap.Error(err)) + c.tel.Logger.Error("failed to read cohere chat response", zap.Error(err)) return nil, err } - c.telemetry.Logger.Error( + c.tel.Logger.Error( "cohere chat request failed", zap.Int("status_code", resp.StatusCode), zap.String("response", string(bodyBytes)), @@ -229,6 +196,10 @@ func (c *Client) handleErrorResponse(resp *http.Response) (*schemas.ChatResponse return nil, clients.NewRateLimitError(&cooldownDelay) } + if resp.StatusCode == http.StatusUnauthorized { + return nil, clients.ErrUnauthorized + } + return nil, clients.ErrProviderUnavailable } diff --git a/pkg/providers/cohere/chat_stream.go b/pkg/providers/cohere/chat_stream.go new file mode 100644 index 00000000..f6f9e9a8 --- /dev/null +++ b/pkg/providers/cohere/chat_stream.go @@ -0,0 +1,242 @@ +package cohere + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/r3labs/sse/v2" + "glide/pkg/providers/clients" + "glide/pkg/telemetry" + + "go.uber.org/zap" + + "glide/pkg/api/schemas" +) + +var StopReason = "stream-end" + +// ChatStream represents cohere chat stream for a specific request +type ChatStream struct { + tel *telemetry.Telemetry + client *http.Client + req *http.Request + reqID string + reqMetadata *schemas.Metadata + resp *http.Response + reader *sse.EventStreamReader + errMapper *ErrorMapper +} + +func NewChatStream( + tel *telemetry.Telemetry, + client *http.Client, + req *http.Request, + reqID string, + reqMetadata *schemas.Metadata, + errMapper *ErrorMapper, +) *ChatStream { + return 
&ChatStream{
+ tel: tel,
+ client: client,
+ req: req,
+ reqID: reqID,
+ reqMetadata: reqMetadata,
+ errMapper: errMapper,
+ }
+}
+
+func (s *ChatStream) Open() error {
+ resp, err := s.client.Do(s.req) //nolint:bodyclose
+ if err != nil {
+ return err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return s.errMapper.Map(resp)
+ }
+
+ s.resp = resp
+ s.reader = sse.NewEventStreamReader(resp.Body, 8192) // TODO: should we expose maxBufferSize?
+
+ return nil
+}
+
+func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
+ var completionChunk ChatCompletionChunk
+
+ for {
+ rawEvent, err := s.reader.ReadEvent()
+ if err != nil {
+ // Cohere doesn't send a [DONE]-style termination marker,
+ // so io.EOF is the normal end of a Cohere stream rather than an error
+ if err == io.EOF {
+ return nil, io.EOF
+ }
+
+ s.tel.L().Warn(
+ "Chat stream is unexpectedly disconnected",
+ zap.String("provider", providerName),
+ zap.Error(err),
+ )
+
+ return nil, clients.ErrProviderUnavailable
+ }
+
+ s.tel.L().Debug(
+ "Raw chat stream chunk",
+ zap.String("provider", providerName),
+ zap.ByteString("rawChunk", rawEvent),
+ )
+
+ event, err := clients.ParseSSEvent(rawEvent)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse chat stream message: %v", err)
+ }
+
+ if !event.HasContent() {
+ s.tel.L().Debug(
+ "Received an empty message in chat stream, skipping it",
+ zap.String("provider", providerName),
+ zap.Any("msg", event),
+ )
+
+ continue
+ }
+
+ err = json.Unmarshal(event.Data, &completionChunk)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err)
+ }
+
+ responseChunk := completionChunk
+
+ var finishReason *schemas.FinishReason
+
+ if responseChunk.IsFinished {
+ finishReason = &schemas.Complete
+
+ return &schemas.ChatStreamChunk{
+ ID: s.reqID,
+ Provider: providerName,
+ Cached: false,
+ ModelName: "NA",
+ Metadata: s.reqMetadata,
+ ModelResponse: schemas.ModelChunkResponse{
+ Metadata: &schemas.Metadata{
+ "generationId": responseChunk.Response.GenerationID,
+ "responseId": responseChunk.Response.ResponseID,
+ },
+ Message: schemas.ChatMessage{
+ Role: "model",
+ Content: responseChunk.Text,
+ },
+ FinishReason: finishReason,
+ },
+ }, nil
+ }
+
+ // TODO: use objectpool here
+ return &schemas.ChatStreamChunk{
+ ID: s.reqID,
+ Provider: providerName,
+ Cached: false,
+ ModelName: "NA",
+ Metadata: s.reqMetadata,
+ ModelResponse: schemas.ModelChunkResponse{
+ Message: schemas.ChatMessage{
+ Role: "model",
+ Content: responseChunk.Text,
+ },
+ FinishReason: finishReason,
+ },
+ }, nil
+ }
+}
+
+func (s *ChatStream) Close() error {
+ if s.resp != nil {
+ return s.resp.Body.Close()
+ }
+
+ return nil
+}
+
+func (c *Client) SupportChatStream() bool {
+ return true
+}
+
+func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+ // Create a new chat request
+ httpRequest, err := c.makeStreamReq(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+
+ return NewChatStream(
+ c.tel,
+ c.httpClient,
+ httpRequest,
+ req.ID,
+ req.Metadata,
+ c.errMapper,
+ ), nil
+}
+
+func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
+ // TODO: consider using objectpool to optimize memory allocation
+ chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template
+
+ chatRequest.Message = request.Message.Content
+
+ // Build the Cohere specific ChatHistory
+ if len(request.MessageHistory) > 0 {
+ chatRequest.ChatHistory = make([]ChatHistory, len(request.MessageHistory))
+ for i, message := range request.MessageHistory {
+ chatRequest.ChatHistory[i] = ChatHistory{
+ // Map the unified message history onto Cohere's chat history fields
+ Role: message.Role,
+ Message: message.Content,
+ User: "",
+ }
+ }
+ }
+
+ return &chatRequest
+}
+
+func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
+ chatRequest := c.createRequestFromStream(req)
+
+ chatRequest.Stream = true
+
+ rawPayload, err := json.Marshal(chatRequest)
+ if err != nil {
+ return nil, fmt.Errorf("unable to marshal cohere chat stream request payload: %w", err)
+ }
+
+ request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload))
+ if err != nil {
+ return nil, fmt.Errorf("unable to create cohere stream chat request: %w", err)
+ }
+
+ request.Header.Set("Content-Type", "application/json")
+ request.Header.Set("Authorization", fmt.Sprintf("Bearer %v", string(c.config.APIKey)))
+ request.Header.Set("Cache-Control", "no-cache")
+ request.Header.Set("Accept", "text/event-stream")
+ request.Header.Set("Connection", "keep-alive")
+
+ // TODO: this could leak information from messages which may not be a desired thing to have
+ c.tel.L().Debug(
+ "Stream chat request",
+ zap.String("chatURL", c.chatURL),
+ zap.Any("payload", chatRequest),
+ )
+
+ return request, nil
+}
diff --git a/pkg/providers/cohere/chat_stream_test.go b/pkg/providers/cohere/chat_stream_test.go
new file mode 100644
index 00000000..3552b45b
--- /dev/null
+++ b/pkg/providers/cohere/chat_stream_test.go
@@ -0,0 +1,156 @@
+package cohere
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "glide/pkg/api/schemas"
+
+ "glide/pkg/providers/clients"
+ "glide/pkg/telemetry"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestCohere_ChatStreamSupported(t *testing.T) {
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ require.True(t, client.SupportChatStream())
+}
+
+func TestCohere_ChatStreamRequest(t *testing.T) {
+ tests := map[string]string{
+ "success stream": "./testdata/chat_stream.success.txt",
+ }
+
+ for name, streamFile := range tests {
+ t.Run(name, func(t *testing.T) {
+ cohereMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ rawPayload, _ := io.ReadAll(r.Body)
+
+ var data interface{}
+ // Parse the JSON body
+ err := json.Unmarshal(rawPayload, &data)
+ if err != nil {
+ t.Errorf("error decoding payload (%q): %v", string(rawPayload), err)
+ }
+
+ chatResponse, err := os.ReadFile(filepath.Clean(streamFile))
+ if err != nil {
+ t.Errorf("error reading cohere chat mock response: %v", err)
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ _, err = w.Write(chatResponse)
+ if err != nil {
+ t.Errorf("error on sending chat response: %v", err)
+ }
+ })
+
+ cohereServer := httptest.NewServer(cohereMock)
+ defer cohereServer.Close()
+
+ ctx := context.Background()
+ providerCfg := DefaultConfig()
+ clientCfg := clients.DefaultClientConfig()
+
+ providerCfg.BaseURL = cohereServer.URL
+
+ client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
+ require.NoError(t, err)
+
+ req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?")
+
+ stream, err 
:= client.ChatStream(ctx, req) + require.NoError(t, err) + + err = stream.Open() + require.NoError(t, err) + + for { + chunk, err := stream.Recv() + + if err == io.EOF { + return + } + + require.NoError(t, err) + require.NotNil(t, chunk) + } + }) + } +} + +func TestCohere_ChatStreamRequestInterrupted(t *testing.T) { + tests := map[string]string{ + "success stream, but with empty event": "./testdata/chat_stream.empty.txt", + } + + for name, streamFile := range tests { + t.Run(name, func(t *testing.T) { + cohereMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean(streamFile)) + if err != nil { + t.Errorf("error reading cohere chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "text/event-stream") + + _, err = w.Write(chatResponse) + if err != nil { + t.Errorf("error on sending chat response: %v", err) + } + }) + + cohereServer := httptest.NewServer(cohereMock) + defer cohereServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = cohereServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?") + stream, err := client.ChatStream(ctx, req) + require.NoError(t, err) + + err = stream.Open() + require.NoError(t, err) + + for { + chunk, err := stream.Recv() + if err != nil { + require.ErrorIs(t, err, io.EOF) + return + } + + require.NoError(t, err) + require.NotNil(t, chunk) + } + }) + } +} diff --git a/pkg/providers/cohere/client.go b/pkg/providers/cohere/client.go index a6cc9cf5..ec778c15 100644 --- a/pkg/providers/cohere/client.go +++ b/pkg/providers/cohere/client.go @@ -23,9 +23,10 @@ type Client struct { baseURL string chatURL string chatRequestTemplate *ChatRequest + errMapper *ErrorMapper config *Config httpClient *http.Client - telemetry *telemetry.Telemetry + tel *telemetry.Telemetry } // NewClient creates a new Cohere client for the Cohere API. 
@@ -48,7 +49,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
 			MaxIdleConnsPerHost: 2,
 		},
 	},
-		telemetry: tel,
+		tel: tel,
 	}
 
 	return c, nil
diff --git a/pkg/providers/cohere/config.go b/pkg/providers/cohere/config.go
index 1a38aefa..50dbcde4 100644
--- a/pkg/providers/cohere/config.go
+++ b/pkg/providers/cohere/config.go
@@ -8,7 +8,7 @@ import (
 // TODO: Add validations
 type Params struct {
 	Temperature      float64       `json:"temperature,omitempty"`
-	Stream           bool          `json:"stream,omitempty"` // unsupported right now
+	Stream           bool          `json:"stream,omitempty"`
 	PreambleOverride string        `json:"preamble_override,omitempty"`
 	ChatHistory      []ChatHistory `json:"chat_history,omitempty"`
 	ConversationID   string        `json:"conversation_id,omitempty"`
diff --git a/pkg/providers/cohere/error.go b/pkg/providers/cohere/error.go
new file mode 100644
index 00000000..3e5ae89e
--- /dev/null
+++ b/pkg/providers/cohere/error.go
@@ -0,0 +1,64 @@
+package cohere
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"time"
+
+	"glide/pkg/providers/clients"
+	"glide/pkg/telemetry"
+	"go.uber.org/zap"
+)
+
+type ErrorMapper struct {
+	tel *telemetry.Telemetry
+}
+
+func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper {
+	return &ErrorMapper{
+		tel: tel,
+	}
+}
+
+func (m *ErrorMapper) Map(resp *http.Response) error {
+	bodyBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		m.tel.Logger.Error(
+			"Failed to unmarshal chat response error",
+			zap.String("provider", providerName),
+			zap.Error(err),
+			zap.ByteString("rawResponse", bodyBytes),
+		)
+
+		return clients.ErrProviderUnavailable
+	}
+
+	m.tel.Logger.Error(
+		"Chat request failed",
+		zap.String("provider", providerName),
+		zap.Int("statusCode", resp.StatusCode),
+		zap.String("response", string(bodyBytes)),
+		zap.Any("headers", resp.Header),
+	)
+
+	if resp.StatusCode == http.StatusTooManyRequests {
+		// Read the value of the "Retry-After" header to get the cooldown delay
+		retryAfter := resp.Header.Get("Retry-After")
+
+		// Parse the value to get the duration
+		cooldownDelay, err := time.ParseDuration(retryAfter)
+		if err != nil {
+			return fmt.Errorf("failed to parse cooldown delay from headers: %w", err)
+		}
+
+		return clients.NewRateLimitError(&cooldownDelay)
+	}
+
+	if resp.StatusCode == http.StatusUnauthorized {
+		return clients.ErrUnauthorized
+	}
+
+	// Server & client errors result in the same error to keep gateway resilient
+	return clients.ErrProviderUnavailable
+}
diff --git a/pkg/providers/cohere/schemas.go b/pkg/providers/cohere/schemas.go
index c807aa56..f6fc310e 100644
--- a/pkg/providers/cohere/schemas.go
+++ b/pkg/providers/cohere/schemas.go
@@ -15,10 +15,10 @@ type ChatCompletion struct {
 }
 
 type TokenCount struct {
-	PromptTokens   float64 `json:"prompt_tokens"`
-	ResponseTokens float64 `json:"response_tokens"`
-	TotalTokens    float64 `json:"total_tokens"`
-	BilledTokens   float64 `json:"billed_tokens"`
+	PromptTokens   int `json:"prompt_tokens"`
+	ResponseTokens int `json:"response_tokens"`
+	TotalTokens    int `json:"total_tokens"`
+	BilledTokens   int `json:"billed_tokens"`
 }
 
 type Meta struct {
@@ -65,3 +65,54 @@ type ConnectorsResponse struct {
 	ContOnFail string            `json:"continue_on_failure"`
 	Options    map[string]string `json:"options"`
 }
+
+// ChatCompletionChunk represents a single SSE event that a chat response is broken down into during chat streaming
+// Ref: https://docs.cohere.com/reference/about
+type ChatCompletionChunk struct {
+	IsFinished bool          `json:"is_finished"`
+	EventType  string        `json:"event_type"`
+	Text       string        `json:"text"`
+	Response   FinalResponse `json:"response,omitempty"`
+}
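+
+// A single text-generation event from the Cohere stream looks like the following
+// (taken from testdata/chat_stream.success.txt), which ChatCompletionChunk unmarshals:
+//
+//	{
+//	  "is_finished": false,
+//	  "event_type": "text-generation",
+//	  "text": " capital"
+//	}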
+
+type FinalResponse struct {
+	ResponseID   string     `json:"response_id"`
+	Text         string     `json:"text"`
+	GenerationID string     `json:"generation_id"`
+	TokenCount   TokenCount `json:"token_count"`
+	Meta         Meta       `json:"meta"`
+	FinishReason string     `json:"finish_reason"`
+}
+
+type ChatMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type ChatHistory struct {
+	Role    string `json:"role"`
+	Message string `json:"message"`
+	User    string `json:"user,omitempty"`
+}
+
+// ChatRequest is a Cohere-specific schema of a chat completion request
+type ChatRequest struct {
+	Model             string        `json:"model"`
+	Message           string        `json:"message"`
+	Temperature       float64       `json:"temperature,omitempty"`
+	PreambleOverride  string        `json:"preamble_override,omitempty"`
+	ChatHistory       []ChatHistory `json:"chat_history,omitempty"`
+	ConversationID    string        `json:"conversation_id,omitempty"`
+	PromptTruncation  string        `json:"prompt_truncation,omitempty"`
+	Connectors        []string      `json:"connectors,omitempty"`
+	SearchQueriesOnly bool          `json:"search_queries_only,omitempty"`
+	CitiationQuality  string        `json:"citation_quality,omitempty"`
+	Stream            bool          `json:"stream,omitempty"`
+}
+
+type Connectors struct {
+	ID              string            `json:"id"`
+	UserAccessToken string            `json:"user_access_token"`
+	ContOnFail      string            `json:"continue_on_failure"`
+	Options         map[string]string `json:"options"`
+}
diff --git a/pkg/providers/cohere/testdata/chat_stream.empty.txt b/pkg/providers/cohere/testdata/chat_stream.empty.txt
new file mode 100644
index 00000000..38471d95
--- /dev/null
+++ b/pkg/providers/cohere/testdata/chat_stream.empty.txt
@@ -0,0 +1,277 @@
+{
+  "is_finished": false,
+  "event_type": "stream-start",
+  "generation_id": "d686011c-e1bb-41c1-9964-823d9b94d394"
+}
+{
+}
+{
+  "is_finished": false,
+  "event_type": "text-generation",
+  "text": " capital"
+}
+{
+  "is_finished": false,
+  "event_type": "text-generation",
+  "text": " of"
+}
+{
+  "is_finished": false,
+  "event_type": "text-generation",
+  "text": " the"
+}
+{
+  "is_finished": false,
+  "event_type": "text-generation",
+  "text": " United"
+}
+{
+  "is_finished": false,
+  "event_type": "text-generation",
+  "text": " Kingdom"
+}
+{
+  "is_finished": false,
+  "event_type": "text-generation",
+  "text": " is"
+}
+{
+  "is_finished": false,
+  "event_type": "text-generation",
+  "text": " London"
+}
+{
+  "is_finished": false,
+  "event_type": "text-generation",
+  "text": "."
+} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " London" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " is" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " a" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " vibrant" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " city" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " full" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " of" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " diverse" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " culture" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " history" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " and" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " iconic" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " landmarks" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "." +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " It" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "'s" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " a" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " global" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " city" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " that" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " has" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " influence" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " in" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " the" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " fields" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " of" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " art" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " fashion" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " finance" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " media" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " and" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " politics" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "." +} +{ + "is_finished": true, + "event_type": "stream-end", + "response": { + "response_id": "d4a5e49e-b892-41c5-950d-a97162d19393", + "text": "The capital of the United Kingdom is London. London is a vibrant city full of diverse culture, history, and iconic landmarks. 
It's a global city that has influence in the fields of art, fashion, finance, media, and politics.", + "generation_id": "d686011c-e1bb-41c1-9964-823d9b94d394", + "chat_history": [ + { + "role": "USER", + "message": "What's the capital of the United Kingdom?" + }, + { + "role": "CHATBOT", + "message": "The capital of the United Kingdom is London. London is a vibrant city full of diverse culture, history, and iconic landmarks. It's a global city that has influence in the fields of art, fashion, finance, media, and politics." + } + ], + "token_count": { + "prompt_tokens": 75, + "response_tokens": 48, + "total_tokens": 123, + "billed_tokens": 57 + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 9, + "output_tokens": 48 + } + } + }, + "finish_reason": "COMPLETE" +} \ No newline at end of file diff --git a/pkg/providers/cohere/testdata/chat_stream.nodone.txt b/pkg/providers/cohere/testdata/chat_stream.nodone.txt new file mode 100644 index 00000000..785fe492 --- /dev/null +++ b/pkg/providers/cohere/testdata/chat_stream.nodone.txt @@ -0,0 +1,22 @@ +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"The"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" capital"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" of"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" the"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" United"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" Kingdom"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" is"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" London"},"logprobs":null,"finish_reason":null}]} + +data: 
{"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + diff --git a/pkg/providers/cohere/testdata/chat_stream.success.txt b/pkg/providers/cohere/testdata/chat_stream.success.txt new file mode 100644 index 00000000..d68c7eda --- /dev/null +++ b/pkg/providers/cohere/testdata/chat_stream.success.txt @@ -0,0 +1,280 @@ +{ + "is_finished": false, + "event_type": "stream-start", + "generation_id": "d686011c-e1bb-41c1-9964-823d9b94d394" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "The" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " capital" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " of" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " the" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " United" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " Kingdom" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " is" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " London" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "." +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " London" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " is" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " a" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " vibrant" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " city" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " full" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " of" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " diverse" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " culture" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " history" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " and" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " iconic" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " landmarks" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "." 
+} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " It" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "'s" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " a" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " global" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " city" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " that" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " has" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " influence" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " in" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " the" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " fields" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " of" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " art" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " fashion" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " finance" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " media" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "," +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " and" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": " politics" +} +{ + "is_finished": false, + "event_type": "text-generation", + "text": "." +} +{ + "is_finished": true, + "event_type": "stream-end", + "response": { + "response_id": "d4a5e49e-b892-41c5-950d-a97162d19393", + "text": "The capital of the United Kingdom is London. London is a vibrant city full of diverse culture, history, and iconic landmarks. It's a global city that has influence in the fields of art, fashion, finance, media, and politics.", + "generation_id": "d686011c-e1bb-41c1-9964-823d9b94d394", + "chat_history": [ + { + "role": "USER", + "message": "What's the capital of the United Kingdom?" + }, + { + "role": "CHATBOT", + "message": "The capital of the United Kingdom is London. London is a vibrant city full of diverse culture, history, and iconic landmarks. It's a global city that has influence in the fields of art, fashion, finance, media, and politics." 
+ } + ], + "token_count": { + "prompt_tokens": 75, + "response_tokens": 48, + "total_tokens": 123, + "billed_tokens": 57 + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 9, + "output_tokens": 48 + } + } + }, + "finish_reason": "COMPLETE" +} \ No newline at end of file diff --git a/pkg/providers/config.go b/pkg/providers/config.go index 55d885f6..58aebff7 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -49,18 +49,18 @@ func DefaultLangModelConfig() *LangModelConfig { } } -func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (*LangModel, error) { +func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) { client, err := c.initClient(tel) if err != nil { return nil, fmt.Errorf("error initializing client: %v", err) } - return NewLangModel(c.ID, client, *c.ErrorBudget, *c.Latency, c.Weight), nil + return NewLangModel(c.ID, client, c.ErrorBudget, *c.Latency, c.Weight), nil } // initClient initializes the language model client based on the provided configuration. // It takes a telemetry object as input and returns a LangModelProvider and an error. -func (c *LangModelConfig) initClient(tel *telemetry.Telemetry) (LangModelProvider, error) { +func (c *LangModelConfig) initClient(tel *telemetry.Telemetry) (LangProvider, error) { switch { case c.OpenAI != nil: return openai.NewClient(c.OpenAI, c.Client, tel) diff --git a/pkg/providers/lang.go b/pkg/providers/lang.go new file mode 100644 index 00000000..f4ab258b --- /dev/null +++ b/pkg/providers/lang.go @@ -0,0 +1,177 @@ +package providers + +import ( + "context" + "io" + "time" + + "glide/pkg/routers/health" + + "glide/pkg/api/schemas" + "glide/pkg/providers/clients" + "glide/pkg/routers/latency" +) + +// LangProvider defines an interface a provider should fulfill to be able to serve language chat requests +type LangProvider interface { + ModelProvider + + SupportChatStream() bool + + Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) + ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) +} + +type LangModel interface { + Model + Provider() string + Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) + ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (<-chan *clients.ChatStreamResult, error) +} + +// LanguageModel wraps provider client and expend it with health & latency tracking +// +// The model health is assumed to be independent of model actions (e.g. chat & chatStream) +// The latency is assumed to be action-specific (e.g. 
streaming chat chunks are much low latency than the full chat action) +type LanguageModel struct { + modelID string + weight int + client LangProvider + healthTracker *health.Tracker + chatLatency *latency.MovingAverage + chatStreamLatency *latency.MovingAverage + latencyUpdateInterval *time.Duration +} + +func NewLangModel(modelID string, client LangProvider, budget *health.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { + return &LanguageModel{ + modelID: modelID, + client: client, + healthTracker: health.NewTracker(budget), + chatLatency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples), + chatStreamLatency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples), + latencyUpdateInterval: latencyConfig.UpdateInterval, + weight: weight, + } +} + +func (m LanguageModel) ID() string { + return m.modelID +} + +func (m LanguageModel) Healthy() bool { + return m.healthTracker.Healthy() +} + +func (m LanguageModel) Weight() int { + return m.weight +} + +func (m LanguageModel) LatencyUpdateInterval() *time.Duration { + return m.latencyUpdateInterval +} + +func (m *LanguageModel) SupportChatStream() bool { + return m.client.SupportChatStream() +} + +func (m LanguageModel) ChatLatency() *latency.MovingAverage { + return m.chatLatency +} + +func (m LanguageModel) ChatStreamLatency() *latency.MovingAverage { + return m.chatStreamLatency +} + +func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) { + startedAt := time.Now() + resp, err := m.client.Chat(ctx, request) + + if err == nil { + // record latency per token to normalize measurements + m.chatLatency.Add(float64(time.Since(startedAt)) / float64(resp.ModelResponse.TokenUsage.ResponseTokens)) + + // successful response + resp.ModelID = m.modelID + + return resp, err + } + + m.healthTracker.TrackErr(err) + + return resp, err +} + +func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (<-chan *clients.ChatStreamResult, error) { + stream, err := m.client.ChatStream(ctx, req) + if err != nil { + m.healthTracker.TrackErr(err) + + return nil, err + } + + startedAt := time.Now() + err = stream.Open() + chunkLatency := time.Since(startedAt) + + // the first chunk latency + m.chatStreamLatency.Add(float64(chunkLatency)) + + if err != nil { + m.healthTracker.TrackErr(err) + + // if connection was not even open, we should not send our clients any messages about this failure + + return nil, err + } + + streamResultC := make(chan *clients.ChatStreamResult) + + go func() { + defer close(streamResultC) + defer stream.Close() + + for { + startedAt = time.Now() + chunk, err := stream.Recv() + chunkLatency = time.Since(startedAt) + + if err != nil { + if err == io.EOF { + // end of the stream + return + } + + streamResultC <- clients.NewChatStreamResult(nil, err) + + m.healthTracker.TrackErr(err) + + return + } + + streamResultC <- clients.NewChatStreamResult(chunk, nil) + + if chunkLatency > 1*time.Millisecond { + // All events are read in a bigger chunks of bytes, so one chunk may contain more than one event. 
+ // Each byte chunk is then parsed, so there is no easy way to precisely guess latency per chunk, + // So we assume that if we spent more than 1ms waiting for a chunk it's likely + // we were trying to read from the connection (otherwise, it would take nanoseconds) + m.chatStreamLatency.Add(float64(chunkLatency)) + } + } + }() + + return streamResultC, nil +} + +func (m *LanguageModel) Provider() string { + return m.client.Provider() +} + +func ChatLatency(model Model) *latency.MovingAverage { + return model.(LanguageModel).ChatLatency() +} + +func ChatStreamLatency(model Model) *latency.MovingAverage { + return model.(LanguageModel).ChatStreamLatency() +} diff --git a/pkg/providers/octoml/chat.go b/pkg/providers/octoml/chat.go index 4860a0b9..5a0aa2bc 100644 --- a/pkg/providers/octoml/chat.go +++ b/pkg/providers/octoml/chat.go @@ -7,12 +7,9 @@ import ( "fmt" "io" "net/http" - "time" "glide/pkg/providers/openai" - "glide/pkg/providers/clients" - "glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -117,33 +114,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - c.telemetry.Logger.Error("failed to read octoml chat response", zap.Error(err)) - } - - c.telemetry.Logger.Error( - "octoml chat request failed", - zap.Int("status_code", resp.StatusCode), - zap.String("response", string(bodyBytes)), - zap.Any("headers", resp.Header), - ) - - if resp.StatusCode == http.StatusTooManyRequests { - // Read the value of the "Retry-After" header to get the cooldown delay - retryAfter := resp.Header.Get("Retry-After") - - // Parse the value to get the duration - cooldownDelay, err := time.ParseDuration(retryAfter) - if err != nil { - return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err) - } - - return nil, clients.NewRateLimitError(&cooldownDelay) - } - - // Server & client errors result in the same error to keep gateway resilient - return nil, clients.ErrProviderUnavailable + return nil, c.errMapper.Map(resp) } // Read the response body into a byte slice @@ -164,19 +135,18 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche // Map response to UnifiedChatResponse schema response := schemas.ChatResponse{ - ID: openAICompletion.ID, - Created: openAICompletion.Created, - Provider: providerName, - Model: openAICompletion.Model, - Cached: false, - ModelResponse: schemas.ProviderResponse{ + ID: openAICompletion.ID, + Created: openAICompletion.Created, + Provider: providerName, + ModelName: openAICompletion.ModelName, + Cached: false, + ModelResponse: schemas.ModelResponse{ SystemID: map[string]string{ "system_fingerprint": openAICompletion.SystemFingerprint, }, Message: schemas.ChatMessage{ Role: openAICompletion.Choices[0].Message.Role, Content: openAICompletion.Choices[0].Message.Content, - Name: "", }, TokenUsage: schemas.TokenUsage{ PromptTokens: openAICompletion.Usage.PromptTokens, diff --git a/pkg/providers/octoml/chat_stream.go b/pkg/providers/octoml/chat_stream.go new file mode 100644 index 00000000..2885195b --- /dev/null +++ b/pkg/providers/octoml/chat_stream.go @@ -0,0 +1,16 @@ +package octoml + +import ( + "context" + + "glide/pkg/api/schemas" + "glide/pkg/providers/clients" +) + +func (c *Client) SupportChatStream() bool { + return false +} + +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) { + return nil, 
clients.ErrChatStreamNotImplemented +} diff --git a/pkg/providers/octoml/client.go b/pkg/providers/octoml/client.go index df8cff5b..66e71502 100644 --- a/pkg/providers/octoml/client.go +++ b/pkg/providers/octoml/client.go @@ -23,6 +23,7 @@ type Client struct { baseURL string chatURL string chatRequestTemplate *ChatRequest + errMapper *ErrorMapper config *Config httpClient *http.Client telemetry *telemetry.Telemetry @@ -40,6 +41,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * chatURL: chatURL, config: providerConfig, chatRequestTemplate: NewChatRequestFromConfig(providerConfig), + errMapper: NewErrorMapper(tel), httpClient: &http.Client{ Timeout: *clientConfig.Timeout, // TODO: use values from the config diff --git a/pkg/providers/octoml/client_test.go b/pkg/providers/octoml/client_test.go index c8a438c1..b6f41b95 100644 --- a/pkg/providers/octoml/client_test.go +++ b/pkg/providers/octoml/client_test.go @@ -63,7 +63,7 @@ func TestOctoMLClient_ChatRequest(t *testing.T) { response, err := client.Chat(ctx, &request) require.NoError(t, err) - require.Equal(t, providerCfg.Model, response.Model) + require.Equal(t, providerCfg.Model, response.ModelName) require.Equal(t, "cmpl-8ea213aece0747aca6d0608b02b57196", response.ID) } @@ -112,21 +112,19 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { defer mockServer.Close() // Create a new client with the mock server URL - client := &Client{ - httpClient: http.DefaultClient, - chatURL: mockServer.URL, - config: &Config{APIKey: "dummy_key"}, - telemetry: telemetry.NewTelemetryMock(), - } + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = mockServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) // Create a chat request payload - payload := &ChatRequest{ - Model: "dummy_model", - Messages: []ChatMessage{{Role: "human", Content: "Hello"}}, - } + payload := schemas.NewChatFromStr("What's the dealio?") - // Call the doChatRequest function - _, err := client.doChatRequest(context.Background(), payload) + _, err = client.Chat(ctx, payload) require.Error(t, err) require.Contains(t, err.Error(), "provider is not available") diff --git a/pkg/providers/octoml/errors.go b/pkg/providers/octoml/errors.go new file mode 100644 index 00000000..23657076 --- /dev/null +++ b/pkg/providers/octoml/errors.go @@ -0,0 +1,56 @@ +package octoml + +import ( + "fmt" + "io" + "net/http" + "time" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" + "go.uber.org/zap" +) + +type ErrorMapper struct { + tel *telemetry.Telemetry +} + +func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper { + return &ErrorMapper{ + tel: tel, + } +} + +func (m *ErrorMapper) Map(resp *http.Response) error { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + m.tel.L().Error("failed to read octoml chat response", zap.Error(err)) + } + + m.tel.L().Error( + "octoml chat request failed", + zap.Int("status_code", resp.StatusCode), + zap.String("response", string(bodyBytes)), + zap.Any("headers", resp.Header), + ) + + if resp.StatusCode == http.StatusTooManyRequests { + // Read the value of the "Retry-After" header to get the cooldown delay + retryAfter := resp.Header.Get("Retry-After") + + // Parse the value to get the duration + cooldownDelay, err := time.ParseDuration(retryAfter) + if err != nil { + return fmt.Errorf("failed to parse cooldown delay from headers: %w", err) + } + + return 
clients.NewRateLimitError(&cooldownDelay) + } + + if resp.StatusCode == http.StatusUnauthorized { + return clients.ErrUnauthorized + } + + // Server & client errors result in the same error to keep gateway resilient + return clients.ErrProviderUnavailable +} diff --git a/pkg/providers/ollama/chat.go b/pkg/providers/ollama/chat.go index f2247dd7..dd4a22fc 100644 --- a/pkg/providers/ollama/chat.go +++ b/pkg/providers/ollama/chat.go @@ -181,24 +181,23 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche // Map response to UnifiedChatResponse schema response := schemas.ChatResponse{ - ID: uuid.NewString(), - Created: int(time.Now().Unix()), - Provider: providerName, - Model: ollamaCompletion.Model, - Cached: false, - ModelResponse: schemas.ProviderResponse{ + ID: uuid.NewString(), + Created: int(time.Now().Unix()), + Provider: providerName, + ModelName: ollamaCompletion.Model, + Cached: false, + ModelResponse: schemas.ModelResponse{ SystemID: map[string]string{ "system_fingerprint": "", }, Message: schemas.ChatMessage{ Role: ollamaCompletion.Message.Role, Content: ollamaCompletion.Message.Content, - Name: "", }, TokenUsage: schemas.TokenUsage{ - PromptTokens: float64(ollamaCompletion.EvalCount), - ResponseTokens: float64(ollamaCompletion.EvalCount), - TotalTokens: float64(ollamaCompletion.EvalCount), + PromptTokens: ollamaCompletion.EvalCount, + ResponseTokens: ollamaCompletion.EvalCount, + TotalTokens: ollamaCompletion.EvalCount, }, }, } diff --git a/pkg/providers/ollama/chat_stream.go b/pkg/providers/ollama/chat_stream.go new file mode 100644 index 00000000..2bf0b87f --- /dev/null +++ b/pkg/providers/ollama/chat_stream.go @@ -0,0 +1,16 @@ +package ollama + +import ( + "context" + + "glide/pkg/api/schemas" + "glide/pkg/providers/clients" +) + +func (c *Client) SupportChatStream() bool { + return false +} + +func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented +} diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index bbcc4ff4..fcd75db9 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -7,39 +7,11 @@ import ( "fmt" "io" "net/http" - "time" - - "glide/pkg/providers/clients" "glide/pkg/api/schemas" "go.uber.org/zap" ) -type ChatMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -// ChatRequest is an OpenAI-specific request schema -type ChatRequest struct { - Model string `json:"model"` - Messages []ChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` - LogitBias *map[int]float64 `json:"logit_bias,omitempty"` - User *string `json:"user,omitempty"` - Seed *int `json:"seed,omitempty"` - Tools []string `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` -} - // NewChatRequestFromConfig fills the struct from the config. 
Not using reflection because of the performance penalty it incurs
 func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 	return &ChatRequest{
@@ -49,7 +21,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 		MaxTokens:        cfg.DefaultParams.MaxTokens,
 		N:                cfg.DefaultParams.N,
 		StopWords:        cfg.DefaultParams.StopWords,
-		Stream:           false, // unsupported right now
+		Stream:           false,
 		FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty,
 		PresencePenalty:  cfg.DefaultParams.PresencePenalty,
 		LogitBias:        cfg.DefaultParams.LogitBias,
@@ -61,23 +33,11 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
 	}
 }
 
-func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage {
-	messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)
-
-	// Add items from messageHistory first and the new chat message last
-	for _, message := range request.MessageHistory {
-		messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
-	}
-
-	messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
-
-	return messages
-}
-
 // Chat sends a chat request to the specified OpenAI model.
 func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
 	// Create a new chat request
-	chatRequest := c.createChatRequestSchema(request)
+	chatRequest := c.createRequestSchema(request)
+	chatRequest.Stream = false
 
 	chatResponse, err := c.doChatRequest(ctx, chatRequest)
 	if err != nil {
@@ -91,12 +51,21 @@ func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schem
 	return chatResponse, nil
 }
 
-func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
+// createRequestSchema creates a new ChatRequest object based on the given request.
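+// It copies the shared request template by value and rebuilds Messages from the
+// unified request, so per-request changes never leak back into the template.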
+func (c *Client) createRequestSchema(request *schemas.ChatRequest) *ChatRequest { // TODO: consider using objectpool to optimize memory allocation - chatRequest := c.chatRequestTemplate // hoping to get a copy of the template - chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request) + chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template + + chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1) - return chatRequest + // Add items from messageHistory first and the new chat message last + for _, message := range request.MessageHistory { + chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content}) + } + + chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content}) + + return &chatRequest } func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { @@ -111,13 +80,14 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, fmt.Errorf("unable to create openai chat request: %w", err) } - req.Header.Set("Authorization", "Bearer "+string(c.config.APIKey)) req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", string(c.config.APIKey))) // TODO: this could leak information from messages which may not be a desired thing to have - c.telemetry.Logger.Debug( - "openai chat request", - zap.String("chat_url", c.chatURL), + c.tel.Logger.Debug( + "Chat Request", + zap.String("provider", c.Provider()), + zap.String("chatURL", c.chatURL), zap.Any("payload", payload), ) @@ -129,71 +99,55 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - c.telemetry.Logger.Error("failed to read openai chat response", zap.Error(err)) - } - - c.telemetry.Logger.Error( - "openai chat request failed", - zap.Int("status_code", resp.StatusCode), - zap.String("response", string(bodyBytes)), - zap.Any("headers", resp.Header), - ) - - if resp.StatusCode == http.StatusTooManyRequests { - // Read the value of the "Retry-After" header to get the cooldown delay - retryAfter := resp.Header.Get("Retry-After") - - // Parse the value to get the duration - cooldownDelay, err := time.ParseDuration(retryAfter) - if err != nil { - return nil, fmt.Errorf("failed to parse cooldown delay from headers: %w", err) - } - - return nil, clients.NewRateLimitError(&cooldownDelay) - } - - // Server & client errors result in the same error to keep gateway resilient - return nil, clients.ErrProviderUnavailable + return nil, c.errMapper.Map(resp) } // Read the response body into a byte slice bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - c.telemetry.Logger.Error("failed to read openai chat response", zap.Error(err)) + c.tel.Logger.Error( + "Failed to read chat response", + zap.String("provider", c.Provider()), zap.Error(err), + zap.ByteString("rawResponse", bodyBytes), + ) + return nil, err } // Parse the response JSON - var openAICompletion ChatCompletion + var chatCompletion ChatCompletion - err = json.Unmarshal(bodyBytes, &openAICompletion) + err = json.Unmarshal(bodyBytes, &chatCompletion) if err != nil { - c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err)) + c.tel.Logger.Error( + "Failed to unmarshal chat response", + zap.String("provider", 
c.Provider()), + zap.ByteString("rawResponse", bodyBytes), + zap.Error(err), + ) + return nil, err } // Map response to ChatResponse schema response := schemas.ChatResponse{ - ID: openAICompletion.ID, - Created: openAICompletion.Created, - Provider: providerName, - Model: openAICompletion.Model, - Cached: false, - ModelResponse: schemas.ProviderResponse{ + ID: chatCompletion.ID, + Created: chatCompletion.Created, + Provider: providerName, + ModelName: chatCompletion.ModelName, + Cached: false, + ModelResponse: schemas.ModelResponse{ SystemID: map[string]string{ - "system_fingerprint": openAICompletion.SystemFingerprint, + "system_fingerprint": chatCompletion.SystemFingerprint, }, Message: schemas.ChatMessage{ - Role: openAICompletion.Choices[0].Message.Role, - Content: openAICompletion.Choices[0].Message.Content, - Name: "", + Role: chatCompletion.Choices[0].Message.Role, + Content: chatCompletion.Choices[0].Message.Content, }, TokenUsage: schemas.TokenUsage{ - PromptTokens: openAICompletion.Usage.PromptTokens, - ResponseTokens: openAICompletion.Usage.CompletionTokens, - TotalTokens: openAICompletion.Usage.TotalTokens, + PromptTokens: chatCompletion.Usage.PromptTokens, + ResponseTokens: chatCompletion.Usage.CompletionTokens, + TotalTokens: chatCompletion.Usage.TotalTokens, }, }, } diff --git a/pkg/providers/openai/chat_stream.go b/pkg/providers/openai/chat_stream.go new file mode 100644 index 00000000..fb8be776 --- /dev/null +++ b/pkg/providers/openai/chat_stream.go @@ -0,0 +1,224 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/r3labs/sse/v2" + "glide/pkg/providers/clients" + "glide/pkg/telemetry" + + "go.uber.org/zap" + + "glide/pkg/api/schemas" +) + +var ( + StopReason = "stop" + streamDoneMarker = []byte("[DONE]") +) + +// ChatStream represents OpenAI chat stream for a specific request +type ChatStream struct { + tel *telemetry.Telemetry + client *http.Client + req *http.Request + reqID string + reqMetadata *schemas.Metadata + resp *http.Response + reader *sse.EventStreamReader + errMapper *ErrorMapper +} + +func NewChatStream( + tel *telemetry.Telemetry, + client *http.Client, + req *http.Request, + reqID string, + reqMetadata *schemas.Metadata, + errMapper *ErrorMapper, +) *ChatStream { + return &ChatStream{ + tel: tel, + client: client, + req: req, + reqID: reqID, + reqMetadata: reqMetadata, + errMapper: errMapper, + } +} + +func (s *ChatStream) Open() error { + resp, err := s.client.Do(s.req) //nolint:bodyclose + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + return s.errMapper.Map(resp) + } + + s.resp = resp + s.reader = sse.NewEventStreamReader(resp.Body, 4096) // TODO: should we expose maxBufferSize? 
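+
+	// From here on, Recv pulls SSE events off s.reader one by one and maps them
+	// to unified chunks until it sees the [DONE] marker or the stream drops.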
+ + return nil +} + +func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { + var completionChunk ChatCompletionChunk + + for { + rawEvent, err := s.reader.ReadEvent() + if err != nil { + s.tel.L().Warn( + "Chat stream is unexpectedly disconnected", + zap.String("provider", providerName), + zap.Error(err), + ) + + // if err is io.EOF, this still means that the stream is interrupted unexpectedly + // because the normal stream termination is done via finding out streamDoneMarker + + return nil, clients.ErrProviderUnavailable + } + + s.tel.L().Debug( + "Raw chat stream chunk", + zap.String("provider", providerName), + zap.ByteString("rawChunk", rawEvent), + ) + + event, err := clients.ParseSSEvent(rawEvent) + + if bytes.Equal(event.Data, streamDoneMarker) { + return nil, io.EOF + } + + if err != nil { + return nil, fmt.Errorf("failed to parse chat stream message: %v", err) + } + + if !event.HasContent() { + s.tel.L().Debug( + "Received an empty message in chat stream, skipping it", + zap.String("provider", providerName), + zap.Any("msg", event), + ) + + continue + } + + err = json.Unmarshal(event.Data, &completionChunk) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal chat stream chunk: %v", err) + } + + responseChunk := completionChunk.Choices[0] + + var finishReason *schemas.FinishReason + + if responseChunk.FinishReason == StopReason { + finishReason = &schemas.Complete + } + + // TODO: use objectpool here + return &schemas.ChatStreamChunk{ + ID: s.reqID, + Provider: providerName, + Cached: false, + ModelName: completionChunk.ModelName, + Metadata: s.reqMetadata, + ModelResponse: schemas.ModelChunkResponse{ + Metadata: &schemas.Metadata{ + "response_id": completionChunk.ID, + "system_fingerprint": completionChunk.SystemFingerprint, + }, + Message: schemas.ChatMessage{ + Role: responseChunk.Delta.Role, + Content: responseChunk.Delta.Content, + }, + FinishReason: finishReason, + }, + }, nil + } +} + +func (s *ChatStream) Close() error { + if s.resp != nil { + return s.resp.Body.Close() + } + + return nil +} + +func (c *Client) SupportChatStream() bool { + return true +} + +func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) { + // Create a new chat request + httpRequest, err := c.makeStreamReq(ctx, req) + if err != nil { + return nil, err + } + + return NewChatStream( + c.tel, + c.httpClient, + httpRequest, + req.ID, + req.Metadata, + c.errMapper, + ), nil +} + +func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest { + // TODO: consider using objectpool to optimize memory allocation + chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template + + chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1) + + // Add items from messageHistory first and the new chat message last + for _, message := range request.MessageHistory { + chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content}) + } + + chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content}) + + return &chatRequest +} + +func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) { + chatRequest := c.createRequestFromStream(req) + + chatRequest.Stream = true + + rawPayload, err := json.Marshal(chatRequest) + if err != nil { + return nil, fmt.Errorf("unable to marshal openAI chat stream request payload: %w", err) + } 
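+
+	// At this point rawPayload is the regular chat payload with streaming enabled,
+	// e.g. (illustrative values only):
+	//	{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"Hi"}],"stream":true}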
+ + request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.chatURL, bytes.NewBuffer(rawPayload)) + if err != nil { + return nil, fmt.Errorf("unable to create OpenAI stream chat request: %w", err) + } + + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Authorization", fmt.Sprintf("Bearer %v", string(c.config.APIKey))) + request.Header.Set("Cache-Control", "no-cache") + request.Header.Set("Accept", "text/event-stream") + request.Header.Set("Connection", "keep-alive") + + // TODO: this could leak information from messages which may not be a desired thing to have + c.tel.L().Debug( + "Stream chat request", + zap.String("chatURL", c.chatURL), + zap.Any("payload", chatRequest), + ) + + return request, nil +} diff --git a/pkg/providers/openai/chat_stream_test.go b/pkg/providers/openai/chat_stream_test.go new file mode 100644 index 00000000..ba1542da --- /dev/null +++ b/pkg/providers/openai/chat_stream_test.go @@ -0,0 +1,155 @@ +package openai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "glide/pkg/api/schemas" + + "github.com/stretchr/testify/require" + "glide/pkg/providers/clients" + "glide/pkg/telemetry" +) + +func TestOpenAIClient_ChatStreamSupported(t *testing.T) { + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + require.True(t, client.SupportChatStream()) +} + +func TestOpenAIClient_ChatStreamRequest(t *testing.T) { + tests := map[string]string{ + "success stream": "./testdata/chat_stream.success.txt", + } + + for name, streamFile := range tests { + t.Run(name, func(t *testing.T) { + openAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean(streamFile)) + if err != nil { + t.Errorf("error reading openai chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "text/event-stream") + + _, err = w.Write(chatResponse) + if err != nil { + t.Errorf("error on sending chat response: %v", err) + } + }) + + openAIServer := httptest.NewServer(openAIMock) + defer openAIServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = openAIServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?") + stream, err := client.ChatStream(ctx, req) + require.NoError(t, err) + + err = stream.Open() + require.NoError(t, err) + + for { + chunk, err := stream.Recv() + + if err == io.EOF { + return + } + + require.NoError(t, err) + require.NotNil(t, chunk) + } + }) + } +} + +func TestOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { + tests := map[string]string{ + "success stream, but no last done message": "./testdata/chat_stream.nodone.txt", + "success stream, but with empty event": "./testdata/chat_stream.empty.txt", + } + + for name, streamFile := range tests { + t.Run(name, func(t *testing.T) { + openAIMock := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawPayload, _ := io.ReadAll(r.Body) + + 
var data interface{} + // Parse the JSON body + err := json.Unmarshal(rawPayload, &data) + if err != nil { + t.Errorf("error decoding payload (%q): %v", string(rawPayload), err) + } + + chatResponse, err := os.ReadFile(filepath.Clean(streamFile)) + if err != nil { + t.Errorf("error reading openai chat mock response: %v", err) + } + + w.Header().Set("Content-Type", "text/event-stream") + + _, err = w.Write(chatResponse) + if err != nil { + t.Errorf("error on sending chat response: %v", err) + } + }) + + openAIServer := httptest.NewServer(openAIMock) + defer openAIServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = openAIServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + req := schemas.NewChatStreamFromStr("What's the capital of the United Kingdom?") + stream, err := client.ChatStream(ctx, req) + require.NoError(t, err) + + err = stream.Open() + require.NoError(t, err) + + for { + chunk, err := stream.Recv() + if err != nil { + require.ErrorIs(t, err, clients.ErrProviderUnavailable) + return + } + + require.NoError(t, err) + require.NotNil(t, chunk) + } + }) + } +} diff --git a/pkg/providers/openai/client_test.go b/pkg/providers/openai/chat_test.go similarity index 67% rename from pkg/providers/openai/client_test.go rename to pkg/providers/openai/chat_test.go index 6bd8298d..bde4486b 100644 --- a/pkg/providers/openai/client_test.go +++ b/pkg/providers/openai/chat_test.go @@ -10,9 +10,8 @@ import ( "path/filepath" "testing" - "glide/pkg/providers/clients" - "glide/pkg/api/schemas" + "glide/pkg/providers/clients" "glide/pkg/telemetry" @@ -66,3 +65,32 @@ func TestOpenAIClient_ChatRequest(t *testing.T) { require.Equal(t, "chatcmpl-123", response.ID) } + +func TestOpenAIClient_RateLimit(t *testing.T) { + openAIMock := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Retry-After", "5m") + w.WriteHeader(http.StatusTooManyRequests) + }) + + openAIServer := httptest.NewServer(openAIMock) + defer openAIServer.Close() + + ctx := context.Background() + providerCfg := DefaultConfig() + clientCfg := clients.DefaultClientConfig() + + providerCfg.BaseURL = openAIServer.URL + + client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) + require.NoError(t, err) + + request := schemas.ChatRequest{Message: schemas.ChatMessage{ + Role: "user", + Content: "What's the biggest animal?", + }} + + _, err = client.Chat(ctx, &request) + + require.Error(t, err) + require.IsType(t, &clients.RateLimitError{}, err) +} diff --git a/pkg/providers/openai/client.go b/pkg/providers/openai/client.go index c56f227b..22d68d45 100644 --- a/pkg/providers/openai/client.go +++ b/pkg/providers/openai/client.go @@ -23,9 +23,10 @@ type Client struct { baseURL string chatURL string chatRequestTemplate *ChatRequest + errMapper *ErrorMapper config *Config httpClient *http.Client - telemetry *telemetry.Telemetry + tel *telemetry.Telemetry } // NewClient creates a new OpenAI client for the OpenAI API. 
@@ -40,6 +41,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * chatURL: chatURL, config: providerConfig, chatRequestTemplate: NewChatRequestFromConfig(providerConfig), + errMapper: NewErrorMapper(tel), httpClient: &http.Client{ Timeout: *clientConfig.Timeout, // TODO: use values from the config @@ -48,7 +50,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * MaxIdleConnsPerHost: 2, }, }, - telemetry: tel, + tel: tel, } return c, nil diff --git a/pkg/providers/openai/config.go b/pkg/providers/openai/config.go index 86854f3e..49781d60 100644 --- a/pkg/providers/openai/config.go +++ b/pkg/providers/openai/config.go @@ -20,7 +20,6 @@ type Params struct { Tools []string `yaml:"tools,omitempty" json:"tools"` ToolChoice interface{} `yaml:"tool_choice,omitempty" json:"tool_choice"` ResponseFormat interface{} `yaml:"response_format,omitempty" json:"response_format"` // TODO: should this be a part of the chat request API? - // Stream bool `json:"stream,omitempty"` // TODO: we are not supporting this at the moment } func DefaultParams() Params { diff --git a/pkg/providers/openai/errors.go b/pkg/providers/openai/errors.go new file mode 100644 index 00000000..94ef5418 --- /dev/null +++ b/pkg/providers/openai/errors.go @@ -0,0 +1,64 @@ +package openai + +import ( + "fmt" + "io" + "net/http" + "time" + + "glide/pkg/providers/clients" + "glide/pkg/telemetry" + "go.uber.org/zap" +) + +type ErrorMapper struct { + tel *telemetry.Telemetry +} + +func NewErrorMapper(tel *telemetry.Telemetry) *ErrorMapper { + return &ErrorMapper{ + tel: tel, + } +} + +func (m *ErrorMapper) Map(resp *http.Response) error { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + m.tel.Logger.Error( + "Failed to unmarshal chat response error", + zap.String("provider", providerName), + zap.Error(err), + zap.ByteString("rawResponse", bodyBytes), + ) + + return clients.ErrProviderUnavailable + } + + m.tel.Logger.Error( + "Chat request failed", + zap.String("provider", providerName), + zap.Int("statusCode", resp.StatusCode), + zap.String("response", string(bodyBytes)), + zap.Any("headers", resp.Header), + ) + + if resp.StatusCode == http.StatusTooManyRequests { + // Read the value of the "Retry-After" header to get the cooldown delay + retryAfter := resp.Header.Get("Retry-After") + + // Parse the value to get the duration + cooldownDelay, err := time.ParseDuration(retryAfter) + if err != nil { + return fmt.Errorf("failed to parse cooldown delay from headers: %w", err) + } + + return clients.NewRateLimitError(&cooldownDelay) + } + + if resp.StatusCode == http.StatusUnauthorized { + return clients.ErrUnauthorized + } + + // Server & client errors result in the same error to keep gateway resilient + return clients.ErrProviderUnavailable +} diff --git a/pkg/providers/openai/schemas.go b/pkg/providers/openai/schemas.go index cf41aebf..022d510e 100644 --- a/pkg/providers/openai/schemas.go +++ b/pkg/providers/openai/schemas.go @@ -2,11 +2,38 @@ package openai // OpenAI Chat Response (also used by Azure OpenAI and OctoML) +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatRequest is an OpenAI-specific request schema +type ChatRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + StopWords []string 
+    Stream           bool             `json:"stream,omitempty"`
+    FrequencyPenalty int              `json:"frequency_penalty,omitempty"`
+    PresencePenalty  int              `json:"presence_penalty,omitempty"`
+    LogitBias        *map[int]float64 `json:"logit_bias,omitempty"`
+    User             *string          `json:"user,omitempty"`
+    Seed             *int             `json:"seed,omitempty"`
+    Tools            []string         `json:"tools,omitempty"`
+    ToolChoice       interface{}      `json:"tool_choice,omitempty"`
+    ResponseFormat   interface{}      `json:"response_format,omitempty"`
+}
+
+// ChatCompletion
+// Ref: https://platform.openai.com/docs/api-reference/chat/object
 type ChatCompletion struct {
     ID                string `json:"id"`
     Object            string `json:"object"`
     Created           int    `json:"created"`
-    Model             string `json:"model"`
+    ModelName         string `json:"model"`
     SystemFingerprint string   `json:"system_fingerprint"`
     Choices           []Choice `json:"choices"`
     Usage             Usage    `json:"usage"`
@@ -20,7 +47,25 @@ type Choice struct {
 }

 type Usage struct {
-    PromptTokens     float64 `json:"prompt_tokens"`
-    CompletionTokens float64 `json:"completion_tokens"`
-    TotalTokens      float64 `json:"total_tokens"`
+    PromptTokens     int `json:"prompt_tokens"`
+    CompletionTokens int `json:"completion_tokens"`
+    TotalTokens      int `json:"total_tokens"`
+}
+
+// ChatCompletionChunk represents one SSE event that a chat response is broken into during chat streaming
+// Ref: https://platform.openai.com/docs/api-reference/chat/streaming
+type ChatCompletionChunk struct {
+    ID                string         `json:"id"`
+    Object            string         `json:"object"`
+    Created           int            `json:"created"`
+    ModelName         string         `json:"model"`
+    SystemFingerprint string         `json:"system_fingerprint"`
+    Choices           []StreamChoice `json:"choices"`
+}
+
+type StreamChoice struct {
+    Index        int         `json:"index"`
+    Delta        ChatMessage `json:"delta"`
+    Logprobs     interface{} `json:"logprobs"`
+    FinishReason string      `json:"finish_reason"`
 }
diff --git a/pkg/providers/openai/testdata/chat_stream.empty.txt b/pkg/providers/openai/testdata/chat_stream.empty.txt
new file mode 100644
index 00000000..ff2a38eb
--- /dev/null
+++ b/pkg/providers/openai/testdata/chat_stream.empty.txt
@@ -0,0 +1,6 @@
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
+
+data:
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
+
diff --git a/pkg/providers/openai/testdata/chat_stream.nodone.txt b/pkg/providers/openai/testdata/chat_stream.nodone.txt
new file mode 100644
index 00000000..785fe492
--- /dev/null
+++ b/pkg/providers/openai/testdata/chat_stream.nodone.txt
@@ -0,0 +1,22 @@
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"The"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" capital"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" of"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" the"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" United"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" Kingdom"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" is"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" London"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
+
diff --git a/pkg/providers/openai/testdata/chat_stream.success.txt b/pkg/providers/openai/testdata/chat_stream.success.txt
new file mode 100644
index 00000000..1e673eaf
--- /dev/null
+++ b/pkg/providers/openai/testdata/chat_stream.success.txt
@@ -0,0 +1,24 @@
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"The"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" capital"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" of"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" the"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" United"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" Kingdom"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" is"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":" London"},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-8wFR3h2Spa9XeRbipfaJczj42pZQg","object":"chat.completion.chunk","created":1708893049,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_86156a94a0","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
+
+data: [DONE]
+
diff --git a/pkg/providers/provider.go b/pkg/providers/provider.go
index 399d6ee7..0fffa08b 100644
--- a/pkg/providers/provider.go
+++ b/pkg/providers/provider.go
@@ -1,106 +1,18 @@
 package providers

 import (
-    "context"
-    "errors"
     "time"
-
-    "glide/pkg/providers/clients"
-    "glide/pkg/routers/health"
-    "glide/pkg/routers/latency"
-
-    "glide/pkg/api/schemas"
 )

-// LangModelProvider defines an interface a provider should fulfill to be able to serve language chat requests
-type LangModelProvider interface {
+// ModelProvider exposes provider context
+type ModelProvider interface {
     Provider() string
-    Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error)
 }

+// Model represents a configured external modality-agnostic model with its routing properties and status
 type Model interface {
     ID() string
     Healthy() bool
-    Latency() *latency.MovingAverage
     LatencyUpdateInterval() *time.Duration
     Weight() int
 }
-
-type LanguageModel interface {
-    Model
-    LangModelProvider
-}
-
-// LangModel wraps provider client and expend it with health & latency tracking
-type LangModel struct {
-    modelID               string
-    weight                int
-    client                LangModelProvider
-    rateLimit             *health.RateLimitTracker
-    errorBudget           *health.TokenBucket // TODO: centralize provider API health tracking in the registry
-    latency               *latency.MovingAverage
-    latencyUpdateInterval *time.Duration
-}
-
-func NewLangModel(modelID string, client LangModelProvider, budget health.ErrorBudget, latencyConfig latency.Config, weight int) *LangModel {
-    return &LangModel{
-        modelID:               modelID,
-        client:                client,
-        rateLimit:             health.NewRateLimitTracker(),
-        errorBudget:           health.NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget()),
-        latency:               latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples),
-        latencyUpdateInterval: latencyConfig.UpdateInterval,
-        weight:                weight,
-    }
-}
-
-func (m *LangModel) ID() string {
-    return m.modelID
-}
-
-func (m *LangModel) Provider() string {
-    return m.client.Provider()
-}
-
-func (m *LangModel) Latency() *latency.MovingAverage {
-    return m.latency
-}
-
-func (m *LangModel) LatencyUpdateInterval() *time.Duration {
-    return m.latencyUpdateInterval
-}
-
-func (m *LangModel) Healthy() bool {
-    return !m.rateLimit.Limited() && m.errorBudget.HasTokens()
-}
-
-func (m *LangModel) Weight() int {
-    return m.weight
-}
-
-func (m *LangModel) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
-    startedAt := time.Now()
-    resp, err := m.client.Chat(ctx, request)
-
-    if err == nil {
-        // record latency per token to normalize measurements
-        m.latency.Add(float64(time.Since(startedAt)) / resp.ModelResponse.TokenUsage.ResponseTokens)
-
-        // successful response
-        resp.ModelID = m.modelID
-
-        return resp, err
-    }
-
-    var rle *clients.RateLimitError
-
-    if errors.As(err, &rle) {
-        m.rateLimit.SetLimited(rle.UntilReset())
-
-        return resp, err
-    }
-
-    _ = m.errorBudget.Take(1)
-
-    return resp, err
-}
diff --git a/pkg/providers/testing.go b/pkg/providers/testing.go
deleted file mode 100644
index 890421a0..00000000
--- a/pkg/providers/testing.go
+++ /dev/null
@@ -1,100 +0,0 @@
-package providers
-
-import (
-    "context"
-    "time"
-
-    "glide/pkg/routers/latency"
-
-    "glide/pkg/api/schemas"
-)
-
-type ResponseMock struct {
-    Msg string
-    Err *error
-}
-
-func (m *ResponseMock) Resp() *schemas.ChatResponse {
-    return &schemas.ChatResponse{
-        ID: "rsp0001",
-        ModelResponse: schemas.ProviderResponse{
-            SystemID: map[string]string{
-                "ID": "0001",
-            },
-            Message: schemas.ChatMessage{
-                Content: m.Msg,
-            },
-        },
-    }
-}
-
-type ProviderMock struct {
-    idx       int
-    responses []ResponseMock
-}
-
-func NewProviderMock(responses []ResponseMock) *ProviderMock {
-    return &ProviderMock{
-        idx:       0,
-        responses: responses,
-    }
-}
-
-func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatRequest) (*schemas.ChatResponse, error) {
-    response := c.responses[c.idx]
-    c.idx++
-
-    if response.Err != nil {
-        return nil, *response.Err
-    }
-
-    return response.Resp(), nil
-}
-
-func (c *ProviderMock) Provider() string {
-    return "provider_mock"
-}
-
-type LangModelMock struct {
-    modelID string
-    healthy bool
-    latency *latency.MovingAverage
-    weight  int
-}
-
-func NewLangModelMock(ID string, healthy bool, avgLatency float64, weight int) *LangModelMock {
-    movingAverage := latency.NewMovingAverage(0.06, 3)
-
-    if avgLatency > 0.0 {
-        movingAverage.Set(avgLatency)
-    }
-
-    return &LangModelMock{
-        modelID: ID,
-        healthy: healthy,
-        latency: movingAverage,
-        weight:  weight,
-    }
-}
-
-func (m *LangModelMock) ID() string {
-    return m.modelID
-}
-
-func (m *LangModelMock) Healthy() bool {
-    return m.healthy
-}
-
-func (m *LangModelMock) Latency() *latency.MovingAverage {
-    return m.latency
-}
-
-func (m *LangModelMock) LatencyUpdateInterval() *time.Duration {
-    updateInterval := 30 * time.Second
-
-    return &updateInterval
-}
-
-func (m *LangModelMock) Weight() int {
-    return m.weight
-}
diff --git a/pkg/providers/testing/lang.go b/pkg/providers/testing/lang.go
new file mode 100644
index 00000000..405b3c53
--- /dev/null
+++ b/pkg/providers/testing/lang.go
@@ -0,0 +1,154 @@
+package testing
+
+import (
+    "context"
+    "io"
+
+    "glide/pkg/api/schemas"
+    "glide/pkg/providers/clients"
+)
+
+// RespMock mocks a chat response or a streaming chat chunk
+type RespMock struct {
+    Msg string
+    Err *error
+}
+
+func (m *RespMock) Resp() *schemas.ChatResponse {
+    return &schemas.ChatResponse{
+        ID: "rsp0001",
+        ModelResponse: schemas.ModelResponse{
+            SystemID: map[string]string{
+                "ID": "0001",
+            },
+            Message: schemas.ChatMessage{
+                Content: m.Msg,
+            },
+        },
+    }
+}
+
+func (m *RespMock) RespChunk() *schemas.ChatStreamChunk {
+    return &schemas.ChatStreamChunk{
+        ID: "rsp0001",
+        ModelResponse: schemas.ModelChunkResponse{
+            Message: schemas.ChatMessage{
+                Content: m.Msg,
+            },
+        },
+    }
+}
+
+// RespStreamMock mocks a chat stream
+type RespStreamMock struct {
+    idx     int
+    OpenErr error
+    Chunks  *[]RespMock
+}
+
+func NewRespStreamMock(chunk *[]RespMock) RespStreamMock {
+    return RespStreamMock{
+        idx:     0,
+        OpenErr: nil,
+        Chunks:  chunk,
+    }
+}
+
+func NewRespStreamWithOpenErr(openErr error) RespStreamMock {
+    return RespStreamMock{
+        idx:     0,
+        OpenErr: openErr,
+        Chunks:  nil,
+    }
+}
+
+func (m *RespStreamMock) Open() error {
+    if m.OpenErr != nil {
+        return m.OpenErr
+    }
+
+    return nil
+}
+
+func (m *RespStreamMock) Recv() (*schemas.ChatStreamChunk, error) {
+    if m.Chunks != nil && m.idx >= len(*m.Chunks) {
+        return nil, io.EOF
+    }
+
+    chunks := *m.Chunks
+
+    chunk := chunks[m.idx]
+    m.idx++
+
+    if chunk.Err != nil {
+        return nil, *chunk.Err
+    }
+
+    return chunk.RespChunk(), nil
+}
+
+func (m *RespStreamMock) Close() error {
+    return nil
+}
+
+// ProviderMock mocks a model provider
+type ProviderMock struct {
+    idx              int
+    chatResps        *[]RespMock
+    chatStreams      *[]RespStreamMock
+    supportStreaming bool
+}
+
+func NewProviderMock(responses []RespMock) *ProviderMock {
+    return &ProviderMock{
+        idx:              0,
+        chatResps:        &responses,
+        supportStreaming: false,
+    }
+}
+
+func NewStreamProviderMock(chatStreams []RespStreamMock) *ProviderMock {
+    return &ProviderMock{
+        idx:              0,
+        chatStreams:      &chatStreams,
+        supportStreaming: true,
+    }
+}
+
+func (c *ProviderMock) SupportChatStream() bool {
+    return c.supportStreaming
+}
+
+func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatRequest) (*schemas.ChatResponse, error) {
+    if c.chatResps == nil {
+        return nil, clients.ErrProviderUnavailable
+    }
+
+    responses := *c.chatResps
+
+    response := responses[c.idx]
+    c.idx++
+
+    if response.Err != nil {
+        return nil, *response.Err
+    }
+
+    return response.Resp(), nil
+}
+
+func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
+    if c.chatStreams == nil || c.idx >= len(*c.chatStreams) {
+        return nil, clients.ErrProviderUnavailable
+    }
+
+    streams := *c.chatStreams
+
+    stream := streams[c.idx]
+    c.idx++
+
+    return &stream, nil
+}
+
+func (c *ProviderMock) Provider() string {
+    return "provider_mock"
+}
diff --git a/pkg/providers/testing/models.go b/pkg/providers/testing/models.go
new file mode 100644
index 00000000..c5f8d6d6
--- /dev/null
+++ b/pkg/providers/testing/models.go
@@ -0,0 +1,57 @@
+package testing
+
+import (
+    "time"
+
+    "glide/pkg/providers"
+    "glide/pkg/routers/latency"
+)
+
+// LangModelMock mocks a language model with a static health status and chat latency
+type LangModelMock struct {
+    modelID     string
+    healthy     bool
+    chatLatency *latency.MovingAverage
+    weight      int
+}
+
+func NewLangModelMock(ID string, healthy bool, avgLatency float64, weight int) LangModelMock {
+    chatLatency := latency.NewMovingAverage(0.06, 3)
+
+    if avgLatency > 0.0 {
+        chatLatency.Set(avgLatency)
+    }
+
+    return LangModelMock{
+        modelID:     ID,
+        healthy:     healthy,
+        chatLatency: chatLatency,
+        weight:      weight,
+    }
+}
+
+func (m LangModelMock) ID() string {
+    return m.modelID
+}
+
+func (m LangModelMock) Healthy() bool {
+    return m.healthy
+}
+
+func (m *LangModelMock) ChatLatency() *latency.MovingAverage {
+    return m.chatLatency
+}
+
+func (m LangModelMock) LatencyUpdateInterval() *time.Duration {
+    updateInterval := 30 * time.Second
+
+    return &updateInterval
+}
+
+func (m LangModelMock) Weight() int {
+    return m.weight
+}
+
+func ChatMockLatency(model providers.Model) *latency.MovingAverage {
+    return model.(LangModelMock).chatLatency
+}
diff --git a/pkg/routers/config.go b/pkg/routers/config.go
index 025f0c6d..f17ed52e 100644
--- a/pkg/routers/config.go
+++ b/pkg/routers/config.go
@@ -29,11 +29,11 @@ func (c *Config) BuildLangRouters(tel *telemetry.Telemetry) ([]*LangRouter, erro
         seenIDs[routerConfig.ID] = true

         if !routerConfig.Enabled {
-            tel.Logger.Info(fmt.Sprintf("Router \"%v\" is disabled, skipping", routerConfig.ID))
+            tel.L().Info(fmt.Sprintf("Router \"%v\" is disabled, skipping", routerConfig.ID))
             continue
         }

-        tel.Logger.Debug("Init router", zap.String("routerID", routerConfig.ID))
+        tel.L().Debug("Init router", zap.String("routerID", routerConfig.ID))

         router, err := NewLangRouter(&c.LanguageRouters[idx], tel)
         if err != nil {
@@ -63,15 +63,16 @@ type LangRouterConfig struct {
 }

 // BuildModels creates LanguageModel slice out of the given config
-func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.LanguageModel, error) {
+func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*providers.LanguageModel, []*providers.LanguageModel, error) { //nolint: cyclop
     var errs error

     seenIDs := make(map[string]bool, len(c.Models))
-    models := make([]providers.LanguageModel, 0, len(c.Models))
+    chatModels := make([]*providers.LanguageModel, 0, len(c.Models))
+    chatStreamModels := make([]*providers.LanguageModel, 0, len(c.Models))

     for _, modelConfig := range c.Models {
         if _, ok := seenIDs[modelConfig.ID]; ok {
-            return nil, fmt.Errorf(
+            return nil, nil, fmt.Errorf(
                 "ID \"%v\" is specified for more than one model in router \"%v\", while it should be unique in scope of that pool",
                 modelConfig.ID,
                 c.ID,
@@ -81,7 +82,7 @@ func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.La
         seenIDs[modelConfig.ID] = true

         if !modelConfig.Enabled {
-            tel.Logger.Info(
+            tel.L().Info(
                 "Model is disabled, skipping",
                 zap.String("router", c.ID),
                 zap.String("model", modelConfig.ID),
@@ -90,7 +91,7 @@ func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.La
             continue
         }

-        tel.Logger.Debug(
+        tel.L().Debug(
             "Init lang model",
             zap.String("router", c.ID),
             zap.String("model", modelConfig.ID),
@@ -102,19 +103,32 @@ func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.La
             continue
         }

-        models = append(models, model)
+        chatModels = append(chatModels, model)
+
+        if !model.SupportChatStream() {
+            tel.L().Warn(
+                "Provider doesn't support streaming chat or hasn't been integrated with it yet, so it won't serve streaming chat requests",
+                zap.String("routerID", c.ID),
+                zap.String("modelID", model.ID()),
+                zap.String("provider", model.Provider()),
+            )
+
+            continue
+        }
+
+        chatStreamModels = append(chatStreamModels, model)
     }

     if errs != nil {
-        return nil, errs
+        return nil, nil, errs
     }

-    if len(models) == 0 {
-        return nil, fmt.Errorf("router \"%v\" must have at least one active model, zero defined", c.ID)
+    if len(chatModels) == 0 {
+        return nil, nil, fmt.Errorf("router \"%v\" must have at least one active model, zero defined", c.ID)
     }

-    if len(models) == 1 {
-        tel.Logger.WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn(
+    if len(chatModels) == 1 {
+        tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn(
             fmt.Sprintf("Router \"%v\" has only one active model defined. "+
                 "This is not recommended for production setups. "+
                 "Define at least a few models to leverage resiliency logic Glide provides",
@@ -123,7 +137,26 @@ func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]providers.La
         )
     }

-    return models, nil
+    if len(chatStreamModels) == 1 {
+        tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn(
+            fmt.Sprintf("Router \"%v\" has only one active model defined with streaming chat support. "+
+                "This is not recommended for production setups. "+
+                "Define at least a few models to leverage resiliency logic Glide provides",
+                c.ID,
+            ),
+        )
+    }
+
+    if len(chatStreamModels) == 0 {
+        tel.L().WithOptions(zap.AddStacktrace(zap.ErrorLevel)).Warn(
+            fmt.Sprintf("Router \"%v\" has no models with streaming chat support. "+
+                "The streaming chat workflow won't work until you define one",
+                c.ID,
+            ),
+        )
+    }
+
+    return chatModels, chatStreamModels, nil
 }

 func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry {
@@ -137,24 +170,35 @@ func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry {
     )
 }

-func (c *LangRouterConfig) BuildRouting(models []providers.LanguageModel) (routing.LangModelRouting, error) {
-    m := make([]providers.Model, 0, len(models))
-    for _, model := range models {
-        m = append(m, model)
+func (c *LangRouterConfig) BuildRouting(
+    chatModels []*providers.LanguageModel,
+    chatStreamModels []*providers.LanguageModel,
+) (routing.LangModelRouting, routing.LangModelRouting, error) {
+    chatModelPool := make([]providers.Model, 0, len(chatModels))
+    chatStreamModelPool := make([]providers.Model, 0, len(chatStreamModels))
+
+    for _, model := range chatModels {
+        chatModelPool = append(chatModelPool, model)
+    }
+
+    for _, model := range chatStreamModels {
+        chatStreamModelPool = append(chatStreamModelPool, model)
     }

     switch c.RoutingStrategy {
     case routing.Priority:
-        return routing.NewPriority(m), nil
+        return routing.NewPriority(chatModelPool), routing.NewPriority(chatStreamModelPool), nil
     case routing.RoundRobin:
-        return routing.NewRoundRobinRouting(m), nil
+        return routing.NewRoundRobinRouting(chatModelPool), routing.NewRoundRobinRouting(chatStreamModelPool), nil
     case routing.WeightedRoundRobin:
-        return routing.NewWeightedRoundRobin(m), nil
+        return routing.NewWeightedRoundRobin(chatModelPool), routing.NewWeightedRoundRobin(chatStreamModelPool), nil
     case routing.LeastLatency:
-        return routing.NewLeastLatencyRouting(m), nil
+        return routing.NewLeastLatencyRouting(providers.ChatLatency, chatModelPool),
+            routing.NewLeastLatencyRouting(providers.ChatStreamLatency, chatStreamModelPool),
+            nil
     }

-    return nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy)
+    return nil, nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy)
 }

 func DefaultLangRouterConfig() LangRouterConfig {
diff --git a/pkg/routers/config_test.go b/pkg/routers/config_test.go
index f2912164..28fb81fe 100644
--- a/pkg/routers/config_test.go
+++ b/pkg/routers/config_test.go
@@ -3,6 +3,8 @@ package routers
 import (
     "testing"

+    "glide/pkg/providers/cohere"
+
     "github.com/stretchr/testify/require"
     "glide/pkg/providers"
     "glide/pkg/providers/clients"
@@ -64,10 +66,53 @@ func TestRouterConfig_BuildModels(t *testing.T) {
     require.NoError(t, err)
     require.Len(t, routers, 2)

-    require.Len(t, routers[0].models, 1)
-    require.IsType(t, &routing.PriorityRouting{}, routers[0].routing)
-    require.Len(t, routers[1].models, 1)
-    require.IsType(t, &routing.LeastLatencyRouting{}, routers[1].routing)
+    require.Len(t, routers[0].chatModels, 1)
+    require.IsType(t, &routing.PriorityRouting{}, routers[0].chatRouting)
+    require.Len(t, routers[1].chatModels, 1)
+    require.IsType(t, &routing.LeastLatencyRouting{}, routers[1].chatRouting)
+}
+
+func TestRouterConfig_BuildModelsPerType(t *testing.T) {
+    tel := telemetry.NewTelemetryMock()
+    openAIParams := openai.DefaultParams()
+    cohereParams := cohere.DefaultParams()
+
+    cfg := LangRouterConfig{
+        ID:              "first_router",
+        Enabled:         true,
+        RoutingStrategy: routing.Priority,
+        Retry:           retry.DefaultExpRetryConfig(),
+        Models: []providers.LangModelConfig{
+            {
+                ID:          "first_model",
+                Enabled:     true,
+                Client:      clients.DefaultClientConfig(),
+                ErrorBudget: health.DefaultErrorBudget(),
+                Latency:     latency.DefaultConfig(),
+                OpenAI: &openai.Config{
+                    APIKey:        "ABC",
+                    DefaultParams: &openAIParams,
+                },
+            },
+            {
+                ID:          "second_model",
+                Enabled:     true,
+                Client:      clients.DefaultClientConfig(),
+                ErrorBudget: health.DefaultErrorBudget(),
+                Latency:     latency.DefaultConfig(),
+                Cohere: &cohere.Config{
+                    APIKey:        "ABC",
+                    DefaultParams: &cohereParams,
+                },
+            },
+        },
+    }
+
+    chatModels, streamChatModels, err := cfg.BuildModels(tel)
+
+    require.Len(t, chatModels, 2)
+    require.Len(t, streamChatModels, 2)
+    require.NoError(t, err)
 }

 func TestRouterConfig_InvalidSetups(t *testing.T) {
diff --git a/pkg/routers/health/tracker.go b/pkg/routers/health/tracker.go
new file mode 100644
index 00000000..f49e310b
--- /dev/null
+++ b/pkg/routers/health/tracker.go
@@ -0,0 +1,44 @@
+package health
+
+import (
+    "errors"
+
+    "glide/pkg/providers/clients"
+)
+
+// Tracker tracks errors and general health of a model provider
+type Tracker struct {
+    unauthorized bool
+    errBudget    *TokenBucket
+    rateLimit    *RateLimitTracker
+}
+
+func NewTracker(budget *ErrorBudget) *Tracker {
+    return &Tracker{
+        unauthorized: false,
+        rateLimit:    NewRateLimitTracker(),
+        errBudget:    NewTokenBucket(budget.TimePerTokenMicro(), budget.Budget()),
+    }
+}
+
+func (t *Tracker) Healthy() bool {
+    return !t.unauthorized && !t.rateLimit.Limited() && t.errBudget.HasTokens()
+}
+
+func (t *Tracker) TrackErr(err error) {
+    var rateLimitErr *clients.RateLimitError
+
+    if errors.Is(err, clients.ErrUnauthorized) {
+        t.unauthorized = true
+
+        return
+    }
+
+    if errors.As(err, &rateLimitErr) {
+        t.rateLimit.SetLimited(rateLimitErr.UntilReset())
+
+        return
+    }
+
+    _ = t.errBudget.Take(1)
+}
diff --git a/pkg/routers/health/tracker_test.go b/pkg/routers/health/tracker_test.go
new file mode 100644
index 00000000..94733ce0
--- /dev/null
+++ b/pkg/routers/health/tracker_test.go
@@ -0,0 +1,37 @@
+package health
+
+import (
+    "testing"
+    "time"
+
+    "github.com/stretchr/testify/require"
+    "glide/pkg/providers/clients"
+)
+
+func TestHealthTracker_HealthyByDefault(t *testing.T) {
+    budget := NewErrorBudget(3, SEC)
+    tracker := NewTracker(budget)
+
+    require.True(t, tracker.Healthy())
+}
+
+func TestHealthTracker_UnhealthyWhenBudgetExceeded(t *testing.T) {
+    budget := NewErrorBudget(3, SEC)
+    tracker := NewTracker(budget)
+
+    for range 3 {
+        tracker.TrackErr(clients.ErrProviderUnavailable)
+    }
+
+    require.False(t, tracker.Healthy())
+}
+
+func TestHealthTracker_RateLimited(t *testing.T) {
+    budget := NewErrorBudget(3, SEC)
+    tracker := NewTracker(budget)
+
+    limitedUntil := 10 * time.Minute
+    tracker.TrackErr(clients.NewRateLimitError(&limitedUntil))
+
+    require.False(t, tracker.Healthy())
+}
diff --git a/pkg/routers/manager.go b/pkg/routers/manager.go
index e8e9e5a7..8dcecd60 100644
--- a/pkg/routers/manager.go
+++ b/pkg/routers/manager.go
@@ -10,7 +10,7 @@ var ErrRouterNotFound = errors.New("no router found with given ID")

 type RouterManager struct {
     Config        *Config
-    telemetry     *telemetry.Telemetry
+    tel           *telemetry.Telemetry
     langRouterMap *map[string]*LangRouter
     langRouters   []*LangRouter
 }
@@ -30,7 +30,7 @@ func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) {

     manager := RouterManager{
         Config:        cfg,
-        telemetry:     tel,
+        tel:           tel,
         langRouters:   langRouters,
         langRouterMap: &langRouterMap,
     }
diff --git a/pkg/routers/retry/config_test.go b/pkg/routers/retry/config_test.go
new file mode 100644
index 00000000..7fe409a3
--- /dev/null
+++ b/pkg/routers/retry/config_test.go
@@ -0,0 +1,13 @@
+package retry
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/require"
+)
+
+func TestRetryConfig_DefaultConfig(t *testing.T) {
+    config := DefaultExpRetryConfig()
+
+    require.NotNil(t, config)
+}
diff --git a/pkg/routers/router.go b/pkg/routers/router.go
index 13d89fa3..aa8ba067 100644
--- a/pkg/routers/router.go
+++ b/pkg/routers/router.go
@@ -20,32 +20,36 @@ var (
 )

 type LangRouter struct {
-    routerID  string
-    Config    *LangRouterConfig
-    routing   routing.LangModelRouting
-    retry     *retry.ExpRetry
-    models    []providers.LanguageModel
-    telemetry *telemetry.Telemetry
+    routerID          string
+    Config            *LangRouterConfig
+    chatModels        []*providers.LanguageModel
+    chatStreamModels  []*providers.LanguageModel
+    chatRouting       routing.LangModelRouting
+    chatStreamRouting routing.LangModelRouting
+    retry             *retry.ExpRetry
+    tel               *telemetry.Telemetry
 }

 func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter, error) {
-    models, err := cfg.BuildModels(tel)
+    chatModels, chatStreamModels, err := cfg.BuildModels(tel)
     if err != nil {
         return nil, err
     }

-    strategy, err := cfg.BuildRouting(models)
+    chatRouting, chatStreamRouting, err := cfg.BuildRouting(chatModels, chatStreamModels)
     if err != nil {
         return nil, err
     }

     router := &LangRouter{
-        routerID:  cfg.ID,
-        Config:    cfg,
-        models:    models,
-        retry:     cfg.BuildRetry(),
-        routing:   strategy,
-        telemetry: tel,
+        routerID:          cfg.ID,
+        Config:            cfg,
+        chatModels:        chatModels,
+        chatStreamModels:  chatStreamModels,
+        retry:             cfg.BuildRetry(),
+        chatRouting:       chatRouting,
+        chatStreamRouting: chatStreamRouting,
+        tel:               tel,
     }

     return router, err
@@ -55,15 +59,15 @@ func (r *LangRouter) ID() string {
     return r.routerID
 }

-func (r *LangRouter) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
-    if len(r.models) == 0 {
+func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) {
+    if len(r.chatModels) == 0 {
         return nil, ErrNoModels
     }

     retryIterator := r.retry.Iterator()

     for retryIterator.HasNext() {
-        modelIterator := r.routing.Iterator()
+        modelIterator := r.chatRouting.Iterator()

         for {
             model, err := modelIterator.Next()
@@ -73,20 +77,20 @@
                 break
             }

-            langModel := model.(providers.LanguageModel)
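+            // The routing strategy hands back a generic Model; cast it to a LangModel to serve the chat request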
+            langModel := model.(providers.LangModel)

             // Check if there is an override in the request
-            if request.Override != (schemas.OverrideChatRequest{}) {
+            if req.Override != nil {
                 // Override the message if the language model ID matches the override model ID
-                if langModel.ID() == request.Override.Model {
-                    request.Message = request.Override.Message
+                if langModel.ID() == req.Override.Model {
+                    req.Message = req.Override.Message
                 }
             }

-            resp, err := langModel.Chat(ctx, request)
+            resp, err := langModel.Chat(ctx, req)
             if err != nil {
-                r.telemetry.Logger.Warn(
-                    "lang model failed processing chat request",
+                r.tel.L().Warn(
+                    "Lang model failed processing chat request",
                     zap.String("routerID", r.ID()),
                     zap.String("modelID", langModel.ID()),
                     zap.String("provider", langModel.Provider()),
@@ -103,7 +107,7 @@

         // no providers were available to handle the request,
         // so we have to wait a bit with a hope there is some available next time
-        r.telemetry.Logger.Warn("no healthy model found, wait and retry", zap.String("routerID", r.ID()))
+        r.tel.L().Warn("No healthy model found to serve chat request, wait and retry", zap.String("routerID", r.ID()))

         err := retryIterator.WaitNext(ctx)
         if err != nil {
@@ -113,7 +117,116 @@
     }

     // if we reach this part, then we are in trouble
-    r.telemetry.Logger.Error("no model was available to handle request", zap.String("routerID", r.ID()))
+    r.tel.L().Error("No model was available to handle chat request", zap.String("routerID", r.ID()))

     return nil, ErrNoModelAvailable
 }
+
+func (r *LangRouter) ChatStream(
+    ctx context.Context,
+    req *schemas.ChatStreamRequest,
+    respC chan<- *schemas.ChatStreamResult,
+) {
+    if len(r.chatStreamModels) == 0 {
+        respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{
+            ID:       req.ID,
+            ErrCode:  "noModels",
+            Message:  ErrNoModels.Error(),
+            Metadata: req.Metadata,
+        })
+
+        return
+    }
+
+    retryIterator := r.retry.Iterator()
+
+    for retryIterator.HasNext() {
+        modelIterator := r.chatStreamRouting.Iterator()
+
+    NextModel:
+        for {
+            model, err := modelIterator.Next()
+
+            if errors.Is(err, routing.ErrNoHealthyModels) {
+                // no healthy model in the pool. Let's retry after some time
+                break
+            }
+
+            langModel := model.(providers.LangModel)
+            modelRespC, err := langModel.ChatStream(ctx, req)
+            if err != nil {
+                r.tel.L().Error(
+                    "Lang model failed to create streaming chat request",
+                    zap.String("routerID", r.ID()),
+                    zap.String("modelID", langModel.ID()),
+                    zap.String("provider", langModel.Provider()),
+                    zap.Error(err),
+                )
+
+                continue
+            }
+
+            for chunkResult := range modelRespC {
+                err = chunkResult.Error()
+                if err != nil {
+                    r.tel.L().Warn(
+                        "Lang model failed processing streaming chat request",
+                        zap.String("routerID", r.ID()),
+                        zap.String("modelID", langModel.ID()),
+                        zap.String("provider", langModel.Provider()),
+                        zap.Error(err),
+                    )
+
+                    // It's challenging to hide an error in case of streaming chat as consumer apps
+                    // may have already used all chunks we streamed so far (e.g. showed them to their users like OpenAI UI does),
+                    // so we cannot easily restart that process from scratch
+                    respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{
+                        ID:       req.ID,
+                        ErrCode:  "modelUnavailable",
+                        Message:  err.Error(),
+                        Metadata: req.Metadata,
+                    })
+
+                    continue NextModel
+                }
+
+                respC <- schemas.NewChatStreamResult(chunkResult.Chunk())
+            }
+
+            return
+        }

+        // no providers were available to handle the request,
+        // so we have to wait a bit with a hope there is some available next time
+        r.tel.L().Warn(
+            "No healthy model found to serve streaming chat request, wait and retry",
+            zap.String("routerID", r.ID()),
+        )
+
+        err := retryIterator.WaitNext(ctx)
+        if err != nil {
+            // something has cancelled the context
+            respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{
+                ID:       req.ID,
+                ErrCode:  "other",
+                Message:  err.Error(),
+                Metadata: req.Metadata,
+            })
+
+            return
+        }
+    }
+
+    // if we reach this part, then we are in trouble
+    r.tel.L().Error(
+        "No model was available to handle streaming chat request. Try to configure more fallback models to avoid this",
+        zap.String("routerID", r.ID()),
+    )
+
+    respC <- schemas.NewChatStreamErrorResult(&schemas.ChatStreamError{
+        ID:       req.ID,
+        ErrCode:  "allModelsUnavailable",
+        Message:  ErrNoModelAvailable.Error(),
+        Metadata: req.Metadata,
+    })
+}
diff --git a/pkg/routers/router_test.go b/pkg/routers/router_test.go
index 77fb7226..4fd50d28 100644
--- a/pkg/routers/router_test.go
+++ b/pkg/routers/router_test.go
@@ -12,28 +12,29 @@ import (
     "github.com/stretchr/testify/require"
     "glide/pkg/api/schemas"
     "glide/pkg/providers"
+    ptesting "glide/pkg/providers/testing"
     "glide/pkg/routers/health"
     "glide/pkg/routers/retry"
     "glide/pkg/routers/routing"
     "glide/pkg/telemetry"
 )

-func TestLangRouter_Priority_PickFistHealthy(t *testing.T) {
+func TestLangRouter_Chat_PickFirstHealthy(t *testing.T) {
     budget := health.NewErrorBudget(3, health.SEC)
     latConfig := latency.DefaultConfig()

-    langModels := []providers.LanguageModel{
+    langModels := []*providers.LanguageModel{
         providers.NewLangModel(
             "first",
-            providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
-            *budget,
+            ptesting.NewProviderMock([]ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}),
+            budget,
             *latConfig,
             1,
         ),
         providers.NewLangModel(
             "second",
-            providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}}),
-            *budget,
+            ptesting.NewProviderMock([]ptesting.RespMock{{Msg: "1"}}),
+            budget,
             *latConfig,
             1,
         ),
@@ -45,12 +46,13 @@
     }

     router := LangRouter{
-        routerID:  "test_router",
-        Config:    &LangRouterConfig{},
-        retry:     retry.NewExpRetry(3, 2, 1*time.Second, nil),
-        routing:   routing.NewPriority(models),
-        models:    langModels,
-        telemetry: telemetry.NewTelemetryMock(),
+        routerID:         "test_router",
+        Config:           &LangRouterConfig{},
+        retry:            retry.NewExpRetry(3, 2, 1*time.Second, nil),
+        chatRouting:      routing.NewPriority(models),
+        chatModels:       langModels,
+        chatStreamModels: langModels,
+        tel:              telemetry.NewTelemetryMock(),
     }

     ctx := context.Background()
@@ -65,28 +67,28 @@
     }
 }

-func TestLangRouter_Priority_PickThirdHealthy(t *testing.T) {
+func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) {
     budget := health.NewErrorBudget(1, health.SEC)
     latConfig := latency.DefaultConfig()

-    langModels := []providers.LanguageModel{
+    langModels := []*providers.LanguageModel{
         providers.NewLangModel(
             "first",
-            providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "3"}}),
-            *budget,
+            ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Msg: "3"}}),
+            budget,
             *latConfig,
             1,
         ),
         providers.NewLangModel(
             "second",
-            providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "4"}}),
-            *budget,
+            ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Msg: "4"}}),
+            budget,
             *latConfig,
             1,
         ),
         providers.NewLangModel(
             "third",
-            providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
-            *budget,
+            ptesting.NewProviderMock([]ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}),
+            budget,
             *latConfig,
             1,
         ),
@@ -100,12 +102,14 @@
     expectedModels := []string{"third", "third"}

     router := LangRouter{
-        routerID:  "test_router",
-        Config:    &LangRouterConfig{},
-        retry:     retry.NewExpRetry(3, 2, 1*time.Second, nil),
-        routing:   routing.NewPriority(models),
-        models:    langModels,
-        telemetry: telemetry.NewTelemetryMock(),
+        routerID:          "test_router",
+        Config:            &LangRouterConfig{},
+        retry:             retry.NewExpRetry(3, 2, 1*time.Second, nil),
+        chatRouting:       routing.NewPriority(models),
+        chatStreamRouting: routing.NewPriority(models),
+        chatModels:        langModels,
+        chatStreamModels:  langModels,
+        tel:               telemetry.NewTelemetryMock(),
     }

     ctx := context.Background()
@@ -120,21 +124,21 @@
     }
 }

-func TestLangRouter_Priority_SuccessOnRetry(t *testing.T) {
+func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) {
     budget := health.NewErrorBudget(1, health.MILLI)
     latConfig := latency.DefaultConfig()

-    langModels := []providers.LanguageModel{
+    langModels := []*providers.LanguageModel{
         providers.NewLangModel(
             "first",
-            providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "2"}}),
-            *budget,
+            ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Msg: "2"}}),
+            budget,
             *latConfig,
             1,
         ),
         providers.NewLangModel(
             "second",
-            providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Msg: "1"}}),
-            *budget,
+            ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Msg: "1"}}),
+            budget,
             *latConfig,
             1,
         ),
@@ -146,12 +150,14 @@
     }

     router := LangRouter{
-        routerID:  "test_router",
-        Config:    &LangRouterConfig{},
-        retry:     retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
-        routing:   routing.NewPriority(models),
-        models:    langModels,
-        telemetry: telemetry.NewTelemetryMock(),
+        routerID:          "test_router",
+        Config:            &LangRouterConfig{},
+        retry:             retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
+        chatRouting:       routing.NewPriority(models),
+        chatStreamRouting: routing.NewPriority(models),
+        chatModels:        langModels,
+        chatStreamModels:  langModels,
+        tel:               telemetry.NewTelemetryMock(),
     }

     resp, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke"))
@@ -161,21 +167,21 @@
     require.Equal(t, "test_router", resp.RouterID)
 }

-func TestLangRouter_Priority_UnhealthyModelInThePool(t *testing.T) {
+func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) {
     budget := health.NewErrorBudget(1, health.MIN)
     latConfig := latency.DefaultConfig()

-    langModels := []providers.LanguageModel{
+    langModels := []*providers.LanguageModel{
         providers.NewLangModel(
             "first",
-            providers.NewProviderMock([]providers.ResponseMock{{Err: &clients.ErrProviderUnavailable}, {Msg: "3"}}),
-            *budget,
+            ptesting.NewProviderMock([]ptesting.RespMock{{Err: &clients.ErrProviderUnavailable}, {Msg: "3"}}),
+            budget,
             *latConfig,
             1,
         ),
         providers.NewLangModel(
             "second",
-            providers.NewProviderMock([]providers.ResponseMock{{Msg: "1"}, {Msg: "2"}}),
-            *budget,
+            ptesting.NewProviderMock([]ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}),
+            budget,
             *latConfig,
             1,
         ),
@@ -187,12 +193,14 @@
     }

     router := LangRouter{
-        routerID:  "test_router",
-        Config:    &LangRouterConfig{},
-        retry:     retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
-        routing:   routing.NewPriority(models),
-        models:    langModels,
-        telemetry: telemetry.NewTelemetryMock(),
+        routerID:          "test_router",
+        Config:            &LangRouterConfig{},
+        retry:             retry.NewExpRetry(3, 2, 1*time.Millisecond, nil),
+        chatRouting:       routing.NewPriority(models),
+        chatModels:        langModels,
+        chatStreamModels:  langModels,
+        chatStreamRouting: routing.NewPriority(models),
+        tel:               telemetry.NewTelemetryMock(),
     }

     for i := 0; i < 2; i++ {
@@ -204,21 +212,21 @@
     }
 }

-func TestLangRouter_Priority_AllModelsUnavailable(t *testing.T) {
+func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) {
     budget := health.NewErrorBudget(1, health.SEC)
     latConfig := latency.DefaultConfig()

-    langModels := []providers.LanguageModel{
+    langModels := []*providers.LanguageModel{
         providers.NewLangModel(
             "first",
-            providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
-            *budget,
+            ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
+            budget,
             *latConfig,
             1,
         ),
         providers.NewLangModel(
             "second",
-            providers.NewProviderMock([]providers.ResponseMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
-            *budget,
+            ptesting.NewProviderMock([]ptesting.RespMock{{Err: &ErrNoModelAvailable}, {Err: &ErrNoModelAvailable}}),
+            budget,
             *latConfig,
             1,
         ),
@@ -230,15 +238,221 @@
     }

     router := LangRouter{
-        routerID:  "test_router",
-        Config:    &LangRouterConfig{},
-        retry:     retry.NewExpRetry(1, 2, 1*time.Millisecond, nil),
-        routing:   routing.NewPriority(models),
-        models:    langModels,
-        telemetry: telemetry.NewTelemetryMock(),
+        routerID:          "test_router",
+        Config:            &LangRouterConfig{},
+        retry:             retry.NewExpRetry(1, 2, 1*time.Millisecond, nil),
+        chatRouting:       routing.NewPriority(models),
+        chatModels:        langModels,
+        chatStreamModels:  langModels,
+        chatStreamRouting: routing.NewPriority(models),
+        tel:               telemetry.NewTelemetryMock(),
     }

     _, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke"))

     require.Error(t, err)
 }
+
+func TestLangRouter_ChatStream(t *testing.T) {
+    budget := health.NewErrorBudget(3, health.SEC)
+    latConfig := latency.DefaultConfig()
+
+    langModels := []*providers.LanguageModel{
+        providers.NewLangModel(
+            "first",
+            ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
+                ptesting.NewRespStreamMock(&[]ptesting.RespMock{
+                    {Msg: "Bill"},
+                    {Msg: "Gates"},
+                    {Msg: "entered"},
+                    {Msg: "the"},
+                    {Msg: "bar"},
+                }),
+            }),
+            budget,
+            *latConfig,
+            1,
+        ),
+        providers.NewLangModel(
+            "second",
+            ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
+                ptesting.NewRespStreamMock(&[]ptesting.RespMock{
+                    {Msg: "Knock"},
+                    {Msg: "Knock"},
+                    {Msg: "joke"},
+                }),
+            }),
+            budget,
+            *latConfig,
+            1,
+        ),
+    }
+
+    models := make([]providers.Model, 0, len(langModels))
+    for _, model := range langModels {
+        models = append(models, model)
+    }
+
+    router := LangRouter{
+        routerID:          "test_stream_router",
+        Config:            &LangRouterConfig{},
+        retry:             retry.NewExpRetry(3, 2, 1*time.Second, nil),
+        chatRouting:       routing.NewPriority(models),
+        chatModels:        langModels,
+        chatStreamRouting: routing.NewPriority(models),
+        chatStreamModels:  langModels,
+        tel:               telemetry.NewTelemetryMock(),
+    }
+
+    ctx := context.Background()
+    req := schemas.NewChatStreamFromStr("tell me a dad joke")
+    respC := make(chan *schemas.ChatStreamResult)
+
+    defer close(respC)
+
+    go router.ChatStream(ctx, req, respC)
+
+    chunks := make([]string, 0, 5)
+
+    for range 5 {
+        select { //nolint:gosimple
+        case chunk := <-respC:
+            require.Nil(t, chunk.Error())
+            require.NotNil(t, chunk.Chunk().ModelResponse.Message.Content)
+
+            chunks = append(chunks, chunk.Chunk().ModelResponse.Message.Content)
+        }
+    }
+
+    require.Equal(t, []string{"Bill", "Gates", "entered", "the", "bar"}, chunks)
+}
+
+func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) {
+    budget := health.NewErrorBudget(3, health.SEC)
+    latConfig := latency.DefaultConfig()
+
+    langModels := []*providers.LanguageModel{
+        providers.NewLangModel(
+            "first",
+            ptesting.NewStreamProviderMock(nil),
+            budget,
+            *latConfig,
+            1,
+        ),
+        providers.NewLangModel(
+            "second",
+            ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
+                ptesting.NewRespStreamMock(
+                    &[]ptesting.RespMock{
+                        {Msg: "Knock"},
+                        {Msg: "knock"},
+                        {Msg: "joke"},
+                    },
+                ),
+            }),
+            budget,
+            *latConfig,
+            1,
+        ),
+    }
+
+    models := make([]providers.Model, 0, len(langModels))
+    for _, model := range langModels {
+        models = append(models, model)
+    }
+
+    router := LangRouter{
+        routerID:          "test_stream_router",
+        Config:            &LangRouterConfig{},
+        retry:             retry.NewExpRetry(3, 2, 1*time.Second, nil),
+        chatRouting:       routing.NewPriority(models),
+        chatModels:        langModels,
+        chatStreamRouting: routing.NewPriority(models),
+        chatStreamModels:  langModels,
+        tel:               telemetry.NewTelemetryMock(),
+    }
+
+    ctx := context.Background()
+    req := schemas.NewChatStreamFromStr("tell me a dad joke")
+    respC := make(chan *schemas.ChatStreamResult)
+
+    defer close(respC)
+
+    go router.ChatStream(ctx, req, respC)
+
+    chunks := make([]string, 0, 3)
+
+    for range 3 {
+        select { //nolint:gosimple
+        case chunk := <-respC:
+            require.Nil(t, chunk.Error())
+            require.NotNil(t, chunk.Chunk().ModelResponse.Message.Content)
+
+            chunks = append(chunks, chunk.Chunk().ModelResponse.Message.Content)
+        }
+    }
+
+    require.Equal(t, []string{"Knock", "knock", "joke"}, chunks)
+}
+
+func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) {
+    budget := health.NewErrorBudget(1, health.SEC)
+    latConfig := latency.DefaultConfig()
+
+    langModels := []*providers.LanguageModel{
+        providers.NewLangModel(
+            "first",
+            ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
+                ptesting.NewRespStreamMock(&[]ptesting.RespMock{
+                    {Err: &clients.ErrProviderUnavailable},
+                }),
+            }),
+            budget,
+            *latConfig,
+            1,
+        ),
+        providers.NewLangModel(
+            "second",
+            ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
+                ptesting.NewRespStreamMock(&[]ptesting.RespMock{
+                    {Err: &clients.ErrProviderUnavailable},
+                }),
+            }),
+            budget,
+            *latConfig,
+            1,
+        ),
+    }
+
+    models := make([]providers.Model, 0, len(langModels))
+    for _, model := range langModels {
+        models = append(models, model)
+    }
+
+    router := LangRouter{
+        routerID:          "test_router",
+        Config:            &LangRouterConfig{},
+        retry:             retry.NewExpRetry(1, 2, 1*time.Millisecond, nil),
+        chatRouting:       routing.NewPriority(models),
+        chatModels:        langModels,
+        chatStreamModels:  langModels,
+        chatStreamRouting: routing.NewPriority(models),
+        tel:               telemetry.NewTelemetryMock(),
+    }
+
+    respC := make(chan *schemas.ChatStreamResult)
+    defer close(respC)
+
+    go router.ChatStream(context.Background(), schemas.NewChatStreamFromStr("tell me a dad joke"), respC)
+
+    errs := make([]string, 0, 3)
+
+    for range 3 {
+        result := <-respC
+        require.Nil(t, result.Chunk())
+
+        errs = append(errs, result.Error().ErrCode)
+    }
+
+    require.Equal(t, []string{"modelUnavailable", "modelUnavailable", "allModelsUnavailable"}, errs)
+}
diff --git a/pkg/routers/routing/least_latency.go b/pkg/routers/routing/least_latency.go
index 2b65dc4c..422fd863 100644
--- a/pkg/routers/routing/least_latency.go
+++ b/pkg/routers/routing/least_latency.go
@@ -5,6 +5,8 @@ import (
     "sync/atomic"
     "time"

+    "glide/pkg/routers/latency"
+
     "glide/pkg/providers"
 )

@@ -12,6 +14,9 @@ const (
     LeastLatency Strategy = "least_latency"
 )

+// LatencyGetter defines where to find latency for the specific model action
+type LatencyGetter = func(model providers.Model) *latency.MovingAverage
+
 // ModelSchedule defines latency update schedule for models
 type ModelSchedule struct {
     mu sync.RWMutex
@@ -57,11 +62,12 @@ func (s *ModelSchedule) Update() {
 // other models' latency may improve over time and outperform the best one),
 // so we need to send some traffic to other models from time to time to update their latency stats
 type LeastLatencyRouting struct {
-    warmupIdx atomic.Uint32
-    schedules []*ModelSchedule
+    latencyGetter LatencyGetter
+    warmupIdx     atomic.Uint32
+    schedules     []*ModelSchedule
 }

-func NewLeastLatencyRouting(models []providers.Model) *LeastLatencyRouting {
+func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []providers.Model) *LeastLatencyRouting {
     schedules := make([]*ModelSchedule, 0, len(models))

     for _, model := range models {
@@ -69,7 +75,8 @@
     }

     return &LeastLatencyRouting{
-        schedules: schedules,
+        latencyGetter: latencyGetter,
+        schedules:     schedules,
     }
 }

@@ -125,7 +132,7 @@ func (r *LeastLatencyRouting) Next() (providers.Model, error) { //nolint:cyclop
         }

         if !schedule.Expired() && !nextSchedule.Expired() &&
-            nextSchedule.model.Latency().Value() > schedule.model.Latency().Value() {
+            r.latencyGetter(nextSchedule.model).Value() > r.latencyGetter(schedule.model).Value() {
             nextSchedule = schedule
         }
     }
@@ -143,7 +150,7 @@ func (r *LeastLatencyRouting) getColdModelSchedules() []*ModelSchedule {
     coldModels := make([]*ModelSchedule, 0, len(r.schedules))

     for _, schedule := range r.schedules {
-        if schedule.model.Healthy() && !schedule.model.Latency().WarmedUp() {
+        if schedule.model.Healthy() && !r.latencyGetter(schedule.model).WarmedUp() {
             coldModels = append(coldModels, schedule)
         }
     }
diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/routers/routing/least_latency_test.go
index dbbc699c..4b517afc 100644
--- a/pkg/routers/routing/least_latency_test.go
+++ b/pkg/routers/routing/least_latency_test.go
@@ -5,6 +5,8 @@ import (
     "testing"
     "time"

+    ptesting "glide/pkg/providers/testing"
+
     "github.com/stretchr/testify/require"
     "glide/pkg/providers"
 )
@@ -33,10 +35,10 @@ func TestLeastLatencyRouting_Warmup(t *testing.T) {
         models := make([]providers.Model, 0, len(tc.models))

         for _, model := range tc.models {
-            models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, model.latency, 1))
+            models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, model.latency, 1))
         }

-        routing := NewLeastLatencyRouting(models)
+        routing := NewLeastLatencyRouting(ptesting.ChatMockLatency, models)
         iterator := routing.Iterator()

         // loop three times over the whole pool to check if we return to the beginning of the list
@@ -104,7 +106,7 @@ func TestLeastLatencyRouting_Routing(t *testing.T) {

         for _, model := range tc.models {
             schedules = append(schedules, &ModelSchedule{
-                model: providers.NewLangModelMock(
+                model: ptesting.NewLangModelMock(
                     model.modelID,
                     model.healthy,
                     model.latency,
@@ -115,7 +117,8 @@
         }

         routing := LeastLatencyRouting{
-            schedules: schedules,
+            latencyGetter: ptesting.ChatMockLatency,
+            schedules:     schedules,
         }

         iterator := routing.Iterator()
@@ -143,10 +146,10 @@ func TestLeastLatencyRouting_NoHealthyModels(t *testing.T) {
     models := make([]providers.Model, 0, len(latencies))

     for idx, latency := range latencies {
-        models = append(models, providers.NewLangModelMock(strconv.Itoa(idx), false, latency, 1))
+        models = append(models, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1))
     }

-    routing := NewLeastLatencyRouting(models)
+    routing := NewLeastLatencyRouting(providers.ChatLatency, models)
     iterator := routing.Iterator()

     _, err := iterator.Next()
diff --git a/pkg/routers/routing/priority_test.go b/pkg/routers/routing/priority_test.go
index 4b0d8f94..2c4bcc1e 100644
--- a/pkg/routers/routing/priority_test.go
+++ b/pkg/routers/routing/priority_test.go
@@ -3,6 +3,8 @@ package routing
 import (
     "testing"

+    ptesting "glide/pkg/providers/testing"
+
     "github.com/stretchr/testify/require"
     "glide/pkg/providers"
 )
@@ -29,7 +31,7 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) {
         models := make([]providers.Model, 0, len(tc.models))

         for _, model := range tc.models {
-            models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 100, 1))
+            models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1))
         }

         routing := NewPriority(models)
@@ -47,9 +49,9 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) {

 func TestPriorityRouting_NoHealthyModels(t *testing.T) {
     models := []providers.Model{
-        providers.NewLangModelMock("first", false, 0, 1),
-        providers.NewLangModelMock("second", false, 0, 1),
-        providers.NewLangModelMock("third", false, 0, 1),
+        ptesting.NewLangModelMock("first", false, 0, 1),
+        ptesting.NewLangModelMock("second", false, 0, 1),
+        ptesting.NewLangModelMock("third", false, 0, 1),
     }

     routing := NewPriority(models)
diff --git a/pkg/routers/routing/round_robin_test.go b/pkg/routers/routing/round_robin_test.go
index becdf69f..f8853ebe 100644
--- a/pkg/routers/routing/round_robin_test.go
+++ b/pkg/routers/routing/round_robin_test.go
@@ -3,6 +3,8 @@ package routing
 import (
     "testing"

+    ptesting "glide/pkg/providers/testing"
+
     "github.com/stretchr/testify/require"
     "glide/pkg/providers"
 )
@@ -30,7 +32,7 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) {
         models := make([]providers.Model, 0, len(tc.models))

         for _, model := range tc.models {
-            models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 100, 1))
+            models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1))
         }

         routing := NewRoundRobinRouting(models)
@@ -50,9 +52,9 @@
 func TestRoundRobinRouting_NoHealthyModels(t *testing.T) {
     models := []providers.Model{
-        providers.NewLangModelMock("first", false, 0, 1),
-        providers.NewLangModelMock("second", false, 0, 1),
-        providers.NewLangModelMock("third", false, 0, 1),
+        ptesting.NewLangModelMock("first", false, 0, 1),
+        ptesting.NewLangModelMock("second", false, 0, 1),
+        ptesting.NewLangModelMock("third", false, 0, 1),
     }

     routing := NewRoundRobinRouting(models)
diff --git a/pkg/routers/routing/weighted_round_robin_test.go b/pkg/routers/routing/weighted_round_robin_test.go
index 71a412b3..af258bb3 100644
--- a/pkg/routers/routing/weighted_round_robin_test.go
+++ b/pkg/routers/routing/weighted_round_robin_test.go
@@ -3,6 +3,8 @@ package routing
 import (
     "testing"

+    ptesting "glide/pkg/providers/testing"
+
     "github.com/stretchr/testify/require"
     "glide/pkg/providers"
 )
@@ -116,7 +118,7 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) {
         models := make([]providers.Model, 0, len(tc.models))

         for _, model := range tc.models {
-            models = append(models, providers.NewLangModelMock(model.modelID, model.healthy, 0, model.weight))
+            models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 0, model.weight))
         }

         routing := NewWeightedRoundRobin(models)
@@ -140,9 +142,9 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) {

 func TestWRoundRobinRouting_NoHealthyModels(t *testing.T) {
     models := []providers.Model{
-        providers.NewLangModelMock("first", false, 0, 1),
-        providers.NewLangModelMock("second", false, 0, 2),
-        providers.NewLangModelMock("third", false, 0, 3),
+        ptesting.NewLangModelMock("first", false, 0, 1),
+        ptesting.NewLangModelMock("second", false, 0, 2),
+        ptesting.NewLangModelMock("third", false, 0, 3),
     }

     routing := NewWeightedRoundRobin(models)
diff --git a/pkg/telemetry/logging_test.go b/pkg/telemetry/logging_test.go
new file mode 100644
index 00000000..c7a61178
--- /dev/null
+++ b/pkg/telemetry/logging_test.go
@@ -0,0 +1,27 @@
+package telemetry
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/require"
+)
+
+func TestLogging_PlainOutputSetup(t *testing.T) {
+    config := LogConfig{
+        Encoding: "console",
+    }
+
+    zapConfig := config.ToZapConfig()
+
+    require.Equal(t, "console", config.Encoding)
+    require.NotNil(t, zapConfig)
+    require.Equal(t, "console", zapConfig.Encoding)
+}
+
+func TestLogging_JSONOutputSetup(t *testing.T) {
+    config := DefaultLogConfig()
+    zapConfig := config.ToZapConfig()
+
+    require.Equal(t, "json", config.Encoding)
+    require.NotNil(t, zapConfig)
+    require.Equal(t, "json", zapConfig.Encoding)
+}
diff --git a/pkg/telemetry/telemetry.go b/pkg/telemetry/telemetry.go
index 868dfaeb..a97f87ac 100644
--- a/pkg/telemetry/telemetry.go
+++ b/pkg/telemetry/telemetry.go
@@ -13,6 +13,10 @@ type Telemetry struct {
     // TODO: add OTEL meter, tracer
 }

+func (t Telemetry) L() *zap.Logger {
+    return t.Logger
+}
+
 func DefaultConfig() *Config {
     return &Config{
         LogConfig: DefaultLogConfig(),
diff --git a/pkg/telemetry/telemetry_test.go b/pkg/telemetry/telemetry_test.go
new file mode 100644
index 00000000..2f69ef89
--- /dev/null
+++ b/pkg/telemetry/telemetry_test.go
@@ -0,0 +1,12 @@
+package telemetry
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/require"
+)
+
+func TestTelemetry_Creation(t *testing.T) {
+    _, err := NewTelemetry(DefaultConfig())
+    require.NoError(t, err)
+}