Skip to content

Commit fb2f54e

Browse files
committed
enable caching for Image content
1 parent 2d15ea9 commit fb2f54e

File tree

9 files changed

+177
-50
lines changed

9 files changed

+177
-50
lines changed

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/JsonFormats.scala

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package io.cequence.openaiscala.anthropic
22

3+
import io.cequence.openaiscala.anthropic.domain.CacheControl.Ephemeral
34
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.{ImageBlock, TextBlock}
45
import io.cequence.openaiscala.anthropic.domain.Content.{
56
ContentBlock,
7+
ContentBlockBase,
68
ContentBlocks,
79
SingleString
810
}
@@ -43,6 +45,8 @@ trait JsonFormats {
4345

4446
implicit lazy val textBlockFormat: Format[TextBlock] = Json.format[TextBlock]
4547

48+
// implicit lazy val contentBlockBaseFormat: Format[ContentBlockBase] =
49+
// Json.format[ContentBlockBase]
4650
implicit lazy val contentBlocksFormat: Format[ContentBlocks] = Json.format[ContentBlocks]
4751

4852
// implicit lazy val textBlockWrites: Writes[TextBlock] = Json.writes[TextBlock]
@@ -55,8 +59,25 @@ trait JsonFormats {
5559
implicit val config: JsonConfiguration = JsonConfiguration(SnakeCase)
5660
Json.writes[TextBlock]
5761
}
58-
implicit lazy val imageBlockWrites: Writes[ImageBlock] =
59-
(block: ImageBlock) =>
62+
// implicit lazy val imageBlockWrites: Writes[ImageBlock] =
63+
// (block: ImageBlock) =>
64+
// Json.obj(
65+
// "type" -> "image",
66+
// "source" -> Json.obj(
67+
// "type" -> block.`type`,
68+
// "media_type" -> block.mediaType,
69+
// "data" -> block.data
70+
// )
71+
// )
72+
73+
implicit lazy val contentBlockWrites: Writes[ContentBlockBase] = {
74+
case ContentBlockBase(tb: TextBlock, None) =>
75+
Json.obj("type" -> "text") ++ Json.toJson(tb)(textBlockWrites).as[JsObject]
76+
case ContentBlockBase(tb: TextBlock, Some(Ephemeral)) =>
77+
Json.obj("type" -> "text", "cache_control" -> "ephemeral") ++ Json
78+
.toJson(tb)(textBlockWrites)
79+
.as[JsObject]
80+
case ContentBlockBase(block: ImageBlock, None) =>
6081
Json.obj(
6182
"type" -> "image",
6283
"source" -> Json.obj(
@@ -65,21 +86,27 @@ trait JsonFormats {
6586
"data" -> block.data
6687
)
6788
)
68-
69-
implicit lazy val contentBlockWrites: Writes[ContentBlock] = {
70-
case tb: TextBlock =>
71-
Json.obj("type" -> "text") ++ Json.toJson(tb)(textBlockWrites).as[JsObject]
72-
case ib: ImageBlock => Json.toJson(ib)(imageBlockWrites)
89+
case ContentBlockBase(block: ImageBlock, Some(Ephemeral)) =>
90+
Json.obj(
91+
"type" -> "image",
92+
"cache_control" -> "ephemeral",
93+
"source" -> Json.obj(
94+
"type" -> block.`type`,
95+
"media_type" -> block.mediaType,
96+
"data" -> block.data
97+
)
98+
)
7399
}
74100

75-
implicit lazy val contentBlockReads: Reads[ContentBlock] =
101+
implicit lazy val contentBlockReads: Reads[ContentBlockBase] =
76102
(json: JsValue) => {
77103
(json \ "type").validate[String].flatMap {
78104
case "text" =>
79105
((json \ "text").validate[String] and
80106
(json \ "cache_control").validateOpt[CacheControl]).tupled.flatMap {
81-
case (text, cacheControl) => JsSuccess(TextBlock(text, cacheControl))
82-
case _ => JsError("Invalid text block")
107+
case (text, cacheControl) =>
108+
JsSuccess(ContentBlockBase(TextBlock(text), cacheControl))
109+
case _ => JsError("Invalid text block")
83110
}
84111

85112
case "image" =>
@@ -88,7 +115,8 @@ trait JsonFormats {
88115
`type` <- (source \ "type").validate[String]
89116
mediaType <- (source \ "media_type").validate[String]
90117
data <- (source \ "data").validate[String]
91-
} yield ImageBlock(`type`, mediaType, data)
118+
cacheControl <- (json \ "cache_control").validateOpt[CacheControl]
119+
} yield ContentBlockBase(ImageBlock(`type`, mediaType, data), cacheControl)
92120
case _ => JsError("Unsupported or invalid content block")
93121
}
94122
}
@@ -107,7 +135,7 @@ trait JsonFormats {
107135
implicit lazy val contentReads: Reads[Content] = new Reads[Content] {
108136
def reads(json: JsValue): JsResult[Content] = json match {
109137
case JsString(str) => JsSuccess(SingleString(str))
110-
case JsArray(_) => Json.fromJson[Seq[ContentBlock]](json).map(ContentBlocks(_))
138+
case JsArray(_) => Json.fromJson[Seq[ContentBlockBase]](json).map(ContentBlocks(_))
111139
case _ => JsError("Invalid content format")
112140
}
113141
}
@@ -144,7 +172,7 @@ trait JsonFormats {
144172
).tupled.flatMap {
145173
case ("user", JsString(str), cacheControl) => Reads.pure(UserMessage(str, cacheControl))
146174
case ("user", json @ JsArray(_), _) => {
147-
Json.fromJson[Seq[ContentBlock]](json) match {
175+
Json.fromJson[Seq[ContentBlockBase]](json) match {
148176
case JsSuccess(contentBlocks, _) =>
149177
Reads.pure(UserMessageContent(contentBlocks))
150178
case JsError(errors) =>
@@ -155,7 +183,7 @@ trait JsonFormats {
155183
Reads.pure(AssistantMessage(str, cacheControl))
156184

157185
case ("assistant", json @ JsArray(_), _) => {
158-
Json.fromJson[Seq[ContentBlock]](json) match {
186+
Json.fromJson[Seq[ContentBlockBase]](json) match {
159187
case JsSuccess(contentBlocks, _) =>
160188
Reads.pure(AssistantMessageContent(contentBlocks))
161189
case JsError(errors) =>
@@ -168,7 +196,7 @@ trait JsonFormats {
168196
implicit lazy val createMessageResponseReads: Reads[CreateMessageResponse] = (
169197
(__ \ "id").read[String] and
170198
(__ \ "role").read[ChatRole] and
171-
(__ \ "content").read[Seq[ContentBlock]].map(ContentBlocks(_)) and
199+
(__ \ "content").read[Seq[ContentBlockBase]].map(ContentBlocks(_)) and
172200
(__ \ "model").read[String] and
173201
(__ \ "stop_reason").readNullable[String] and
174202
(__ \ "stop_sequence").readNullable[String] and

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/domain/Content.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,24 @@ trait Cacheable {
1212
}
1313

1414
object Content {
15-
case class SingleString(text: String, override val cacheControl: Option[CacheControl] = None) extends Content
15+
case class SingleString(
16+
text: String,
17+
override val cacheControl: Option[CacheControl] = None
18+
) extends Content
1619
with Cacheable
1720

18-
case class ContentBlocks(blocks: Seq[ContentBlock]) extends Content
21+
case class ContentBlocks(blocks: Seq[ContentBlockBase]) extends Content
22+
23+
case class ContentBlockBase(
24+
content: ContentBlock,
25+
override val cacheControl: Option[CacheControl] = None
26+
) extends Content
27+
with Cacheable
1928

2029
sealed trait ContentBlock
2130

2231
object ContentBlock {
23-
case class TextBlock(text: String, override val cacheControl: Option[CacheControl] = None)
24-
extends ContentBlock
25-
with Cacheable
32+
case class TextBlock(text: String) extends ContentBlock
2633

2734
case class ImageBlock(
2835
`type`: String,

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/domain/Message.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package io.cequence.openaiscala.anthropic.domain
22

33
import io.cequence.openaiscala.anthropic.domain.Content.{
44
ContentBlock,
5+
ContentBlockBase,
56
ContentBlocks,
67
SingleString
78
}
@@ -18,14 +19,14 @@ object Message {
1819
cacheControl: Option[CacheControl] = None
1920
) extends Message(ChatRole.User, SingleString(contentString, cacheControl))
2021

21-
case class UserMessageContent(contentBlocks: Seq[ContentBlock])
22+
case class UserMessageContent(contentBlocks: Seq[ContentBlockBase])
2223
extends Message(ChatRole.User, ContentBlocks(contentBlocks))
2324

2425
case class AssistantMessage(
2526
contentString: String,
2627
cacheControl: Option[CacheControl] = None
2728
) extends Message(ChatRole.Assistant, SingleString(contentString, cacheControl))
2829

29-
case class AssistantMessageContent(contentBlocks: Seq[ContentBlock])
30+
case class AssistantMessageContent(contentBlocks: Seq[ContentBlockBase])
3031
extends Message(ChatRole.Assistant, ContentBlocks(contentBlocks))
3132
}

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/service/impl/OpenAIAnthropicChatCompletionService.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ private[service] class OpenAIAnthropicChatCompletionService(
4040
): Future[ChatCompletionResponse] = {
4141
underlying
4242
.createMessage(
43-
toAnthropic(messages),
43+
toAnthropicMessages(messages, settings),
4444
toAnthropic(settings, messages)
4545
)
4646
.map(toOpenAI)
@@ -64,7 +64,7 @@ private[service] class OpenAIAnthropicChatCompletionService(
6464
): Source[ChatCompletionChunkResponse, NotUsed] =
6565
underlying
6666
.createMessageStreamed(
67-
toAnthropic(messages),
67+
toAnthropicMessages(messages, settings),
6868
toAnthropic(settings, messages)
6969
)
7070
.map(toOpenAI)

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/service/impl/package.scala

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
package io.cequence.openaiscala.anthropic.service
22

3+
import io.cequence.openaiscala.anthropic.domain.CacheControl.Ephemeral
34
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
4-
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlocks
5+
import io.cequence.openaiscala.anthropic.domain.Content.{ContentBlockBase, ContentBlocks}
56
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse.UsageInfo
67
import io.cequence.openaiscala.anthropic.domain.response.{
78
ContentBlockDelta,
89
CreateMessageResponse
910
}
1011
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
11-
import io.cequence.openaiscala.anthropic.domain.{Content, Message}
12+
import io.cequence.openaiscala.anthropic.domain.{CacheControl, Content, Message}
1213
import io.cequence.openaiscala.domain.response.{
1314
ChatCompletionChoiceChunkInfo,
1415
ChatCompletionChoiceInfo,
@@ -35,29 +36,79 @@ import java.{util => ju}
3536

3637
package object impl extends AnthropicServiceConsts {
3738

38-
def toAnthropic(messages: Seq[OpenAIBaseMessage])
39-
: Seq[Message] = // send settings, cache_system, cache_user, // cache_tools_definition
39+
val AnthropicCacheControl = "cache_control"
40+
41+
def toAnthropicMessages(
42+
messages: Seq[OpenAIBaseMessage],
43+
settings: CreateChatCompletionSettings
44+
): Seq[Message] = {
45+
// send settings, cache_system, cache_user, // cache_tools_definition
4046
// TODO: handle other message types (e.g. assistant)
41-
messages.collect {
42-
case OpenAIUserMessage(content, _) => Message.UserMessage(content)
43-
case OpenAIUserSeqMessage(contents, _) =>
44-
Message.UserMessageContent(contents.map(toAnthropic))
45-
// legacy message type
46-
case MessageSpec(role, content, _) if role == ChatRole.User =>
47-
Message.UserMessage(content)
48-
}
47+
val useSystemCache: Option[CacheControl] =
48+
if (settings.useAnthropicSystemMessagesCache) Some(Ephemeral) else None
49+
val countUserMessagesToCache = settings.anthropicCachedUserMessagesCount
50+
51+
def onlyOnceCacheControl(cacheUsed: Boolean): Option[CacheControl] =
52+
if (cacheUsed) None else useSystemCache
53+
54+
// cacheSystemMessages
55+
// cacheUserMessages - number of user messages to cache (1-4) (1-3). 1
56+
57+
// (system, user) => 1x system, 3x user
58+
// (_, user) => 4x user
59+
// (system, _) => 1x system
60+
61+
// construct Anthropic messages out of OpenAI messages
62+
// the first N user messages are marked as cached, where N is equal to countUserMessagesToCache
63+
// if useSystemCache is true, the last system message is marked as cached
4964

50-
def toAnthropic(content: OpenAIContent): Content.ContentBlock = {
65+
// so I need to keep track, while foldLefting, of the number of user messages we are still able to cache
66+
67+
messages
68+
.foldLeft((List.empty[Message], countUserMessagesToCache): (List[Message], Int)) {
69+
case ((acc, userMessagesToCache), message) =>
70+
message match {
71+
case OpenAIUserMessage(content, _) =>
72+
val cacheControl = if (userMessagesToCache > 0) Some(Ephemeral) else None
73+
(
74+
acc :+ Message.UserMessage(content, cacheControl),
75+
userMessagesToCache - cacheControl.map(_ => 1).getOrElse(0)
76+
)
77+
case OpenAIUserSeqMessage(contents, _) => {
78+
val (contentBlocks, remainingCache) =
79+
contents.foldLeft((Seq.empty[ContentBlockBase], userMessagesToCache)) {
80+
case ((acc, cacheLeft), content) =>
81+
val (block, newCacheLeft) = toAnthropic(cacheLeft)(content)
82+
(acc :+ block, newCacheLeft)
83+
}
84+
(acc :+ Message.UserMessageContent(contentBlocks), remainingCache)
85+
}
86+
87+
}
88+
}
89+
._1
90+
91+
}
92+
93+
def toAnthropic(userMessagesToCache: Int)(content: OpenAIContent)
94+
: (Content.ContentBlockBase, Int) = {
95+
val cacheControl = if (userMessagesToCache > 0) Some(Ephemeral) else None
96+
val newCacheControlCount = userMessagesToCache - cacheControl.map(_ => 1).getOrElse(0)
5197
content match {
52-
case OpenAITextContent(text) => TextBlock(text)
98+
case OpenAITextContent(text) =>
99+
(ContentBlockBase(TextBlock(text), cacheControl), newCacheControlCount)
100+
53101
case OpenAIImageContent(url) =>
54102
if (url.startsWith("data:")) {
55103
val mediaTypeEncodingAndData = url.drop(5)
56104
val mediaType = mediaTypeEncodingAndData.takeWhile(_ != ';')
57105
val encodingAndData = mediaTypeEncodingAndData.drop(mediaType.length + 1)
58106
val encoding = mediaType.takeWhile(_ != ',')
59107
val data = encodingAndData.drop(encoding.length + 1)
60-
Content.ContentBlock.ImageBlock(encoding, mediaType, data)
108+
ContentBlockBase(
109+
Content.ContentBlock.ImageBlock(encoding, mediaType, data),
110+
cacheControl
111+
) -> newCacheControlCount
61112
} else {
62113
throw new IllegalArgumentException(
63114
"Image content only supported by providing image data directly."
@@ -123,7 +174,9 @@ package object impl extends AnthropicServiceConsts {
123174
)
124175

125176
def toOpenAIAssistantMessage(content: ContentBlocks): AssistantMessage = {
126-
val textContents = content.blocks.collect { case TextBlock(text, None) => text } // TODO
177+
val textContents = content.blocks.collect { case ContentBlockBase(TextBlock(text), _) =>
178+
text
179+
} // TODO
127180
// TODO: log if there is more than one text content
128181
if (textContents.isEmpty) {
129182
throw new IllegalArgumentException("No text content found in the response")

anthropic-client/src/test/scala/io/cequence/openaiscala/anthropic/JsonFormatsSpec.scala

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package io.cequence.openaiscala.anthropic
33
import io.cequence.openaiscala.anthropic.JsonFormatsSpec.JsonPrintMode
44
import io.cequence.openaiscala.anthropic.JsonFormatsSpec.JsonPrintMode.{Compact, Pretty}
55
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.{ImageBlock, TextBlock}
6+
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlockBase
67
import io.cequence.openaiscala.anthropic.domain.Message
78
import io.cequence.openaiscala.anthropic.domain.Message.{
89
AssistantMessage,
@@ -33,7 +34,12 @@ class JsonFormatsSpec extends AnyWordSpecLike with Matchers with JsonFormats {
3334

3435
"serialize and deserialize a user message with text content blocks" in {
3536
val userMessage =
36-
UserMessageContent(Seq(TextBlock("Hello, world!"), TextBlock("How are you?")))
37+
UserMessageContent(
38+
Seq(
39+
ContentBlockBase(TextBlock("Hello, world!")),
40+
ContentBlockBase(TextBlock("How are you?"))
41+
)
42+
)
3743
val json =
3844
"""{"role":"user","content":[{"type":"text","text":"Hello, world!"},{"type":"text","text":"How are you?"}]}"""
3945
testCodec[Message](userMessage, json)
@@ -47,7 +53,12 @@ class JsonFormatsSpec extends AnyWordSpecLike with Matchers with JsonFormats {
4753

4854
"serialize and deserialize an assistant message with text content blocks" in {
4955
val assistantMessage =
50-
AssistantMessageContent(Seq(TextBlock("Hello, world!"), TextBlock("How are you?")))
56+
AssistantMessageContent(
57+
Seq(
58+
ContentBlockBase(TextBlock("Hello, world!")),
59+
ContentBlockBase(TextBlock("How are you?"))
60+
)
61+
)
5162
val json =
5263
"""{"role":"assistant","content":[{"type":"text","text":"Hello, world!"},{"type":"text","text":"How are you?"}]}"""
5364
testCodec[Message](assistantMessage, json)
@@ -68,7 +79,9 @@ class JsonFormatsSpec extends AnyWordSpecLike with Matchers with JsonFormats {
6879

6980
"serialize and deserialize a message with an image content" in {
7081
val userMessage =
71-
UserMessageContent(Seq(ImageBlock("base64", "image/jpeg", "/9j/4AAQSkZJRg...")))
82+
UserMessageContent(
83+
Seq(ContentBlockBase(ImageBlock("base64", "image/jpeg", "/9j/4AAQSkZJRg...")))
84+
)
7285
testCodec[Message](userMessage, expectedImageContentJson, Pretty)
7386
}
7487

0 commit comments

Comments
 (0)