22
33import com .github .llmjava .cohere4j .callback .AsyncCallback ;
44import com .github .llmjava .cohere4j .callback .StreamingCallback ;
5- import com .github .llmjava .cohere4j .request .CompletionRequest ;
6- import com .github .llmjava .cohere4j .response .CompletionResponse ;
7- import com .github .llmjava .cohere4j .response .streaming .StreamingCompletionResponse ;
5+ import com .github .llmjava .cohere4j .request .GenerationRequest ;
6+ import com .github .llmjava .cohere4j .response .GenerationResponse ;
7+ import com .github .llmjava .cohere4j .response .streaming .StreamingGenerationResponse ;
8+ import com .github .llmjava .cohere4j .response .streaming .ResponseConverter ;
9+ import com .google .gson .Gson ;
810import retrofit2 .Call ;
911import retrofit2 .Response ;
1012
1315public class CohereClient {
1416 private final CohereApi api ;
1517 private final CohereConfig config ;
18+ private final Gson gson ;
1619
1720 CohereClient (Builder builder ) {
1821 this .api = builder .api ;
1922 this .config = builder .config ;
23+ this .gson = builder .gson ;
2024 }
2125
22- public String generate (CompletionRequest request ) {
26+ public GenerationResponse generate (GenerationRequest request ) {
2327 try {
24- Response <CompletionResponse > response = api .generate (request ).execute ();
28+ Response <GenerationResponse > response = api .generate (request ).execute ();
2529 if (response .isSuccessful ()) {
26- return response .body (). getTexts (). get ( 0 ) ;
30+ return response .body ();
2731 } else {
2832 throw newException (response );
2933 }
@@ -32,40 +36,48 @@ public String generate(CompletionRequest request) {
3236 }
3337 }
3438
35- public void generateAsync (CompletionRequest request , AsyncCallback <String > callback ) {
36- api .generate (request ).enqueue (new retrofit2 .Callback <CompletionResponse >() {
39+ public void generateAsync (GenerationRequest request , AsyncCallback <GenerationResponse > callback ) {
40+ api .generate (request ).enqueue (new retrofit2 .Callback <GenerationResponse >() {
3741 @ Override
38- public void onResponse (Call <CompletionResponse > call , Response <CompletionResponse > response ) {
42+ public void onResponse (Call <GenerationResponse > call , Response <GenerationResponse > response ) {
3943 if (response .isSuccessful ()) {
40- callback .onSuccess (response .body (). getTexts (). get ( 0 ) );
44+ callback .onSuccess (response .body ());
4145 } else {
4246 callback .onFailure (newException (response ));
4347 }
4448 }
4549
4650 @ Override
47- public void onFailure (Call <CompletionResponse > call , Throwable throwable ) {
51+ public void onFailure (Call <GenerationResponse > call , Throwable throwable ) {
4852 callback .onFailure (throwable );
4953 }
5054 });
5155 }
5256
53- public void generateStream (CompletionRequest request , StreamingCallback <String > callback ) {
54- api .generateStream (request ).enqueue (new retrofit2 .Callback <CompletionResponse >() {
57+ public void generateStream (GenerationRequest request , StreamingCallback <StreamingGenerationResponse > callback ) {
58+ if (!request .isStreaming ()) {
59+ throw new IllegalArgumentException ("Expected a streaming request" );
60+ }
61+ ResponseConverter converter = new ResponseConverter (gson );
62+ api .generateStream (request ).enqueue (new retrofit2 .Callback <String >() {
5563 @ Override
56- public void onResponse (Call <CompletionResponse > call , Response <CompletionResponse > response ) {
64+ public void onResponse (Call <String > call , Response <String > response ) {
5765 if (response .isSuccessful ()) {
58- CompletionResponse resp = response .body ();
59- callback .onPart (resp .getTexts ().get (0 ));
60- callback .onComplete ();
66+ for (StreamingGenerationResponse resp : converter .toStreamingGenerationResponse (response .body ())) {
67+ if (resp .isFinished ()) {
68+ callback .onComplete (resp );
69+ } else {
70+ callback .onPart (resp );
71+ }
72+ }
6173
6274 } else {
6375 callback .onFailure (newException (response ));
6476 }
6577 }
6678
6779 @ Override
68- public void onFailure (Call <CompletionResponse > call , Throwable throwable ) {
80+ public void onFailure (Call <String > call , Throwable throwable ) {
6981 callback .onFailure (throwable );
7082 }
7183 });
@@ -90,10 +102,13 @@ private static RuntimeException newException(retrofit2.Response<?> response) {
90102 public static class Builder {
91103 private CohereApi api ;
92104 private CohereConfig config ;
105+ private Gson gson ;
93106
94107 public Builder withConfig (CohereConfig config ) {
95108 this .config = config ;
96- this .api = new CohereApiFactory ().build (config );
109+ CohereApiFactory factory = new CohereApiFactory ();
110+ this .api = factory .createGson ().createHttpClient (config ).build ();
111+ this .gson = factory .gson ;
97112 return this ;
98113 }
99114
0 commit comments