Skip to content

Commit 63f4bd9

Browse files
authored
Add support for summarize and rerank (#15)
1 parent 85cf7aa commit 63f4bd9

File tree

8 files changed

+383
-0
lines changed

8 files changed

+383
-0
lines changed

src/main/java/com/github/llmjava/cohere4j/CohereApi.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,24 @@ public interface CohereApi {
8686
@Headers({"accept: application/json", "content-type: application/json"})
8787
Call<DetectLanguageResponse>
8888
detectLanguage(@Body DetectLanguageRequest request);
89+
90+
/**
91+
* Generates an English summary for the text in the request by POSTing it, as JSON, to /v1/summarize.
92+
* @param request summarization request, sent as the request body
93+
* @return call wrapping the summarization response
94+
*/
95+
@POST("/v1/summarize")
96+
@Headers({"accept: application/json", "content-type: application/json"})
97+
Call<SummarizeResponse>
98+
summarize(@Body SummarizeRequest request);
99+
100+
/**
101+
* Takes a query and a list of texts and produces an ordered array with each text assigned a relevance score; the request is POSTed, as JSON, to /v1/rerank.
102+
* @param request rerank request, sent as the request body
103+
* @return call wrapping the rerank response
104+
*/
105+
@POST("/v1/rerank")
106+
@Headers({"accept: application/json", "content-type: application/json"})
107+
Call<RerankResponse>
108+
rerank(@Body RerankRequest request);
89109
}

src/main/java/com/github/llmjava/cohere4j/CohereClient.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,22 @@ public void detectLanguageAsync(DetectLanguageRequest request, AsyncCallback<Det
101101
execute(api.detectLanguage(request), callback);
102102
}
103103

104+
public SummarizeResponse summarize(SummarizeRequest request) {
105+
return execute(api.summarize(request));
106+
}
107+
108+
public void summarizeAsync(SummarizeRequest request, AsyncCallback<SummarizeResponse> callback) {
109+
execute(api.summarize(request), callback);
110+
}
111+
112+
public RerankResponse rerank(RerankRequest request) {
113+
return execute(api.rerank(request));
114+
}
115+
116+
public void rerankAsync(RerankRequest request, AsyncCallback<RerankResponse> callback) {
117+
execute(api.rerank(request), callback);
118+
}
119+
104120
private <T> T execute(Call<T> action) {
105121
try {
106122
Response<T> response = action.execute();
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package com.github.llmjava.cohere4j.request;
2+
3+
import java.util.ArrayList;
4+
import java.util.Collection;
5+
import java.util.List;
6+
7+
/**
 * Request payload for the rerank endpoint: a search query plus the documents to score against it.
 *
 * <p>Instances are created through {@link Builder} and do not change afterwards.
 * Field names are snake_case, presumably because they are serialized verbatim as
 * JSON keys - TODO confirm against the converter configuration.
 */
public class RerankRequest {
    /**
     * The identifier of the model to use, one of : rerank-english-v2.0, rerank-multilingual-v2.0
     */
    private final String model;
    /**
     * The search query.
     */
    private final String query;
    /**
     * A list of document objects or strings to rerank.
     * If a document is provided the text fields is required and all other fields will be preserved in the response.
     * The total max chunks (length of documents * max_chunks_per_doc) must be less than 10000.
     */
    private final List<String> documents;
    /**
     * The number of most relevant documents or indices to return, defaults to the length of the documents
     */
    private final Integer top_n;
    /**
     * If false, returns results without the doc text - the api will return a list of {index, relevance score} where index is inferred from the list passed into the request.
     * If true, returns results with the doc text passed in - the api will return an ordered list of {index, text, relevance score} where index + text refers to the list passed into the request.
     */
    private final Boolean return_documents;
    /**
     * The maximum number of chunks to produce internally from a document
     */
    private final Integer max_chunks_per_doc;

    /**
     * Copies the builder's current state into a new request.
     *
     * @param builder source of the field values
     */
    RerankRequest(Builder builder) {
        this.model = builder.model;
        this.query = builder.query;
        // Snapshot the list: the builder keeps appending to its own list, and without
        // a copy a later withDocument() call would alter this already-built request.
        this.documents = new ArrayList<>(builder.documents);
        this.top_n = builder.top_n;
        this.return_documents = builder.return_documents;
        this.max_chunks_per_doc = builder.max_chunks_per_doc;
    }

    /**
     * Fluent builder for {@link RerankRequest}. Every setter returns this builder so
     * calls can be chained; values that are never set stay {@code null}, and the
     * document list starts out empty.
     */
    public static class Builder {
        private String model;

        private String query;
        private List<String> documents = new ArrayList<>();
        private Integer top_n;
        private Boolean return_documents;
        private Integer max_chunks_per_doc;

        /**
         * @param model identifier of the rerank model to use
         * @return this builder
         */
        public Builder withModel(String model) {
            this.model = model;
            return this;
        }

        /**
         * @param query the search query
         * @return this builder
         */
        public Builder withQuery(String query) {
            this.query = query;
            return this;
        }

        /**
         * Appends one document to the list to rerank; call repeatedly to add more.
         *
         * @param document document text
         * @return this builder
         */
        public Builder withDocument(String document) {
            this.documents.add(document);
            return this;
        }

        /**
         * @param top_n number of most relevant documents or indices to return
         * @return this builder
         */
        public Builder withTopN(Integer top_n) {
            this.top_n = top_n;
            return this;
        }

        /**
         * @param return_documents whether results should carry the document text
         * @return this builder
         */
        public Builder withReturnDocuments(Boolean return_documents) {
            this.return_documents = return_documents;
            return this;
        }

        /**
         * @param max_chunks_per_doc maximum number of chunks to produce internally from a document
         * @return this builder
         */
        public Builder withMaxChunksPerDoc(Integer max_chunks_per_doc) {
            this.max_chunks_per_doc = max_chunks_per_doc;
            return this;
        }

        /**
         * Creates a request from the values set so far. The builder can keep being
         * used afterwards without affecting requests it has already produced.
         *
         * @return a new request
         */
        public RerankRequest build() {
            return new RerankRequest(this);
        }
    }
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package com.github.llmjava.cohere4j.request;
2+
3+
import java.util.ArrayList;
4+
import java.util.Collection;
5+
import java.util.List;
6+
7+
/**
8+
* This request generates a summary in English for a given text.
9+
*/
10+
/**
 * Request payload for the summarize endpoint, created through {@link Builder}.
 *
 * <p>Fields are private and final so a request cannot change once built. Their
 * names are snake_case where the builder method is camelCase, presumably because
 * they are serialized verbatim as JSON keys - TODO confirm against the converter
 * configuration.
 */
public class SummarizeRequest {
    /**
     * The text to generate a summary for. Can be up to 100,000 characters long. Currently the only supported language is English.
     */
    private final String text;

    /**
     * One of short, medium, long, or auto. Indicates the approximate length of the summary. If auto is selected, the best option will be picked based on the input text.
     *
     * NOTE(review): the text this was copied from names both "auto" and "medium" as the default - confirm against the API reference.
     */
    private final String length;

    /**
     * One of paragraph, bullets, or auto. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If auto is selected, the best option will be picked based on the input text.
     *
     * NOTE(review): the text this was copied from names both "auto" and "paragraph" as the default - confirm against the API reference.
     */
    private final String format;

    /**
     * The identifier of the model to generate the summary with. Currently available models are command (default), command-nightly (experimental), command-light, and command-light-nightly (experimental). Smaller, "light" models are faster, while larger models will perform better.
     *
     * Default: command
     */
    private final String model;

    /**
     * One of low, medium, high, or auto. Controls how close to the original text the summary is. high extractiveness summaries will lean towards reusing sentences verbatim, while low extractiveness summaries will tend to paraphrase more. If auto is selected, the best option will be picked based on the input text.
     *
     * NOTE(review): the text this was copied from names both "auto" and "low" as the default - confirm against the API reference.
     */
    private final String extractiveness;

    /**
     * Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1.
     */
    private final Float temperature;

    /**
     * A free-form instruction for modifying how the summaries get generated. Should complete the sentence "Generate a summary _". Eg. "focusing on the next steps" or "written by Yoda"
     */
    private final String additional_command;

    /**
     * Copies the builder's current values into a new request.
     *
     * @param builder source of the field values
     */
    SummarizeRequest(Builder builder) {
        this.text = builder.text;
        this.length = builder.length;
        this.format = builder.format;
        this.model = builder.model;
        this.extractiveness = builder.extractiveness;
        this.temperature = builder.temperature;
        this.additional_command = builder.additional_command;
    }

    /**
     * Fluent builder for {@link SummarizeRequest}. Every setter returns this builder
     * so calls can be chained; values that are never set stay {@code null}.
     */
    public static class Builder {

        private String text;
        private String length;
        private String format;
        private String model;
        private String extractiveness;
        private Float temperature;
        private String additional_command;

        /**
         * @param text the text to summarize
         * @return this builder
         */
        public Builder withText(String text) {
            this.text = text;
            return this;
        }

        /**
         * @param length approximate summary length: short, medium, long or auto
         * @return this builder
         */
        public Builder withLength(String length) {
            this.length = length;
            return this;
        }

        /**
         * @param format summary style: paragraph, bullets or auto
         * @return this builder
         */
        public Builder withFormat(String format) {
            this.format = format;
            return this;
        }

        /**
         * @param model identifier of the model to generate the summary with
         * @return this builder
         */
        public Builder withModel(String model) {
            this.model = model;
            return this;
        }

        /**
         * @param extractiveness how closely the summary follows the original: low, medium, high or auto
         * @return this builder
         */
        public Builder withExtractiveness(String extractiveness) {
            this.extractiveness = extractiveness;
            return this;
        }

        /**
         * @param temperature randomness of the output, documented as ranging from 0 to 5
         * @return this builder
         */
        public Builder withTemperature(Float temperature) {
            this.temperature = temperature;
            return this;
        }

        /**
         * @param additional_command free-form instruction completing "Generate a summary _"
         * @return this builder
         */
        public Builder withAdditionalCommand(String additional_command) {
            this.additional_command = additional_command;
            return this;
        }

        /**
         * Creates a request from the values set so far.
         *
         * @return a new request
         */
        public SummarizeRequest build() {
            return new SummarizeRequest(this);
        }
    }
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package com.github.llmjava.cohere4j.response;
2+
3+
import java.util.HashMap;
4+
import java.util.List;
5+
import java.util.Map;
6+
7+
public class RerankResponse {
8+
private String id;
9+
private List<Result> results;
10+
private Map<Integer, Result> resultByIndex = new HashMap<>();
11+
private Meta meta;
12+
13+
private void init() {
14+
if(resultByIndex.isEmpty() && !(results==null || results.isEmpty())) {
15+
for(Result result: results) {
16+
resultByIndex.put(result.index, result);
17+
}
18+
}
19+
}
20+
21+
public Result getResultByIndex(int index) {
22+
init();
23+
return resultByIndex.get(index);
24+
}
25+
26+
public Result getResultByRank(int rank) {
27+
return results.get(rank);
28+
}
29+
30+
public Float getScoreByIndex(int index) {
31+
init();
32+
Result result = resultByIndex.get(index);
33+
return result.getRelevanceScore();
34+
}
35+
36+
public static class Result {
37+
private Integer index;
38+
private Float relevance_score;
39+
40+
public Integer getIndex() {
41+
return index;
42+
}
43+
44+
public Float getRelevanceScore() {
45+
return relevance_score;
46+
}
47+
}
48+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package com.github.llmjava.cohere4j.response;
2+
3+
public class SummarizeResponse {
4+
private String id;
5+
private String summary;
6+
private Meta meta;
7+
8+
public String getSummary() {
9+
return summary;
10+
}
11+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package com.github.llmjava.cohere4j;
2+
3+
import com.github.llmjava.cohere4j.callback.AsyncCallback;
4+
import com.github.llmjava.cohere4j.request.RerankRequest;
5+
import com.github.llmjava.cohere4j.response.RerankResponse;
6+
import org.apache.commons.configuration2.ex.ConfigurationException;
7+
8+
public class RerankExample {
9+
10+
public static void main(String[] args) throws ConfigurationException {
11+
CohereConfig config = CohereConfig.fromProperties("cohere.properties");
12+
CohereClient client = new CohereClient.Builder().withConfig(config).build();
13+
14+
RerankRequest request = new RerankRequest.Builder()
15+
.withReturnDocuments(false)
16+
.withMaxChunksPerDoc(10)
17+
.withModel("rerank-english-v2.0")
18+
.withQuery("What is the capital of the United States?")
19+
.withDocument("Carson City is the capital city of the American state of Nevada.")
20+
.withDocument("The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.")
21+
.withDocument("Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.")
22+
.withDocument("Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.")
23+
.build();
24+
25+
System.out.println("--- Sync example");
26+
RerankResponse response = client.rerank(request);
27+
for(int i=0; i<4; i++) {
28+
System.out.println("Score Document "+i+": " + response.getScoreByIndex(i));
29+
}
30+
31+
client.rerankAsync(request, new AsyncCallback<RerankResponse>() {
32+
@Override
33+
public void onSuccess(RerankResponse response) {
34+
System.out.println("--- Async example - onSuccess");
35+
for(int i=0; i<4; i++) {
36+
System.out.println("Score Document "+i+": " + response.getScoreByIndex(i));
37+
}
38+
}
39+
40+
@Override
41+
public void onFailure(Throwable throwable) {
42+
System.out.println("--- Async example - onFailure");
43+
throwable.printStackTrace();
44+
}
45+
});
46+
}
47+
}

0 commit comments

Comments
 (0)