Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/main/java/com/github/llmjava/cohere4j/CohereApi.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.github.llmjava.cohere4j;

import com.github.llmjava.cohere4j.request.GenerationRequest;
import com.github.llmjava.cohere4j.response.GenerationResponse;
import com.github.llmjava.cohere4j.request.EmbedRequest;
import com.github.llmjava.cohere4j.request.GenerateRequest;
import com.github.llmjava.cohere4j.response.EmbedResponse;
import com.github.llmjava.cohere4j.response.GenerateResponse;
import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.Headers;
Expand All @@ -12,13 +14,17 @@ public interface CohereApi {

@POST("/v1/generate")
@Headers({"accept: application/json", "content-type: application/json"})
Call<GenerationResponse>
generate(@Body GenerationRequest request);
Call<GenerateResponse>
generate(@Body GenerateRequest request);

@Streaming
@POST("/v1/generate")
@Headers({"accept: application/stream+json", "content-type: application/json"})
Call<String>
generateStream(@Body GenerationRequest request);
generateStream(@Body GenerateRequest request);

@POST("/v1/embed")
@Headers({"accept: application/json", "content-type: application/json"})
Call<EmbedResponse>
embed(@Body EmbedRequest request);
}
81 changes: 49 additions & 32 deletions src/main/java/com/github/llmjava/cohere4j/CohereClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import com.github.llmjava.cohere4j.callback.AsyncCallback;
import com.github.llmjava.cohere4j.callback.StreamingCallback;
import com.github.llmjava.cohere4j.request.GenerationRequest;
import com.github.llmjava.cohere4j.response.GenerationResponse;
import com.github.llmjava.cohere4j.response.streaming.StreamingGenerationResponse;
import com.github.llmjava.cohere4j.request.EmbedRequest;
import com.github.llmjava.cohere4j.request.GenerateRequest;
import com.github.llmjava.cohere4j.response.EmbedResponse;
import com.github.llmjava.cohere4j.response.GenerateResponse;
import com.github.llmjava.cohere4j.response.streaming.StreamGenerateResponse;
import com.github.llmjava.cohere4j.response.streaming.ResponseConverter;
import com.google.gson.Gson;
import retrofit2.Call;
Expand All @@ -23,38 +25,15 @@ public class CohereClient {
this.gson = builder.gson;
}

public GenerationResponse generate(GenerationRequest request) {
try {
Response<GenerationResponse> response = api.generate(request).execute();
if (response.isSuccessful()) {
return response.body();
} else {
throw newException(response);
}
} catch (IOException e) {
throw new RuntimeException(e);
}
/** Synchronously generates completions for the given request; blocks until the API responds. */
public GenerateResponse generate(GenerateRequest request) {
    Call<GenerateResponse> call = api.generate(request);
    return execute(call);
}

public void generateAsync(GenerationRequest request, AsyncCallback<GenerationResponse> callback) {
api.generate(request).enqueue(new retrofit2.Callback<GenerationResponse>() {
@Override
public void onResponse(Call<GenerationResponse> call, Response<GenerationResponse> response) {
if (response.isSuccessful()) {
callback.onSuccess(response.body());
} else {
callback.onFailure(newException(response));
}
}

@Override
public void onFailure(Call<GenerationResponse> call, Throwable throwable) {
callback.onFailure(throwable);
}
});
/** Generates completions without blocking; the outcome is delivered to the callback. */
public void generateAsync(GenerateRequest request, AsyncCallback<GenerateResponse> callback) {
    Call<GenerateResponse> call = api.generate(request);
    execute(call, callback);
}

public void generateStream(GenerationRequest request, StreamingCallback<StreamingGenerationResponse> callback) {
public void generateStream(GenerateRequest request, StreamingCallback<StreamGenerateResponse> callback) {
if(!request.isStreaming()) {
throw new IllegalArgumentException("Expected a streaming request");
}
Expand All @@ -63,7 +42,7 @@ public void generateStream(GenerationRequest request, StreamingCallback<Streamin
@Override
public void onResponse(Call<String> call, Response<String> response) {
if (response.isSuccessful()) {
for(StreamingGenerationResponse resp: converter.toStreamingGenerationResponse(response.body())) {
for(StreamGenerateResponse resp: converter.toStreamingGenerationResponse(response.body())) {
if(resp.isFinished()) {
callback.onComplete(resp);
} else {
Expand All @@ -83,6 +62,44 @@ public void onFailure(Call<String> call, Throwable throwable) {
});
}

/** Synchronously embeds the request's texts; blocks until the API responds. */
public EmbedResponse embed(EmbedRequest request) {
    Call<EmbedResponse> call = api.embed(request);
    return execute(call);
}

/** Embeds the request's texts without blocking; the outcome is delivered to the callback. */
public void embedAsync(EmbedRequest request, AsyncCallback<EmbedResponse> callback) {
    Call<EmbedResponse> call = api.embed(request);
    execute(call, callback);
}

/**
 * Executes the call synchronously and returns the deserialized body.
 * HTTP-level errors are surfaced via newException; transport-level
 * IOExceptions are wrapped in an unchecked RuntimeException.
 */
private <T> T execute(Call<T> call) {
    final Response<T> response;
    try {
        response = call.execute();
    } catch (IOException e) {
        // Network / serialization failure: callers do not declare checked exceptions.
        throw new RuntimeException(e);
    }
    if (!response.isSuccessful()) {
        throw newException(response);
    }
    return response.body();
}
/**
 * Enqueues the call and routes the HTTP outcome to the callback:
 * successful responses go to onSuccess, HTTP errors and transport
 * failures go to onFailure.
 */
private <T> void execute(Call<T> call, AsyncCallback<T> callback) {
    call.enqueue(new retrofit2.Callback<T>() {
        @Override
        public void onResponse(Call<T> ignored, Response<T> response) {
            if (!response.isSuccessful()) {
                callback.onFailure(newException(response));
                return;
            }
            callback.onSuccess(response.body());
        }

        @Override
        public void onFailure(Call<T> ignored, Throwable cause) {
            callback.onFailure(cause);
        }
    });
}

/**
* Parse exceptions:
* status code: 429; body: {"message":"You are using a Trial key, which is limited to 5 API calls / minute. You can continue to use the Trial key for free or upgrade to a Production key with higher rate limits at 'https://dashboard.cohere.ai/api-keys'. Contact us on 'https://discord.gg/XW44jPfYJu' or email us at support@cohere.com with any questions"}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package com.github.llmjava.cohere4j.request;

import java.util.ArrayList;
import java.util.List;

/**
 * Request payload for the /v1/embed endpoint. Immutable once built;
 * construct instances through {@link Builder}.
 */
public class EmbedRequest {
    /**
     * An array of strings for the model to embed. Maximum number of texts per call is 96.
     * We recommend reducing the length of each text to be under 512 tokens for optimal quality.
     */
    private String[] texts;
    /**
     * The identifier of the model. Smaller "light" models are faster, while larger models
     * will perform better. Custom models can also be supplied with their full ID.
     *
     * Available models and corresponding embedding dimensions:
     *
     * embed-english-v2.0 (default) 4096
     * embed-english-light-v2.0 1024
     * embed-multilingual-v2.0 768
     */
    private String model;
    /**
     * One of NONE|START|END to specify how the API will handle inputs longer than the
     * maximum token length.
     *
     * Passing START will discard the start of the input. END will discard the end of the
     * input. In both cases, input is discarded until the remaining input is exactly the
     * maximum input token length for the model.
     *
     * If NONE is selected, when the input exceeds the maximum input token length an error
     * will be returned.
     *
     * Default: END
     */
    private String truncate;

    EmbedRequest(Builder builder) {
        // toArray(new String[0]) is the idiomatic (and on modern JVMs fastest) form.
        this.texts = builder.texts.toArray(new String[0]);
        this.model = builder.model;
        this.truncate = builder.truncate;
    }

    /** Returns the texts to embed, in the order they were added to the builder. */
    public String[] getTexts() {
        return texts;
    }

    /** Returns the model identifier, or null to let the API use its default. */
    public String getModel() {
        return model;
    }

    /** Returns the truncate strategy (NONE|START|END), or null for the API default (END). */
    public String getTruncate() {
        return truncate;
    }

    /** Fluent builder for {@link EmbedRequest}. */
    public static class Builder {

        private List<String> texts = new ArrayList<>();
        private String model;
        private String truncate;

        /** Adds one text to embed; may be called repeatedly (API limit: 96 texts per call). */
        public Builder withText(String text) {
            this.texts.add(text);
            return this;
        }

        public Builder withModel(String model) {
            this.model = model;
            return this;
        }

        public Builder withTruncate(String truncate) {
            this.truncate = truncate;
            return this;
        }

        public EmbedRequest build() {
            return new EmbedRequest(this);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import java.util.List;
import java.util.Map;

public class GenerationRequest {
public class GenerateRequest {

/**
* The input text that serves as the starting point for generating the response.
Expand Down Expand Up @@ -122,7 +122,7 @@ public class GenerationRequest {
*/
private Map<String, Double> logit_bias;

GenerationRequest(Builder builder) {
GenerateRequest(Builder builder) {
prompt = builder.prompt;
model = builder.model;
num_generations = builder.num_generations;
Expand Down Expand Up @@ -212,8 +212,8 @@ public Builder withConfig(CohereConfig config) {
return this;
}

public GenerationRequest build() {
return new GenerationRequest(this);
public GenerateRequest build() {
return new GenerateRequest(this);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.github.llmjava.cohere4j.response;

/**
 * Response payload of the /v1/embed endpoint. Fields are populated by Gson
 * during deserialization; the class exposes read-only accessors.
 */
public class EmbedResponse {
    /** Identifier assigned to this API call. */
    private String id;

    /** The input texts, echoed back by the API. */
    private String[] texts;

    /** One embedding vector per input text; embeddings[i] belongs to texts[i]. */
    private Float[][] embeddings;

    /** API metadata attached to the response (e.g. API version). */
    private Meta meta;

    /** Returns the identifier assigned to this API call. */
    public String getId() {
        return id;
    }

    public String[] getTexts() {
        return texts;
    }

    public Float[][] getEmbeddings() {
        return embeddings;
    }

    /**
     * Returns the embedding vector for the text at the given index.
     * Throws ArrayIndexOutOfBoundsException for an invalid index.
     */
    public Float[] getEmbeddings(int index) {
        return embeddings[index];
    }

    /** Returns the response metadata, or null if the API sent none. */
    public Meta getMeta() {
        return meta;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import java.util.ArrayList;
import java.util.List;

public class GenerationResponse {
public class GenerateResponse {
private String id;
private String prompt;
private List<Generation> generations;
private Meta meta;

public String getPrompt() {
return prompt;
Expand Down
9 changes: 9 additions & 0 deletions src/main/java/com/github/llmjava/cohere4j/response/Meta.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.github.llmjava.cohere4j.response;

/**
 * Metadata block returned with Cohere API responses; populated by Gson.
 */
public class Meta {
    // Field name matches the JSON key "api_version" for Gson's default mapping.
    ApiVersion api_version;

    /** Returns the API version info, or null if absent from the response. */
    public ApiVersion getApiVersion() {
        return api_version;
    }

    /** Version of the API that served the call. */
    public static class ApiVersion {
        String version;

        /** Returns the version string, or null if absent. */
        public String getVersion() {
            return version;
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ public ResponseConverter(Gson gson) {
* Type elmType = new TypeToken<T>(){}.getType();
* Type listType = new TypeToken<ArrayList<T>>(){}.getType();
*/
public List<StreamingGenerationResponse> toStreamingGenerationResponse(String responseBody) {
List<StreamingGenerationResponse> responses = new ArrayList<>();
Type elmType = new TypeToken<StreamingGenerationResponse>(){}.getType();
Type listType = new TypeToken<ArrayList<StreamingGenerationResponse>>(){}.getType();
public List<StreamGenerateResponse> toStreamingGenerationResponse(String responseBody) {
List<StreamGenerateResponse> responses = new ArrayList<>();
Type elmType = new TypeToken<StreamGenerateResponse>(){}.getType();
Type listType = new TypeToken<ArrayList<StreamGenerateResponse>>(){}.getType();
String[] lines = responseBody.split("\n");
for(String line: lines) {
if(line.charAt(0)=='[') {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import java.util.List;

public class StreamingGenerationResponse {
public class StreamGenerateResponse {

private String text;
private Integer index;
Expand Down
38 changes: 38 additions & 0 deletions src/test/java/com/github/llmjava/cohere4j/EmbeddingsExample.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.github.llmjava.cohere4j;

import com.github.llmjava.cohere4j.callback.AsyncCallback;
import com.github.llmjava.cohere4j.request.EmbedRequest;
import com.github.llmjava.cohere4j.response.EmbedResponse;
import java.util.Arrays;
import org.apache.commons.configuration2.ex.ConfigurationException;

/**
 * Runnable example demonstrating synchronous and asynchronous calls to the
 * Cohere embed endpoint. Reads the API key from cohere.properties.
 */
public class EmbeddingsExample {

    public static void main(String[] args) throws ConfigurationException {
        CohereConfig config = CohereConfig.fromProperties("cohere.properties");
        CohereClient client = new CohereClient.Builder().withConfig(config).build();

        String text = "tell a joke";
        EmbedRequest request = new EmbedRequest.Builder()
                .withText(text)
                .build();

        System.out.println("--- Sync example");
        EmbedResponse response = client.embed(request);
        System.out.println("Texts: " + response.getTexts()[0]);
        // Arrays.toString renders the vector values; concatenating a Float[] directly
        // would only print an object reference like [Ljava.lang.Float;@1a2b3c.
        System.out.println("Embeddings: " + Arrays.toString(response.getEmbeddings(0)));
        client.embedAsync(request, new AsyncCallback<EmbedResponse>() {
            @Override
            public void onSuccess(EmbedResponse completion) {
                System.out.println("--- Async example - onSuccess");
                System.out.println("Texts: " + completion.getTexts()[0]);
                System.out.println("Embeddings: " + Arrays.toString(completion.getEmbeddings(0)));
            }

            @Override
            public void onFailure(Throwable throwable) {
                System.out.println("--- Async example - onFailure");
                throwable.printStackTrace();
            }
        });
    }
}
Loading