Skip to content

Commit b23c5ba

Browse files
authored
add support classification req (#5)
1 parent 86ccac7 commit b23c5ba

File tree

7 files changed

+213
-3
lines changed

7 files changed

+213
-3
lines changed

src/main/java/com/github/llmjava/cohere4j/CohereApi.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package com.github.llmjava.cohere4j;
22

3+
import com.github.llmjava.cohere4j.request.ClassifyRequest;
34
import com.github.llmjava.cohere4j.request.EmbedRequest;
45
import com.github.llmjava.cohere4j.request.GenerateRequest;
6+
import com.github.llmjava.cohere4j.response.ClassifyResponse;
57
import com.github.llmjava.cohere4j.response.EmbedResponse;
68
import com.github.llmjava.cohere4j.response.GenerateResponse;
79
import retrofit2.Call;
@@ -27,4 +29,9 @@ public interface CohereApi {
2729
@Headers({"accept: application/json", "content-type: application/json"})
2830
Call<EmbedResponse>
2931
embed(@Body EmbedRequest request);
32+
33+
@POST("/v1/classify")
34+
@Headers({"accept: application/json", "content-type: application/json"})
35+
Call<ClassifyResponse>
36+
classify(@Body ClassifyRequest request);
3037
}

src/main/java/com/github/llmjava/cohere4j/CohereClient.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import com.github.llmjava.cohere4j.callback.AsyncCallback;
44
import com.github.llmjava.cohere4j.callback.StreamingCallback;
5+
import com.github.llmjava.cohere4j.request.ClassifyRequest;
56
import com.github.llmjava.cohere4j.request.EmbedRequest;
67
import com.github.llmjava.cohere4j.request.GenerateRequest;
8+
import com.github.llmjava.cohere4j.response.ClassifyResponse;
79
import com.github.llmjava.cohere4j.response.EmbedResponse;
810
import com.github.llmjava.cohere4j.response.GenerateResponse;
911
import com.github.llmjava.cohere4j.response.streaming.StreamGenerateResponse;
@@ -70,6 +72,14 @@ public void embedAsync(EmbedRequest request, AsyncCallback<EmbedResponse> callba
7072
execute(api.embed(request), callback);
7173
}
7274

75+
public ClassifyResponse classify(ClassifyRequest request) {
76+
return execute(api.classify(request));
77+
}
78+
79+
public void classifyAsync(ClassifyRequest request, AsyncCallback<ClassifyResponse> callback) {
80+
execute(api.classify(request), callback);
81+
}
82+
7383
private <T> T execute(Call<T> action) {
7484
try {
7585
Response<T> response = action.execute();
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package com.github.llmjava.cohere4j.request;
2+
3+
import java.util.ArrayList;
4+
import java.util.List;
5+
6+
public class ClassifyRequest {
7+
/**
8+
* Represents a list of queries to be classified, each entry must not be empty. The maximum is 96 inputs.
9+
*/
10+
private List<String> inputs;
11+
/**
12+
* An array of examples to provide context to the model. Each example is a text string and its associated label/class. Each unique label requires at least 2 examples associated with it; the maximum number of examples is 2500, and each example has a maximum length of 512 tokens. The values should be structured as {text: "...",label: "..."}.
13+
*
14+
* Note: Custom Models trained on classification examples don't require the examples parameter to be passed in explicitly.
15+
*/
16+
private List<Example> examples;
17+
/**
18+
* The identifier of the model. Currently available models are embed-multilingual-v2.0, embed-english-light-v2.0, and embed-english-v2.0 (default). Smaller "light" models are faster, while larger models will perform better. Custom models can also be supplied with their full ID.
19+
*/
20+
private String model;
21+
/**
22+
* The ID of a custom playground preset. You can create presets in the playground. If you use a preset, all other parameters become optional, and any included parameters will override the preset's parameters.
23+
*/
24+
private String preset;
25+
/**
26+
* One of NONE|START|END to specify how the API will handle inputs longer than the maximum token length.
27+
*
28+
* Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
29+
*
30+
* If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
31+
*
32+
* Default: END
33+
*/
34+
private String truncate;
35+
36+
ClassifyRequest(Builder builder) {
37+
this.inputs = builder.inputs;
38+
this.examples = builder.examples;
39+
this.model = builder.model;
40+
this.preset = builder.preset;
41+
this.truncate = builder.truncate;
42+
}
43+
44+
public static class Example {
45+
private String text;
46+
private String label;
47+
48+
public Example(String text, String label) {
49+
this.text = text;
50+
this.label = label;
51+
}
52+
}
53+
54+
public static class Builder {
55+
private List<String> inputs = new ArrayList<>();
56+
private List<Example> examples = new ArrayList<>();
57+
private String model;
58+
private String preset;
59+
private String truncate;
60+
61+
public Builder withInput(String text) {
62+
this.inputs.add(text);
63+
return this;
64+
}
65+
66+
public Builder withExample(String text, String label) {
67+
this.examples.add(new Example(text, label));
68+
return this;
69+
}
70+
71+
public Builder withModel(String model) {
72+
this.model = model;
73+
return this;
74+
}
75+
76+
public Builder withPreset(String preset) {
77+
this.preset = preset;
78+
return this;
79+
}
80+
81+
public Builder withTruncate(String truncate) {
82+
this.truncate = truncate;
83+
return this;
84+
}
85+
86+
public ClassifyRequest build() {
87+
return new ClassifyRequest(this);
88+
}
89+
}
90+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package com.github.llmjava.cohere4j.response;
2+
3+
import java.util.Map;
4+
5+
public class ClassifyResponse {
6+
private String id;
7+
8+
private Classification[] classifications;
9+
10+
public static class Classification {
11+
private String id;
12+
private String input;
13+
private String prediction;
14+
private Float confidence;
15+
private Map<String, ClassificationDetail> labels;
16+
17+
public String getInput() {
18+
return input;
19+
}
20+
21+
public String getPrediction() {
22+
return prediction;
23+
}
24+
25+
public Float getConfidence() {
26+
return confidence;
27+
}
28+
29+
public Float getConfidence(String label) {
30+
return labels.get(label).confidence;
31+
}
32+
}
33+
34+
public static class ClassificationDetail {
35+
private Float confidence;
36+
}
37+
38+
private Meta meta;
39+
40+
public Classification[] getClassifications() {
41+
return classifications;
42+
}
43+
public Classification getClassification(int index) {
44+
return classifications[index];
45+
}
46+
}

src/main/java/com/github/llmjava/cohere4j/response/Meta.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,8 @@ public class Meta {
66
public static class ApiVersion {
77
String version;
88
}
9+
10+
public String getApiVersion() {
11+
return api_version.version;
12+
}
913
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package com.github.llmjava.cohere4j;
2+
3+
import com.github.llmjava.cohere4j.callback.AsyncCallback;
4+
import com.github.llmjava.cohere4j.request.ClassifyRequest;
5+
import com.github.llmjava.cohere4j.request.EmbedRequest;
6+
import com.github.llmjava.cohere4j.response.ClassifyResponse;
7+
import com.github.llmjava.cohere4j.response.EmbedResponse;
8+
import org.apache.commons.configuration2.ex.ConfigurationException;
9+
10+
public class ClassificationExample {
11+
12+
public static void main(String[] args) throws ConfigurationException {
13+
CohereConfig config = CohereConfig.fromProperties("cohere.properties");
14+
CohereClient client = new CohereClient.Builder().withConfig(config).build();
15+
16+
ClassifyRequest request = new ClassifyRequest.Builder()
17+
.withExample("Dermatologists don't like her!", "Spam")
18+
.withExample("Hello, open to this?", "Spam")
19+
.withExample("I need help please wire me $1000 right now", "Spam")
20+
.withExample("Nice to know you ;)", "Spam")
21+
.withExample("Please help me?", "Spam")
22+
.withExample("Your parcel will be delivered today", "Not spam")
23+
.withExample("Review changes to our Terms and Conditions", "Not spam")
24+
.withExample("Weekly sync notes", "Not spam")
25+
.withExample("Re: Follow up from today’s meeting", "Not spam")
26+
.withExample("Pre-read for tomorrow", "Not spam")
27+
.withInput("Confirm your email address")
28+
.withInput("hey i need u to send some $")
29+
.withTruncate("END")
30+
.build();
31+
32+
System.out.println("--- Sync example");
33+
ClassifyResponse response = client.classify(request);
34+
System.out.println("Input: " + response.getClassification(0).getInput());
35+
System.out.println("Prediction: " + response.getClassification(0).getPrediction());
36+
System.out.println("Confidence: " + response.getClassification(0).getConfidence());
37+
38+
client.classifyAsync(request, new AsyncCallback<ClassifyResponse>() {
39+
@Override
40+
public void onSuccess(ClassifyResponse response) {
41+
System.out.println("--- Async example - onSuccess");
42+
System.out.println("Input: " + response.getClassification(0).getInput());
43+
System.out.println("Prediction: " + response.getClassification(0).getPrediction());
44+
System.out.println("Confidence: " + response.getClassification(0).getConfidence()); }
45+
46+
@Override
47+
public void onFailure(Throwable throwable) {
48+
System.out.println("--- Async example - onFailure");
49+
throwable.printStackTrace();
50+
}
51+
});
52+
}
53+
}

src/test/java/com/github/llmjava/cohere4j/EmbeddingsExample.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ public static void main(String[] args) throws ConfigurationException {
2222
System.out.println("Embeddings: " + response.getEmbeddings(0));
2323
client.embedAsync(request, new AsyncCallback<EmbedResponse>() {
2424
@Override
25-
public void onSuccess(EmbedResponse completion) {
25+
public void onSuccess(EmbedResponse response) {
2626
System.out.println("--- Async example - onSuccess");
27-
System.out.println("Texts: " + completion.getTexts()[0]);
28-
System.out.println("Embeddings: " + completion.getEmbeddings(0));
27+
System.out.println("Texts: " + response.getTexts()[0]);
28+
System.out.println("Embeddings: " + response.getEmbeddings(0));
2929
}
3030

3131
@Override

0 commit comments

Comments
 (0)