// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Diagnostics;
using System.Net.Http.Json;
using Microsoft.Extensions.Logging;

namespace Microsoft.DevProxy.Abstractions;

public class OllamaLanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
{
    private readonly LanguageModelConfiguration? _configuration = configuration;
    private readonly ILogger _logger = logger;
    private bool? _lmAvailable;
    private readonly Dictionary<string, OllamaLanguageModelCompletionResponse> _cacheCompletion = new();
    private readonly Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> _cacheChatCompletion = new();

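    // Returns whether the configured language model is available. The underlying
    // connectivity check runs only once; the result is cached in _lmAvailable.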
    public async Task<bool> IsEnabled()
    {
        if (_lmAvailable.HasValue)
        {
            return _lmAvailable.Value;
        }

        _lmAvailable = await IsEnabledInternal();
        return _lmAvailable.Value;
    }

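    // Validates the configuration, pings the configured URL, and sends a test
    // prompt to confirm the model actually responds.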
    private async Task<bool> IsEnabledInternal()
    {
        if (_configuration is null || !_configuration.Enabled)
        {
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Url))
        {
            _logger.LogError("URL is not set. Language model will be disabled");
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Model))
        {
            _logger.LogError("Model is not set. Language model will be disabled");
            return false;
        }

        _logger.LogDebug("Checking LM availability at {url}...", _configuration.Url);

        try
        {
            // Check if the language model is running
            using var client = new HttpClient();
            var response = await client.GetAsync(_configuration.Url);
            _logger.LogDebug("Response: {response}", response.StatusCode);

            if (!response.IsSuccessStatusCode)
            {
                return false;
            }

            var testCompletion = await GenerateCompletionInternal("Are you there? Reply with a yes or no.");
            if (testCompletion?.Error is not null)
            {
                _logger.LogError("Error: {error}", testCompletion.Error);
                return false;
            }

            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Couldn't reach language model at {url}", _configuration.Url);
            return false;
        }
    }

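    // Generates a completion for the given prompt, returning a cached response
    // when response caching is enabled and the prompt has been seen before.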
    public async Task<ILanguageModelCompletionResponse?> GenerateCompletion(string prompt)
    {
        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabled));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheCompletion.TryGetValue(prompt, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
            return cachedResponse;
        }

        var response = await GenerateCompletionInternal(prompt);
        if (response == null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError(response.Error);
            return null;
        }
        else
        {
            if (_configuration.CacheResponses && response.Response is not null)
            {
                _cacheCompletion[prompt] = response;
            }

            return response;
        }
    }

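    // Posts the prompt to Ollama's /api/generate endpoint with streaming disabled
    // and deserializes the JSON response.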
    private async Task<OllamaLanguageModelCompletionResponse?> GenerateCompletionInternal(string prompt)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            using var client = new HttpClient();
            var url = $"{_configuration.Url}/api/generate";
            _logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);

            var response = await client.PostAsJsonAsync(url,
                new
                {
                    prompt,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate completion");
            return null;
        }
    }

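    // Generates a chat completion for the given message history. When caching is
    // enabled, previous responses are matched on message content via the
    // TryGetCacheValue extension defined below.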
    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletion(ILanguageModelChatCompletionMessage[] messages)
    {
        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabled));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheChatCompletion.TryGetCacheValue(messages, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for message: {lastMessage}", messages.Last().Content);
            return cachedResponse;
        }

        var response = await GenerateChatCompletionInternal(messages);
        if (response == null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError(response.Error);
            return null;
        }
        else
        {
            if (_configuration.CacheResponses && response.Response is not null)
            {
                _cacheChatCompletion[messages] = response;
            }

            return response;
        }
    }

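    // Posts the message history to Ollama's /api/chat endpoint with streaming
    // disabled and deserializes the JSON response.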
    private async Task<OllamaLanguageModelChatCompletionResponse?> GenerateChatCompletionInternal(ILanguageModelChatCompletionMessage[] messages)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            using var client = new HttpClient();
            var url = $"{_configuration.Url}/api/chat";
            _logger.LogDebug("Requesting chat completion. Message: {lastMessage}", messages.Last().Content);

            var response = await client.PostAsJsonAsync(url,
                new
                {
                    messages,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelChatCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate chat completion");
            return null;
        }
    }
}

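// Arrays compare by reference by default, so these helpers match chat message
// arrays used as cache keys by element-wise content (SequenceEqual) instead.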
internal static class CacheChatCompletionExtensions
{
    public static ILanguageModelChatCompletionMessage[]? GetKey(
        this Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages)
    {
        return cache.Keys.FirstOrDefault(k => k.SequenceEqual(messages));
    }

    // Named TryGetCacheValue rather than TryGetValue: the built-in
    // Dictionary<TKey, TValue>.TryGetValue would otherwise take precedence at the
    // call site and compare array keys by reference, missing equivalent histories.
    public static bool TryGetCacheValue(
        this Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages, out OllamaLanguageModelChatCompletionResponse? value)
    {
        var key = cache.GetKey(messages);
        if (key is null)
        {
            value = null;
            return false;
        }

        value = cache[key];
        return true;
    }
}