-
Notifications
You must be signed in to change notification settings - Fork 812
/
OpenAiClientTest.java
483 lines (422 loc) · 18.3 KB
/
OpenAiClientTest.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
package com.unfbx.chatgpt;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.unfbx.chatgpt.entity.Tts.TextToSpeech;
import com.unfbx.chatgpt.entity.billing.BillingUsage;
import com.unfbx.chatgpt.entity.billing.CreditGrantsResponse;
import com.unfbx.chatgpt.entity.billing.Subscription;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.entity.completions.Completion;
import com.unfbx.chatgpt.entity.completions.CompletionResponse;
import com.unfbx.chatgpt.entity.edits.Edit;
import com.unfbx.chatgpt.entity.edits.EditResponse;
import com.unfbx.chatgpt.entity.embeddings.Embedding;
import com.unfbx.chatgpt.entity.embeddings.EmbeddingResponse;
import com.unfbx.chatgpt.entity.engines.Engine;
import com.unfbx.chatgpt.entity.files.File;
import com.unfbx.chatgpt.entity.common.DeleteResponse;
import com.unfbx.chatgpt.entity.files.UploadFileResponse;
import com.unfbx.chatgpt.entity.fineTune.Event;
import com.unfbx.chatgpt.entity.fineTune.FineTune;
import com.unfbx.chatgpt.entity.fineTune.FineTuneDeleteResponse;
import com.unfbx.chatgpt.entity.fineTune.FineTuneResponse;
import com.unfbx.chatgpt.entity.images.*;
import com.unfbx.chatgpt.entity.models.Model;
import com.unfbx.chatgpt.entity.moderations.Moderation;
import com.unfbx.chatgpt.entity.moderations.ModerationResponse;
import com.unfbx.chatgpt.entity.whisper.Transcriptions;
import com.unfbx.chatgpt.entity.whisper.Translations;
import com.unfbx.chatgpt.entity.whisper.Whisper;
import com.unfbx.chatgpt.entity.whisper.WhisperResponse;
import com.unfbx.chatgpt.interceptor.OpenAILogger;
import com.unfbx.chatgpt.interceptor.OpenAiResponseInterceptor;
import com.unfbx.chatgpt.utils.TikTokensUtil;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import org.junit.*;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.time.LocalDate;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
/**
 * Example-driven tests for {@link OpenAiClient}.
 *
 * <p>Every test except {@link #testJson()} goes through the client built in {@link #before()}.
 * That client needs a real API key and API host filled in. Many tests also read sample files
 * from hard-coded local Windows paths, so adjust those before running a test.
 *
 * @author https://www.unfbx.com
 * 2023-02-11
 */
@Slf4j
public class OpenAiClientTest {

    /** Client under test; rebuilt before every test by {@link #before()}. */
    private OpenAiClient v2;

    /**
     * Builds the OkHttp client and the {@link OpenAiClient} used by every test.
     *
     * <p>The OkHttp client has header-level logging, the response interceptor, a 10s connect
     * timeout and 30s read/write timeouts.
     */
    @Before
    public void before() {
        // Optional: route traffic through a local HTTP proxy (the proxy may be left out entirely).
        // Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 7890));
        HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
        // Never enable BODY-level logging in production or test environments;
        // use one of NONE, BASIC or HEADERS there.
        httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
        OkHttpClient okHttpClient = new OkHttpClient
                .Builder()
//                .proxy(proxy)
                .addInterceptor(httpLoggingInterceptor)
                .addInterceptor(new OpenAiResponseInterceptor())
                .connectTimeout(10, TimeUnit.SECONDS)
                .writeTimeout(30, TimeUnit.SECONDS)
                .readTimeout(30, TimeUnit.SECONDS)
                .build();
        v2 = OpenAiClient.builder()
                // Several keys may be supplied; the configured key strategy picks one per request.
                .apiKey(Arrays.asList("************************"))
                // Custom key-selection strategy; the default is KeyRandomStrategy.
                //.keyStrategy(new KeyRandomStrategy())
                .keyStrategy(new FirstKeyStrategy())
                .okHttpClient(okHttpClient)
                // Base URL of your own proxy, if you run one; can be omitted otherwise.
                .apiHost("https://*********/")
                .build();
    }

    /** Logs the account name and hard limit returned by the subscription endpoint. */
    @Test
    public void subscription() {
        Subscription subscription = v2.subscription();
        log.info("用户名:{}", subscription.getAccountName());
        log.info("用户总余额(美元):{}", subscription.getHardLimitUsd());
        log.info("更多信息看Subscription类");
    }

    /** Logs total usage between 2023-03-07 and today. */
    @Test
    public void billingUsage() {
        LocalDate start = LocalDate.of(2023, 3, 7);
        BillingUsage billingUsage = v2.billingUsage(start, LocalDate.now());
        log.info("总使用金额(美分):{}", billingUsage.getTotalUsage());
        log.info("更多信息看BillingUsage类");
    }

    /**
     * Compares locally computed token counts ({@link TikTokensUtil}) with the usage the API
     * reports for a two-message gpt-3.5-turbo chat.
     */
    @Test
    public void chatTokensTest() {
        // Context-window size the max-token budget is derived from.
        final int contextWindowTokens = 4096;
        String modelName = ChatCompletion.Model.GPT_3_5_TURBO.getName();
        List<Message> messages = new ArrayList<>(2);
        messages.add(Message.builder().role(Message.Role.USER).content("关注微信公众号:程序员的黑洞。").build());
        messages.add(Message.builder().role(Message.Role.USER).content("进入chatgpt-java交流群获取最新版本更新通知。").build());
        ChatCompletion chatCompletion = ChatCompletion
                .builder()
                .messages(messages)
                // Pin the model explicitly: the budget below is counted with this model's tokenizer.
                .model(modelName)
                .maxTokens(contextWindowTokens - TikTokensUtil.tokens(modelName, messages))
                .build();
        ChatCompletionResponse chatCompletionResponse = v2.chatCompletion(chatCompletion);
        // Prompt token count computed locally; TikTokensUtil.tokens(chatCompletion.getModel(), messages)
        // is an equivalent alternative.
        long tokens = chatCompletion.tokens();
        log.info("Message集合文本:【{}】", messages);
        log.info("本地计算的请求的tokens数{}", tokens);
        log.info("本地计算的返回的tokens数{}", TikTokensUtil.tokens(chatCompletion.getModel(), chatCompletionResponse.getChoices().get(0).getMessage().getContent()));
        log.info("---------------------------------------------------");
        log.info("Open AI 官方计算的总的tokens数{}", chatCompletionResponse.getUsage().getTotalTokens());
        log.info("Open AI 官方计算的请求的tokens数{}", chatCompletionResponse.getUsage().getPromptTokens());
        log.info("Open AI 官方计算的返回的tokens数{}", chatCompletionResponse.getUsage().getCompletionTokens());
    }

    /**
     * Round-trips a {@link Completion} through a Jackson mapper configured like a typical client
     * mapper. Checks that the prompt survives serialization and deserialization.
     */
    @Test
    public void testJson() throws JsonProcessingException {
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper
                .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
                .configure(SerializationFeature.INDENT_OUTPUT, true)
                .setSerializationInclusion(JsonInclude.Include.NON_NULL)
                .setTimeZone(TimeZone.getTimeZone("GMT+8"))
                .setLocale(Locale.CHINA);
        Completion completion = Completion.builder().prompt("你好啊").build();
        String jsonStr = objectMapper.writeValueAsString(completion);
        Completion completion1 = objectMapper.readValue(jsonStr, Completion.class);
        if (!Objects.equals(completion.getPrompt(), completion1.getPrompt())) {
            throw new AssertionError("prompt lost in JSON round trip: " + jsonStr);
        }
    }

    /** Logs granted, used and available credit. */
    @Test
    public void creditGrants() {
        CreditGrantsResponse creditGrantsResponse = v2.creditGrants();
        log.info("账户总余额(美元):{}", creditGrantsResponse.getTotalGranted());
        log.info("账户总使用金额(美元):{}", creditGrantsResponse.getTotalUsed());
        log.info("账户总剩余金额(美元):{}", creditGrantsResponse.getTotalAvailable());
    }

    /** Sends a single user message to gpt-3.5-turbo and prints every returned choice. */
    @Test
    public void chat() {
        Message message = Message.builder().role(Message.Role.USER).content("你好啊我的伙伴!").build();
        ChatCompletion chatCompletion = ChatCompletion
                .builder()
                .messages(Collections.singletonList(message))
                .model(ChatCompletion.Model.GPT_3_5_TURBO.getName())
                .build();
        ChatCompletionResponse chatCompletionResponse = v2.chatCompletion(chatCompletion);
        chatCompletionResponse.getChoices().forEach(choice -> System.out.println(choice.getMessage()));
    }

    /** Speech-to-text (Whisper) with explicit model, prompt, language, temperature and VTT output. */
    @Test
    public void speechToTextTranscriptions() {
        Transcriptions transcriptions = Transcriptions.builder()
                .model(Whisper.Model.WHISPER_1.getName())
                .prompt("提示语")
                .language("zh")
                .temperature(0.2)
                .responseFormat(Whisper.ResponseFormat.VTT.getName())
                .build();
        WhisperResponse whisperResponse =
                v2.speechToTextTranscriptions(new java.io.File("C:\\Users\\grt\\Desktop\\1.m4a"), transcriptions);
        System.out.println(whisperResponse.getText());
    }

    /** Speech-to-text using the client's default transcription options. */
    @Test
    public void speechToTextTranscriptionsV2() {
        WhisperResponse whisperResponse =
                v2.speechToTextTranscriptions(new java.io.File("C:\\Users\\grt\\Desktop\\1.m4a"));
        System.out.println(whisperResponse.getText());
    }

    /** Speech-to-text plus translation with explicit model, temperature and JSON output. */
    @Test
    public void speechToTextTranslations() {
        Translations translations = Translations.builder()
                .model(Whisper.Model.WHISPER_1.getName())
//                .prompt("提示语")
                .temperature(0.2)
                .responseFormat(Whisper.ResponseFormat.JSON.getName())
                .build();
        WhisperResponse whisperResponse =
                v2.speechToTextTranslations(new java.io.File("C:\\Users\\**\\Desktop\\1.m4a"), translations);
        System.out.println(whisperResponse.getText());
    }

    /** Speech-to-text plus translation using the client's default options. */
    @Test
    public void speechToTextTranslationsV2() {
        WhisperResponse whisperResponse =
                v2.speechToTextTranslations(new java.io.File("C:\\Users\\**\\Desktop\\1.m4a"));
        System.out.println(whisperResponse.getText());
    }

    /** Prints owner, id and object type of every model visible to the key. */
    @Test
    public void models() {
        List<Model> models = v2.models();
        models.forEach(e -> {
            System.out.print(e.getOwnedBy() + " ");
            System.out.print(e.getId() + " ");
            System.out.println(e.getObject() + " ");
        });
    }

    /** Retrieves a single model by id. */
    @Test
    public void model() {
        Model model = v2.model("code-davinci-002");
        System.out.println(model.toString());
    }

    /** Text completion from a bare prompt string; prints every choice. */
    @Test
    public void completions() {
        CompletionResponse completions = v2.completions("我想申请转专业,从计算机专业转到会计学专业,帮我完成一份两百字左右的申请书");
        completions.getChoices().forEach(System.out::println);
    }

    /**
     * Multi-turn "conversation" over the completions endpoint.
     *
     * <p>With {@code echo(true)} each response text includes the prompt, so it is fed back
     * together with the next instruction.
     */
    @Test
    public void completionsV3() {
        String question = "Human: 帮我把下面的文本翻译成英文;我爱你中国\n";
        Completion q = Completion.builder()
                .prompt(question)
                .stop(Arrays.asList(" Human:", " Bot:"))
                .echo(true)
                .build();
        CompletionResponse completions = v2.completions(q);
        String text = completions.getChoices().get(0).getText();
        q.setPrompt(text + "\n" + "再翻译成韩文\n");
        completions = v2.completions(q);
        text = completions.getChoices().get(0).getText();
        q.setPrompt(text + "\n" + "再翻译成日文\n");
        completions = v2.completions(q);
        text = completions.getChoices().get(0).getText();
        System.out.println(text);
    }

    /** Completion asking for 2 results chosen from 3 server-side candidates ({@code n}/{@code bestOf}). */
    @Test
    public void completionsV2() {
        Completion q = Completion.builder()
                .prompt("三体人是什么?")
                .n(2)
                .bestOf(3)
                .build();
        CompletionResponse completions = v2.completions(q);
        System.out.println(completions);
    }

    /** Edit endpoint applied to a broken Java statement with the code-edit model. */
    @Test
    public void editText() {
        // Text-edit variant:
        // Edit edit = Edit.builder().input("我爱你麻麻").instruction("帮我修改错别字").model(Edit.Model.TEXT_DAVINCI_EDIT_001.getName()).build();
        Edit edit = Edit.builder().input("System.out.pri(\"AAAAA\");").instruction("帮我修改这个java代码").model(Edit.Model.CODE_DAVINCI_EDIT_001.getName()).build();
        EditResponse editResponse = v2.edit(edit);
        System.out.println(editResponse);
    }

    /** Image generation requesting base64-encoded JSON output. */
    @Test
    public void genImages() {
        Image image = Image.builder().prompt("电脑画面").responseFormat(ResponseFormat.B64_JSON.getName()).build();
        ImageResponse imageResponse = v2.genImages(image);
        System.out.println(imageResponse);
    }

    /** Image generation from a bare prompt string. */
    @Test
    public void genImagesV2() {
        ImageResponse imageResponse = v2.genImages("睡着的小朋友");
        System.out.println(imageResponse);
    }

    /**
     * Image edit with an {@link ImageEdit} request.
     *
     * <p>An RGB input is rejected with
     * {@code Invalid input image - format must be in ['RGBA', 'LA', 'L'], got RGB.}, so the
     * source image must be RGBA, LA or L.
     */
    @Test
    public void editImageV2() {
        ImageEdit imageEdit = ImageEdit.builder().prompt("去除图片中的文字").build();
        List<Item> images = v2.editImages(new java.io.File("C:\\Users\\FLJS188\\Desktop\\o.png"),
                imageEdit);
        System.out.println(images);
    }

    /** Image edit from a file plus a bare prompt string. */
    @Test
    public void editImageV3() {
        List<Item> images = v2.editImages(new java.io.File("C:\\Users\\***\\Desktop\\1.png"),
                "去除图片中的文字");
        System.out.println(images);
    }

    /** Same call as {@link #editImageV3()}: image edit from a file plus a prompt string. */
    @Test
    public void editImage() {
        List<Item> images = v2.editImages(new java.io.File("C:\\Users\\***\\Desktop\\1.png"),
                "去除图片中的文字");
        System.out.println(images);
    }

    /** Image variations with an explicit (default-valued) {@link ImageVariations} request. */
    @Test
    public void variationsImagesV2() {
        ImageVariations imageVariations = ImageVariations.builder().build();
        ImageResponse imageResponse = v2.variationsImages(new java.io.File("C:\\Users\\***\\Desktop\\12.png"), imageVariations);
        System.out.println(imageResponse);
    }

    /** Image variations using the client's default options. */
    @Test
    public void variationsImages() {
        ImageResponse imageResponse = v2.variationsImages(new java.io.File("C:\\Users\\***\\Desktop\\12.png"));
        System.out.println(imageResponse);
    }

    /** Embeddings for two inputs via an {@link Embedding} request object. */
    @Test
    public void embeddingsV2() {
        Embedding embedding = Embedding.builder().input(Arrays.asList("我爱你亲爱的姑娘", "i love you")).build();
        EmbeddingResponse embeddings = v2.embeddings(embedding);
        System.out.println(embeddings);
    }

    /** Embeddings for a list of strings. */
    @Test
    public void embeddingsV3() {
        EmbeddingResponse embeddings = v2.embeddings(Arrays.asList("我爱你亲爱的姑娘", "i love you"));
        System.out.println(embeddings);
    }

    /** Embeddings for a single string. */
    @Test
    public void embeddings() {
        EmbeddingResponse embeddings = v2.embeddings("我爱你");
        System.out.println(embeddings);
    }

    /** Lists uploaded files. */
    @Test
    public void files() {
        List<File> files = v2.files();
        System.out.println(files);
    }

    /** Retrieves metadata of one uploaded file. */
    @Test
    public void retrieveFile() {
        File files = v2.retrieveFile("file-EHB0Wp3wcZu6tpbwkB6xeiEd");
        System.out.println(files);
    }

    /**
     * Not available to free accounts: "To help mitigate abuse, downloading of fine-tune training
     * files is disabled for free accounts." Not exercised yet, hence ignored.
     */
    @Test
    @Ignore
    public void retrieveFileContent() {
//        ResponseBody responseBody = v2.retrieveFileContent("file-EHB0Wp3wcZu6tpbwkB6xeiEd");
//        System.out.println(responseBody);
    }

    /** Uploads a file with the client's default purpose. */
    @Test
    public void uploadFileV1() {
        UploadFileResponse uploadFileResponse = v2.uploadFile(new java.io.File("C:\\Users\\***\\Desktop\\2.txt"));
        System.out.println(uploadFileResponse);
    }

    /** Uploads a file with an explicit {@code fine-tune} purpose. */
    @Test
    public void uploadFileV2() {
        UploadFileResponse uploadFileResponse = v2.uploadFile("fine-tune", new java.io.File("C:\\Users\\***\\Desktop\\2.txt"));
        System.out.println(uploadFileResponse);
    }

    /** Deletes an uploaded file by id. */
    @Test
    public void deleteFile() {
        DeleteResponse deleteResponse = v2.deleteFile("file-GreIoKq6lWHvq8PDwDZIGJjm");
        System.out.println(deleteResponse);
    }

    /** Moderation of a single string. */
    @Test
    public void moderations() {
        ModerationResponse moderations = v2.moderations("I want to kill them.");
        System.out.println(moderations);
    }

    /** Moderation of a list of strings. */
    @Test
    public void moderationsv3() {
        List<String> list = Collections.singletonList("I want to kill them.");
        ModerationResponse moderations = v2.moderations(list);
        System.out.println(moderations);
    }

    /** Moderation via a {@link Moderation} request object. */
    @Test
    public void moderationsV2() {
        Moderation moderation = Moderation.builder().input(Collections.singletonList("I want to kill them.")).build();
        ModerationResponse moderations = v2.moderations(moderation);
        System.out.println(moderations);
    }

    /** Lists engines. */
    @Test
    public void engines() {
        List<Engine> engines = v2.engines();
        System.out.println(engines);
    }

    /** Retrieves one engine by id. */
    @Test
    public void engine() {
        Engine engines = v2.engine("code-davinci-002");
        System.out.println(engines);
    }

    /** Starts a fine-tune from a training-file id with default options. */
    @Test
    public void fineTune() {
        FineTuneResponse fineTuneResponse = v2.fineTune("file-EHB0Wp3wcZu6tpbwkB6xeiEd");
        System.out.println(fineTuneResponse);
    }

    /** Starts a fine-tune of the ada model with a custom model-name suffix. */
    @Test
    public void fineTuneV2() {
        FineTune fineTune = FineTune.builder()
                .trainingFile("file-OcQb9zg35cxa4WLBZJ9K2523")
                .suffix("grttttttttt")
                .model(FineTune.Model.ADA.getName())
                .build();
        FineTuneResponse fineTuneResponse = v2.fineTune(fineTune);
        System.out.println(fineTuneResponse);
    }

    /** Lists fine-tune jobs. */
    @Test
    public void fineTunes() {
        List<FineTuneResponse> fineTuneResponses = v2.fineTunes();
        System.out.println(fineTuneResponses);
    }

    /** Retrieves one fine-tune job by id. */
    @Test
    public void retrieveFineTune() {
        FineTuneResponse fineTuneResponses = v2.retrieveFineTune("ft-bU0xJzVfrgOjqoy1e9lC2oDP");
        System.out.println(fineTuneResponses);
    }

    /** Cancels a fine-tune job; its status is expected to move from pending to cancelled. */
    @Test
    public void cancelFineTune() {
        FineTuneResponse fineTuneResponses = v2.cancelFineTune("ft-KohbEOCbPyNTyQmt5UV1F1cb");
        System.out.println(fineTuneResponses);
    }

    /** Lists events of one fine-tune job. */
    @Test
    public void fineTuneEvents() {
        List<Event> events = v2.fineTuneEvents("ft-KohbEOCbPyNTyQmt5UV1F1cb");
        System.out.println(events);
    }

    /** Deletes a fine-tuned model by its full model name. */
    @Test
    public void deleteFineTuneModel() {
        FineTuneDeleteResponse deleteResponse = v2.deleteFineTuneModel("ada:ft-winter:grttttttttt-2023-02-17-01-29-27");
        System.out.println(deleteResponse);
    }
}