From e76150a2f7b2f7e36f91cdf1293343ed9581dc09 Mon Sep 17 00:00:00 2001
From: Demis Bellot
Date: Thu, 24 Oct 2024 02:02:46 +0800
Subject: [PATCH] Fix IsOnlineAsync APIs

---
 .../AnthropicAiProvider.cs                  | 44 ++++++++++++++++---
 AiServer.ServiceInterface/OpenAiProvider.cs | 34 +++++++++-----
 AiServer.Tests/OpenAiProviderTests.cs       | 10 ++++-
 AiServer.Tests/TestUtils.cs                 | 13 +++++-
 4 files changed, 83 insertions(+), 18 deletions(-)

diff --git a/AiServer.ServiceInterface/AnthropicAiProvider.cs b/AiServer.ServiceInterface/AnthropicAiProvider.cs
index 991cf58..63c27e0 100644
--- a/AiServer.ServiceInterface/AnthropicAiProvider.cs
+++ b/AiServer.ServiceInterface/AnthropicAiProvider.cs
@@ -8,17 +8,23 @@ namespace AiServer.ServiceInterface;
 public class AnthropicAiProvider(ILogger<AnthropicAiProvider> log) : OpenAiProviderBase(log)
 {
-    protected override async Task<OpenAiChatResponse> SendOpenAiChatRequestAsync(AiProvider provider, OpenAiChat request,
-        Action<HttpRequestMessage>? requestFilter, Action<HttpResponseMessage> responseFilter, CancellationToken token=default)
+    protected override Action<HttpRequestMessage>? CreateRequestFilter(AiProvider provider)
     {
-        var url = (provider.ApiBaseUrl ?? provider.AiType?.ApiBaseUrl).CombineWith("/v1/messages");
-        Action<HttpRequestMessage>? useRequestFilter = req => {
+        Action<HttpRequestMessage>? requestFilter = req =>
+        {
             req.Headers.Add("x-api-key", provider.ApiKey);
             req.Headers.Add("anthropic-version", "2023-06-01");
         };
+        return requestFilter;
+    }
+
+    protected override async Task<OpenAiChatResponse> SendOpenAiChatRequestAsync(AiProvider provider, OpenAiChat request,
+        Action<HttpRequestMessage>? requestFilter=null, Action<HttpResponseMessage>? responseFilter=null, CancellationToken token=default)
+    {
+        var url = (provider.ApiBaseUrl ?? provider.AiType?.ApiBaseUrl).CombineWith("/v1/messages");
         var anthropicRequest = ToAnthropicMessageRequest(request);
         var responseJson = await url.PostJsonToUrlAsync(anthropicRequest,
-            requestFilter: useRequestFilter,
+            requestFilter: requestFilter,
             responseFilter: responseFilter, token: token);
         // responseJson.Print();
@@ -78,6 +84,34 @@ public OpenAiChatResponse ToOpenAiChatResponse(AnthropicMessageResponse response
         return ret;
     }
+
+    public override async Task<bool> IsOnlineAsync(AiProvider provider, CancellationToken token = default)
+    {
+        try
+        {
+            var apiModel = provider.GetPreferredAiModel();
+            var request = new OpenAiChat
+            {
+                Model = apiModel,
+                Messages = [
+                    new() { Role = "user", Content = "1+1=" },
+                ],
+                MaxTokens = 2,
+                Stream = false,
+            };
+
+            var requestFilter = CreateRequestFilter(provider);
+            var response = await SendOpenAiChatRequestAsync(provider, request,
+                requestFilter: requestFilter, responseFilter: null, token: token);
+            return true;
+        }
+        catch (Exception e)
+        {
+            if (e is TaskCanceledException)
+                throw;
+            return false;
+        }
+    }
 }
 
 [DataContract]
diff --git a/AiServer.ServiceInterface/OpenAiProvider.cs b/AiServer.ServiceInterface/OpenAiProvider.cs
index 1b676a9..f470a36 100644
--- a/AiServer.ServiceInterface/OpenAiProvider.cs
+++ b/AiServer.ServiceInterface/OpenAiProvider.cs
@@ -24,12 +24,27 @@ public string GetApiEndpointUrlFor(AiProvider aiProvider, TaskType taskType)
     }
 
     public virtual async Task<OpenAiChatResponse> ChatAsync(AiProvider provider, OpenAiChat request, CancellationToken token = default)
+    {
+        var requestFilter = CreateRequestFilter(provider);
+        return await ChatAsync(provider, request, token, requestFilter);
+    }
+
+    protected virtual Action<HttpRequestMessage>? CreateRequestFilter(AiProvider provider)
     {
         Action<HttpRequestMessage>? requestFilter = provider.ApiKey != null
-            ? req => req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", provider.ApiKey)
+            ? req =>
+            {
+                if (provider.ApiKeyHeader != null)
+                {
+                    req.Headers.Add(provider.ApiKeyHeader, provider.ApiKey);
+                }
+                else
+                {
+                    req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", provider.ApiKey);
+                }
+            }
             : null;
-
-        return await ChatAsync(provider, request, token, requestFilter);
+        return requestFilter;
     }
 
     protected virtual async Task<OpenAiChatResponse> ChatAsync(AiProvider provider, OpenAiChat request, CancellationToken token, Action<HttpRequestMessage>? requestFilter)
@@ -103,7 +118,7 @@ protected virtual async Task<OpenAiChatResponse> ChatAsync(AiProvider provider, Op
     }
 
     protected virtual async Task<OpenAiChatResponse> SendOpenAiChatRequestAsync(AiProvider provider, OpenAiChat request,
-        Action<HttpRequestMessage>? requestFilter, Action<HttpResponseMessage> responseFilter, CancellationToken token=default)
+        Action<HttpRequestMessage>? requestFilter=null, Action<HttpResponseMessage>? responseFilter=null, CancellationToken token=default)
     {
         var url = GetApiEndpointUrlFor(provider,TaskType.OpenAiChat);
         var responseJson = await url.PostJsonToUrlAsync(request,
@@ -113,14 +128,11 @@ protected virtual async Task<OpenAiChatResponse> SendOpenAiChatRequestAsync(AiPr
     }
 
-    public async Task<bool> IsOnlineAsync(AiProvider provider, CancellationToken token = default)
+    public virtual async Task<bool> IsOnlineAsync(AiProvider provider, CancellationToken token = default)
     {
         try
         {
-            Action<HttpRequestMessage>? requestFilter = provider.ApiKey != null
-                ? req => req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", provider.ApiKey)
-                : null;
-
+            var requestFilter = CreateRequestFilter(provider);
             var heartbeatUrl = provider.HeartbeatUrl;
             if (heartbeatUrl != null)
             {
@@ -138,8 +150,8 @@ public async Task<bool> IsOnlineAsync(AiProvider provider, CancellationToken tok
                 Stream = false,
             };
 
-            var url = GetApiEndpointUrlFor(provider, TaskType.OpenAiChat);
-            await url.PostJsonToUrlAsync(request, requestFilter:requestFilter, token: token);
+            var response = await SendOpenAiChatRequestAsync(provider, request,
+                requestFilter: requestFilter, responseFilter: null, token: token);
             return true;
         }
         catch (Exception e)
diff --git a/AiServer.Tests/OpenAiProviderTests.cs b/AiServer.Tests/OpenAiProviderTests.cs
index 256f708..03d77a6 100644
--- a/AiServer.Tests/OpenAiProviderTests.cs
+++ b/AiServer.Tests/OpenAiProviderTests.cs
@@ -76,6 +76,9 @@ public async Task Can_Send_Ollama_Qwen2_Request()
         });
         response.PrintDump();
+
+        var isOnline = await openAi.IsOnlineAsync(TestUtils.OpenRouterProvider);
+        Assert.IsTrue(isOnline);
     }
 
     [Test]
@@ -142,6 +145,9 @@ public async Task Can_Send_Google_GeminiPro_PVQ_Request()
         });
         response.PrintDump();
+
+        var isOnline = await openAi.IsOnlineAsync(TestUtils.GoogleAiProvider);
+        Assert.IsTrue(isOnline);
     }
 
     [Test]
@@ -164,8 +170,10 @@ public async Task Can_Send_Anthropic_Haiku_Request()
             ],
             MaxTokens = 100,
         });
-        response.PrintDump();
+
+        var isOnline = await openAi.IsOnlineAsync(TestUtils.AnthropicProvider);
+        Assert.IsTrue(isOnline);
     }
 
     [Test]
diff --git a/AiServer.Tests/TestUtils.cs b/AiServer.Tests/TestUtils.cs
index e2b7d73..60b719c 100644
--- a/AiServer.Tests/TestUtils.cs
+++ b/AiServer.Tests/TestUtils.cs
@@ -132,6 +132,8 @@ public static JsonApiClient CreateSystemClient()
         ["mixtral:8x7b"] = "mistralai/mistral-7b-instruct",
         ["mistral-nemo:12b"] = "mistralai/mistral-nemo",
         ["gemma:7b"] = "google/gemma-7b-it",
+        ["gemma2:9b"] = "google/gemma-2-9b-it",
+        ["gemma2:27"] = "google/gemma-2-27b-it",
         ["mixtral:8x7b"] = "mistralai/mixtral-8x7b-instruct",
         ["mixtral:8x22b"] = "mistralai/mixtral-8x22b-instruct",
         ["llama3:8b"] = "meta-llama/llama-3-8b-instruct",
@@ -309,9 +311,15 @@ public static JsonApiClient CreateSystemClient()
         Enabled = true,
         Models =
         [
-            new() { Model = "gemma2:27b", },
"gemma2:27b", }, + // new() { Model = "gemma:7b", }, + new() { Model = "gemma2:9b", }, + new() { Model = "gemma2:27", }, new() { Model = "mixtral:8x22b", }, + new() { Model = "llama3:8b" }, new() { Model = "llama3:70b" }, + new() { Model = "llama3.1:8b" }, + new() { Model = "llama3.1:70b" }, + new() { Model = "llama3.1:405b" }, new() { Model = "wizardlm2:7b", }, new() { Model = "wizardlm2:8x22b", }, new() { Model = "mistral-small", }, @@ -436,6 +444,9 @@ public static JsonApiClient CreateSystemClient() ApiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY"), Enabled = true, AiType = AnthropicAiType, + Models = [ + new() { Model = "claude-3-haiku" }, + ], }; }