Skip to content

Commit

Permalink
[NOID] Fixes #4133: switch default model for cypher/schema interactio…
Browse files Browse the repository at this point in the history
…ns to gpt-4o (#4143)

Co-authored-by: gmarcostam <92850018+gmarcostam@users.noreply.github.com>
  • Loading branch information
vga91 and gmarcostam committed Dec 3, 2024
1 parent 9302bf4 commit 7c953b4
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 10 deletions.
12 changes: 6 additions & 6 deletions docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ This procedure `apoc.ml.openai.chat` takes a list of maps of chat exchanges betw

It uses the `/chat/create` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^].

Additional configuration is passed to the API, the default model used is `gpt-3.5-turbo`.
Additional configuration is passed to the API, the default model used is `gpt-4o`.

.Chat Completion Call
[source,cypher]
Expand All @@ -162,7 +162,7 @@ CALL apoc.ml.openai.chat([
.Chat Completion Response
----
{created=1684248203, id="chatcmpl-7GqBXZr94avd4fluYDi2fWEz7DIHL",
object="chat.completion", model="gpt-3.5-turbo-0301",
object="chat.completion", model="gpt-4o-0301",
usage={completion_tokens=2, prompt_tokens=26, total_tokens=28},
choices=[{finish_reason="stop", index=0, message={role="assistant", content="Earth."}}]}
----
Expand Down Expand Up @@ -269,7 +269,7 @@ RETURN m.title
| name | description | mandatory
| retries | The number of retries in case of API call failures | no, default `3`
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
| model | The Open AI model | no, default `gpt-3.5-turbo`
| model | The Open AI model | no, default `gpt-4o`
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
|===

Expand Down Expand Up @@ -318,7 +318,7 @@ RETURN *
|===
| name | description | mandatory
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
| model | The Open AI model | no, default `gpt-3.5-turbo`
| model | The Open AI model | no, default `gpt-4o`
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
|===

Expand Down Expand Up @@ -383,7 +383,7 @@ RETURN DISTINCT a.name
| name | description | mandatory
| count | The number of queries to retrieve | no, default `1`
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
| model | The Open AI model | no, default `gpt-3.5-turbo`
| model | The Open AI model | no, default `gpt-4o`
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
|===

Expand Down Expand Up @@ -460,7 +460,7 @@ RETURN *
|===
| name | description | mandatory
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
| model | The Open AI model | no, default `gpt-3.5-turbo`
| model | The Open AI model | no, default `gpt-4o`
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
|===

Expand Down
3 changes: 2 additions & 1 deletion full/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,12 @@ public Stream<MapResult> chatCompletion(
@Name("api_key") String apiKey,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration)
throws Exception {
String model = configuration.putIfAbsent("model", "gpt-4o");
return executeRequest(
apiKey,
configuration,
"chat/completions",
"gpt-3.5-turbo",
model,
"messages",
messages,
"$",
Expand Down
2 changes: 1 addition & 1 deletion full/src/main/java/apoc/ml/Prompt.java
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ private String prompt(
if (assistantPrompt != null && !assistantPrompt.isBlank())
prompt.add(Map.of("role", "assistant", "content", assistantPrompt));
String apiKey = (String) conf.get("apiKey");
String model = (String) conf.getOrDefault("model", "gpt-3.5-turbo");
String model = (String) conf.getOrDefault("model", "gpt-4o");
String result = OpenAI.executeRequest(
apiKey, Map.of(), "chat/completions", model, "messages", prompt, "$", apocConfig)
.map(v -> (Map<String, Object>) v)
Expand Down
4 changes: 2 additions & 2 deletions full/src/test/java/apoc/ml/OpenAIAzureIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public void completion() {
db,
"CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, $conf)",
getParams(OPENAI_CHAT_URL),
(row) -> assertCompletion(row, "gpt-35-turbo"));
(row) -> assertCompletion(row, "gpt-4o"));
}

@Test
Expand All @@ -75,7 +75,7 @@ public void chatCompletion() {
+ "{role:\"user\", content:\"What planet do humans live on?\"}\n"
+ "], $apiKey, $conf)",
getParams(OPENAI_COMPLETION_URL),
(row) -> assertChatCompletion(row, "gpt-35-turbo"));
(row) -> assertChatCompletion(row, "gpt-4o"));
}

private static Map<String, Object> getParams(String url) {
Expand Down
22 changes: 22 additions & 0 deletions full/src/test/java/apoc/ml/PromptIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,26 @@ public void testSchemaFromEmptyQueries() {
s -> assertThat(s).contains("doesn't have"));
});
}

@Test
public void testQueryGpt35Turbo() {
testResult(db, """
CALL apoc.ml.query($query, {model: 'gpt-3.5-turbo', retries: $retries, apiKey: $apiKey})
""",
Map.of(
"query", "What movies has Tom Hanks acted in?",
"retries", 2L,
"apiKey", OPENAI_KEY
),
(r) -> {
List<Map<String, Object>> list = r.stream().toList();
Assertions.assertThat(list).hasSize(12);
Assertions.assertThat(list.stream()
.map(m -> m.get("query"))
.filter(Objects::nonNull)
.map(Object::toString)
.map(String::trim))
.isNotEmpty();
});
}
}

0 comments on commit 7c953b4

Please sign in to comment.