Skip to content

Commit a8dfef8

Browse files
committed
Anthropic - system messages fix + anthropic examples fixed
1 parent 980cc5c commit a8dfef8

File tree

5 files changed

+42
-18
lines changed

5 files changed

+42
-18
lines changed

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/service/impl/package.scala

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ package object impl extends AnthropicServiceConsts {
4242
messages: Seq[OpenAIBaseMessage],
4343
settings: CreateChatCompletionSettings
4444
): Seq[Message] = {
45+
assert(
46+
messages.forall(_.isSystem),
47+
"All messages must be system messages"
48+
)
49+
4550
val useSystemCache: Option[CacheControl] =
4651
if (settings.useAnthropicSystemMessagesCache) Some(Ephemeral) else None
4752

@@ -52,12 +57,15 @@ package object impl extends AnthropicServiceConsts {
5257
if (index == messages.size - 1)
5358
ContentBlockBase(TextBlock(content), Some(cacheControl))
5459
else ContentBlockBase(TextBlock(content), None)
60+
5561
case None => ContentBlockBase(TextBlock(content))
5662
}
5763
}
5864

59-
if (messageStrings.isEmpty) Seq.empty
60-
else Seq(SystemMessageContent(messageStrings))
65+
if (messageStrings.isEmpty)
66+
Seq.empty
67+
else
68+
Seq(SystemMessageContent(messageStrings))
6169
}
6270

6371
def toAnthropicMessages(
@@ -67,8 +75,10 @@ package object impl extends AnthropicServiceConsts {
6775

6876
val anthropicMessages: Seq[Message] = messages.collect {
6977
case OpenAIUserMessage(content, _) => Message.UserMessage(content)
78+
7079
case OpenAIUserSeqMessage(contents, _) =>
7180
Message.UserMessageContent(contents.map(toAnthropic))
81+
7282
case OpenAIAssistantMessage(content, _) => Message.AssistantMessage(content)
7383

7484
// legacy message type
@@ -82,27 +92,30 @@ package object impl extends AnthropicServiceConsts {
8292

8393
val anthropicMessagesWithCache: Seq[Message] = anthropicMessages
8494
.foldLeft((List.empty[Message], countUserMessagesToCache)) {
85-
case ((acc, userMessagesToCache), message) =>
95+
case ((acc, userMessagesToCacheCount), message) =>
8696
message match {
8797
case Message.UserMessage(contentString, _) =>
88-
val newCacheControl = if (userMessagesToCache > 0) Some(Ephemeral) else None
98+
val newCacheControl = if (userMessagesToCacheCount > 0) Some(Ephemeral) else None
8999
(
90100
acc :+ Message.UserMessage(contentString, newCacheControl),
91-
userMessagesToCache - newCacheControl.map(_ => 1).getOrElse(0)
101+
userMessagesToCacheCount - newCacheControl.map(_ => 1).getOrElse(0)
92102
)
103+
93104
case Message.UserMessageContent(contentBlocks) =>
94105
val (newContentBlocks, remainingCache) =
95-
contentBlocks.foldLeft((Seq.empty[ContentBlockBase], userMessagesToCache)) {
106+
contentBlocks.foldLeft((Seq.empty[ContentBlockBase], userMessagesToCacheCount)) {
96107
case ((acc, cacheLeft), content) =>
97108
val (block, newCacheLeft) =
98109
toAnthropic(cacheLeft)(content.asInstanceOf[OpenAIContent])
99110
(acc :+ block, newCacheLeft)
100111
}
101112
(acc :+ Message.UserMessageContent(newContentBlocks), remainingCache)
113+
102114
case assistant: Message.AssistantMessage =>
103-
(acc :+ assistant, userMessagesToCache)
115+
(acc :+ assistant, userMessagesToCacheCount)
116+
104117
case assistants: Message.AssistantMessageContent =>
105-
(acc :+ assistants, userMessagesToCache)
118+
(acc :+ assistants, userMessagesToCacheCount)
106119
}
107120
}
108121
._1

openai-core/src/main/scala/io/cequence/openaiscala/domain/BaseMessage.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package io.cequence.openaiscala.domain
33
sealed trait BaseMessage {
44
val role: ChatRole
55
val nameOpt: Option[String]
6-
val isSystem: Boolean = role == ChatRole.System
6+
def isSystem: Boolean = role == ChatRole.System
77
}
88

99
final case class SystemMessage(

openai-examples/src/main/scala/io/cequence/openaiscala/examples/nonopenai/AnthropicCreateChatCompletionCachedWithOpenAIAdapter.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
44
import io.cequence.openaiscala.domain.{NonOpenAIModelId, SystemMessage, UserMessage}
55
import io.cequence.openaiscala.examples.ExampleBase
66
import io.cequence.openaiscala.service.OpenAIChatCompletionService
7+
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettingsOps._
78

89
import scala.concurrent.Future
910

@@ -15,15 +16,17 @@ object AnthropicCreateChatCompletionCachedWithOpenAIAdapter
1516
ChatCompletionProvider.anthropic(withCache = true)
1617

1718
private val messages = Seq(
18-
SystemMessage("You are a helpful assistant."),
19+
SystemMessage("You are a helpful assistant who knows elfs personally."),
1920
UserMessage("What is the weather like in Norway?")
2021
)
2122

2223
override protected def run: Future[_] =
2324
service
2425
.createChatCompletion(
2526
messages = messages,
26-
settings = CreateChatCompletionSettings(NonOpenAIModelId.claude_3_5_sonnet_20241022)
27+
settings = CreateChatCompletionSettings(
28+
NonOpenAIModelId.claude_3_5_sonnet_20241022
29+
).setUseAnthropicSystemMessagesCache(true), // this is how we pass it through the adapter
2730
)
2831
.map { content =>
2932
println(content.choices.headOption.map(_.message.content).getOrElse("N/A"))

openai-examples/src/main/scala/io/cequence/openaiscala/examples/nonopenai/AnthropicCreateMessage.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package io.cequence.openaiscala.examples.nonopenai
33
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
44
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlockBase
55
import io.cequence.openaiscala.anthropic.domain.Message
6-
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
6+
import io.cequence.openaiscala.anthropic.domain.Message.{SystemMessage, UserMessage}
77
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse
88
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
99
import io.cequence.openaiscala.anthropic.service.{AnthropicService, AnthropicServiceFactory}
@@ -15,16 +15,19 @@ import scala.concurrent.Future
1515
// requires `openai-scala-anthropic-client` as a dependency and `ANTHROPIC_API_KEY` environment variable to be set
1616
object AnthropicCreateMessage extends ExampleBase[AnthropicService] {
1717

18-
override protected val service: AnthropicService = AnthropicServiceFactory(withCache = true)
18+
override protected val service: AnthropicService = AnthropicServiceFactory()
1919

20-
val messages: Seq[Message] = Seq(UserMessage("What is the weather like in Norway?"))
20+
val messages: Seq[Message] = Seq(
21+
SystemMessage("You are a helpful assistant who knows elfs personally."),
22+
UserMessage("What is the weather like in Norway?")
23+
)
2124

2225
override protected def run: Future[_] =
2326
service
2427
.createMessage(
2528
messages,
2629
settings = AnthropicCreateMessageSettings(
27-
model = NonOpenAIModelId.claude_3_haiku_20240307,
30+
model = NonOpenAIModelId.claude_3_5_haiku_20241022,
2831
max_tokens = 4096
2932
)
3033
)

openai-examples/src/main/scala/io/cequence/openaiscala/examples/nonopenai/AnthropicCreateMessageStreamed.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package io.cequence.openaiscala.examples.nonopenai
22

33
import akka.stream.scaladsl.Sink
44
import io.cequence.openaiscala.anthropic.domain.Message
5-
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
5+
import io.cequence.openaiscala.anthropic.domain.Message.{SystemMessage, UserMessage}
66
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
77
import io.cequence.openaiscala.anthropic.service.{AnthropicService, AnthropicServiceFactory}
88
import io.cequence.openaiscala.domain.NonOpenAIModelId
@@ -15,14 +15,19 @@ object AnthropicCreateMessageStreamed extends ExampleBase[AnthropicService] {
1515

1616
override protected val service: AnthropicService = AnthropicServiceFactory()
1717

18-
val messages: Seq[Message] = Seq(UserMessage("What is the weather like in Norway?"))
18+
val messages: Seq[Message] = Seq(
19+
SystemMessage("You are a helpful assistant who knows elfs personally."),
20+
UserMessage("What is the weather like in Norway?")
21+
)
22+
23+
private val modelId = NonOpenAIModelId.claude_3_5_haiku_20241022
1924

2025
override protected def run: Future[_] =
2126
service
2227
.createMessageStreamed(
2328
messages,
2429
settings = AnthropicCreateMessageSettings(
25-
model = NonOpenAIModelId.claude_3_haiku_20240307,
30+
model = modelId,
2631
max_tokens = 4096
2732
)
2833
)

0 commit comments

Comments
 (0)