Skip to content

Commit 5de49e3

Browse files
authored
Merge pull request cequence-io#88 from cequence-io/feature/3654-prompt-caching
Feature/3654 prompt caching
2 parents 8b3a581 + 60d5507 commit 5de49e3

26 files changed

+704
-128
lines changed
Lines changed: 149 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package io.cequence.openaiscala.anthropic
22

3-
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.{ImageBlock, TextBlock}
3+
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.{MediaBlock, TextBlock}
44
import io.cequence.openaiscala.anthropic.domain.Content.{
5-
ContentBlock,
5+
ContentBlockBase,
66
ContentBlocks,
77
SingleString
88
}
@@ -19,9 +19,10 @@ import io.cequence.openaiscala.anthropic.domain.response.{
1919
CreateMessageResponse,
2020
DeltaText
2121
}
22-
import io.cequence.openaiscala.anthropic.domain.{ChatRole, Content, Message}
22+
import io.cequence.openaiscala.anthropic.domain.{CacheControl, ChatRole, Content, Message}
2323
import io.cequence.wsclient.JsonUtil
2424
import play.api.libs.functional.syntax._
25+
import play.api.libs.json.JsonNaming.SnakeCase
2526
import play.api.libs.json._
2627

2728
object JsonFormats extends JsonFormats
@@ -32,6 +33,84 @@ trait JsonFormats {
3233
JsonUtil.enumFormat[ChatRole](ChatRole.allValues: _*)
3334
implicit lazy val usageInfoFormat: Format[UsageInfo] = Json.format[UsageInfo]
3435

36+
/**
 * Builds the JSON fragment carrying a cache-control marker, keyed by
 * `cache_control`, so it can be merged into an enclosing JSON object.
 *
 * @param cacheControl the cache-control setting to encode
 * @return a single-field object whose `cache_control` value holds the `type`
 */
def writeJsObject(cacheControl: CacheControl): JsObject = {
  val cacheType = cacheControl match {
    case CacheControl.Ephemeral => "ephemeral"
  }
  Json.obj("cache_control" -> Json.obj("type" -> cacheType))
}
40+
41+
/**
 * JSON format for a bare [[CacheControl]] value, i.e. the object that sits
 * under a `cache_control` key (for example `{"type": "ephemeral"}`).
 *
 * `writes` emits exactly the shape that `reads` accepts, so values round-trip.
 * Use `writeJsObject` when the `cache_control` wrapper key itself is needed.
 */
implicit lazy val cacheControlFormat: Format[CacheControl] = new Format[CacheControl] {

  // Only an object consisting solely of `"type": "ephemeral"` is accepted.
  def reads(json: JsValue): JsResult[CacheControl] = json match {
    case JsObject(fields) if fields == Map("type" -> JsString("ephemeral")) =>
      JsSuccess(CacheControl.Ephemeral)
    case JsObject(fields) =>
      JsError(s"Invalid cache control $fields")
    case other =>
      JsError(s"Invalid cache control ${other}")
  }

  // This used to delegate to `writeJsObject`, which nests the value under
  // `cache_control`. `reads` rejects that shape, so a written value could not
  // be read back (e.g. through macro-derived formats of classes holding it).
  def writes(cacheControl: CacheControl): JsValue = cacheControl match {
    case CacheControl.Ephemeral => Json.obj("type" -> "ephemeral")
  }
}
53+
54+
/**
 * JSON format for an optional [[CacheControl]].
 *
 * Reading: JSON `null` becomes `None`; anything else is delegated to
 * `cacheControlFormat` and wrapped in `Some`.
 * Writing: `None` becomes JSON `null`; a defined value is delegated to
 * `cacheControlFormat`.
 */
implicit lazy val cacheControlOptionFormat: Format[Option[CacheControl]] =
  Format(
    Reads[Option[CacheControl]] {
      case JsNull => JsSuccess(None)
      case other  => cacheControlFormat.reads(other).map(Some(_))
    },
    Writes[Option[CacheControl]] { maybeCacheControl =>
      maybeCacheControl.fold[JsValue](JsNull)(cc => cacheControlFormat.writes(cc))
    }
  )
66+
67+
/**
 * Serializes a [[ContentBlockBase]]: the wrapped block's own JSON (text blocks
 * get an extra `"type": "text"` field) merged with the optional
 * `cache_control` fragment.
 */
implicit lazy val contentBlockBaseWrites: Writes[ContentBlockBase] =
  Writes[ContentBlockBase] { block =>
    val payload: JsObject = block.content match {
      case text: TextBlock =>
        Json.obj("type" -> "text") ++ textBlockWrites.writes(text).as[JsObject]
      case media: MediaBlock =>
        // the media writer already emits its own "type" field
        mediaBlockWrites.writes(media).as[JsObject]
    }
    payload ++ cacheControlToJsObject(block.cacheControl)
  }
77+
78+
/**
 * Parses a content block, dispatching on its `type` field:
 *  - `"text"` reads `text` into a [[TextBlock]];
 *  - `"image"` / `"document"` read `source.type`, `source.media_type` and
 *    `source.data` into a [[MediaBlock]];
 *  - any other type yields a `JsError`.
 *
 * An optional `cache_control` field is read alongside the block in both cases.
 */
implicit lazy val contentBlockBaseReads: Reads[ContentBlockBase] =
  (json: JsValue) => {
    (json \ "type").validate[String].flatMap {
      case "text" =>
        // The tuple pattern always matches, so no fallback case is needed here
        // (the former `case _ => JsError("Invalid text block")` was unreachable).
        ((json \ "text").validate[String] and
          (json \ "cache_control").validateOpt[CacheControl]).tupled.map {
          case (text, cacheControl) =>
            ContentBlockBase(TextBlock(text), cacheControl)
        }

      case imageOrDocument @ ("image" | "document") =>
        for {
          source <- (json \ "source").validate[JsObject]
          `type` <- (source \ "type").validate[String]
          mediaType <- (source \ "media_type").validate[String]
          data <- (source \ "data").validate[String]
          cacheControl <- (json \ "cache_control").validateOpt[CacheControl]
        } yield ContentBlockBase(
          MediaBlock(imageOrDocument, `type`, mediaType, data),
          cacheControl
        )

      case _ => JsError("Unsupported or invalid content block")
    }
  }
104+
105+
/** Combined reader/writer for a single [[ContentBlockBase]]. */
implicit lazy val contentBlockBaseFormat: Format[ContentBlockBase] =
  Format[ContentBlockBase](contentBlockBaseReads, contentBlockBaseWrites)
109+
/** Combined reader/writer for a sequence of [[ContentBlockBase]] values. */
implicit lazy val contentBlockBaseSeqFormat: Format[Seq[ContentBlockBase]] = {
  val seqReads: Reads[Seq[ContentBlockBase]] = Reads.seq(contentBlockBaseReads)
  val seqWrites: Writes[Seq[ContentBlockBase]] = Writes.seq(contentBlockBaseWrites)
  Format(seqReads, seqWrites)
}
113+
35114
implicit lazy val userMessageFormat: Format[UserMessage] = Json.format[UserMessage]
36115
implicit lazy val userMessageContentFormat: Format[UserMessageContent] =
37116
Json.format[UserMessageContent]
@@ -44,92 +123,114 @@ trait JsonFormats {
44123

45124
implicit lazy val contentBlocksFormat: Format[ContentBlocks] = Json.format[ContentBlocks]
46125

47-
// implicit val textBlockWrites: Writes[TextBlock] = Json.writes[TextBlock]
48-
implicit val textBlockReads: Reads[TextBlock] = Json.reads[TextBlock]
126+
// Reads for TextBlock, derived with the Json.reads macro.
// A snake_case JsonConfiguration is declared implicit inside this block so it is
// in scope for the macro call below without leaking into the rest of the trait.
implicit lazy val textBlockReads: Reads[TextBlock] = {
127+
implicit val config: JsonConfiguration = JsonConfiguration(SnakeCase)
128+
Json.reads[TextBlock]
129+
}
130+
131+
// Writes for TextBlock, derived with the Json.writes macro.
// A snake_case JsonConfiguration is declared implicit inside this block so it is
// in scope for the macro call below without leaking into the rest of the trait.
implicit lazy val textBlockWrites: Writes[TextBlock] = {
132+
implicit val config: JsonConfiguration = JsonConfiguration(SnakeCase)
133+
Json.writes[TextBlock]
134+
}
49135

50-
implicit val textBlockWrites: Writes[TextBlock] = Json.writes[TextBlock]
51-
implicit val imageBlockWrites: Writes[ImageBlock] =
52-
(block: ImageBlock) =>
136+
implicit lazy val mediaBlockWrites: Writes[MediaBlock] =
137+
(block: MediaBlock) =>
53138
Json.obj(
54-
"type" -> "image",
139+
"type" -> block.`type`,
55140
"source" -> Json.obj(
56-
"type" -> block.`type`,
141+
"type" -> block.encoding,
57142
"media_type" -> block.mediaType,
58143
"data" -> block.data
59144
)
60145
)
61146

62-
implicit val contentBlockWrites: Writes[ContentBlock] = {
63-
case tb: TextBlock =>
64-
Json.obj("type" -> "text") ++ Json.toJson(tb)(textBlockWrites).as[JsObject]
65-
case ib: ImageBlock => Json.toJson(ib)(imageBlockWrites)
66-
}
67-
68-
implicit val contentBlockReads: Reads[ContentBlock] =
69-
(json: JsValue) => {
70-
(json \ "type").validate[String].flatMap {
71-
case "text" => (json \ "text").validate[String].map(TextBlock.apply)
72-
case "image" =>
73-
for {
74-
source <- (json \ "source").validate[JsObject]
75-
`type` <- (source \ "type").validate[String]
76-
mediaType <- (source \ "media_type").validate[String]
77-
data <- (source \ "data").validate[String]
78-
} yield ImageBlock(`type`, mediaType, data)
79-
case _ => JsError("Unsupported or invalid content block")
80-
}
81-
}
147+
/**
 * Turns an optional cache-control setting into a JSON fragment that can be
 * merged (`++`) into an enclosing object.
 *
 * @param maybeCacheControl the optional cache-control setting
 * @return the `writeJsObject` fragment when defined, otherwise an empty object
 */
private def cacheControlToJsObject(maybeCacheControl: Option[CacheControl]): JsObject =
  maybeCacheControl match {
    case Some(cacheControl) => writeJsObject(cacheControl)
    case None               => Json.obj()
  }
82149

83-
implicit val contentReads: Reads[Content] = new Reads[Content] {
150+
implicit lazy val contentReads: Reads[Content] = new Reads[Content] {
84151
def reads(json: JsValue): JsResult[Content] = json match {
85152
case JsString(str) => JsSuccess(SingleString(str))
86-
case JsArray(_) => Json.fromJson[Seq[ContentBlock]](json).map(ContentBlocks(_))
153+
case JsArray(_) => Json.fromJson[Seq[ContentBlockBase]](json).map(ContentBlocks(_))
87154
case _ => JsError("Invalid content format")
88155
}
89156
}
90157

91-
implicit val baseMessageWrites: Writes[Message] = new Writes[Message] {
158+
/**
 * Serializes any [[Content]] value.
 *
 *  - `SingleString` becomes `{"content": <text>}` merged with the optional
 *    `cache_control` fragment;
 *  - `ContentBlocks` becomes `{"content": [<block>, ...]}`;
 *  - a lone `ContentBlockBase` is written as the block object itself.
 */
implicit lazy val contentWrites: Writes[Content] = new Writes[Content] {
  def writes(content: Content): JsValue = content match {
    case SingleString(text, cacheControl) =>
      Json.obj("content" -> text) ++ cacheControlToJsObject(cacheControl)
    case ContentBlocks(blocks) =>
      Json.obj("content" -> Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites)))
    // ContentBlockBase also extends the sealed Content trait; without this case
    // the match was non-exhaustive and threw a MatchError for such values.
    case block: ContentBlockBase =>
      contentBlockBaseWrites.writes(block)
  }
}
166+
167+
implicit lazy val baseMessageWrites: Writes[Message] = new Writes[Message] {
92168
def writes(message: Message): JsValue = message match {
93-
case UserMessage(content) => Json.obj("role" -> "user", "content" -> content)
169+
case UserMessage(content, cacheControl) =>
170+
val baseObj = Json.obj("role" -> "user", "content" -> content)
171+
baseObj ++ cacheControlToJsObject(cacheControl)
172+
94173
case UserMessageContent(content) =>
95174
Json.obj(
96175
"role" -> "user",
97-
"content" -> content.map(Json.toJson(_)(contentBlockWrites))
176+
"content" -> content.map(Json.toJson(_)(contentBlockBaseWrites))
98177
)
99-
case AssistantMessage(content) => Json.obj("role" -> "assistant", "content" -> content)
178+
179+
case AssistantMessage(content, cacheControl) =>
180+
val baseObj = Json.obj("role" -> "assistant", "content" -> content)
181+
baseObj ++ cacheControlToJsObject(cacheControl)
182+
100183
case AssistantMessageContent(content) =>
101184
Json.obj(
102185
"role" -> "assistant",
103-
"content" -> content.map(Json.toJson(_)(contentBlockWrites))
186+
"content" -> content.map(Json.toJson(_)(contentBlockBaseWrites))
104187
)
105188
// Add cases for other subclasses if necessary
106189
}
107190
}
108191

109-
implicit val baseMessageReads: Reads[Message] = (
192+
implicit lazy val baseMessageReads: Reads[Message] = (
110193
(__ \ "role").read[String] and
111-
(__ \ "content").lazyRead(contentReads)
194+
(__ \ "content").read[JsValue] and
195+
(__ \ "cache_control").readNullable[CacheControl]
112196
).tupled.flatMap {
113-
case ("user", SingleString(text)) => Reads.pure(UserMessage(text))
114-
case ("user", ContentBlocks(blocks)) => Reads.pure(UserMessageContent(blocks))
115-
case ("assistant", SingleString(text)) => Reads.pure(AssistantMessage(text))
116-
case ("assistant", ContentBlocks(blocks)) => Reads.pure(AssistantMessageContent(blocks))
197+
case ("user", JsString(str), cacheControl) => Reads.pure(UserMessage(str, cacheControl))
198+
case ("user", json @ JsArray(_), _) => {
199+
Json.fromJson[Seq[ContentBlockBase]](json) match {
200+
case JsSuccess(contentBlocks, _) =>
201+
Reads.pure(UserMessageContent(contentBlocks))
202+
case JsError(errors) =>
203+
Reads(_ => JsError(errors))
204+
}
205+
}
206+
case ("assistant", JsString(str), cacheControl) =>
207+
Reads.pure(AssistantMessage(str, cacheControl))
208+
209+
case ("assistant", json @ JsArray(_), _) => {
210+
Json.fromJson[Seq[ContentBlockBase]](json) match {
211+
case JsSuccess(contentBlocks, _) =>
212+
Reads.pure(AssistantMessageContent(contentBlocks))
213+
case JsError(errors) =>
214+
Reads(_ => JsError(errors))
215+
}
216+
}
117217
case _ => Reads(_ => JsError("Unsupported role or content type"))
118218
}
119219

120-
implicit val createMessageResponseReads: Reads[CreateMessageResponse] = (
220+
implicit lazy val createMessageResponseReads: Reads[CreateMessageResponse] = (
121221
(__ \ "id").read[String] and
122222
(__ \ "role").read[ChatRole] and
123-
(__ \ "content").read[Seq[ContentBlock]].map(ContentBlocks(_)) and
223+
(__ \ "content").read[Seq[ContentBlockBase]].map(ContentBlocks(_)) and
124224
(__ \ "model").read[String] and
125225
(__ \ "stop_reason").readNullable[String] and
126226
(__ \ "stop_sequence").readNullable[String] and
127227
(__ \ "usage").read[UsageInfo]
128228
)(CreateMessageResponse.apply _)
129229

130-
implicit val createMessageChunkResponseReads: Reads[CreateMessageChunkResponse] =
230+
implicit lazy val createMessageChunkResponseReads: Reads[CreateMessageChunkResponse] =
131231
Json.reads[CreateMessageChunkResponse]
132232

133-
implicit val deltaTextReads: Reads[DeltaText] = Json.reads[DeltaText]
134-
implicit val contentBlockDeltaReads: Reads[ContentBlockDelta] = Json.reads[ContentBlockDelta]
233+
implicit lazy val deltaTextReads: Reads[DeltaText] = Json.reads[DeltaText]
234+
implicit lazy val contentBlockDeltaReads: Reads[ContentBlockDelta] =
235+
Json.reads[ContentBlockDelta]
135236
}

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/domain/Content.scala

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,80 @@ package io.cequence.openaiscala.anthropic.domain
22

33
sealed trait Content
44

5+
/** Cache-control setting that can be attached to cacheable content. */
sealed trait CacheControl

object CacheControl {

  /** The only cache-control variant defined here. */
  case object Ephemeral extends CacheControl
}
9+
10+
/** Mixed into values that may carry an optional [[CacheControl]] setting. */
trait Cacheable {

  /** The cache-control setting, if one was specified. */
  def cacheControl: Option[CacheControl]
}
13+
514
object Content {
6-
case class SingleString(text: String) extends Content
15+
case class SingleString(
16+
text: String,
17+
override val cacheControl: Option[CacheControl] = None
18+
) extends Content
19+
with Cacheable
720

8-
case class ContentBlocks(blocks: Seq[ContentBlock]) extends Content
21+
case class ContentBlocks(blocks: Seq[ContentBlockBase]) extends Content
22+
23+
/**
 * A content block paired with an optional cache-control setting.
 *
 * @param content      the wrapped block
 * @param cacheControl optional cache-control setting; defaults to `None`
 */
case class ContentBlockBase(
  content: ContentBlock,
  override val cacheControl: Option[CacheControl] = None
) extends Content with Cacheable
928

1029
sealed trait ContentBlock
1130

1231
object ContentBlock {
1332
case class TextBlock(text: String) extends ContentBlock
14-
case class ImageBlock(
33+
34+
case class MediaBlock(
1535
`type`: String,
36+
encoding: String,
1637
mediaType: String,
1738
data: String
1839
) extends ContentBlock
40+
41+
/** Factory helpers that build base64-encoded media blocks wrapped in a [[ContentBlockBase]]. */
object MediaBlock {

  // Every factory below produces a block with this encoding.
  private val Base64Encoding = "base64"

  /**
   * Builds a `document` block with media type `application/pdf`.
   *
   * @param data         the encoded document payload
   * @param cacheControl optional cache-control setting for the block
   */
  def pdf(
    data: String,
    cacheControl: Option[CacheControl] = None
  ): ContentBlockBase = {
    val document = MediaBlock("document", Base64Encoding, "application/pdf", data)
    ContentBlockBase(document, cacheControl)
  }

  /**
   * Builds an `image` block for an arbitrary media type.
   *
   * @param mediaType    the media type string stored on the block
   * @param data         the encoded image payload
   * @param cacheControl optional cache-control setting for the block
   */
  def image(
    mediaType: String
  )(
    data: String,
    cacheControl: Option[CacheControl] = None
  ): ContentBlockBase = {
    val picture = MediaBlock("image", Base64Encoding, mediaType, data)
    ContentBlockBase(picture, cacheControl)
  }

  /** Shorthand for `image("image/jpeg")`. */
  def jpeg(
    data: String,
    cacheControl: Option[CacheControl] = None
  ): ContentBlockBase =
    image(mediaType = "image/jpeg")(data, cacheControl)

  /** Shorthand for `image("image/png")`. */
  def png(
    data: String,
    cacheControl: Option[CacheControl] = None
  ): ContentBlockBase =
    image(mediaType = "image/png")(data, cacheControl)

  /** Shorthand for `image("image/gif")`. */
  def gif(
    data: String,
    cacheControl: Option[CacheControl] = None
  ): ContentBlockBase =
    image(mediaType = "image/gif")(data, cacheControl)

  /** Shorthand for `image("image/webp")`. */
  def webp(
    data: String,
    cacheControl: Option[CacheControl] = None
  ): ContentBlockBase =
    image(mediaType = "image/webp")(data, cacheControl)
}
79+
1980
}
2081
}
Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package io.cequence.openaiscala.anthropic.domain
22

33
import io.cequence.openaiscala.anthropic.domain.Content.{
4-
ContentBlock,
4+
ContentBlockBase,
55
ContentBlocks,
66
SingleString
77
}
@@ -13,12 +13,19 @@ sealed abstract class Message private (
1313

1414
object Message {
1515

16-
case class UserMessage(contentString: String)
17-
extends Message(ChatRole.User, SingleString(contentString))
18-
case class UserMessageContent(contentBlocks: Seq[ContentBlock])
16+
case class UserMessage(
17+
contentString: String,
18+
cacheControl: Option[CacheControl] = None
19+
) extends Message(ChatRole.User, SingleString(contentString, cacheControl))
20+
21+
case class UserMessageContent(contentBlocks: Seq[ContentBlockBase])
1922
extends Message(ChatRole.User, ContentBlocks(contentBlocks))
20-
case class AssistantMessage(contentString: String)
21-
extends Message(ChatRole.Assistant, SingleString(contentString))
22-
case class AssistantMessageContent(contentBlocks: Seq[ContentBlock])
23+
24+
case class AssistantMessage(
25+
contentString: String,
26+
cacheControl: Option[CacheControl] = None
27+
) extends Message(ChatRole.Assistant, SingleString(contentString, cacheControl))
28+
29+
case class AssistantMessageContent(contentBlocks: Seq[ContentBlockBase])
2330
extends Message(ChatRole.Assistant, ContentBlocks(contentBlocks))
2431
}

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/domain/settings/AnthropicCreateMessageSettings.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ final case class AnthropicCreateMessageSettings(
55
// See [[models|https://docs.anthropic.com/claude/docs/models-overview]] for additional details and options.
66
model: String,
77

8-
// System prompt.
9-
// A system prompt is a way of providing context and instructions to Claude, such as specifying a particular goal or role. See our [[guide to system prompts|https://docs.anthropic.com/claude/docs/system-prompts]].
10-
system: Option[String] = None,
8+
// // System prompt.
9+
// // A system prompt is a way of providing context and instructions to Claude, such as specifying a particular goal or role. See our [[guide to system prompts|https://docs.anthropic.com/claude/docs/system-prompts]].
10+
// system: Option[String] = None,
1111

1212
// The maximum number of tokens to generate before stopping.
1313
// Note that our models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate.

0 commit comments

Comments
 (0)