Fix IsOnlineAsync APIs
mythz committed Oct 23, 2024
1 parent 5426d0a commit e76150a
Showing 4 changed files with 83 additions and 18 deletions.
44 changes: 39 additions & 5 deletions AiServer.ServiceInterface/AnthropicAiProvider.cs
@@ -8,17 +8,23 @@ namespace AiServer.ServiceInterface;

public class AnthropicAiProvider(ILogger<AnthropicAiProvider> log) : OpenAiProviderBase(log)
{
protected override async Task<OpenAiChatResponse> SendOpenAiChatRequestAsync(AiProvider provider, OpenAiChat request,
Action<HttpRequestMessage>? requestFilter, Action<HttpResponseMessage> responseFilter, CancellationToken token=default)
protected override Action<HttpRequestMessage>? CreateRequestFilter(AiProvider provider)
{
var url = (provider.ApiBaseUrl ?? provider.AiType?.ApiBaseUrl).CombineWith("/v1/messages");
Action<HttpRequestMessage>? useRequestFilter = req => {
Action<HttpRequestMessage>? requestFilter = req =>
{
req.Headers.Add("x-api-key", provider.ApiKey);
req.Headers.Add("anthropic-version", "2023-06-01");
};
return requestFilter;
}

protected override async Task<OpenAiChatResponse> SendOpenAiChatRequestAsync(AiProvider provider, OpenAiChat request,
Action<HttpRequestMessage>? requestFilter=null, Action<HttpResponseMessage>? responseFilter=null, CancellationToken token=default)
{
var url = (provider.ApiBaseUrl ?? provider.AiType?.ApiBaseUrl).CombineWith("/v1/messages");
var anthropicRequest = ToAnthropicMessageRequest(request);
var responseJson = await url.PostJsonToUrlAsync(anthropicRequest,
requestFilter: useRequestFilter,
requestFilter: requestFilter,
responseFilter: responseFilter, token: token);

// responseJson.Print();
@@ -78,6 +84,34 @@ public OpenAiChatResponse ToOpenAiChatResponse(AnthropicMessageResponse response

return ret;
}

public override async Task<bool> IsOnlineAsync(AiProvider provider, CancellationToken token = default)
{
try
{
var apiModel = provider.GetPreferredAiModel();
var request = new OpenAiChat
{
Model = apiModel,
Messages = [
new() { Role = "user", Content = "1+1=" },
],
MaxTokens = 2,
Stream = false,
};

var requestFilter = CreateRequestFilter(provider);
var response = await SendOpenAiChatRequestAsync(provider, request,
requestFilter: requestFilter, responseFilter: null, token: token);
return true;
}
catch (Exception e)
{
if (e is TaskCanceledException)
throw;
return false;
}
}
}

[DataContract]
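A minimal usage sketch of the Anthropic override added above (not part of the commit): the AnthropicAiProvider constructor, the AiProvider properties and the probe behaviour come from the diff, while the NullLogger wiring, the ApiBaseUrl value and the model entry are illustrative assumptions (namespaces for the AiServer types omitted).

```csharp
using Microsoft.Extensions.Logging.Abstractions;

// Hypothetical wiring: construct the provider with a no-op logger (primary constructor shown in the diff)
var anthropic = new AnthropicAiProvider(NullLogger<AnthropicAiProvider>.Instance);

// Minimal AiProvider config; property names follow the diff, values are placeholders
var provider = new AiProvider
{
    ApiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY"),
    ApiBaseUrl = "https://api.anthropic.com",         // assumed endpoint; the override appends /v1/messages
    Models = [ new() { Model = "claude-3-haiku" } ],  // same entry TestUtils adds below
};

// The override builds the x-api-key / anthropic-version filter via CreateRequestFilter, sends a tiny
// "1+1=" chat with MaxTokens = 2, and maps any failure to false (rethrowing only TaskCanceledException).
var online = await anthropic.IsOnlineAsync(provider);
Console.WriteLine(online ? "Anthropic reachable" : "Anthropic offline");
```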
34 changes: 23 additions & 11 deletions AiServer.ServiceInterface/OpenAiProvider.cs
@@ -24,12 +24,27 @@ public string GetApiEndpointUrlFor(AiProvider aiProvider, TaskType taskType)
}

public virtual async Task<OpenAiChatResult> ChatAsync(AiProvider provider, OpenAiChat request, CancellationToken token = default)
{
var requestFilter = CreateRequestFilter(provider);
return await ChatAsync(provider, request, token, requestFilter);
}

protected virtual Action<HttpRequestMessage>? CreateRequestFilter(AiProvider provider)
{
Action<HttpRequestMessage>? requestFilter = provider.ApiKey != null
? req => req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", provider.ApiKey)
? req =>
{
if (provider.ApiKeyHeader != null)
{
req.Headers.Add(provider.ApiKeyHeader, provider.ApiKey);
}
else
{
req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", provider.ApiKey);
}
}
: null;

return await ChatAsync(provider, request, token, requestFilter);
return requestFilter;
}

protected virtual async Task<OpenAiChatResult> ChatAsync(AiProvider provider, OpenAiChat request, CancellationToken token, Action<HttpRequestMessage>? requestFilter)
@@ -103,7 +118,7 @@ protected virtual async Task<OpenAiChatResult> ChatAsync(AiProvider provider, Op
}

protected virtual async Task<OpenAiChatResponse> SendOpenAiChatRequestAsync(AiProvider provider, OpenAiChat request,
Action<HttpRequestMessage>? requestFilter, Action<HttpResponseMessage> responseFilter, CancellationToken token=default)
Action<HttpRequestMessage>? requestFilter=null, Action<HttpResponseMessage>? responseFilter=null, CancellationToken token=default)
{
var url = GetApiEndpointUrlFor(provider,TaskType.OpenAiChat);
var responseJson = await url.PostJsonToUrlAsync(request,
@@ -113,14 +128,11 @@ protected virtual async Task<OpenAiChatResponse> SendOpenAiChatRequestAsync(AiPr
return response;
}

public async Task<bool> IsOnlineAsync(AiProvider provider, CancellationToken token = default)
public virtual async Task<bool> IsOnlineAsync(AiProvider provider, CancellationToken token = default)
{
try
{
Action<HttpRequestMessage>? requestFilter = provider.ApiKey != null
? req => req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", provider.ApiKey)
: null;

var requestFilter = CreateRequestFilter(provider);
var heartbeatUrl = provider.HeartbeatUrl;
if (heartbeatUrl != null)
{
@@ -138,8 +150,8 @@ public async Task<bool> IsOnlineAsync(AiProvider provider, CancellationToken tok
Stream = false,
};

var url = GetApiEndpointUrlFor(provider, TaskType.OpenAiChat);
await url.PostJsonToUrlAsync(request, requestFilter:requestFilter, token: token);
var response = await SendOpenAiChatRequestAsync(provider, request,
requestFilter: requestFilter, responseFilter: null, token: token);
return true;
}
catch (Exception e)
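The auth decision that the new CreateRequestFilter centralizes can be restated on its own. A simplified sketch, with a hypothetical ProviderAuth record standing in for the two AiProvider members the rule actually reads:

```csharp
using System;
using System.Net.Http;
using System.Net.Http.Headers;

// Hypothetical stand-in for the two AiProvider members the rule reads; illustration only.
public record ProviderAuth(string? ApiKey, string? ApiKeyHeader);

public static class AuthFilterSketch
{
    public static Action<HttpRequestMessage>? CreateRequestFilter(ProviderAuth provider)
    {
        if (provider.ApiKey == null)
            return null; // no key configured -> no filter, the request goes out unauthenticated

        return req =>
        {
            if (provider.ApiKeyHeader != null)
                req.Headers.Add(provider.ApiKeyHeader, provider.ApiKey);   // custom header, e.g. "x-api-key"
            else
                req.Headers.Authorization = new AuthenticationHeaderValue("Bearer", provider.ApiKey);
        };
    }
}
```

After this change the same filter backs ChatAsync, SendOpenAiChatRequestAsync and IsOnlineAsync, so a provider configured with ApiKeyHeader no longer falls back to a Bearer header on the online probe.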
10 changes: 9 additions & 1 deletion AiServer.Tests/OpenAiProviderTests.cs
@@ -76,6 +76,9 @@ public async Task Can_Send_Ollama_Qwen2_Request()
});

response.PrintDump();

var isOnline = await openAi.IsOnlineAsync(TestUtils.OpenRouterProvider);
Assert.IsTrue(isOnline);
}

[Test]
@@ -142,6 +145,9 @@ public async Task Can_Send_Google_GeminiPro_PVQ_Request()
});

response.PrintDump();

var isOnline = await openAi.IsOnlineAsync(TestUtils.GoogleAiProvider);
Assert.IsTrue(isOnline);
}

[Test]
@@ -164,8 +170,10 @@ public async Task Can_Send_Anthropic_Haiku_Request()
],
MaxTokens = 100,
});

response.PrintDump();

var isOnline = await openAi.IsOnlineAsync(TestUtils.AnthropicProvider);
Assert.IsTrue(isOnline);
}

[Test]
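A hypothetical hardening of the new online checks (not in this commit), written as a fragment that would sit inside Can_Send_Anthropic_Haiku_Request above: bound the probe with a CancellationToken so a hung provider fails fast, relying on the Anthropic override rethrowing TaskCanceledException rather than reporting it as offline.

```csharp
// Hypothetical fragment, not part of the commit: cap the online probe at 10 seconds.
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
try
{
    var isOnline = await openAi.IsOnlineAsync(TestUtils.AnthropicProvider, cts.Token);
    Assert.IsTrue(isOnline);
}
catch (TaskCanceledException)
{
    // The Anthropic override rethrows cancellation instead of returning false
    Assert.Fail("IsOnlineAsync did not complete within 10 seconds");
}
```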
13 changes: 12 additions & 1 deletion AiServer.Tests/TestUtils.cs
@@ -132,6 +132,8 @@ public static JsonApiClient CreateSystemClient()
["mixtral:8x7b"] = "mistralai/mistral-7b-instruct",
["mistral-nemo:12b"] = "mistralai/mistral-nemo",
["gemma:7b"] = "google/gemma-7b-it",
["gemma2:9b"] = "google/gemma-2-9b-it",
["gemma2:27"] = "google/gemma-2-27b-it",
["mixtral:8x7b"] = "mistralai/mixtral-8x7b-instruct",
["mixtral:8x22b"] = "mistralai/mixtral-8x22b-instruct",
["llama3:8b"] = "meta-llama/llama-3-8b-instruct",
@@ -309,9 +311,15 @@ public static JsonApiClient CreateSystemClient()
Enabled = true,
Models =
[
new() { Model = "gemma2:27b", },
// new() { Model = "gemma:7b", },
new() { Model = "gemma2:9b", },
new() { Model = "gemma2:27", },
new() { Model = "mixtral:8x22b", },
new() { Model = "llama3:8b" },
new() { Model = "llama3:70b" },
new() { Model = "llama3.1:8b" },
new() { Model = "llama3.1:70b" },
new() { Model = "llama3.1:405b" },
new() { Model = "wizardlm2:7b", },
new() { Model = "wizardlm2:8x22b", },
new() { Model = "mistral-small", },
@@ -436,6 +444,9 @@ public static JsonApiClient CreateSystemClient()
ApiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY"),
Enabled = true,
AiType = AnthropicAiType,
Models = [
new() { Model = "claude-3-haiku" },
],
};

}
