src/main/java/com/github/llmjava/cohere4j/CohereApi.java (28 additions, 0 deletions)
@@ -3,9 +3,11 @@
import com.github.llmjava.cohere4j.request.ClassifyRequest;
import com.github.llmjava.cohere4j.request.EmbedRequest;
import com.github.llmjava.cohere4j.request.GenerateRequest;
import com.github.llmjava.cohere4j.request.TokenizeRequest;
import com.github.llmjava.cohere4j.response.ClassifyResponse;
import com.github.llmjava.cohere4j.response.EmbedResponse;
import com.github.llmjava.cohere4j.response.GenerateResponse;
import com.github.llmjava.cohere4j.response.TokenizeResponse;
import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.Headers;
@@ -14,24 +16,50 @@

public interface CohereApi {

/**
* This endpoint generates realistic text conditioned on a given input.
*/
@POST("/v1/generate")
@Headers({"accept: application/json", "content-type: application/json"})
Call<GenerateResponse>
generate(@Body GenerateRequest request);

/**
* This endpoint generates realistic text conditioned on a given input.
*/
@Streaming
@POST("/v1/generate")
@Headers({"accept: application/stream+json", "content-type: application/json"})
Call<String>
generateStream(@Body GenerateRequest request);

/**
* This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents.
*
* Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.
*
* If you want to learn more about how to use the embedding model, have a look at the Semantic Search Guide.
*/
@POST("/v1/embed")
@Headers({"accept: application/json", "content-type: application/json"})
Call<EmbedResponse>
embed(@Body EmbedRequest request);

/**
* This endpoint makes a prediction about which label fits the specified text inputs best. To make a prediction, Classify uses the provided examples of text + label pairs as a reference.
*
* Note: Custom Models trained on classification examples don't require the examples parameter to be passed in explicitly.
*/
@POST("/v1/classify")
@Headers({"accept: application/json", "content-type: application/json"})
Call<ClassifyResponse>
classify(@Body ClassifyRequest request);

/**
* This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE). To learn more about tokenization and byte pair encoding, see the tokens page.
*/
@POST("/v1/tokenize")
@Headers({"accept: application/json", "content-type: application/json"})
Call<TokenizeResponse>
tokenize(@Body TokenizeRequest request);
}
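For orientation, here is a minimal sketch (not part of this PR) of how the CohereApi interface above could be wired up with Retrofit and the new tokenize endpoint called directly. The base URL, the Authorization: Bearer header, the GsonConverterFactory and the class name are assumptions for illustration; in the library this wiring is normally handled by CohereClient.

// Illustrative only: the base URL, auth header and converter below are assumptions,
// not taken from this PR.
import com.github.llmjava.cohere4j.CohereApi;
import com.github.llmjava.cohere4j.request.TokenizeRequest;
import com.github.llmjava.cohere4j.response.TokenizeResponse;
import okhttp3.OkHttpClient;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;

public class RawTokenizeSketch {
    public static void main(String[] args) throws Exception {
        // Attach the API key on every request (assumed bearer-token scheme).
        OkHttpClient http = new OkHttpClient.Builder()
                .addInterceptor(chain -> chain.proceed(chain.request().newBuilder()
                        .addHeader("Authorization", "Bearer " + System.getenv("COHERE_API_KEY"))
                        .build()))
                .build();

        CohereApi api = new Retrofit.Builder()
                .baseUrl("https://api.cohere.ai") // assumed base URL
                .client(http)
                .addConverterFactory(GsonConverterFactory.create())
                .build()
                .create(CohereApi.class);

        TokenizeRequest request = new TokenizeRequest.Builder()
                .withText("tokenize me! :D")
                .withModel("command")
                .build();

        // Synchronous Retrofit call; execute() returns a retrofit2.Response<TokenizeResponse>.
        TokenizeResponse body = api.tokenize(request).execute().body();
        System.out.println(java.util.Arrays.toString(body.getTokens()));
    }
}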
src/main/java/com/github/llmjava/cohere4j/CohereClient.java (10 additions, 0 deletions)
@@ -5,9 +5,11 @@
import com.github.llmjava.cohere4j.request.ClassifyRequest;
import com.github.llmjava.cohere4j.request.EmbedRequest;
import com.github.llmjava.cohere4j.request.GenerateRequest;
import com.github.llmjava.cohere4j.request.TokenizeRequest;
import com.github.llmjava.cohere4j.response.ClassifyResponse;
import com.github.llmjava.cohere4j.response.EmbedResponse;
import com.github.llmjava.cohere4j.response.GenerateResponse;
import com.github.llmjava.cohere4j.response.TokenizeResponse;
import com.github.llmjava.cohere4j.response.streaming.StreamGenerateResponse;
import com.github.llmjava.cohere4j.response.streaming.ResponseConverter;
import com.google.gson.Gson;
@@ -80,6 +82,14 @@ public void classifyAsync(ClassifyRequest request, AsyncCallback<ClassifyResponse> callback) {
execute(api.classify(request), callback);
}

public TokenizeResponse tokenize(TokenizeRequest request) {
return execute(api.tokenize(request));
}

public void tokenizeAsync(TokenizeRequest request, AsyncCallback<TokenizeResponse> callback) {
execute(api.tokenize(request), callback);
}

private <T> T execute(Call<T> action) {
try {
Response<T> response = action.execute();
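The snippet above cuts off before the asynchronous overload of execute that tokenizeAsync (and the other *Async methods) delegate to. A plausible sketch of that overload, assuming AsyncCallback exposes onSuccess/onFailure as the examples further down use it; the actual implementation in CohereClient may differ.

private <T> void execute(Call<T> action, AsyncCallback<T> callback) {
    // Retrofit's enqueue() runs the HTTP call on a background thread and
    // reports back through this anonymous Callback.
    action.enqueue(new retrofit2.Callback<T>() {
        @Override
        public void onResponse(Call<T> call, Response<T> response) {
            if (response.isSuccessful()) {
                callback.onSuccess(response.body());
            } else {
                // Hypothetical error handling; the library may wrap errors differently.
                callback.onFailure(new RuntimeException("HTTP " + response.code()));
            }
        }

        @Override
        public void onFailure(Call<T> call, Throwable throwable) {
            callback.onFailure(throwable);
        }
    });
}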
src/main/java/com/github/llmjava/cohere4j/request/TokenizeRequest.java (new file, 38 additions)
@@ -0,0 +1,38 @@
package com.github.llmjava.cohere4j.request;

public class TokenizeRequest {
/**
* The string to be tokenized; the minimum text length is 1 character and the maximum text length is 65536 characters.
*/
private String text;
/**
* An optional parameter to provide the model name. This will ensure that the tokenization uses the tokenizer used by that model.
*/
private String model;


TokenizeRequest(Builder builder) {
this.text = builder.text;
this.model = builder.model;
}

public static class Builder {

private String text;
private String model;

public Builder withText(String text) {
this.text = text;
return this;
}

public Builder withModel(String model) {
this.model = model;
return this;
}

public TokenizeRequest build() {
return new TokenizeRequest(this);
}
}
}
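CohereClient imports Gson, so the request presumably goes over the wire as plain Gson-serialized JSON, with the private fields above mapping one-to-one onto the /v1/tokenize body. A quick sketch of the payload under default Gson settings (the class name is illustrative):

import com.github.llmjava.cohere4j.request.TokenizeRequest;
import com.google.gson.Gson;

public class TokenizeRequestJsonSketch {
    public static void main(String[] args) {
        TokenizeRequest request = new TokenizeRequest.Builder()
                .withText("tokenize me! :D")
                .withModel("command")
                .build();
        // Prints: {"text":"tokenize me! :D","model":"command"}
        // (By default Gson serializes the private fields under their declared names.)
        System.out.println(new Gson().toJson(request));
    }
}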
src/main/java/com/github/llmjava/cohere4j/response/TokenizeResponse.java (new file, 23 additions)
@@ -0,0 +1,23 @@
package com.github.llmjava.cohere4j.response;

public class TokenizeResponse {
private Integer[] tokens;
private String[] token_strings;
private Meta meta;

public Integer[] getTokens() {
return tokens;
}

public Integer getToken(int index) {
return tokens[index];
}

public String[] getTokenStrings() {
return token_strings;
}

public String getTokenString(int index) {
return token_strings[index];
}
}
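The tokens and token_strings arrays appear to be parallel (one numeric ID per BPE piece), so zipping them gives a readable view of the segmentation. A small helper sketch, assuming a TokenizeResponse obtained as in the TokenizationExample at the end of this diff:

// Illustrative helper, not part of the PR: print each token ID next to its text piece.
static void printTokens(TokenizeResponse response) {
    Integer[] ids = response.getTokens();
    String[] pieces = response.getTokenStrings();
    for (int i = 0; i < ids.length; i++) {
        System.out.println(ids[i] + " -> \"" + pieces[i] + "\"");
    }
}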
src/test/java/com/github/llmjava/cohere4j/ClassificationExample.java
@@ -2,9 +2,7 @@

import com.github.llmjava.cohere4j.callback.AsyncCallback;
import com.github.llmjava.cohere4j.request.ClassifyRequest;
import com.github.llmjava.cohere4j.request.EmbedRequest;
import com.github.llmjava.cohere4j.response.ClassifyResponse;
import com.github.llmjava.cohere4j.response.EmbedResponse;
import org.apache.commons.configuration2.ex.ConfigurationException;

public class ClassificationExample {
@@ -39,9 +37,9 @@ public static void main(String[] args) throws ConfigurationException {
@Override
public void onSuccess(ClassifyResponse response) {
System.out.println("--- Async example - onSuccess");
System.out.println("Input: " + response.getClassification(0).getInput());
System.out.println("Prediction: " + response.getClassification(0).getPrediction());
System.out.println("Confidence: " + response.getClassification(0).getConfidence()); }
System.out.println("Input: " + response.getClassification(1).getInput());
System.out.println("Prediction: " + response.getClassification(1).getPrediction());
System.out.println("Confidence: " + response.getClassification(1).getConfidence()); }

@Override
public void onFailure(Throwable throwable) {
@@ -20,7 +20,7 @@ public static void main(String[] args) throws ConfigurationException {
.build();

System.out.println("--- Sync example");
System.out.println(client.generate(request1).getTexts());
System.out.println(client.generate(request1).getTexts().get(0));
client.generateAsync(request1, new AsyncCallback<GenerateResponse>() {
@Override
public void onSuccess(GenerateResponse completion) {
src/test/java/com/github/llmjava/cohere4j/TokenizationExample.java (new file, 39 additions)
@@ -0,0 +1,39 @@
package com.github.llmjava.cohere4j;

import com.github.llmjava.cohere4j.callback.AsyncCallback;
import com.github.llmjava.cohere4j.request.TokenizeRequest;
import com.github.llmjava.cohere4j.response.TokenizeResponse;
import org.apache.commons.configuration2.ex.ConfigurationException;

public class TokenizationExample {

public static void main(String[] args) throws ConfigurationException {
CohereConfig config = CohereConfig.fromProperties("cohere.properties");
CohereClient client = new CohereClient.Builder().withConfig(config).build();

TokenizeRequest request = new TokenizeRequest.Builder()
.withText("tokenize me! :D")
.withModel("command")
.build();

System.out.println("--- Sync example");
TokenizeResponse response = client.tokenize(request);
System.out.println("Token: " + response.getTokenString(0));
System.out.println("Token ID: " + response.getToken(0));

client.tokenizeAsync(request, new AsyncCallback<TokenizeResponse>() {
@Override
public void onSuccess(TokenizeResponse response) {
System.out.println("--- Async example - onSuccess");
System.out.println("Token: " + response.getTokenString(0));
System.out.println("Token ID: " + response.getToken(0));
}

@Override
public void onFailure(Throwable throwable) {
System.out.println("--- Async example - onFailure");
throwable.printStackTrace();
}
});
}
}