Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/main/java/com/github/llmjava/cohere4j/CohereApi.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.github.llmjava.cohere4j;

import com.github.llmjava.cohere4j.request.ClassifyRequest;
import com.github.llmjava.cohere4j.request.EmbedRequest;
import com.github.llmjava.cohere4j.request.GenerateRequest;
import com.github.llmjava.cohere4j.response.ClassifyResponse;
import com.github.llmjava.cohere4j.response.EmbedResponse;
import com.github.llmjava.cohere4j.response.GenerateResponse;
import retrofit2.Call;
Expand All @@ -27,4 +29,9 @@ public interface CohereApi {
@Headers({"accept: application/json", "content-type: application/json"})
Call<EmbedResponse>
embed(@Body EmbedRequest request);

@POST("/v1/classify")
@Headers({"accept: application/json", "content-type: application/json"})
Call<ClassifyResponse>
classify(@Body ClassifyRequest request);
}
10 changes: 10 additions & 0 deletions src/main/java/com/github/llmjava/cohere4j/CohereClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import com.github.llmjava.cohere4j.callback.AsyncCallback;
import com.github.llmjava.cohere4j.callback.StreamingCallback;
import com.github.llmjava.cohere4j.request.ClassifyRequest;
import com.github.llmjava.cohere4j.request.EmbedRequest;
import com.github.llmjava.cohere4j.request.GenerateRequest;
import com.github.llmjava.cohere4j.response.ClassifyResponse;
import com.github.llmjava.cohere4j.response.EmbedResponse;
import com.github.llmjava.cohere4j.response.GenerateResponse;
import com.github.llmjava.cohere4j.response.streaming.StreamGenerateResponse;
Expand Down Expand Up @@ -70,6 +72,14 @@ public void embedAsync(EmbedRequest request, AsyncCallback<EmbedResponse> callba
execute(api.embed(request), callback);
}

public ClassifyResponse classify(ClassifyRequest request) {
return execute(api.classify(request));
}

public void classifyAsync(ClassifyRequest request, AsyncCallback<ClassifyResponse> callback) {
execute(api.classify(request), callback);
}

private <T> T execute(Call<T> action) {
try {
Response<T> response = action.execute();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package com.github.llmjava.cohere4j.request;

import java.util.ArrayList;
import java.util.List;

public class ClassifyRequest {
/**
* Represents a list of queries to be classified, each entry must not be empty. The maximum is 96 inputs.
*/
private List<String> inputs;
/**
* An array of examples to provide context to the model. Each example is a text string and its associated label/class. Each unique label requires at least 2 examples associated with it; the maximum number of examples is 2500, and each example has a maximum length of 512 tokens. The values should be structured as {text: "...",label: "..."}.
*
* Note: Custom Models trained on classification examples don't require the examples parameter to be passed in explicitly.
*/
private List<Example> examples;
/**
* The identifier of the model. Currently available models are embed-multilingual-v2.0, embed-english-light-v2.0, and embed-english-v2.0 (default). Smaller "light" models are faster, while larger models will perform better. Custom models can also be supplied with their full ID.
*/
private String model;
/**
* The ID of a custom playground preset. You can create presets in the playground. If you use a preset, all other parameters become optional, and any included parameters will override the preset's parameters.
*/
private String preset;
/**
* One of NONE|START|END to specify how the API will handle inputs longer than the maximum token length.
*
* Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
*
* If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
*
* Default: END
*/
private String truncate;

ClassifyRequest(Builder builder) {
this.inputs = builder.inputs;
this.examples = builder.examples;
this.model = builder.model;
this.preset = builder.preset;
this.truncate = builder.truncate;
}

public static class Example {
private String text;
private String label;

public Example(String text, String label) {
this.text = text;
this.label = label;
}
}

public static class Builder {
private List<String> inputs = new ArrayList<>();
private List<Example> examples = new ArrayList<>();
private String model;
private String preset;
private String truncate;

public Builder withInput(String text) {
this.inputs.add(text);
return this;
}

public Builder withExample(String text, String label) {
this.examples.add(new Example(text, label));
return this;
}

public Builder withModel(String model) {
this.model = model;
return this;
}

public Builder withPreset(String preset) {
this.preset = preset;
return this;
}

public Builder withTruncate(String truncate) {
this.truncate = truncate;
return this;
}

public ClassifyRequest build() {
return new ClassifyRequest(this);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.github.llmjava.cohere4j.response;

import java.util.Map;

public class ClassifyResponse {
private String id;

private Classification[] classifications;

public static class Classification {
private String id;
private String input;
private String prediction;
private Float confidence;
private Map<String, ClassificationDetail> labels;

public String getInput() {
return input;
}

public String getPrediction() {
return prediction;
}

public Float getConfidence() {
return confidence;
}

public Float getConfidence(String label) {
return labels.get(label).confidence;
}
}

public static class ClassificationDetail {
private Float confidence;
}

private Meta meta;

public Classification[] getClassifications() {
return classifications;
}
public Classification getClassification(int index) {
return classifications[index];
}
}
4 changes: 4 additions & 0 deletions src/main/java/com/github/llmjava/cohere4j/response/Meta.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@ public class Meta {
public static class ApiVersion {
String version;
}

public String getApiVersion() {
return api_version.version;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.github.llmjava.cohere4j;

import com.github.llmjava.cohere4j.callback.AsyncCallback;
import com.github.llmjava.cohere4j.request.ClassifyRequest;
import com.github.llmjava.cohere4j.request.EmbedRequest;
import com.github.llmjava.cohere4j.response.ClassifyResponse;
import com.github.llmjava.cohere4j.response.EmbedResponse;
import org.apache.commons.configuration2.ex.ConfigurationException;

public class ClassificationExample {

public static void main(String[] args) throws ConfigurationException {
CohereConfig config = CohereConfig.fromProperties("cohere.properties");
CohereClient client = new CohereClient.Builder().withConfig(config).build();

ClassifyRequest request = new ClassifyRequest.Builder()
.withExample("Dermatologists don't like her!", "Spam")
.withExample("Hello, open to this?", "Spam")
.withExample("I need help please wire me $1000 right now", "Spam")
.withExample("Nice to know you ;)", "Spam")
.withExample("Please help me?", "Spam")
.withExample("Your parcel will be delivered today", "Not spam")
.withExample("Review changes to our Terms and Conditions", "Not spam")
.withExample("Weekly sync notes", "Not spam")
.withExample("Re: Follow up from today’s meeting", "Not spam")
.withExample("Pre-read for tomorrow", "Not spam")
.withInput("Confirm your email address")
.withInput("hey i need u to send some $")
.withTruncate("END")
.build();

System.out.println("--- Sync example");
ClassifyResponse response = client.classify(request);
System.out.println("Input: " + response.getClassification(0).getInput());
System.out.println("Prediction: " + response.getClassification(0).getPrediction());
System.out.println("Confidence: " + response.getClassification(0).getConfidence());

client.classifyAsync(request, new AsyncCallback<ClassifyResponse>() {
@Override
public void onSuccess(ClassifyResponse response) {
System.out.println("--- Async example - onSuccess");
System.out.println("Input: " + response.getClassification(0).getInput());
System.out.println("Prediction: " + response.getClassification(0).getPrediction());
System.out.println("Confidence: " + response.getClassification(0).getConfidence()); }

@Override
public void onFailure(Throwable throwable) {
System.out.println("--- Async example - onFailure");
throwable.printStackTrace();
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ public static void main(String[] args) throws ConfigurationException {
System.out.println("Embeddings: " + response.getEmbeddings(0));
client.embedAsync(request, new AsyncCallback<EmbedResponse>() {
@Override
public void onSuccess(EmbedResponse completion) {
public void onSuccess(EmbedResponse response) {
System.out.println("--- Async example - onSuccess");
System.out.println("Texts: " + completion.getTexts()[0]);
System.out.println("Embeddings: " + completion.getEmbeddings(0));
System.out.println("Texts: " + response.getTexts()[0]);
System.out.println("Embeddings: " + response.getEmbeddings(0));
}

@Override
Expand Down