Skip to content

Commit 1f5078f

Browse files
authored
add support tokenization req (#6)
1 parent b23c5ba commit 1f5078f

File tree

7 files changed

+142
-6
lines changed

7 files changed

+142
-6
lines changed

src/main/java/com/github/llmjava/cohere4j/CohereApi.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import com.github.llmjava.cohere4j.request.ClassifyRequest;
44
import com.github.llmjava.cohere4j.request.EmbedRequest;
55
import com.github.llmjava.cohere4j.request.GenerateRequest;
6+
import com.github.llmjava.cohere4j.request.TokenizeRequest;
67
import com.github.llmjava.cohere4j.response.ClassifyResponse;
78
import com.github.llmjava.cohere4j.response.EmbedResponse;
89
import com.github.llmjava.cohere4j.response.GenerateResponse;
10+
import com.github.llmjava.cohere4j.response.TokenizeResponse;
911
import retrofit2.Call;
1012
import retrofit2.http.Body;
1113
import retrofit2.http.Headers;
@@ -14,24 +16,50 @@
1416

1517
/**
 * Retrofit service interface declaring the Cohere REST endpoints used by this client.
 * Every method posts a JSON request body and returns a Retrofit {@code Call} for the caller to run.
 */
public interface CohereApi {
1618

19+
/**
20+
* This endpoint generates realistic text conditioned on a given input.
21+
*/
1722
@POST("/v1/generate")
1823
@Headers({"accept: application/json", "content-type: application/json"})
1924
Call<GenerateResponse>
2025
generate(@Body GenerateRequest request);
2126

27+
/**
28+
* Streaming variant of {@link #generate}: posts to the same /v1/generate path, but is marked {@code @Streaming} and asks for an application/stream+json response, exposed as a {@code Call<String>}.
29+
*/
2230
@Streaming
2331
@POST("/v1/generate")
2432
@Headers({"accept: application/stream+json", "content-type: application/json"})
2533
Call<String>
2634
generateStream(@Body GenerateRequest request);
2735

36+
/**
37+
* This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents.
38+
*
39+
* Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.
40+
*
41+
* If you want to learn more about how to use the embedding model, have a look at the Semantic Search Guide.
42+
*/
2843
@POST("/v1/embed")
2944
@Headers({"accept: application/json", "content-type: application/json"})
3045
Call<EmbedResponse>
3146
embed(@Body EmbedRequest request);
3247

48+
/**
49+
* This endpoint makes a prediction about which label fits the specified text inputs best. To make a prediction, Classify uses the provided examples of text + label pairs as a reference.
50+
*
51+
* Note: Custom Models trained on classification examples don't require the examples parameter to be passed in explicitly.
52+
*/
3353
@POST("/v1/classify")
3454
@Headers({"accept: application/json", "content-type: application/json"})
3555
Call<ClassifyResponse>
3656
classify(@Body ClassifyRequest request);
57+
58+
/**
59+
* This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE). To learn more about tokenization and byte pair encoding, see the tokens page.
60+
*/
61+
@POST("/v1/tokenize")
62+
@Headers({"accept: application/json", "content-type: application/json"})
63+
Call<TokenizeResponse>
64+
tokenize(@Body TokenizeRequest request);
3765
}

src/main/java/com/github/llmjava/cohere4j/CohereClient.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import com.github.llmjava.cohere4j.request.ClassifyRequest;
66
import com.github.llmjava.cohere4j.request.EmbedRequest;
77
import com.github.llmjava.cohere4j.request.GenerateRequest;
8+
import com.github.llmjava.cohere4j.request.TokenizeRequest;
89
import com.github.llmjava.cohere4j.response.ClassifyResponse;
910
import com.github.llmjava.cohere4j.response.EmbedResponse;
1011
import com.github.llmjava.cohere4j.response.GenerateResponse;
12+
import com.github.llmjava.cohere4j.response.TokenizeResponse;
1113
import com.github.llmjava.cohere4j.response.streaming.StreamGenerateResponse;
1214
import com.github.llmjava.cohere4j.response.streaming.ResponseConverter;
1315
import com.google.gson.Gson;
@@ -80,6 +82,14 @@ public void classifyAsync(ClassifyRequest request, AsyncCallback<ClassifyRespons
8082
execute(api.classify(request), callback);
8183
}
8284

85+
public TokenizeResponse tokenize(TokenizeRequest request) {
86+
return execute(api.tokenize(request));
87+
}
88+
89+
public void tokenizeAsync(TokenizeRequest request, AsyncCallback<TokenizeResponse> callback) {
90+
execute(api.tokenize(request), callback);
91+
}
92+
8393
private <T> T execute(Call<T> action) {
8494
try {
8595
Response<T> response = action.execute();
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package com.github.llmjava.cohere4j.request;
2+
3+
/**
 * Immutable request payload for the tokenize endpoint.
 *
 * <p>Instances are created through {@link Builder}. The documented minimum text length
 * (1 character) is checked at construction time so that an unusable request fails fast
 * locally instead of after a network round-trip.
 */
public class TokenizeRequest {
    /**
     * The string to be tokenized, the minimum text length is 1 character, and the maximum text length is 65536 characters.
     */
    private final String text;
    /**
     * An optional parameter to provide the model name. This will ensure that the tokenization uses the tokenizer used by that model.
     */
    private final String model;

    /**
     * Copies the builder state into the new request.
     *
     * @param builder source of the field values
     * @throws IllegalArgumentException if the builder's text is {@code null} or empty
     */
    TokenizeRequest(Builder builder) {
        // Only the lower bound is enforced here: the upper bound is expressed in "characters",
        // which may not coincide with String.length() (UTF-16 code units), so it is left to the server.
        if (builder.text == null || builder.text.isEmpty()) {
            throw new IllegalArgumentException("text must contain at least 1 character");
        }
        this.text = builder.text;
        this.model = builder.model;
    }

    /**
     * Fluent builder for {@link TokenizeRequest}.
     */
    public static class Builder {

        private String text;
        private String model;

        /**
         * Sets the text to tokenize (required).
         *
         * @param text the string to be tokenized
         * @return this builder
         */
        public Builder withText(String text) {
            this.text = text;
            return this;
        }

        /**
         * Sets the model whose tokenizer should be used (optional).
         *
         * @param model the model name, or {@code null} to leave it unset
         * @return this builder
         */
        public Builder withModel(String model) {
            this.model = model;
            return this;
        }

        /**
         * Creates the request from the values collected so far.
         *
         * @return a new {@link TokenizeRequest}
         * @throws IllegalArgumentException if no non-empty text was supplied
         */
        public TokenizeRequest build() {
            return new TokenizeRequest(this);
        }
    }
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.github.llmjava.cohere4j.response;
2+
3+
public class TokenizeResponse {
4+
private Integer[] tokens;
5+
private String[] token_strings;
6+
private Meta meta;
7+
8+
public Integer[] getTokens() {
9+
return tokens;
10+
}
11+
12+
public Integer getToken(int index) {
13+
return tokens[index];
14+
}
15+
16+
public String[] getTokenStrings() {
17+
return token_strings;
18+
}
19+
20+
public String getTokenString(int index) {
21+
return token_strings[index];
22+
}
23+
}

src/test/java/com/github/llmjava/cohere4j/ClassificationExample.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22

33
import com.github.llmjava.cohere4j.callback.AsyncCallback;
44
import com.github.llmjava.cohere4j.request.ClassifyRequest;
5-
import com.github.llmjava.cohere4j.request.EmbedRequest;
65
import com.github.llmjava.cohere4j.response.ClassifyResponse;
7-
import com.github.llmjava.cohere4j.response.EmbedResponse;
86
import org.apache.commons.configuration2.ex.ConfigurationException;
97

108
public class ClassificationExample {
@@ -39,9 +37,9 @@ public static void main(String[] args) throws ConfigurationException {
3937
@Override
4038
public void onSuccess(ClassifyResponse response) {
4139
System.out.println("--- Async example - onSuccess");
42-
System.out.println("Input: " + response.getClassification(0).getInput());
43-
System.out.println("Prediction: " + response.getClassification(0).getPrediction());
44-
System.out.println("Confidence: " + response.getClassification(0).getConfidence()); }
40+
System.out.println("Input: " + response.getClassification(1).getInput());
41+
System.out.println("Prediction: " + response.getClassification(1).getPrediction());
42+
System.out.println("Confidence: " + response.getClassification(1).getConfidence()); }
4543

4644
@Override
4745
public void onFailure(Throwable throwable) {

src/test/java/com/github/llmjava/cohere4j/GenerationExample.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public static void main(String[] args) throws ConfigurationException {
2020
.build();
2121

2222
System.out.println("--- Sync example");
23-
System.out.println(client.generate(request1).getTexts());
23+
System.out.println(client.generate(request1).getTexts().get(0));
2424
client.generateAsync(request1, new AsyncCallback<GenerateResponse>() {
2525
@Override
2626
public void onSuccess(GenerateResponse completion) {
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package com.github.llmjava.cohere4j;
2+
3+
import com.github.llmjava.cohere4j.callback.AsyncCallback;
4+
import com.github.llmjava.cohere4j.request.TokenizeRequest;
5+
import com.github.llmjava.cohere4j.response.TokenizeResponse;
6+
import org.apache.commons.configuration2.ex.ConfigurationException;
7+
8+
public class TokenizationExample {
9+
10+
public static void main(String[] args) throws ConfigurationException {
11+
CohereConfig config = CohereConfig.fromProperties("cohere.properties");
12+
CohereClient client = new CohereClient.Builder().withConfig(config).build();
13+
14+
TokenizeRequest request = new TokenizeRequest.Builder()
15+
.withText("tokenize me! :D")
16+
.withModel("command")
17+
.build();
18+
19+
System.out.println("--- Sync example");
20+
TokenizeResponse response = client.tokenize(request);
21+
System.out.println("Token: " + response.getTokenString(0));
22+
System.out.println("Token ID: " + response.getToken(0));
23+
24+
client.tokenizeAsync(request, new AsyncCallback<TokenizeResponse>() {
25+
@Override
26+
public void onSuccess(TokenizeResponse response) {
27+
System.out.println("--- Async example - onSuccess");
28+
System.out.println("Token: " + response.getTokenString(0));
29+
System.out.println("Token ID: " + response.getToken(0));
30+
}
31+
32+
@Override
33+
public void onFailure(Throwable throwable) {
34+
System.out.println("--- Async example - onFailure");
35+
throwable.printStackTrace();
36+
}
37+
});
38+
}
39+
}

0 commit comments

Comments
 (0)