From f8418a62f0729235064642a81fad8683eb25a3e5 Mon Sep 17 00:00:00 2001 From: Richard Hightower Date: Sat, 29 Jul 2023 16:02:31 -0500 Subject: [PATCH] R2 (#41) * updated version * ingore run.sh * release attempt * added stacktrace * tried to fix build * updated version * fix build * improved string escaping * fixed encode * patches --- build.gradle | 2 +- .../com/cloudurable/jai/OpenAIClient.java | 120 ++++++--------- .../chat/ChatRequestSerializer.java | 18 +++ .../cloudurable/jai/util/JsonSerializer.java | 17 +- .../jai/examples/GenerateSequenceDiagram.java | 62 ++++++++ .../com/cloudurable/jai/examples/Main.java | 10 +- .../jai/examples/WhoWonUFC290.java | 145 ++++++++++++------ .../jai/util/JsonSerializerTest.java | 7 + 8 files changed, 255 insertions(+), 126 deletions(-) create mode 100644 src/test/java/com/cloudurable/jai/examples/GenerateSequenceDiagram.java diff --git a/build.gradle b/build.gradle index 50571c1..49eafc9 100644 --- a/build.gradle +++ b/build.gradle @@ -9,7 +9,7 @@ plugins { } -def jarVersion = "1.0.3" +def jarVersion = "1.0.6" group = 'com.cloudurable' archivesBaseName = "jai" diff --git a/src/main/java/com/cloudurable/jai/OpenAIClient.java b/src/main/java/com/cloudurable/jai/OpenAIClient.java index 7e116a8..14cb487 100644 --- a/src/main/java/com/cloudurable/jai/OpenAIClient.java +++ b/src/main/java/com/cloudurable/jai/OpenAIClient.java @@ -33,6 +33,7 @@ import com.cloudurable.jai.model.text.embedding.EmbeddingResponse; import com.cloudurable.jai.util.MultipartEntityBuilder; import com.cloudurable.jai.util.RequestResponseUtils; +import io.nats.jparse.parser.JsonParserBuilder; import java.net.URI; import java.net.http.HttpClient; @@ -54,18 +55,21 @@ public class OpenAIClient implements Client, ClientAsync { private final SecretHolder apiKey; private final String apiEndpoint; private final HttpClient httpClient; + private final boolean validateJson; /** * Constructs an OpenAIClient object. * - * @param apiKey The API key for authentication with the OpenAI API. - * @param apiEndpoint The API endpoint URL for the OpenAI API. - * @param httpClient The HTTP client used for making API requests. + * @param apiKey The API key for authentication with the OpenAI API. + * @param apiEndpoint The API endpoint URL for the OpenAI API. + * @param httpClient The HTTP client used for making API requests. + * @param validateJson */ - public OpenAIClient(SecretHolder apiKey, String apiEndpoint, HttpClient httpClient) { + public OpenAIClient(SecretHolder apiKey, String apiEndpoint, HttpClient httpClient, boolean validateJson) { this.apiKey = apiKey; this.apiEndpoint = apiEndpoint; this.httpClient = httpClient; + this.validateJson = validateJson; } /** @@ -87,10 +91,7 @@ public static Builder builder() { @Override public CompletableFuture> chatAsync(final ChatRequest chatRequest) { - final String jsonRequest = ChatRequestSerializer.serialize(chatRequest); - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/chat/completions") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(ChatRequestSerializer.serialize(chatRequest), "/chat/completions"); return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()) .thenApply((Function, ClientResponse>) response -> @@ -99,6 +100,19 @@ public CompletableFuture> chatAsync(fi } + private HttpRequest buildGptRequest(String jsonRequestBody, String path) { + if (validateJson) { + try { + JsonParserBuilder.builder().build().parse(jsonRequestBody); + } catch (Exception ex) { + throw new IllegalArgumentException(jsonRequestBody, ex); + } + } + final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody(path) + .POST(HttpRequest.BodyPublishers.ofString(jsonRequestBody)); + return requestBuilder.build(); + } + /** * Sends a completion request to the OpenAI API and returns the client response. * @@ -108,10 +122,7 @@ public CompletableFuture> chatAsync(fi @Override public CompletableFuture> completionAsync( final CompletionRequest completionRequest) { - final String jsonRequest = CompletionRequestSerializer.serialize(completionRequest); - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/completions") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(CompletionRequestSerializer.serialize(completionRequest), "/completions"); return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()) .thenApply((Function, ClientResponse>) response -> @@ -122,11 +133,7 @@ public CompletableFuture> @Override public CompletableFuture> moderateAsync(CreateModerationRequest moderationRequest) { - final String jsonRequest = CreateModerationRequestSerializer.serialize(moderationRequest); - - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/moderations") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(CreateModerationRequestSerializer.serialize(moderationRequest), "/moderations"); try { return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()) .thenApply(response -> getCreateModerationResponse(moderationRequest, response)); @@ -140,11 +147,7 @@ public CompletableFuture moderate(CreateModerationRequest moderationRequest) { - final String jsonRequest = CreateModerationRequestSerializer.serialize(moderationRequest); - - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/moderations") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(CreateModerationRequestSerializer.serialize(moderationRequest), "/moderations"); try { final HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); return getCreateModerationResponse(moderationRequest, response); @@ -467,11 +470,7 @@ public FineTuneData cancelFineTune(String id) { @Override public CompletableFuture> createFineTuneAsync(CreateFineTuneRequest createFineTuneRequest) { - final String jsonRequest = CreateFineTuneRequestSerializer.serialize(createFineTuneRequest); - // Build and send the HTTP request - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/fine-tunes") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(CreateFineTuneRequestSerializer.serialize(createFineTuneRequest), "/fine-tunes"); try { return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()) .thenApply(response -> getCreateFineTuneResponse(createFineTuneRequest, response)); @@ -483,11 +482,7 @@ public CompletableFuture> cr @Override public ClientResponse createFineTune(CreateFineTuneRequest createFineTuneRequest) { - final String jsonRequest = CreateFineTuneRequestSerializer.serialize(createFineTuneRequest); - // Build and send the HTTP request - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/fine-tunes") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(CreateFineTuneRequestSerializer.serialize(createFineTuneRequest), "/fine-tunes"); try { final HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); return getCreateFineTuneResponse(createFineTuneRequest, response); @@ -551,11 +546,7 @@ private HttpRequest.Builder createRequestBuilder(String path) { @Override public ClientResponse chat(final ChatRequest chatRequest) { - final String jsonRequest = ChatRequestSerializer.serialize(chatRequest); - - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/chat/completions") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(ChatRequestSerializer.serialize(chatRequest), "/chat/completions"); try { final HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); return getChatResponse(chatRequest, response); @@ -591,11 +582,7 @@ private HttpRequest.Builder createRequestBuilderWithBody(final String path) { */ @Override public ClientResponse completion(final CompletionRequest completionRequest) { - final String jsonRequest = CompletionRequestSerializer.serialize(completionRequest); - // Build and send the HTTP request - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/completions") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(CompletionRequestSerializer.serialize(completionRequest), "/completions"); try { final HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); return getCompletionResponse(completionRequest, response); @@ -606,11 +593,7 @@ public ClientResponse completion(final Co @Override public ClientResponse edit(final EditRequest editRequest) { - final String jsonRequest = EditRequestSerializer.serialize(editRequest); - // Build and send the HTTP request - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/edits") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(EditRequestSerializer.serialize(editRequest), "/edits"); try { final HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); return getEditResponse(editRequest, response); @@ -621,11 +604,7 @@ public ClientResponse edit(final EditRequest editRequ @Override public ClientResponse embedding(EmbeddingRequest embeddingRequest) { - final String jsonRequest = EmbeddingRequestSerializer.serialize(embeddingRequest); - // Build and send the HTTP request - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/embeddings") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(EmbeddingRequestSerializer.serialize(embeddingRequest), "/embeddings"); try { final HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); return getEmbeddingResponse(embeddingRequest, response); @@ -721,11 +700,7 @@ public ClientResponse translate(TranslateReques @Override public ClientResponse createImage(CreateImageRequest imageRequest) { - final String jsonRequest = ImageRequestSerializer.buildJson(imageRequest); - - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/images/generations") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(ImageRequestSerializer.buildJson(imageRequest), "/images/generations"); try { final HttpResponse response = httpClient.send(request, HttpResponse.BodyHandlers.ofString()); return getCreateImageResponse(imageRequest, response); @@ -769,11 +744,7 @@ public ClientResponse createImageVar @Override public CompletableFuture> createImageAsync(CreateImageRequest imageRequest) { - final String jsonRequest = ImageRequestSerializer.buildJson(imageRequest); - - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/images/generations") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(ImageRequestSerializer.buildJson(imageRequest), "/images/generations"); return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()) .thenApply(response -> @@ -815,11 +786,7 @@ public CompletableFuture> embeddingAsync(final EmbeddingRequest embeddingRequest) { - final String jsonRequest = EmbeddingRequestSerializer.serialize(embeddingRequest); - // Build and send the HTTP request - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/embeddings") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(EmbeddingRequestSerializer.serialize(embeddingRequest), "/embeddings"); return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()) .thenApply((Function, ClientResponse>) response -> getEmbeddingResponse(embeddingRequest, response)).exceptionally(e -> @@ -829,11 +796,7 @@ public CompletableFuture> em @Override public CompletableFuture> editAsync(EditRequest editRequest) { - final String jsonRequest = EditRequestSerializer.serialize(editRequest); - // Build and send the HTTP request - final HttpRequest.Builder requestBuilder = createRequestBuilderWithJsonBody("/edits") - .POST(HttpRequest.BodyPublishers.ofString(jsonRequest)); - final HttpRequest request = requestBuilder.build(); + final HttpRequest request = buildGptRequest(EditRequestSerializer.serialize(editRequest), "/edits"); return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()) .thenApply((Function, ClientResponse>) response -> getEditResponse(editRequest, response)).exceptionally(e -> @@ -851,6 +814,8 @@ public static class Builder { private HttpClient.Builder httpClientBuilder; + private boolean validateJson; + private Builder() { } @@ -922,6 +887,11 @@ public Builder setApiEndpoint(String apiEndpoint) { return this; } + public Builder validateJson(boolean validateJson) { + this.validateJson = validateJson; + return this; + } + /** * Builds the OpenAIClient object. * @@ -930,9 +900,11 @@ public Builder setApiEndpoint(String apiEndpoint) { */ public OpenAIClient build() { validateParameters(); - return new OpenAIClient(apiKey, apiEndpoint, getHttpClient()); + return new OpenAIClient(apiKey, apiEndpoint, getHttpClient(), validateJson); } + + private void validateParameters() { if (apiKey == null || apiKey.isEmpty()) { throw new IllegalArgumentException("API key is required"); @@ -942,6 +914,8 @@ private void validateParameters() { } } + + } } diff --git a/src/main/java/com/cloudurable/jai/model/text/completion/chat/ChatRequestSerializer.java b/src/main/java/com/cloudurable/jai/model/text/completion/chat/ChatRequestSerializer.java index af8baaf..7e655f4 100644 --- a/src/main/java/com/cloudurable/jai/model/text/completion/chat/ChatRequestSerializer.java +++ b/src/main/java/com/cloudurable/jai/model/text/completion/chat/ChatRequestSerializer.java @@ -127,6 +127,24 @@ public static void writeObjectParameter(JsonSerializer jsonBodyBuilder, ObjectPa } jsonBodyBuilder.endArray(); } + + /** + * { + * "type": "array", + * "items": { + * "type": "number" + * } + * } + */ + if (parameter instanceof ArrayParameter) { + jsonBodyBuilder.startNestedObjectAttribute("items"); + ArrayParameter ap =(ArrayParameter) parameter; + jsonBodyBuilder.addAttribute("type", ap.getElementType().toString().toLowerCase()); + if (ap.getType() == ParameterType.OBJECT) { + //writeObjectParameter(jsonBodyBuilder, ap.getObjectParam().get()); + } + jsonBodyBuilder.endObject(); + } jsonBodyBuilder.endObject(); } diff --git a/src/main/java/com/cloudurable/jai/util/JsonSerializer.java b/src/main/java/com/cloudurable/jai/util/JsonSerializer.java index 155197d..fa5d407 100644 --- a/src/main/java/com/cloudurable/jai/util/JsonSerializer.java +++ b/src/main/java/com/cloudurable/jai/util/JsonSerializer.java @@ -214,7 +214,7 @@ public void addAttribute(String name, String value) { } } - private static StringBuilder encodeString(String value) { + public static StringBuilder encodeString(String value) { StringBuilder strBuilder = new StringBuilder(value.length()); char[] charArray = value.toCharArray(); @@ -226,13 +226,22 @@ private static StringBuilder encodeString(String value) { strBuilder.append("\\\""); break; case '\n': - strBuilder.append("\\\\n"); + strBuilder.append("\\n"); break; case '\r': - strBuilder.append("\\\\r"); + strBuilder.append("\\r"); break; case '\t': - strBuilder.append("\\\t"); + strBuilder.append("\\t"); + break; + case '\b': + strBuilder.append("\\b"); + break; + case '\\': + strBuilder.append("\\\\"); + break; + case '/': + strBuilder.append("\\/"); break; default: strBuilder.append(ch); diff --git a/src/test/java/com/cloudurable/jai/examples/GenerateSequenceDiagram.java b/src/test/java/com/cloudurable/jai/examples/GenerateSequenceDiagram.java new file mode 100644 index 0000000..84724c0 --- /dev/null +++ b/src/test/java/com/cloudurable/jai/examples/GenerateSequenceDiagram.java @@ -0,0 +1,62 @@ +package com.cloudurable.jai.examples; + +import com.cloudurable.jai.OpenAIClient; +import com.cloudurable.jai.model.ClientResponse; +import com.cloudurable.jai.model.text.completion.chat.ChatRequest; +import com.cloudurable.jai.model.text.completion.chat.ChatResponse; +import com.cloudurable.jai.model.text.completion.chat.Message; +import com.cloudurable.jai.model.text.completion.chat.Role; + +public class GenerateSequenceDiagram { + + public static void main(String... args) throws Exception { + final var client = OpenAIClient.builder().setApiKey(System.getenv("OPENAI_API_KEY")).build(); + + final var chat = client.chat(ChatRequest.builder().addMessage(Message.builder() + .role(Role.SYSTEM).content("\n" + + "First, think about this and the steps to do a good job. As a senior developer, your task is to create documentation that utilizes Mermaid sequence diagrams to describe the functions of specific code blocks. Your target audience may not necessarily have strong technical skills, so the diagrams should be easily understandable, capturing all the essential points without being overly detailed.\n" + + "\n" + + "\n" + + "Your diagrams will benefit from the inclusion of any pertinent business rules or domain knowledge found in comments or log statements within the code. Incorporate these insights into your diagrams to improve readability and comprehension.\n" + + "\n" + + "Make sure that the diagrams clearly represent key concepts. If the code you're working with pertains to a specific domain, be sure to incorporate relevant language into the diagram. For instance, if the code involves transcribing an audio file using the OpenAI API, mention these specifics instead of using generic language like \"API call\". You can say print instead of System.out.println or println. You can say `Read Audio File` instead of `readBytes` when describing messages. Speak in the business domain language if possible or known. `read file(\"/Users/me/Documents/audio_notes_patents/meeting_notes1.m4a\")` should just be `Read Audio File`. Send `request to API` should be `Send Transcribe request to OpenAI`. Do not do this `User->>OpenAI: User question: \"Who won Main card fights in UFC 290? ...\"` but this `User->>OpenAI: User question`.\n" + + "\n" + + "For each code block, generate two versions of the Mermaid sequence diagram - one capturing the \"happy path\" or the expected sequence of operations, and the other highlighting potential error handling. \n" + + "\n" + + "\n" + + "Keep in mind the basics of the Mermaid grammar for sequence diagrams while creating the diagrams:\n" + + "\n" + + "* sequenceDiagram: participantDeclaration+ interaction+;\n" + + "* participantDeclaration: participant participantName as participantAlias;;\n" + + "* interaction: interactionArrow | note;;\n" + + "* interactionArrow: participantName? interactionArrowType participantName (interactionMessage)?;;\n" + + "* note: 'Note' ('over' participantName)? ':' noteContent;;\n" + + "* interactionMessage: '-->>' | '->>';\n" + + "* interactionArrowType: '--' | '-';;\n" + + "* participantName: [A-Za-z0-9_]+;\n" + + "* participantAlias: [A-Za-z0-9_]+;\n" + + "* noteContent: ~[\\r\\n]+;\n" + + "\n" + + "Don't use participant aliases. Don't use notes. \n" + + "\n" + + "Please ponder on the purpose of the interactionMessage and the \n" + + "different types of interactionMessage and the significance of the `:` in the message \n" + + "and what a good description of a interactionMessage looks like. Identiry participant/actors if we are making an HTTP call or using another class. \n" + + "\n" + + "Use the main method to create the entire sequence diagram if a main method exists. Otherwise do the steps for each method without further prompting.\n" + + "\n" + + "For each class file contents that I give you do the following use the instructions above then generate these five outputs:\n" + + "\n" + + "1. Start by describing what each method does in plain English. \n" + + "2. Describe how you will generate each step into the mermaid sequence diagram.\n" + + "3. Then after generate the mermaid happy path sequence diagram. Use plain English not tech jargon.\n" + + "4. Last, generate the mermaid exceptional path sequence diagram. \n" + + "\n" + + "You will ask me for a class source file. If you understand say, \n" + + "\"I am a mermaid sequence GOD!\". Then repeat back what you will do. \n" + + "Then ask for the first class file. After I give you a class, repeat the four things that you will do then do them. \n") + .build()).build()); + + + } +} diff --git a/src/test/java/com/cloudurable/jai/examples/Main.java b/src/test/java/com/cloudurable/jai/examples/Main.java index 70c600d..78661d0 100644 --- a/src/test/java/com/cloudurable/jai/examples/Main.java +++ b/src/test/java/com/cloudurable/jai/examples/Main.java @@ -36,7 +36,7 @@ public class Main { public static void main(final String... args) { try { - chatWithFunctions(); +// chatWithFunctions(); // listFiles(); @@ -75,7 +75,7 @@ public static void main(final String... args) { // editImage(); // callCreateImage(); // callTranslate(); -// callTranscribe(); + callTranscribe(); // callEmbeddingAsyncExample(); // callEmbeddingExample(); // callEditAsyncExample(); @@ -388,9 +388,11 @@ private static void callTranslate() throws IOException { private static void callTranscribe() throws IOException { // Create the client - final OpenAIClient client = OpenAIClient.builder().setApiKey(System.getenv("OPEN_AI_KEY")).build(); - File file = new File("test.m4a"); + final var openAiKey = System.getenv("OPENAI_API_KEY"); + final OpenAIClient client = OpenAIClient.builder().setApiKey(openAiKey).build(); + + File file = new File("/Users/richardhightower/Documents/audio_notes_patents/meeting_notes1.m4a"); byte[] bytes = Files.readAllBytes(file.toPath()); // Create the chat request diff --git a/src/test/java/com/cloudurable/jai/examples/WhoWonUFC290.java b/src/test/java/com/cloudurable/jai/examples/WhoWonUFC290.java index 5fa139d..dcb4dfd 100644 --- a/src/test/java/com/cloudurable/jai/examples/WhoWonUFC290.java +++ b/src/test/java/com/cloudurable/jai/examples/WhoWonUFC290.java @@ -164,7 +164,7 @@ public static void main(String... args) { // Searching news articles based on the queries List results = queries.stream() - .map(WhoWonUFC290::searchNews).collect(Collectors.toList()); + .map(SearchNewsService::searchNews).collect(Collectors.toList()); // Extracting relevant information from the articles List articles = results.stream().map(arrayNode -> @@ -172,39 +172,23 @@ public static void main(String... args) { .flatMap(List::stream).collect(Collectors.toList()); // Extracting article content and generating embeddings for each article - List articleContent = articles.stream().map(article -> - String.format("%s %s %s", article.getString("title"), - article.getString("description"), article.getString("content").substring(0, 100))) - .collect(Collectors.toList()); - List articleEmbeddings = embeddings(articleContent); + List articleContent = ArticleExtractor.extractArticlesFromJsonNodes(articles); + List articleEmbeddings = EmbeddingsExtractor.extractEmbeddingsFromArticles(articleContent); - // Calculating cosine similarities between the hypothetical answer embedding and article embeddings - List cosineSimilarities = articleEmbeddings.stream() - .map(articleEmbedding -> dot(hypotheticalAnswerEmbedding, articleEmbedding)) - .collect(Collectors.toList()); - - // Creating a set of scored articles based on cosine similarities - Set articleSet = IntStream.range(0, - Math.min(cosineSimilarities.size(), articleContent.size())) - .mapToObj(i -> new ScoredArticle(articles.get(i), cosineSimilarities.get(i))) - .collect(Collectors.toSet()); + //Add similarity scores to the articles + Set articleSet = ArticleScoreCreator.scoreArticles(hypotheticalAnswerEmbedding, articles, articleContent, articleEmbeddings); // Sorting the articles based on their scores - List sortedArticles = new ArrayList<>(articleSet); - Collections.sort(sortedArticles, (o1, o2) -> Float.compare(o2.getScore(), o1.getScore())); + List sortedArticles = sortArticles(articleSet); // Printing the top 5 scored articles - sortedArticles.subList(0, 5).forEach(s -> System.out.println(s)); + displayTopFiveArticles(sortedArticles); // Formatting the top results as JSON strings - String formattedTopResults = String.join(",\n", sortedArticles.stream().map(sa -> sa.getContent()) - .map(article -> String.format(Json.niceJson("{'title':'%s', 'url':'%s', 'description':'%s', 'content':'%s'}\n"), - article.getString("title"), article.getString("url"), article.getString("description"), - getArticleContent(article))).collect(Collectors.toList()).subList(0, 10)); + String formattedTopResults = formatTopResultsToAddToContext(sortedArticles); // Generating the final answer with the formatted top results - String finalAnswer = jsonGPT(ANSWER_INPUT.replace("{USER_QUESTION}", USER_QUESTION) - .replace("{formatted_top_results}", formattedTopResults)); + String finalAnswer = generateFinalResponse(formattedTopResults); System.out.println(finalAnswer); long endTime = System.currentTimeMillis(); @@ -214,6 +198,30 @@ public static void main(String... args) { } } + private static void displayTopFiveArticles(List sortedArticles) { + sortedArticles.subList(0, 5).forEach(s -> System.out.println(s)); + } + + private static String generateFinalResponse(String formattedTopResults) { + return jsonGPT(ANSWER_INPUT.replace("{USER_QUESTION}", USER_QUESTION) + .replace("{formatted_top_results}", formattedTopResults)); + } + + private static String formatTopResultsToAddToContext(List sortedArticles) { + String formattedTopResults = String.join(",\n", sortedArticles.stream().map(sa -> sa.getContent()) + .map(article -> String.format(Json.niceJson("{'title':'%s', 'url':'%s', 'description':'%s', 'content':'%s'}\n"), + article.getString("title"), article.getString("url"), article.getString("description"), + getArticleContent(article))).collect(Collectors.toList()).subList(0, 10)); + return formattedTopResults; + } + + public static List sortArticles(Set articleSet) { + List sortedArticles = new ArrayList<>(articleSet); + Collections.sort(sortedArticles, (o1, o2) -> Float.compare(o2.getScore(), o1.getScore())); + return sortedArticles; + } + + private static Object getArticleContent(ObjectNode article) { String content = article.getString("content"); if (content.length() < 250) { @@ -269,31 +277,80 @@ public static String dateStr(Instant instant) { return localDate.format(formatter); } - public static ArrayNode searchNews(final String query) { - final var end = Instant.now(); - final var start = end.minus(java.time.Duration.ofDays(5)); - return searchNews(query, start, end, 5); + + public static class ArticleExtractor { + + public static List extractArticlesFromJsonNodes(List articles) { + List articleContent = articles.stream().map(article -> + String.format("%s %s %s", article.getString("title"), + article.getString("description"), article.getString("content").substring(0, 100))) + .collect(Collectors.toList()); + return articleContent; + } + + } + + public static class EmbeddingsExtractor { + + public static List extractEmbeddingsFromArticles(List articleContent) { + List articleEmbeddings = embeddings(articleContent); + return articleEmbeddings; + } + + } + + public static class ArticleScoreCreator { + + public static Set scoreArticles(float[] hypotheticalAnswerEmbedding, List articles, List articleContent, List articleEmbeddings) { + + List cosineSimilarities = similarityScores(hypotheticalAnswerEmbedding, articleEmbeddings); + + // Creating a set of scored articles based on cosine similarities + Set articleSet = IntStream.range(0, + Math.min(cosineSimilarities.size(), articleContent.size())) + .mapToObj(i -> new ScoredArticle(articles.get(i), cosineSimilarities.get(i))) + .collect(Collectors.toSet()); + return articleSet; + } + + public static List similarityScores(float[] hypotheticalAnswerEmbedding, List articleEmbeddings) { + // Calculating cosine similarities between the hypothetical answer embedding and article embeddings + List cosineSimilarities = articleEmbeddings.stream() + .map(articleEmbedding -> dot(hypotheticalAnswerEmbedding, articleEmbedding)) + .collect(Collectors.toList()); + return cosineSimilarities; + } } - public static ArrayNode searchNews(final String query, final Instant start, final Instant end, final int pageSize) { - System.out.println(query); - try { - String url = "https://newsapi.org/v2/everything?q=" + URLEncoder.encode(query, StandardCharsets.UTF_8) - + "&apiKey=" + System.getenv("NEWS_API_KEY") + "&language=en" + "&sortBy=relevancy" - + "&from=" + dateStr(start) + "&to=" + dateStr(end) + "&pageSize=" + pageSize; + public static class SearchNewsService { - HttpClient httpClient = HttpClient.newHttpClient(); - HttpResponse response = httpClient.send(HttpRequest.newBuilder().uri(URI.create(url)) - .GET().setHeader("Content-Type", "application/json").build(), HttpResponse.BodyHandlers.ofString()); + public static ArrayNode searchNews(final String query) { + final var end = Instant.now(); + final var start = end.minus(java.time.Duration.ofDays(5)); + return searchNews(query, start, end, 5); + } - if (response.statusCode() >= 200 && response.statusCode() < 299) { - return JsonParserBuilder.builder().build().parse(response.body()).atPath("articles").asCollection().asArray(); - } else { - throw new IllegalStateException(" status code " + response.statusCode() + " " + response.body()); + public static ArrayNode searchNews(final String query, final Instant start, final Instant end, final int pageSize) { + System.out.println(query); + try { + + String url = "https://newsapi.org/v2/everything?q=" + URLEncoder.encode(query, StandardCharsets.UTF_8) + + "&apiKey=" + System.getenv("NEWS_API_KEY") + "&language=en" + "&sortBy=relevancy" + + "&from=" + dateStr(start) + "&to=" + dateStr(end) + "&pageSize=" + pageSize; + + HttpClient httpClient = HttpClient.newHttpClient(); + HttpResponse response = httpClient.send(HttpRequest.newBuilder().uri(URI.create(url)) + .GET().setHeader("Content-Type", "application/json").build(), HttpResponse.BodyHandlers.ofString()); + + if (response.statusCode() >= 200 && response.statusCode() < 299) { + return JsonParserBuilder.builder().build().parse(response.body()).atPath("articles").asCollection().asArray(); + } else { + throw new IllegalStateException(" status code " + response.statusCode() + " " + response.body()); + } + } catch (Exception ex) { + throw new IllegalStateException(ex); } - } catch (Exception ex) { - throw new IllegalStateException(ex); } } diff --git a/src/test/java/com/cloudurable/jai/util/JsonSerializerTest.java b/src/test/java/com/cloudurable/jai/util/JsonSerializerTest.java index 71526d1..bc16656 100644 --- a/src/test/java/com/cloudurable/jai/util/JsonSerializerTest.java +++ b/src/test/java/com/cloudurable/jai/util/JsonSerializerTest.java @@ -113,5 +113,12 @@ void writeNestedList() { } + @Test + void encode() { + StringBuilder builder = JsonSerializer.encodeString("Hi mom \n how are you \\ \t \b \r \"'good'\" "); + assertEquals("Hi mom \\n how are you \\\\ \\t \\b \\r \\\"'good'\\\" ", builder.toString()); + System.out.println(builder); + } + }