Skip to content

Commit 2132697

Browse files
authored
support streaming generation (#2)
1 parent ea3700b commit 2132697

File tree

12 files changed: 303 additions (+303) and 129 deletions (-129).

pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@
4949
<artifactId>converter-gson</artifactId>
5050
<version>${retrofit2.version}</version>
5151
</dependency>
52+
<dependency>
53+
<groupId>com.squareup.retrofit2</groupId>
54+
<artifactId>converter-scalars</artifactId>
55+
<version>${retrofit2.version}</version>
56+
</dependency>
5257

5358
<dependency>
5459
<groupId>org.junit.jupiter</groupId>
src/main/java/com/github/llmjava/cohere4j/CohereApi.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
package com.github.llmjava.cohere4j;
22

3-
import com.github.llmjava.cohere4j.request.CompletionRequest;
4-
import com.github.llmjava.cohere4j.response.CompletionResponse;
5-
import com.github.llmjava.cohere4j.response.streaming.StreamingCompletionResponse;
3+
import com.github.llmjava.cohere4j.request.GenerationRequest;
4+
import com.github.llmjava.cohere4j.response.GenerationResponse;
65
import retrofit2.Call;
76
import retrofit2.http.Body;
87
import retrofit2.http.Headers;
@@ -13,13 +12,13 @@ public interface CohereApi {
1312

1413
@POST("/v1/generate")
1514
@Headers({"accept: application/json", "content-type: application/json"})
16-
Call<CompletionResponse>
17-
generate(@Body CompletionRequest request);
15+
Call<GenerationResponse>
16+
generate(@Body GenerationRequest request);
1817

1918
@Streaming
2019
@POST("/v1/generate")
2120
@Headers({"accept: application/stream+json", "content-type: application/json"})
22-
Call<CompletionResponse>
23-
generateStream(@Body CompletionRequest request);
21+
Call<String>
22+
generateStream(@Body GenerationRequest request);
2423

2524
}

src/main/java/com/github/llmjava/cohere4j/CohereApiFactory.java

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,41 @@
66
import okhttp3.OkHttpClient;
77
import retrofit2.Retrofit;
88
import retrofit2.converter.gson.GsonConverterFactory;
9+
import retrofit2.converter.scalars.ScalarsConverterFactory;
910

1011
import java.time.Duration;
1112

1213
public class CohereApiFactory {
1314

14-
public CohereApi build(CohereConfig config) {
15+
Gson gson;
16+
OkHttpClient okHttpClient;
17+
18+
public CohereApiFactory createHttpClient(CohereConfig config) {
1519
String apiKey = config.getApiKey();
1620
Duration timeout = config.getTimeout();
17-
CohereApi api = buildApi(apiKey, timeout);
18-
return api;
19-
}
20-
21-
CohereApi buildApi(String apiKey, Duration timeout) {
22-
OkHttpClient okHttpClient = new OkHttpClient.Builder()
21+
okHttpClient = new OkHttpClient.Builder()
2322
.addInterceptor(new AuthorizationInterceptor(apiKey))
2423
.callTimeout(timeout)
2524
.connectTimeout(timeout)
2625
.readTimeout(timeout)
2726
.writeTimeout(timeout)
2827
.build();
28+
return this;
29+
}
2930

30-
Gson gson = new GsonBuilder()
31+
CohereApiFactory createGson() {
32+
this.gson = new GsonBuilder()
3133
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
34+
.setLenient()
3235
.create();
36+
return this;
37+
}
3338

39+
CohereApi build() {
3440
Retrofit retrofit = new Retrofit.Builder()
3541
.baseUrl(CohereConfig.BASE_URL)
3642
.client(okHttpClient)
43+
.addConverterFactory(ScalarsConverterFactory.create())
3744
.addConverterFactory(GsonConverterFactory.create(gson))
3845
.build();
3946

src/main/java/com/github/llmjava/cohere4j/CohereClient.java

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import com.github.llmjava.cohere4j.callback.AsyncCallback;
44
import com.github.llmjava.cohere4j.callback.StreamingCallback;
5-
import com.github.llmjava.cohere4j.request.CompletionRequest;
6-
import com.github.llmjava.cohere4j.response.CompletionResponse;
7-
import com.github.llmjava.cohere4j.response.streaming.StreamingCompletionResponse;
5+
import com.github.llmjava.cohere4j.request.GenerationRequest;
6+
import com.github.llmjava.cohere4j.response.GenerationResponse;
7+
import com.github.llmjava.cohere4j.response.streaming.StreamingGenerationResponse;
8+
import com.github.llmjava.cohere4j.response.streaming.ResponseConverter;
9+
import com.google.gson.Gson;
810
import retrofit2.Call;
911
import retrofit2.Response;
1012

@@ -13,17 +15,19 @@
1315
public class CohereClient {
1416
private final CohereApi api;
1517
private final CohereConfig config;
18+
private final Gson gson;
1619

1720
CohereClient(Builder builder) {
1821
this.api = builder.api;
1922
this.config = builder.config;
23+
this.gson = builder.gson;
2024
}
2125

22-
public String generate(CompletionRequest request) {
26+
public GenerationResponse generate(GenerationRequest request) {
2327
try {
24-
Response<CompletionResponse> response = api.generate(request).execute();
28+
Response<GenerationResponse> response = api.generate(request).execute();
2529
if (response.isSuccessful()) {
26-
return response.body().getTexts().get(0);
30+
return response.body();
2731
} else {
2832
throw newException(response);
2933
}
@@ -32,40 +36,48 @@ public String generate(CompletionRequest request) {
3236
}
3337
}
3438

35-
public void generateAsync(CompletionRequest request, AsyncCallback<String> callback) {
36-
api.generate(request).enqueue(new retrofit2.Callback<CompletionResponse>() {
39+
public void generateAsync(GenerationRequest request, AsyncCallback<GenerationResponse> callback) {
40+
api.generate(request).enqueue(new retrofit2.Callback<GenerationResponse>() {
3741
@Override
38-
public void onResponse(Call<CompletionResponse> call, Response<CompletionResponse> response) {
42+
public void onResponse(Call<GenerationResponse> call, Response<GenerationResponse> response) {
3943
if (response.isSuccessful()) {
40-
callback.onSuccess(response.body().getTexts().get(0));
44+
callback.onSuccess(response.body());
4145
} else {
4246
callback.onFailure(newException(response));
4347
}
4448
}
4549

4650
@Override
47-
public void onFailure(Call<CompletionResponse> call, Throwable throwable) {
51+
public void onFailure(Call<GenerationResponse> call, Throwable throwable) {
4852
callback.onFailure(throwable);
4953
}
5054
});
5155
}
5256

53-
public void generateStream(CompletionRequest request, StreamingCallback<String> callback) {
54-
api.generateStream(request).enqueue(new retrofit2.Callback<CompletionResponse>() {
57+
public void generateStream(GenerationRequest request, StreamingCallback<StreamingGenerationResponse> callback) {
58+
if(!request.isStreaming()) {
59+
throw new IllegalArgumentException("Expected a streaming request");
60+
}
61+
ResponseConverter converter = new ResponseConverter(gson);
62+
api.generateStream(request).enqueue(new retrofit2.Callback<String>() {
5563
@Override
56-
public void onResponse(Call<CompletionResponse> call, Response<CompletionResponse> response) {
64+
public void onResponse(Call<String> call, Response<String> response) {
5765
if (response.isSuccessful()) {
58-
CompletionResponse resp = response.body();
59-
callback.onPart(resp.getTexts().get(0));
60-
callback.onComplete();
66+
for(StreamingGenerationResponse resp: converter.toStreamingGenerationResponse(response.body())) {
67+
if(resp.isFinished()) {
68+
callback.onComplete(resp);
69+
} else {
70+
callback.onPart(resp);
71+
}
72+
}
6173

6274
} else {
6375
callback.onFailure(newException(response));
6476
}
6577
}
6678

6779
@Override
68-
public void onFailure(Call<CompletionResponse> call, Throwable throwable) {
80+
public void onFailure(Call<String> call, Throwable throwable) {
6981
callback.onFailure(throwable);
7082
}
7183
});
@@ -90,10 +102,13 @@ private static RuntimeException newException(retrofit2.Response<?> response) {
90102
public static class Builder {
91103
private CohereApi api;
92104
private CohereConfig config;
105+
private Gson gson;
93106

94107
public Builder withConfig(CohereConfig config) {
95108
this.config = config;
96-
this.api = new CohereApiFactory().build(config);
109+
CohereApiFactory factory = new CohereApiFactory();
110+
this.api = factory.createGson().createHttpClient(config).build();
111+
this.gson = factory.gson;
97112
return this;
98113
}
99114

src/main/java/com/github/llmjava/cohere4j/callback/StreamingCallback.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
public interface StreamingCallback<S> {
44

55
void onPart(S response);
6-
void onComplete();
6+
void onComplete(S response);
77
void onFailure(Throwable throwable);
88
}

src/main/java/com/github/llmjava/cohere4j/request/CompletionRequest.java

Lines changed: 0 additions & 80 deletions
This file was deleted.

0 commit comments

Comments
 (0)