Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/main/java/com/github/llmjava/cohere4j/CohereApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,24 @@ public interface CohereApi {
@Headers({"accept: application/json", "content-type: application/json"})
Call<DetectLanguageResponse>
detectLanguage(@Body DetectLanguageRequest request);

/**
* This endpoint generates a summary in English for a given text.
* @param request Summarization request
* @return Summarization response
*/
@POST("/v1/summarize")
@Headers({"accept: application/json", "content-type: application/json"})
Call<SummarizeResponse>
summarize(@Body SummarizeRequest request);

/**
* This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
* @param request rerank request
* @return rerank response
*/
@POST("/v1/rerank")
@Headers({"accept: application/json", "content-type: application/json"})
Call<RerankResponse>
rerank(@Body RerankRequest request);
}
16 changes: 16 additions & 0 deletions src/main/java/com/github/llmjava/cohere4j/CohereClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,22 @@ public void detectLanguageAsync(DetectLanguageRequest request, AsyncCallback<Det
execute(api.detectLanguage(request), callback);
}

public SummarizeResponse summarize(SummarizeRequest request) {
return execute(api.summarize(request));
}

public void summarizeAsync(SummarizeRequest request, AsyncCallback<SummarizeResponse> callback) {
execute(api.summarize(request), callback);
}

public RerankResponse rerank(RerankRequest request) {
return execute(api.rerank(request));
}

public void rerankAsync(RerankRequest request, AsyncCallback<RerankResponse> callback) {
execute(api.rerank(request), callback);
}

private <T> T execute(Call<T> action) {
try {
Response<T> response = action.execute();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package com.github.llmjava.cohere4j.request;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

public class RerankRequest {
/**
* The identifier of the model to use, one of : rerank-english-v2.0, rerank-multilingual-v2.0
*/
private String model;
/**
* The search query.
*/
private String query;
/**
* A list of document objects or strings to rerank.
* If a document is provided the text fields is required and all other fields will be preserved in the response.
* The total max chunks (length of documents * max_chunks_per_doc) must be less than 10000.
*/
private List<String> documents;
/**
* The number of most relevant documents or indices to return, defaults to the length of the documents
*/
private Integer top_n;
/**
* If false, returns results without the doc text - the api will return a list of {index, relevance score} where index is inferred from the list passed into the request.
* If true, returns results with the doc text passed in - the api will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request.
*/
private Boolean return_documents;
/**
* The maximum number of chunks to produce internally from a document
*/
private Integer max_chunks_per_doc;

RerankRequest(Builder builder) {
this.model = builder.model;
this.query = builder.query;
this.documents = builder.documents;
this.top_n = builder.top_n;
this.return_documents = builder.return_documents;
this.max_chunks_per_doc = builder.max_chunks_per_doc;
}

public static class Builder {
private String model;

private String query;
private List<String> documents = new ArrayList<>();
private Integer top_n;
private Boolean return_documents;
private Integer max_chunks_per_doc;

public Builder withModel(String model) {
this.model = model;
return this;
}

public Builder withQuery(String query) {
this.query = query;
return this;
}

public Builder withDocument(String document) {
this.documents.add(document);
return this;
}

public Builder withTopN(Integer top_n) {
this.top_n = top_n;
return this;
}

public Builder withReturnDocuments(Boolean return_documents) {
this.return_documents = return_documents;
return this;
}

public Builder withMaxChunksPerDoc(Integer max_chunks_per_doc) {
this.max_chunks_per_doc = max_chunks_per_doc;
return this;
}

public RerankRequest build() {
return new RerankRequest(this);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package com.github.llmjava.cohere4j.request;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
* This request generates a summary in English for a given text.
*/
public class SummarizeRequest {
/**
* The text to generate a summary for. Can be up to 100,000 characters long. Currently the only supported language is English.
*/
String text;

/**
* One of short, medium, long, or auto defaults to auto. Indicates the approximate length of the summary. If auto is selected, the best option will be picked based on the input text.
*
* Default: medium
*/
String length;

/**
* One of paragraph, bullets, or auto, defaults to auto. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If auto is selected, the best option will be picked based on the input text.
*
* Default: paragraph
*/
String format;

/**
* The identifier of the model to generate the summary with. Currently available models are command (default), command-nightly (experimental), command-light, and command-light-nightly (experimental). Smaller, "light" models are faster, while larger models will perform better.
*
* Default: command
*/
String model;

/**
* One of low, medium, high, or auto, defaults to auto. Controls how close to the original text the summary is. high extractiveness summaries will lean towards reusing sentences verbatim, while low extractiveness summaries will tend to paraphrase more. If auto is selected, the best option will be picked based on the input text.
*
* Default: low
*/
String extractiveness;

/**
* Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1.
*/
Float temperature;

/**
* A free-form instruction for modifying how the summaries get generated. Should complete the sentence "Generate a summary _". Eg. "focusing on the next steps" or "written by Yoda"
*/
String additional_command;
SummarizeRequest(Builder builder) {
this.text = builder.text;
this.length = builder.length;
this.format = builder.format;
this.model = builder.model;
this.extractiveness = builder.extractiveness;
this.temperature = builder.temperature;
this.additional_command = builder.additional_command;
}

public static class Builder {

private String text;
private String length;
private String format;
private String model;
private String extractiveness;
private Float temperature;
private String additional_command;

public Builder withText(String text) {
this.text = text;
return this;
}

public Builder withLength(String length) {
this.length = length;
return this;
}

public Builder withFormat(String format) {
this.format = format;
return this;
}

public Builder withModel(String model) {
this.model = model;
return this;
}

public Builder withExtractiveness(String extractiveness) {
this.extractiveness = extractiveness;
return this;
}

public Builder withTemperature(Float temperature) {
this.temperature = temperature;
return this;
}

public Builder withAdditionalCommand(String additional_command) {
this.additional_command = additional_command;
return this;
}

public SummarizeRequest build() {
return new SummarizeRequest(this);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.github.llmjava.cohere4j.response;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class RerankResponse {
private String id;
private List<Result> results;
private Map<Integer, Result> resultByIndex = new HashMap<>();
private Meta meta;

private void init() {
if(resultByIndex.isEmpty() && !(results==null || results.isEmpty())) {
for(Result result: results) {
resultByIndex.put(result.index, result);
}
}
}

public Result getResultByIndex(int index) {
init();
return resultByIndex.get(index);
}

public Result getResultByRank(int rank) {
return results.get(rank);
}

public Float getScoreByIndex(int index) {
init();
Result result = resultByIndex.get(index);
return result.getRelevanceScore();
}

public static class Result {
private Integer index;
private Float relevance_score;

public Integer getIndex() {
return index;
}

public Float getRelevanceScore() {
return relevance_score;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.github.llmjava.cohere4j.response;

public class SummarizeResponse {
private String id;
private String summary;
private Meta meta;

public String getSummary() {
return summary;
}
}
47 changes: 47 additions & 0 deletions src/test/java/com/github/llmjava/cohere4j/RerankExample.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.github.llmjava.cohere4j;

import com.github.llmjava.cohere4j.callback.AsyncCallback;
import com.github.llmjava.cohere4j.request.RerankRequest;
import com.github.llmjava.cohere4j.response.RerankResponse;
import org.apache.commons.configuration2.ex.ConfigurationException;

public class RerankExample {

public static void main(String[] args) throws ConfigurationException {
CohereConfig config = CohereConfig.fromProperties("cohere.properties");
CohereClient client = new CohereClient.Builder().withConfig(config).build();

RerankRequest request = new RerankRequest.Builder()
.withReturnDocuments(false)
.withMaxChunksPerDoc(10)
.withModel("rerank-english-v2.0")
.withQuery("What is the capital of the United States?")
.withDocument("Carson City is the capital city of the American state of Nevada.")
.withDocument("The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.")
.withDocument("Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.")
.withDocument("Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.")
.build();

System.out.println("--- Sync example");
RerankResponse response = client.rerank(request);
for(int i=0; i<4; i++) {
System.out.println("Score Document "+i+": " + response.getScoreByIndex(i));
}

client.rerankAsync(request, new AsyncCallback<RerankResponse>() {
@Override
public void onSuccess(RerankResponse response) {
System.out.println("--- Async example - onSuccess");
for(int i=0; i<4; i++) {
System.out.println("Score Document "+i+": " + response.getScoreByIndex(i));
}
}

@Override
public void onFailure(Throwable throwable) {
System.out.println("--- Async example - onFailure");
throwable.printStackTrace();
}
});
}
}
Loading