Skip to content

Commit ea3700b

Browse files
authored
added generation api (#1)
1 parent 1fb43ee commit ea3700b

File tree

14 files changed

+471
-0
lines changed

14 files changed

+471
-0
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.github.llmjava.cohere4j;
2+
3+
4+
import okhttp3.Interceptor;
5+
import okhttp3.Request;
6+
import okhttp3.Response;
7+
8+
import java.io.IOException;
9+
10+
public class AuthorizationInterceptor implements Interceptor {
11+
12+
private final String apiKey;
13+
14+
AuthorizationInterceptor(String apiKey) {
15+
this.apiKey = apiKey;
16+
}
17+
18+
@Override
19+
public Response intercept(Chain chain) throws IOException {
20+
21+
Request request = chain.request()
22+
.newBuilder()
23+
.addHeader("Authorization", "Bearer " + apiKey)
24+
.build();
25+
26+
return chain.proceed(request);
27+
}
28+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package com.github.llmjava.cohere4j;
2+
3+
import com.github.llmjava.cohere4j.request.CompletionRequest;
4+
import com.github.llmjava.cohere4j.response.CompletionResponse;
5+
import com.github.llmjava.cohere4j.response.streaming.StreamingCompletionResponse;
6+
import retrofit2.Call;
7+
import retrofit2.http.Body;
8+
import retrofit2.http.Headers;
9+
import retrofit2.http.POST;
10+
import retrofit2.http.Streaming;
11+
12+
public interface CohereApi {
13+
14+
@POST("/v1/generate")
15+
@Headers({"accept: application/json", "content-type: application/json"})
16+
Call<CompletionResponse>
17+
generate(@Body CompletionRequest request);
18+
19+
@Streaming
20+
@POST("/v1/generate")
21+
@Headers({"accept: application/stream+json", "content-type: application/json"})
22+
Call<CompletionResponse>
23+
generateStream(@Body CompletionRequest request);
24+
25+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package com.github.llmjava.cohere4j;
2+
3+
import com.google.gson.FieldNamingPolicy;
4+
import com.google.gson.Gson;
5+
import com.google.gson.GsonBuilder;
6+
import okhttp3.OkHttpClient;
7+
import retrofit2.Retrofit;
8+
import retrofit2.converter.gson.GsonConverterFactory;
9+
10+
import java.time.Duration;
11+
12+
public class CohereApiFactory {
13+
14+
public CohereApi build(CohereConfig config) {
15+
String apiKey = config.getApiKey();
16+
Duration timeout = config.getTimeout();
17+
CohereApi api = buildApi(apiKey, timeout);
18+
return api;
19+
}
20+
21+
CohereApi buildApi(String apiKey, Duration timeout) {
22+
OkHttpClient okHttpClient = new OkHttpClient.Builder()
23+
.addInterceptor(new AuthorizationInterceptor(apiKey))
24+
.callTimeout(timeout)
25+
.connectTimeout(timeout)
26+
.readTimeout(timeout)
27+
.writeTimeout(timeout)
28+
.build();
29+
30+
Gson gson = new GsonBuilder()
31+
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
32+
.create();
33+
34+
Retrofit retrofit = new Retrofit.Builder()
35+
.baseUrl(CohereConfig.BASE_URL)
36+
.client(okHttpClient)
37+
.addConverterFactory(GsonConverterFactory.create(gson))
38+
.build();
39+
40+
return retrofit.create(CohereApi.class);
41+
}
42+
43+
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package com.github.llmjava.cohere4j;
2+
3+
import com.github.llmjava.cohere4j.callback.AsyncCallback;
4+
import com.github.llmjava.cohere4j.callback.StreamingCallback;
5+
import com.github.llmjava.cohere4j.request.CompletionRequest;
6+
import com.github.llmjava.cohere4j.response.CompletionResponse;
7+
import com.github.llmjava.cohere4j.response.streaming.StreamingCompletionResponse;
8+
import retrofit2.Call;
9+
import retrofit2.Response;
10+
11+
import java.io.IOException;
12+
13+
public class CohereClient {
14+
private final CohereApi api;
15+
private final CohereConfig config;
16+
17+
CohereClient(Builder builder) {
18+
this.api = builder.api;
19+
this.config = builder.config;
20+
}
21+
22+
public String generate(CompletionRequest request) {
23+
try {
24+
Response<CompletionResponse> response = api.generate(request).execute();
25+
if (response.isSuccessful()) {
26+
return response.body().getTexts().get(0);
27+
} else {
28+
throw newException(response);
29+
}
30+
} catch (IOException e) {
31+
throw new RuntimeException(e);
32+
}
33+
}
34+
35+
public void generateAsync(CompletionRequest request, AsyncCallback<String> callback) {
36+
api.generate(request).enqueue(new retrofit2.Callback<CompletionResponse>() {
37+
@Override
38+
public void onResponse(Call<CompletionResponse> call, Response<CompletionResponse> response) {
39+
if (response.isSuccessful()) {
40+
callback.onSuccess(response.body().getTexts().get(0));
41+
} else {
42+
callback.onFailure(newException(response));
43+
}
44+
}
45+
46+
@Override
47+
public void onFailure(Call<CompletionResponse> call, Throwable throwable) {
48+
callback.onFailure(throwable);
49+
}
50+
});
51+
}
52+
53+
public void generateStream(CompletionRequest request, StreamingCallback<String> callback) {
54+
api.generateStream(request).enqueue(new retrofit2.Callback<CompletionResponse>() {
55+
@Override
56+
public void onResponse(Call<CompletionResponse> call, Response<CompletionResponse> response) {
57+
if (response.isSuccessful()) {
58+
CompletionResponse resp = response.body();
59+
callback.onPart(resp.getTexts().get(0));
60+
callback.onComplete();
61+
62+
} else {
63+
callback.onFailure(newException(response));
64+
}
65+
}
66+
67+
@Override
68+
public void onFailure(Call<CompletionResponse> call, Throwable throwable) {
69+
callback.onFailure(throwable);
70+
}
71+
});
72+
}
73+
74+
/**
75+
* Parse exceptions:
76+
* status code: 429; body: {"message":"You are using a Trial key, which is limited to 5 API calls / minute. You can continue to use the Trial key for free or upgrade to a Production key with higher rate limits at 'https://dashboard.cohere.ai/api-keys'. Contact us on 'https://discord.gg/XW44jPfYJu' or email us at support@cohere.com with any questions"}
77+
*/
78+
private static RuntimeException newException(retrofit2.Response<?> response) {
79+
try {
80+
int code = response.code();
81+
String body = response.errorBody().string();
82+
String errorMessage = String.format("status code: %s; body: %s", code, body);
83+
return new RuntimeException(errorMessage);
84+
} catch (IOException e) {
85+
return new RuntimeException(e);
86+
}
87+
}
88+
89+
90+
public static class Builder {
91+
private CohereApi api;
92+
private CohereConfig config;
93+
94+
public Builder withConfig(CohereConfig config) {
95+
this.config = config;
96+
this.api = new CohereApiFactory().build(config);
97+
return this;
98+
}
99+
100+
public CohereClient build() {
101+
return new CohereClient(this);
102+
}
103+
}
104+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package com.github.llmjava.cohere4j;
2+
3+
import org.apache.commons.configuration2.Configuration;
4+
import org.apache.commons.configuration2.builder.fluent.Configurations;
5+
import org.apache.commons.configuration2.ex.ConfigurationException;
6+
7+
import java.time.Duration;
8+
9+
public class CohereConfig {
10+
private final Configuration config;
11+
12+
public CohereConfig(Configuration config) {
13+
this.config = config;
14+
}
15+
16+
public static final String BASE_URL = "https://api.cohere.ai/";
17+
18+
public static final String API_KEY = "cohere.apiKey";
19+
20+
public static final String TIMEOUT = "timeout";
21+
22+
public static final Long DEFAULT_TIMEOUT_MILLIS = 10*1000l;
23+
24+
public String getApiKey() {
25+
return config.getString(API_KEY);
26+
}
27+
28+
public Duration getTimeout() {
29+
Long timeout = config.getLong(CohereConfig.TIMEOUT, CohereConfig.DEFAULT_TIMEOUT_MILLIS);
30+
return Duration.ofMillis(timeout);
31+
}
32+
33+
public static CohereConfig fromProperties(String path) throws ConfigurationException {
34+
Configuration baseConfig = new Configurations().properties(path);
35+
return new CohereConfig(baseConfig);
36+
}
37+
38+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package com.github.llmjava.cohere4j.callback;
2+
3+
public interface AsyncCallback<T> {
4+
5+
void onSuccess(T response);
6+
void onFailure(Throwable throwable);
7+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.github.llmjava.cohere4j.callback;
2+
3+
public interface StreamingCallback<S> {
4+
5+
void onPart(S response);
6+
void onComplete();
7+
void onFailure(Throwable throwable);
8+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package com.github.llmjava.cohere4j.request;
2+
3+
public class CompletionRequest {
4+
5+
/**
6+
* max_tokens
7+
* integer
8+
* The maximum number of tokens the model will generate as part of the response. Note: Setting a low value may result in incomplete generations.
9+
* Defaults to 20. See BPE Tokens for more details.
10+
*
11+
* Can only be set to 0 if return_likelihoods is set to ALL to get the likelihood of the prompt.
12+
*/
13+
private Integer max_tokens;
14+
15+
/**
16+
* One of NONE|START|END to specify how the API will handle inputs longer than the maximum token length.
17+
*
18+
* Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
19+
*
20+
* If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
21+
*
22+
* Default: END
23+
*/
24+
private String truncate;
25+
26+
/**
27+
* One of GENERATION|ALL|NONE to specify how and if the token likelihoods are returned with the response. Defaults to NONE.
28+
*
29+
* If GENERATION is selected, the token likelihoods will only be provided for generated text.
30+
*
31+
* If ALL is selected, the token likelihoods will be provided both for the prompt and the generated text.
32+
*
33+
* Default: NONE
34+
*/
35+
private String return_likelihoods;
36+
37+
/**
38+
* The input text that serves as the starting point for generating the response.
39+
* Note: The prompt will be pre-processed and modified before reaching the model.
40+
*/
41+
private String prompt;
42+
43+
CompletionRequest(Builder builder) {
44+
this.max_tokens = builder.max_tokens;
45+
this.truncate = builder.truncate;
46+
this.return_likelihoods = builder.return_likelihoods;
47+
this.prompt = builder.prompt;
48+
}
49+
50+
public static class Builder {
51+
private Integer max_tokens;
52+
private String truncate;
53+
private String return_likelihoods;
54+
private String prompt;
55+
56+
public Builder withMaxTokens(Integer maxTokens) {
57+
this.max_tokens = maxTokens;
58+
return this;
59+
}
60+
61+
public Builder withTruncate(String truncate) {
62+
this.truncate = truncate;
63+
return this;
64+
}
65+
66+
public Builder withLikelihoods(String likelihoods) {
67+
this.return_likelihoods = likelihoods;
68+
return this;
69+
}
70+
71+
public Builder withPrompt(String prompt) {
72+
this.prompt = prompt;
73+
return this;
74+
}
75+
76+
public CompletionRequest build() {
77+
return new CompletionRequest(this);
78+
}
79+
}
80+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package com.github.llmjava.cohere4j.response;
2+
3+
import java.util.ArrayList;
4+
import java.util.List;
5+
6+
public class CompletionResponse {
7+
private String id;
8+
private String prompt;
9+
private List<Generation> generations;
10+
11+
public String getPrompt() {
12+
return prompt;
13+
}
14+
15+
public List<String> getTexts() {
16+
List<String> texts = new ArrayList<>();
17+
for(Generation generation: generations) {
18+
texts.add(generation.getText());
19+
}
20+
return texts;
21+
}
22+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package com.github.llmjava.cohere4j.response;
2+
3+
import java.util.List;
4+
5+
public class Generation {
6+
private String id;
7+
private String text;
8+
private Integer index;
9+
private Double likelihood;
10+
private List<Likelihood> token_likelihoods;
11+
12+
public String getText() {
13+
return text;
14+
}
15+
}

0 commit comments

Comments
 (0)