Skip to content

Commit

Permalink
R2 (#41)
Browse files Browse the repository at this point in the history
* updated version

* ingore run.sh

* release attempt

* added stacktrace

* tried to fix build

* updated version

* fix build

* improved string escaping

* fixed encode

* patches
  • Loading branch information
RichardHightower authored Jul 29, 2023
1 parent 72e6c0c commit f8418a6
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 126 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ plugins {
}


def jarVersion = "1.0.3"
def jarVersion = "1.0.6"

group = 'com.cloudurable'
archivesBaseName = "jai"
Expand Down
120 changes: 47 additions & 73 deletions src/main/java/com/cloudurable/jai/OpenAIClient.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,24 @@ public static void writeObjectParameter(JsonSerializer jsonBodyBuilder, ObjectPa
}
jsonBodyBuilder.endArray();
}

/**
* {
* "type": "array",
* "items": {
* "type": "number"
* }
* }
*/
if (parameter instanceof ArrayParameter) {
jsonBodyBuilder.startNestedObjectAttribute("items");
ArrayParameter ap =(ArrayParameter) parameter;
jsonBodyBuilder.addAttribute("type", ap.getElementType().toString().toLowerCase());
if (ap.getType() == ParameterType.OBJECT) {
//writeObjectParameter(jsonBodyBuilder, ap.getObjectParam().get());
}
jsonBodyBuilder.endObject();
}
jsonBodyBuilder.endObject();
}

Expand Down
17 changes: 13 additions & 4 deletions src/main/java/com/cloudurable/jai/util/JsonSerializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ public void addAttribute(String name, String value) {
}
}

private static StringBuilder encodeString(String value) {
public static StringBuilder encodeString(String value) {
StringBuilder strBuilder = new StringBuilder(value.length());
char[] charArray = value.toCharArray();

Expand All @@ -226,13 +226,22 @@ private static StringBuilder encodeString(String value) {
strBuilder.append("\\\"");
break;
case '\n':
strBuilder.append("\\\\n");
strBuilder.append("\\n");
break;
case '\r':
strBuilder.append("\\\\r");
strBuilder.append("\\r");
break;
case '\t':
strBuilder.append("\\\t");
strBuilder.append("\\t");
break;
case '\b':
strBuilder.append("\\b");
break;
case '\\':
strBuilder.append("\\\\");
break;
case '/':
strBuilder.append("\\/");
break;
default:
strBuilder.append(ch);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package com.cloudurable.jai.examples;

import com.cloudurable.jai.OpenAIClient;
import com.cloudurable.jai.model.ClientResponse;
import com.cloudurable.jai.model.text.completion.chat.ChatRequest;
import com.cloudurable.jai.model.text.completion.chat.ChatResponse;
import com.cloudurable.jai.model.text.completion.chat.Message;
import com.cloudurable.jai.model.text.completion.chat.Role;

public class GenerateSequenceDiagram {

public static void main(String... args) throws Exception {
final var client = OpenAIClient.builder().setApiKey(System.getenv("OPENAI_API_KEY")).build();

final var chat = client.chat(ChatRequest.builder().addMessage(Message.builder()
.role(Role.SYSTEM).content("\n" +
"First, think about this and the steps to do a good job. As a senior developer, your task is to create documentation that utilizes Mermaid sequence diagrams to describe the functions of specific code blocks. Your target audience may not necessarily have strong technical skills, so the diagrams should be easily understandable, capturing all the essential points without being overly detailed.\n" +
"\n" +
"\n" +
"Your diagrams will benefit from the inclusion of any pertinent business rules or domain knowledge found in comments or log statements within the code. Incorporate these insights into your diagrams to improve readability and comprehension.\n" +
"\n" +
"Make sure that the diagrams clearly represent key concepts. If the code you're working with pertains to a specific domain, be sure to incorporate relevant language into the diagram. For instance, if the code involves transcribing an audio file using the OpenAI API, mention these specifics instead of using generic language like \"API call\". You can say print instead of System.out.println or println. You can say `Read Audio File` instead of `readBytes` when describing messages. Speak in the business domain language if possible or known. `read file(\"/Users/me/Documents/audio_notes_patents/meeting_notes1.m4a\")` should just be `Read Audio File`. Send `request to API` should be `Send Transcribe request to OpenAI`. Do not do this `User->>OpenAI: User question: \"Who won Main card fights in UFC 290? ...\"` but this `User->>OpenAI: User question`.\n" +
"\n" +
"For each code block, generate two versions of the Mermaid sequence diagram - one capturing the \"happy path\" or the expected sequence of operations, and the other highlighting potential error handling. \n" +
"\n" +
"\n" +
"Keep in mind the basics of the Mermaid grammar for sequence diagrams while creating the diagrams:\n" +
"\n" +
"* sequenceDiagram: participantDeclaration+ interaction+;\n" +
"* participantDeclaration: participant participantName as participantAlias;;\n" +
"* interaction: interactionArrow | note;;\n" +
"* interactionArrow: participantName? interactionArrowType participantName (interactionMessage)?;;\n" +
"* note: 'Note' ('over' participantName)? ':' noteContent;;\n" +
"* interactionMessage: '-->>' | '->>';\n" +
"* interactionArrowType: '--' | '-';;\n" +
"* participantName: [A-Za-z0-9_]+;\n" +
"* participantAlias: [A-Za-z0-9_]+;\n" +
"* noteContent: ~[\\r\\n]+;\n" +
"\n" +
"Don't use participant aliases. Don't use notes. \n" +
"\n" +
"Please ponder on the purpose of the interactionMessage and the \n" +
"different types of interactionMessage and the significance of the `:` in the message \n" +
"and what a good description of a interactionMessage looks like. Identiry participant/actors if we are making an HTTP call or using another class. \n" +
"\n" +
"Use the main method to create the entire sequence diagram if a main method exists. Otherwise do the steps for each method without further prompting.\n" +
"\n" +
"For each class file contents that I give you do the following use the instructions above then generate these five outputs:\n" +
"\n" +
"1. Start by describing what each method does in plain English. \n" +
"2. Describe how you will generate each step into the mermaid sequence diagram.\n" +
"3. Then after generate the mermaid happy path sequence diagram. Use plain English not tech jargon.\n" +
"4. Last, generate the mermaid exceptional path sequence diagram. \n" +
"\n" +
"You will ask me for a class source file. If you understand say, \n" +
"\"I am a mermaid sequence GOD!\". Then repeat back what you will do. \n" +
"Then ask for the first class file. After I give you a class, repeat the four things that you will do then do them. \n")
.build()).build());


}
}
10 changes: 6 additions & 4 deletions src/test/java/com/cloudurable/jai/examples/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class Main {
public static void main(final String... args) {
try {

chatWithFunctions();
// chatWithFunctions();


// listFiles();
Expand Down Expand Up @@ -75,7 +75,7 @@ public static void main(final String... args) {
// editImage();
// callCreateImage();
// callTranslate();
// callTranscribe();
callTranscribe();
// callEmbeddingAsyncExample();
// callEmbeddingExample();
// callEditAsyncExample();
Expand Down Expand Up @@ -388,9 +388,11 @@ private static void callTranslate() throws IOException {

private static void callTranscribe() throws IOException {
// Create the client
final OpenAIClient client = OpenAIClient.builder().setApiKey(System.getenv("OPEN_AI_KEY")).build();

File file = new File("test.m4a");
final var openAiKey = System.getenv("OPENAI_API_KEY");
final OpenAIClient client = OpenAIClient.builder().setApiKey(openAiKey).build();

File file = new File("/Users/richardhightower/Documents/audio_notes_patents/meeting_notes1.m4a");

byte[] bytes = Files.readAllBytes(file.toPath());
// Create the chat request
Expand Down
145 changes: 101 additions & 44 deletions src/test/java/com/cloudurable/jai/examples/WhoWonUFC290.java
Original file line number Diff line number Diff line change
Expand Up @@ -164,47 +164,31 @@ public static void main(String... args) {

// Searching news articles based on the queries
List<ArrayNode> results = queries.stream()
.map(WhoWonUFC290::searchNews).collect(Collectors.toList());
.map(SearchNewsService::searchNews).collect(Collectors.toList());

// Extracting relevant information from the articles
List<ObjectNode> articles = results.stream().map(arrayNode ->
arrayNode.map(on -> on.asCollection().asObject()))
.flatMap(List::stream).collect(Collectors.toList());

// Extracting article content and generating embeddings for each article
List<String> articleContent = articles.stream().map(article ->
String.format("%s %s %s", article.getString("title"),
article.getString("description"), article.getString("content").substring(0, 100)))
.collect(Collectors.toList());
List<float[]> articleEmbeddings = embeddings(articleContent);
List<String> articleContent = ArticleExtractor.extractArticlesFromJsonNodes(articles);
List<float[]> articleEmbeddings = EmbeddingsExtractor.extractEmbeddingsFromArticles(articleContent);

// Calculating cosine similarities between the hypothetical answer embedding and article embeddings
List<Float> cosineSimilarities = articleEmbeddings.stream()
.map(articleEmbedding -> dot(hypotheticalAnswerEmbedding, articleEmbedding))
.collect(Collectors.toList());

// Creating a set of scored articles based on cosine similarities
Set<ScoredArticle> articleSet = IntStream.range(0,
Math.min(cosineSimilarities.size(), articleContent.size()))
.mapToObj(i -> new ScoredArticle(articles.get(i), cosineSimilarities.get(i)))
.collect(Collectors.toSet());
//Add similarity scores to the articles
Set<ScoredArticle> articleSet = ArticleScoreCreator.scoreArticles(hypotheticalAnswerEmbedding, articles, articleContent, articleEmbeddings);

// Sorting the articles based on their scores
List<ScoredArticle> sortedArticles = new ArrayList<>(articleSet);
Collections.sort(sortedArticles, (o1, o2) -> Float.compare(o2.getScore(), o1.getScore()));
List<ScoredArticle> sortedArticles = sortArticles(articleSet);

// Printing the top 5 scored articles
sortedArticles.subList(0, 5).forEach(s -> System.out.println(s));
displayTopFiveArticles(sortedArticles);

// Formatting the top results as JSON strings
String formattedTopResults = String.join(",\n", sortedArticles.stream().map(sa -> sa.getContent())
.map(article -> String.format(Json.niceJson("{'title':'%s', 'url':'%s', 'description':'%s', 'content':'%s'}\n"),
article.getString("title"), article.getString("url"), article.getString("description"),
getArticleContent(article))).collect(Collectors.toList()).subList(0, 10));
String formattedTopResults = formatTopResultsToAddToContext(sortedArticles);

// Generating the final answer with the formatted top results
String finalAnswer = jsonGPT(ANSWER_INPUT.replace("{USER_QUESTION}", USER_QUESTION)
.replace("{formatted_top_results}", formattedTopResults));
String finalAnswer = generateFinalResponse(formattedTopResults);

System.out.println(finalAnswer);
long endTime = System.currentTimeMillis();
Expand All @@ -214,6 +198,30 @@ public static void main(String... args) {
}
}

private static void displayTopFiveArticles(List<ScoredArticle> sortedArticles) {
sortedArticles.subList(0, 5).forEach(s -> System.out.println(s));
}

private static String generateFinalResponse(String formattedTopResults) {
return jsonGPT(ANSWER_INPUT.replace("{USER_QUESTION}", USER_QUESTION)
.replace("{formatted_top_results}", formattedTopResults));
}

private static String formatTopResultsToAddToContext(List<ScoredArticle> sortedArticles) {
String formattedTopResults = String.join(",\n", sortedArticles.stream().map(sa -> sa.getContent())
.map(article -> String.format(Json.niceJson("{'title':'%s', 'url':'%s', 'description':'%s', 'content':'%s'}\n"),
article.getString("title"), article.getString("url"), article.getString("description"),
getArticleContent(article))).collect(Collectors.toList()).subList(0, 10));
return formattedTopResults;
}

public static List<ScoredArticle> sortArticles(Set<ScoredArticle> articleSet) {
List<ScoredArticle> sortedArticles = new ArrayList<>(articleSet);
Collections.sort(sortedArticles, (o1, o2) -> Float.compare(o2.getScore(), o1.getScore()));
return sortedArticles;
}


private static Object getArticleContent(ObjectNode article) {
String content = article.getString("content");
if (content.length() < 250) {
Expand Down Expand Up @@ -269,31 +277,80 @@ public static String dateStr(Instant instant) {
return localDate.format(formatter);
}

public static ArrayNode searchNews(final String query) {
final var end = Instant.now();
final var start = end.minus(java.time.Duration.ofDays(5));
return searchNews(query, start, end, 5);

public static class ArticleExtractor {

public static List<String> extractArticlesFromJsonNodes(List<ObjectNode> articles) {
List<String> articleContent = articles.stream().map(article ->
String.format("%s %s %s", article.getString("title"),
article.getString("description"), article.getString("content").substring(0, 100)))
.collect(Collectors.toList());
return articleContent;
}

}

public static class EmbeddingsExtractor {

public static List<float[]> extractEmbeddingsFromArticles(List<String> articleContent) {
List<float[]> articleEmbeddings = embeddings(articleContent);
return articleEmbeddings;
}

}

public static class ArticleScoreCreator {

public static Set<ScoredArticle> scoreArticles(float[] hypotheticalAnswerEmbedding, List<ObjectNode> articles, List<String> articleContent, List<float[]> articleEmbeddings) {

List<Float> cosineSimilarities = similarityScores(hypotheticalAnswerEmbedding, articleEmbeddings);

// Creating a set of scored articles based on cosine similarities
Set<ScoredArticle> articleSet = IntStream.range(0,
Math.min(cosineSimilarities.size(), articleContent.size()))
.mapToObj(i -> new ScoredArticle(articles.get(i), cosineSimilarities.get(i)))
.collect(Collectors.toSet());
return articleSet;
}

public static List<Float> similarityScores(float[] hypotheticalAnswerEmbedding, List<float[]> articleEmbeddings) {
// Calculating cosine similarities between the hypothetical answer embedding and article embeddings
List<Float> cosineSimilarities = articleEmbeddings.stream()
.map(articleEmbedding -> dot(hypotheticalAnswerEmbedding, articleEmbedding))
.collect(Collectors.toList());
return cosineSimilarities;
}
}

public static ArrayNode searchNews(final String query, final Instant start, final Instant end, final int pageSize) {
System.out.println(query);
try {

String url = "https://newsapi.org/v2/everything?q=" + URLEncoder.encode(query, StandardCharsets.UTF_8)
+ "&apiKey=" + System.getenv("NEWS_API_KEY") + "&language=en" + "&sortBy=relevancy"
+ "&from=" + dateStr(start) + "&to=" + dateStr(end) + "&pageSize=" + pageSize;
public static class SearchNewsService {

HttpClient httpClient = HttpClient.newHttpClient();
HttpResponse<String> response = httpClient.send(HttpRequest.newBuilder().uri(URI.create(url))
.GET().setHeader("Content-Type", "application/json").build(), HttpResponse.BodyHandlers.ofString());
public static ArrayNode searchNews(final String query) {
final var end = Instant.now();
final var start = end.minus(java.time.Duration.ofDays(5));
return searchNews(query, start, end, 5);
}

if (response.statusCode() >= 200 && response.statusCode() < 299) {
return JsonParserBuilder.builder().build().parse(response.body()).atPath("articles").asCollection().asArray();
} else {
throw new IllegalStateException(" status code " + response.statusCode() + " " + response.body());
public static ArrayNode searchNews(final String query, final Instant start, final Instant end, final int pageSize) {
System.out.println(query);
try {

String url = "https://newsapi.org/v2/everything?q=" + URLEncoder.encode(query, StandardCharsets.UTF_8)
+ "&apiKey=" + System.getenv("NEWS_API_KEY") + "&language=en" + "&sortBy=relevancy"
+ "&from=" + dateStr(start) + "&to=" + dateStr(end) + "&pageSize=" + pageSize;

HttpClient httpClient = HttpClient.newHttpClient();
HttpResponse<String> response = httpClient.send(HttpRequest.newBuilder().uri(URI.create(url))
.GET().setHeader("Content-Type", "application/json").build(), HttpResponse.BodyHandlers.ofString());

if (response.statusCode() >= 200 && response.statusCode() < 299) {
return JsonParserBuilder.builder().build().parse(response.body()).atPath("articles").asCollection().asArray();
} else {
throw new IllegalStateException(" status code " + response.statusCode() + " " + response.body());
}
} catch (Exception ex) {
throw new IllegalStateException(ex);
}
} catch (Exception ex) {
throw new IllegalStateException(ex);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,12 @@ void writeNestedList() {

}

@Test
void encode() {
StringBuilder builder = JsonSerializer.encodeString("Hi mom \n how are you \\ \t \b \r \"'good'\" ");
assertEquals("Hi mom \\n how are you \\\\ \\t \\b \\r \\\"'good'\\\" ", builder.toString());
System.out.println(builder);
}


}

0 comments on commit f8418a6

Please sign in to comment.