Skip to content

Commit a2c07ef

Browse files
authored
add all support req generation params (#3)
1 parent 2132697 commit a2c07ef

File tree

4 files changed

+144
-10
lines changed

4 files changed

+144
-10
lines changed

src/main/java/com/github/llmjava/cohere4j/CohereConfig.java

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import org.apache.commons.configuration2.ex.ConfigurationException;
66

77
import java.time.Duration;
8+
import java.util.*;
89

910
public class CohereConfig {
1011
private final Configuration config;
@@ -29,10 +30,75 @@ public Duration getTimeout() {
2930
Long timeout = config.getLong(CohereConfig.TIMEOUT, CohereConfig.DEFAULT_TIMEOUT_MILLIS);
3031
return Duration.ofMillis(timeout);
3132
}
33+
public Integer getTopK() {
34+
return config.getInteger("topK", 0);
35+
}
36+
public Double geTopP() {
37+
return config.getDouble("topP", 0.0);
38+
}
39+
public Double getTemperature() {
40+
return config.getDouble("temperature", 0.75);
41+
}
42+
43+
public String getModel() {
44+
return config.getString("cohere.model", "command");
45+
}
46+
public Integer getNumGenerations() {
47+
return config.getInteger("num_generations", 1);
48+
}
49+
public Boolean isStream() {
50+
return config.getBoolean("stream", false);
51+
}
52+
public Integer getMaxTokens() {
53+
return config.getInteger("max_tokens", 1024);
54+
}
55+
public String getTruncate() {
56+
return config.getString("truncate", "END");
57+
}
58+
public String getPreset() {
59+
return config.getString("preset");
60+
}
61+
public List<String> getEndSequences() {
62+
List<String> result = new ArrayList<>();
63+
String sequences = config.getString("end_sequences");
64+
if(sequences!=null) {
65+
for(String seq: sequences.split(",")) result.add(seq);
66+
}
67+
return result;
68+
}
69+
public List<String> getStopSequences() {
70+
List<String> result = new ArrayList<>();
71+
String sequences = config.getString("stop_sequences");
72+
if(sequences!=null) {
73+
for(String seq: sequences.split(",")) result.add(seq);
74+
}
75+
return result;
76+
}
77+
public Double getFrequencyPenalty() {
78+
return config.getDouble("frequency_penalty", 0.0);
79+
}
80+
public Double getPresencePenalty() {
81+
return config.getDouble("presence_penalty", 0.0);
82+
}
83+
public String getReturnLikelihoods() {
84+
return config.getString("return_likelihoods");
85+
}
86+
public Map<String, Double> getLogitBias() {
87+
Map<String, Double> result = new HashMap<>();
88+
String sequences = config.getString("logit_bias");
89+
if(sequences!=null) {
90+
for(String pair: sequences.split(",")) {
91+
String[] kv = pair.split(":");
92+
if(kv.length==2) {
93+
result.put(kv[0], Double.valueOf(kv[1]));
94+
}
95+
}
96+
}
97+
return result;
98+
}
3299

33100
public static CohereConfig fromProperties(String path) throws ConfigurationException {
34101
Configuration baseConfig = new Configurations().properties(path);
35102
return new CohereConfig(baseConfig);
36103
}
37-
38104
}

src/main/java/com/github/llmjava/cohere4j/request/GenerationRequest.java

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.github.llmjava.cohere4j.request;
22

3+
import com.github.llmjava.cohere4j.CohereConfig;
4+
35
import java.util.List;
46
import java.util.Map;
57

@@ -18,6 +20,11 @@ public class GenerationRequest {
1820
*/
1921
private String model;
2022

23+
/**
24+
* The maximum number of generations that will be returned. Defaults to 1, min value of 1, max value of 5.
25+
*/
26+
private Integer num_generations;
27+
2128
/**
2229
* When true, the response will be a JSON stream of events. Streaming is beneficial for user interfaces that render the contents of the response piece by piece, as it gets generated.
2330
*
@@ -116,12 +123,22 @@ public class GenerationRequest {
116123
private Map<String, Double> logit_bias;
117124

118125
GenerationRequest(Builder builder) {
119-
this.prompt = builder.prompt;
120-
this.model = builder.model;
121-
this.stream = builder.stream;
122-
this.max_tokens = builder.max_tokens;
123-
this.truncate = builder.truncate;
124-
this.return_likelihoods = builder.return_likelihoods;
126+
prompt = builder.prompt;
127+
model = builder.model;
128+
num_generations = builder.num_generations;
129+
stream = builder.stream;
130+
max_tokens = builder.max_tokens;
131+
truncate = builder.truncate;
132+
temperature = builder.temperature;
133+
preset = builder.preset;
134+
end_sequences = builder.end_sequences;
135+
stop_sequences = builder.stop_sequences;
136+
k = builder.k;
137+
p = builder.p;
138+
frequency_penalty = builder.frequency_penalty;
139+
presence_penalty = builder.presence_penalty;
140+
return_likelihoods = builder.return_likelihoods;
141+
logit_bias = builder.logit_bias;
125142
}
126143

127144
public Boolean isStreaming() {
@@ -131,10 +148,20 @@ public Boolean isStreaming() {
131148
public static class Builder {
132149
private String prompt;
133150
private String model;
151+
private Integer num_generations;
134152
private Boolean stream;
135153
private Integer max_tokens;
136154
private String truncate;
155+
private Double temperature;
156+
private String preset;
157+
private List<String> end_sequences;
158+
private List<String> stop_sequences;
159+
private Integer k;
160+
private Double p;
161+
private Double frequency_penalty;
162+
private Double presence_penalty;
137163
private String return_likelihoods;
164+
private Map<String, Double> logit_bias;
138165

139166
public Builder withPrompt(String prompt) {
140167
this.prompt = prompt;
@@ -166,6 +193,25 @@ public Builder withLikelihoods(String likelihoods) {
166193
return this;
167194
}
168195

196+
public Builder withConfig(CohereConfig config) {
197+
model = config.getModel();
198+
num_generations = config.getNumGenerations();
199+
stream = config.isStream();
200+
max_tokens = config.getMaxTokens();
201+
truncate = config.getTruncate();
202+
temperature = config.getTemperature();
203+
preset = config.getPreset();
204+
end_sequences = config.getEndSequences();
205+
stop_sequences = config.getStopSequences();
206+
k = config.getTopK();
207+
p = config.geTopP();
208+
frequency_penalty = config.getFrequencyPenalty();
209+
presence_penalty = config.getPresencePenalty();
210+
return_likelihoods = config.getReturnLikelihoods();
211+
logit_bias = config.getLogitBias();
212+
return this;
213+
}
214+
169215
public GenerationRequest build() {
170216
return new GenerationRequest(this);
171217
}

src/test/java/com/github/llmjava/cohere4j/GenerationExample.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public static void main(String[] args) throws ConfigurationException {
1616
String text = "tell a joke";
1717
GenerationRequest request1 = new GenerationRequest.Builder()
1818
.withPrompt(text)
19-
.withMaxTokens(1024)
19+
.withConfig(config)
2020
.build();
2121

2222
System.out.println("--- Sync example");
@@ -37,8 +37,8 @@ public void onFailure(Throwable throwable) {
3737

3838
GenerationRequest request2 = new GenerationRequest.Builder()
3939
.withPrompt(text)
40+
.withConfig(config)
4041
.withStream(true)
41-
.withMaxTokens(1024)
4242
.build();
4343
client.generateStream(request2, new StreamingCallback<StreamingGenerationResponse>() {
4444
@Override

src/test/resources/cohere.properties

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,26 @@
22
cohere.apiKey=${env:COHERE_API_KEY}
33

44
# timeout in milliseconds
5-
timeout=10000
5+
timeout=10000
6+
7+
topK=0
8+
topP=0.0
9+
temperature=0.75
10+
11+
cohere.model=command
12+
num_generations=1
13+
stream=false
14+
max_tokens=1024
15+
truncate=END
16+
#preset=
17+
18+
# comma separated
19+
#end_sequences=
20+
#stop_sequences;
21+
22+
frequency_penalty=0.0
23+
presence_penalty=0.0
24+
return_likelihoods=NONE
25+
26+
# comma separated key:value sequence
27+
#logit_bias=

0 commit comments

Comments
 (0)