Skip to content

Commit fa465d4

Browse files
Adds the OpenAIMockResponsePlugin (#768)
1 parent 9e73c68 commit fa465d4

15 files changed

+757
-170
lines changed

dev-proxy-abstractions/ILanguageModelClient.cs

Lines changed: 0 additions & 4 deletions
This file was deleted.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT License.
3+
4+
namespace Microsoft.DevProxy.Abstractions;

/// <summary>
/// A single message in a chat-completion conversation exchanged with a
/// language model.
/// </summary>
public interface ILanguageModelChatCompletionMessage
{
    /// <summary>The text content of the message.</summary>
    string Content { get; set; }
    /// <summary>
    /// The conversational role of the message author — presumably values such
    /// as "user" or "assistant"; verify against the model API in use.
    /// </summary>
    string Role { get; set; }
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT License.
3+
4+
namespace Microsoft.DevProxy.Abstractions;

/// <summary>
/// Client for a language model that can generate plain completions and chat
/// completions.
/// </summary>
public interface ILanguageModelClient
{
    /// <summary>
    /// Generates a chat completion for the given conversation messages.
    /// </summary>
    /// <param name="messages">The conversation so far, oldest first.</param>
    /// <returns>The model's response, or <c>null</c> when no response could be produced.</returns>
    Task<ILanguageModelCompletionResponse?> GenerateChatCompletion(ILanguageModelChatCompletionMessage[] messages);
    /// <summary>
    /// Generates a completion for a single free-form prompt.
    /// </summary>
    /// <param name="prompt">The prompt text.</param>
    /// <returns>The model's response, or <c>null</c> when no response could be produced.</returns>
    Task<ILanguageModelCompletionResponse?> GenerateCompletion(string prompt);
    /// <summary>
    /// Determines whether the language model is configured and reachable.
    /// </summary>
    Task<bool> IsEnabled();
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT License.
3+
4+
namespace Microsoft.DevProxy.Abstractions;

/// <summary>
/// The outcome of a language model completion request: either a response
/// text, an error, or neither when the model produced no output.
/// </summary>
public interface ILanguageModelCompletionResponse
{
    /// <summary>Error message when the request failed; <c>null</c> on success.</summary>
    string? Error { get; set; }
    /// <summary>The generated text; may be <c>null</c> when unavailable.</summary>
    string? Response { get; set; }
}

dev-proxy/LanguageModel/LanguageModelConfiguration.cs renamed to dev-proxy-abstractions/LanguageModel/LanguageModelConfiguration.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT License.
33

4-
namespace Microsoft.DevProxy.LanguageModel;
4+
namespace Microsoft.DevProxy.Abstractions;
55

66
public class LanguageModelConfiguration
77
{
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT License.
3+
4+
using System.Diagnostics;
5+
using System.Net.Http.Json;
6+
using Microsoft.Extensions.Logging;
7+
8+
namespace Microsoft.DevProxy.Abstractions;
9+
10+
/// <summary>
/// <see cref="ILanguageModelClient"/> implementation backed by a local Ollama
/// server. Optionally caches responses (keyed by prompt or by message
/// sequence) when <see cref="LanguageModelConfiguration.CacheResponses"/> is set.
/// </summary>
public class OllamaLanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
{
    private readonly LanguageModelConfiguration? _configuration = configuration;
    private readonly ILogger _logger = logger;
    // Cached result of the availability probe; null until IsEnabled() has run.
    private bool? _lmAvailable;
    // One shared HttpClient for all requests. Creating a new HttpClient per
    // call (as before) risks socket exhaustion under load.
    private static readonly HttpClient _httpClient = new();
    private readonly Dictionary<string, OllamaLanguageModelCompletionResponse> _cacheCompletion = new();
    // The chat cache must compare keys structurally: Dictionary's own
    // TryGetValue compares array keys by reference, so without a custom
    // comparer a lookup with a freshly-built messages array could never hit.
    private readonly Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> _cacheChatCompletion = new(MessagesEqualityComparer.Instance);

    /// <summary>
    /// Checks whether the configured model is reachable and responds to a test
    /// prompt. The result is computed once and reused for the client lifetime.
    /// </summary>
    public async Task<bool> IsEnabled()
    {
        if (_lmAvailable.HasValue)
        {
            return _lmAvailable.Value;
        }

        _lmAvailable = await IsEnabledInternal();
        return _lmAvailable.Value;
    }

    private async Task<bool> IsEnabledInternal()
    {
        if (_configuration is null || !_configuration.Enabled)
        {
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Url))
        {
            _logger.LogError("URL is not set. Language model will be disabled");
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Model))
        {
            _logger.LogError("Model is not set. Language model will be disabled");
            return false;
        }

        _logger.LogDebug("Checking LM availability at {url}...", _configuration.Url);

        try
        {
            // check if lm is on
            var response = await _httpClient.GetAsync(_configuration.Url);
            _logger.LogDebug("Response: {response}", response.StatusCode);

            if (!response.IsSuccessStatusCode)
            {
                return false;
            }

            // The server may be up while the model is not; issue a test
            // prompt to confirm completions actually work.
            var testCompletion = await GenerateCompletionInternal("Are you there? Reply with a yes or no.");
            if (testCompletion?.Error is not null)
            {
                _logger.LogError("Error: {error}", testCompletion.Error);
                return false;
            }

            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Couldn't reach language model at {url}", _configuration.Url);
            return false;
        }
    }

    /// <summary>
    /// Generates a completion for <paramref name="prompt"/>. Returns
    /// <c>null</c> when the model is unconfigured, unavailable, not yet
    /// checked via <see cref="IsEnabled"/>, or the request fails.
    /// </summary>
    public async Task<ILanguageModelCompletionResponse?> GenerateCompletion(string prompt)
    {
        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabled));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheCompletion.TryGetValue(prompt, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
            return cachedResponse;
        }

        var response = await GenerateCompletionInternal(prompt);
        if (response is null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError(response.Error);
            return null;
        }

        // Only successful responses with content are worth caching.
        if (_configuration.CacheResponses && response.Response is not null)
        {
            _cacheCompletion[prompt] = response;
        }

        return response;
    }

    private async Task<OllamaLanguageModelCompletionResponse?> GenerateCompletionInternal(string prompt)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            var url = $"{_configuration.Url}/api/generate";
            _logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);

            var response = await _httpClient.PostAsJsonAsync(url,
                new
                {
                    prompt,
                    model = _configuration.Model,
                    // stream = false makes Ollama return a single JSON object
                    // instead of a stream of chunks.
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            // Recorded so the request target can be surfaced in mock output logs.
            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate completion");
            return null;
        }
    }

    /// <summary>
    /// Generates a chat completion for <paramref name="messages"/>. Returns
    /// <c>null</c> when the model is unconfigured, unavailable, not yet
    /// checked via <see cref="IsEnabled"/>, or the request fails.
    /// </summary>
    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletion(ILanguageModelChatCompletionMessage[] messages)
    {
        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabled));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheChatCompletion.TryGetValue(messages, out var cachedResponse))
        {
            // LastOrDefault guards against an empty messages array (Last() would throw).
            _logger.LogDebug("Returning cached response for message: {lastMessage}", messages.LastOrDefault()?.Content);
            return cachedResponse;
        }

        var response = await GenerateChatCompletionInternal(messages);
        if (response is null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError(response.Error);
            return null;
        }

        if (_configuration.CacheResponses && response.Response is not null)
        {
            _cacheChatCompletion[messages] = response;
        }

        return response;
    }

    private async Task<OllamaLanguageModelChatCompletionResponse?> GenerateChatCompletionInternal(ILanguageModelChatCompletionMessage[] messages)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            var url = $"{_configuration.Url}/api/chat";
            _logger.LogDebug("Requesting chat completion. Message: {lastMessage}", messages.LastOrDefault()?.Content);

            var response = await _httpClient.PostAsJsonAsync(url,
                new
                {
                    messages,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelChatCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate chat completion");
            return null;
        }
    }

    // Structural equality for message arrays: two arrays are the same cache
    // key when they contain the same Content/Role pairs in the same order.
    private sealed class MessagesEqualityComparer : IEqualityComparer<ILanguageModelChatCompletionMessage[]>
    {
        public static readonly MessagesEqualityComparer Instance = new();

        public bool Equals(ILanguageModelChatCompletionMessage[]? x, ILanguageModelChatCompletionMessage[]? y)
        {
            if (ReferenceEquals(x, y))
            {
                return true;
            }
            if (x is null || y is null || x.Length != y.Length)
            {
                return false;
            }

            for (var i = 0; i < x.Length; i++)
            {
                if (x[i].Content != y[i].Content || x[i].Role != y[i].Role)
                {
                    return false;
                }
            }

            return true;
        }

        public int GetHashCode(ILanguageModelChatCompletionMessage[] obj)
        {
            var hash = new HashCode();
            foreach (var message in obj)
            {
                hash.Add(message.Content);
                hash.Add(message.Role);
            }
            return hash.ToHashCode();
        }
    }
}
244+
245+
// Helpers for using a dictionary keyed by message arrays as a cache with
// structural (sequence) key matching rather than array-reference matching.
internal static class CacheChatCompletionExtensions
{
    // Finds the stored key whose messages are element-wise equal to the
    // supplied sequence; null when no stored key matches.
    public static OllamaLanguageModelChatCompletionMessage[]? GetKey(
        this Dictionary<OllamaLanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages) =>
        cache.Keys.FirstOrDefault(storedKey => storedKey.SequenceEqual(messages));

    // Sequence-aware TryGetValue: resolves the structural key first, then
    // reads the cached response for it. Returns false (value null) on a miss.
    public static bool TryGetValue(
        this Dictionary<OllamaLanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages, out OllamaLanguageModelChatCompletionResponse? value)
    {
        var matchingKey = cache.GetKey(messages);
        value = matchingKey is null ? null : cache[matchingKey];
        return matchingKey is not null;
    }
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT License.
3+
4+
using System.Text.Json.Serialization;
5+
6+
namespace Microsoft.DevProxy.Abstractions;
7+
8+
/// <summary>
/// Base shape of a JSON response from the Ollama REST API. Snake_case wire
/// fields are mapped via <see cref="JsonPropertyNameAttribute"/>; renaming a
/// property or attribute changes the serialization contract.
/// </summary>
public abstract class OllamaResponse : ILanguageModelCompletionResponse
{
    /// <summary>When the model produced the response (wire field "created_at").</summary>
    [JsonPropertyName("created_at")]
    public DateTime CreatedAt { get; set; } = DateTime.MinValue;
    /// <summary>Whether generation finished (wire field "done").</summary>
    public bool Done { get; set; } = false;
    /// <summary>Error reported by the API; null on success.</summary>
    public string? Error { get; set; }
    /// <summary>Number of tokens in the generated output.</summary>
    [JsonPropertyName("eval_count")]
    public long EvalCount { get; set; }
    // NOTE(review): duration values are passed through as received from the
    // API — presumably nanoseconds per Ollama's API docs; confirm before
    // converting or displaying.
    [JsonPropertyName("eval_duration")]
    public long EvalDuration { get; set; }
    [JsonPropertyName("load_duration")]
    public long LoadDuration { get; set; }
    /// <summary>Name of the model that produced the response.</summary>
    public string Model { get; set; } = string.Empty;
    [JsonPropertyName("prompt_eval_count")]
    public long PromptEvalCount { get; set; }
    [JsonPropertyName("prompt_eval_duration")]
    public long PromptEvalDuration { get; set; }
    /// <summary>The generated text; see <see cref="ILanguageModelCompletionResponse.Response"/>.</summary>
    public virtual string? Response { get; set; }
    [JsonPropertyName("total_duration")]
    public long TotalDuration { get; set; }
    // custom property added to log in the mock output
    public string RequestUrl { get; set; } = string.Empty;
}

/// <summary>Response from Ollama's /api/generate (plain completion) endpoint.</summary>
public class OllamaLanguageModelCompletionResponse : OllamaResponse
{
    /// <summary>Conversation context token IDs returned by the API.</summary>
    public int[] Context { get; set; } = [];
}

/// <summary>Response from Ollama's /api/chat (chat completion) endpoint.</summary>
public class OllamaLanguageModelChatCompletionResponse : OllamaResponse
{
    /// <summary>The message returned by the model.</summary>
    public OllamaLanguageModelChatCompletionMessage Message { get; set; } = new();
    /// <summary>
    /// Adapts the chat response to the plain-completion contract by exposing
    /// the message content as the response text.
    /// NOTE(review): the setter replaces Message with one carrying only
    /// Content — any Role is dropped; a null assignment is ignored. Confirm
    /// this is intended.
    /// </summary>
    public override string? Response
    {
        get => Message.Content;
        set
        {
            if (value is null)
            {
                return;
            }

            Message = new() { Content = value };
        }
    }
}

/// <summary>
/// Chat message with value-based equality on Content and Role, which lets
/// message arrays be compared element-wise (e.g. via SequenceEqual) for
/// cache-key matching.
/// </summary>
public class OllamaLanguageModelChatCompletionMessage : ILanguageModelChatCompletionMessage
{
    public string Content { get; set; } = string.Empty;
    public string Role { get; set; } = string.Empty;

    // Two messages are equal when they have the same runtime type and the
    // same Content and Role.
    public override bool Equals(object? obj)
    {
        if (obj is null || GetType() != obj.GetType())
        {
            return false;
        }

        OllamaLanguageModelChatCompletionMessage m = (OllamaLanguageModelChatCompletionMessage)obj;
        return Content == m.Content && Role == m.Role;
    }

    // Kept in sync with Equals: both consider exactly Content and Role.
    public override int GetHashCode()
    {
        return HashCode.Combine(Content, Role);
    }
}

0 commit comments

Comments
 (0)