LLamaSharpPromptExecutionSettingsConverter.cs
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLamaSharp.SemanticKernel;

/// <summary>
/// JSON converter for <see cref="LLamaSharpPromptExecutionSettings"/>
/// </summary>
public class LLamaSharpPromptExecutionSettingsConverter
    : JsonConverter<LLamaSharpPromptExecutionSettings>
{
    /// <inheritdoc/>
    public override LLamaSharpPromptExecutionSettings Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        var requestSettings = new LLamaSharpPromptExecutionSettings();

        while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
        {
            if (reader.TokenType == JsonTokenType.PropertyName)
            {
                // Normalize the property name so both camelCase ("topP") and snake_case ("top_p")
                // spellings match the upper-case labels below.
                var propertyName = reader.GetString()?.ToUpperInvariant();
                reader.Read();

                switch (propertyName)
                {
                    case "MODELID":
                    case "MODEL_ID":
                        requestSettings.ModelId = reader.GetString();
                        break;
                    case "TEMPERATURE":
                        requestSettings.Temperature = reader.GetDouble();
                        break;
                    case "TOPP":
                    case "TOP_P":
                        requestSettings.TopP = reader.GetDouble();
                        break;
                    case "FREQUENCYPENALTY":
                    case "FREQUENCY_PENALTY":
                        requestSettings.FrequencyPenalty = reader.GetDouble();
                        break;
                    case "PRESENCEPENALTY":
                    case "PRESENCE_PENALTY":
                        requestSettings.PresencePenalty = reader.GetDouble();
                        break;
                    case "MAXTOKENS":
                    case "MAX_TOKENS":
                        requestSettings.MaxTokens = reader.GetInt32();
                        break;
                    case "STOPSEQUENCES":
                    case "STOP_SEQUENCES":
                        requestSettings.StopSequences = JsonSerializer.Deserialize<IList<string>>(ref reader, options) ?? Array.Empty<string>();
                        break;
                    case "RESULTSPERPROMPT":
                    case "RESULTS_PER_PROMPT":
                        requestSettings.ResultsPerPrompt = reader.GetInt32();
                        break;
                    case "TOKENSELECTIONBIASES":
                    case "TOKEN_SELECTION_BIASES":
                        requestSettings.TokenSelectionBiases = JsonSerializer.Deserialize<IDictionary<int, int>>(ref reader, options) ?? new Dictionary<int, int>();
                        break;
                    default:
                        // Unknown properties are skipped rather than treated as errors.
                        reader.Skip();
                        break;
                }
            }
        }

        return requestSettings;
    }

    /// <inheritdoc/>
    public override void Write(Utf8JsonWriter writer, LLamaSharpPromptExecutionSettings value, JsonSerializerOptions options)
    {
        writer.WriteStartObject();

        writer.WriteNumber("temperature", value.Temperature);
        writer.WriteNumber("top_p", value.TopP);
        writer.WriteNumber("frequency_penalty", value.FrequencyPenalty);
        writer.WriteNumber("presence_penalty", value.PresencePenalty);

        if (value.MaxTokens is null)
        {
            writer.WriteNull("max_tokens");
        }
        else
        {
            writer.WriteNumber("max_tokens", (decimal)value.MaxTokens);
        }

        writer.WritePropertyName("stop_sequences");
        JsonSerializer.Serialize(writer, value.StopSequences, options);

        writer.WriteNumber("results_per_prompt", value.ResultsPerPrompt);

        writer.WritePropertyName("token_selection_biases");
        JsonSerializer.Serialize(writer, value.TokenSelectionBiases, options);

        writer.WriteEndObject();
    }
}
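
A minimal usage sketch (not part of the file above; the wiring and the property values are illustrative assumptions, while the property names come from the converter itself): the converter is added to a JsonSerializerOptions instance, after which Write emits the snake_case keys shown above and Read accepts both camelCase and snake_case spellings.

using System.Collections.Generic;
using System.Text.Json;
using LLamaSharp.SemanticKernel;

// Register the converter so (de)serialization of LLamaSharpPromptExecutionSettings uses it.
var options = new JsonSerializerOptions();
options.Converters.Add(new LLamaSharpPromptExecutionSettingsConverter());

// Illustrative values only; the properties are the ones handled in Read/Write above.
var settings = new LLamaSharpPromptExecutionSettings
{
    Temperature = 0.7,
    TopP = 0.9,
    MaxTokens = 256,
    StopSequences = new List<string> { "\n\n" }
};

// Write emits snake_case keys such as "top_p" and "max_tokens".
string json = JsonSerializer.Serialize(settings, options);

// Read uppercases property names, so both "topP" and "top_p" map to the same setting.
var roundTripped = JsonSerializer.Deserialize<LLamaSharpPromptExecutionSettings>(json, options);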