Skip to content

Commit

Permalink
Properly implement serialize / deserialize
Browse files Browse the repository at this point in the history
  • Loading branch information
ksylvest committed Aug 15, 2024
1 parent 4e39526 commit 3cf8c02
Show file tree
Hide file tree
Showing 20 changed files with 459 additions and 93 deletions.
3 changes: 3 additions & 0 deletions .rubocop.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ RSpec/SpecFilePathFormat:
CustomTransform:
OmniAI: omniai

RSpec/MultipleMemoizedHelpers:
Enabled: false

Layout/FirstHashElementIndentation:
EnforcedStyle: consistent

Expand Down
26 changes: 13 additions & 13 deletions Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
omniai-anthropic (1.6.3)
omniai-anthropic (1.8.0)
event_stream_parser
omniai
zeitwerk
Expand All @@ -18,7 +18,7 @@ GEM
bigdecimal
rexml
diff-lcs (1.5.1)
docile (1.4.0)
docile (1.4.1)
domain_name (0.6.20240107)
event_stream_parser (1.0.0)
ffi (1.17.0)
Expand All @@ -35,7 +35,7 @@ GEM
ffi-compiler (1.3.2)
ffi (>= 1.15.5)
rake
hashdiff (1.1.0)
hashdiff (1.1.1)
http (5.2.0)
addressable (~> 2.8)
base64 (~> 0.1)
Expand All @@ -50,20 +50,20 @@ GEM
llhttp-ffi (0.5.0)
ffi-compiler (~> 1.0)
rake (~> 13.0)
omniai (1.6.3)
omniai (1.8.0)
event_stream_parser
http
zeitwerk
parallel (1.25.1)
parser (3.3.4.0)
parallel (1.26.2)
parser (3.3.4.2)
ast (~> 2.4.1)
racc
public_suffix (6.0.0)
racc (1.8.0)
public_suffix (6.0.1)
racc (1.8.1)
rainbow (3.1.1)
rake (13.2.1)
regexp_parser (2.9.2)
rexml (3.3.3)
rexml (3.3.5)
strscan
rspec (3.13.0)
rspec-core (~> 3.13.0)
Expand All @@ -80,7 +80,7 @@ GEM
rspec-support (3.13.1)
rspec_junit_formatter (0.6.0)
rspec-core (>= 2, < 4, != 2.12.0)
rubocop (1.65.0)
rubocop (1.65.1)
json (~> 2.3)
language_server-protocol (>= 3.17.0)
parallel (~> 1.10)
Expand All @@ -91,11 +91,11 @@ GEM
rubocop-ast (>= 1.31.1, < 2.0)
ruby-progressbar (~> 1.7)
unicode-display_width (>= 2.4.0, < 3.0)
rubocop-ast (1.31.3)
rubocop-ast (1.32.0)
parser (>= 3.3.1.0)
rubocop-rake (0.6.0)
rubocop (~> 1.0)
rubocop-rspec (3.0.3)
rubocop-rspec (3.0.4)
rubocop (~> 1.61)
ruby-progressbar (1.13.0)
simplecov (0.22.0)
Expand All @@ -111,7 +111,7 @@ GEM
crack (>= 0.3.2)
hashdiff (>= 0.4.0, < 2.0.0)
yard (0.9.36)
zeitwerk (2.6.16)
zeitwerk (2.6.17)

PLATFORMS
aarch64-linux-gnu
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ A chat completion is generated by passing in prompts using any a variety of form

```ruby
completion = client.chat('Tell me a joke!')
completion.choice.message.content # 'Why did the chicken cross the road? To get to the other side.'
completion.text # 'Why did the chicken cross the road? To get to the other side.'
```

```ruby
completion = client.chat do |prompt|
prompt.system('You are a helpful assistant.')
prompt.user('What is the capital of Canada?')
end
completion.choice.message.content # 'The capital of Canada is Ottawa.'
completion.text # 'The capital of Canada is Ottawa.'
```

#### Model
Expand All @@ -57,7 +57,7 @@ completion.choice.message.content # 'The capital of Canada is Ottawa.'

```ruby
completion = client.chat('Provide code for fibonacci', model: OmniAI::Anthropic::Chat::Model::CLAUDE_SONNET)
completion.choice.message.content # 'def fibonacci(n)...end'
completion.text # 'def fibonacci(n)...end'
```

[Anthropic API Reference `model`](https://docs.anthropic.com/en/api/messages)
Expand All @@ -68,7 +68,7 @@ completion.choice.message.content # 'def fibonacci(n)...end'

```ruby
completion = client.chat('Pick a number between 1 and 5', temperature: 1.0)
completion.choice.message.content # '3'
completion.text # '3'
```

[Anthropic API Reference `temperature`](https://docs.anthropic.com/en/api/messages)
Expand All @@ -79,7 +79,7 @@ completion.choice.message.content # '3'

```ruby
stream = proc do |chunk|
print(chunk.choice.delta.content) # 'Better', 'three', 'hours', ...
print(chunk.text) # 'Better', 'three', 'hours', ...
end
client.chat('Be poetic.', stream:)
```
Expand All @@ -94,7 +94,7 @@ client.chat('Be poetic.', stream:)
completion = client.chat(format: :json) do |prompt|
prompt.system(OmniAI::Chat::JSON_PROMPT)
prompt.user('What is the name of the drummer for the Beatles?')
JSON.parse(completion.choice.message.content) # { "name": "Ringo" }
JSON.parse(completion.text) # { "name": "Ringo" }
```

[Anthropic API Reference `control-output-format`](https://docs.anthropic.com/en/docs/control-output-format)
49 changes: 27 additions & 22 deletions lib/omniai/anthropic/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module Anthropic
# prompt.system('You are an expert in the field of AI.')
# prompt.user('What are the biggest risks of AI?')
# end
# completion.choice.message.content # '...'
# completion.text # '...'
class Chat < OmniAI::Chat
module Model
CLAUDE_INSTANT_1_0 = 'claude-instant-1.2'
Expand All @@ -27,26 +27,22 @@ module Model

DEFAULT_MODEL = Model::CLAUDE_SONNET

# @param [Media]
# @return [Hash]
# @example
# media = Media.new(...)
# MEDIA_SERIALIZER.call(media)
MEDIA_SERIALIZER = lambda do |media, *|
{
type: media.kind, # i.e. 'image' / 'video' / 'audio' / ...
source: {
type: 'base64',
media_type: media.type, # i.e. 'image/jpeg' / 'video/ogg' / 'audio/mpeg' / ...
data: media.data,
},
}
end

# @return [Context]
CONTEXT = Context.build do |context|
context.serializers[:file] = MEDIA_SERIALIZER
context.serializers[:url] = MEDIA_SERIALIZER
context.serializers[:file] = MediaSerializer.method(:serialize)
context.serializers[:url] = MediaSerializer.method(:serialize)

context.serializers[:choice] = ChoiceSerializer.method(:serialize)
context.deserializers[:choice] = ChoiceSerializer.method(:deserialize)

context.serializers[:tool_call] = ToolCallSerializer.method(:serialize)
context.deserializers[:tool_call] = ToolCallSerializer.method(:deserialize)

context.serializers[:function] = FunctionSerializer.method(:serialize)
context.deserializers[:function] = FunctionSerializer.method(:deserialize)

context.deserializers[:content] = ContentSerializer.method(:deserialize)
context.deserializers[:payload] = PayloadSerializer.method(:deserialize)
end

# @return [Hash]
Expand All @@ -63,21 +59,30 @@ def payload

# @return [Array<Hash>]
def messages
messages = @prompt.messages.filter(&:user?)
messages.map { |message| message.serialize(context: CONTEXT) }
messages = @prompt.messages.reject(&:system?)
messages.map { |message| message.serialize(context:) }
end

# @return [String, nil]
def system
messages = @prompt.messages.filter(&:system?)
messages.map(&:content).join("\n\n") if messages.any?
return if messages.empty?

messages.filter(&:text?).map(&:text).join("\n\n")
end

# @return [String]
def path
"/#{Client::VERSION}/messages"
end

protected

# @return [Context]
def context
CONTEXT
end

private

# @return [Array<Hash>, nil]
Expand Down
25 changes: 25 additions & 0 deletions lib/omniai/anthropic/chat/choice_serializer.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# frozen_string_literal: true

module OmniAI
module Anthropic
class Chat
# Overrides choice serialize / deserialize.
module ChoiceSerializer
# @param choice [OmniAI::Chat::Choice]
# @param context [Context]
# @return [Hash]
def self.serialize(choice, context:)
choice.message.serialize(context:)
end

# @param data [Hash]
# @param context [Context]
# @return [OmniAI::Chat::Choice]
def self.deserialize(data, context:)
message = OmniAI::Chat::Message.deserialize(data, context:)
OmniAI::Chat::Choice.new(message:)
end
end
end
end
end
20 changes: 20 additions & 0 deletions lib/omniai/anthropic/chat/content_serializer.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# frozen_string_literal: true

module OmniAI
module Anthropic
class Chat
# Overrides content serialize / deserialize.
module ContentSerializer
# @param data [Hash]
# @param context [Context]
# @return [OmniAI::Chat::Text, OmniAI::Chat::ToolCall]
def self.deserialize(data, context:)
case data['type']
when 'text' then OmniAI::Chat::Text.deserialize(data, context:)
when 'tool_use' then OmniAI::Chat::ToolCall.deserialize(data, context:)
end
end
end
end
end
end
27 changes: 27 additions & 0 deletions lib/omniai/anthropic/chat/function_serializer.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# frozen_string_literal: true

module OmniAI
module Anthropic
class Chat
# Overrides function serialize / deserialize.
module FunctionSerializer
# @param function [OmniAI::Chat::Function]
# @return [Hash]
def self.serialize(function, *)
{
name: function.name,
input: function.arguments,
}
end

# @param data [Hash]
# @return [OmniAI::Chat::Function]
def self.deserialize(data, *)
name = data['name']
arguments = data['input']
OmniAI::Chat::Function.new(name:, arguments:)
end
end
end
end
end
23 changes: 23 additions & 0 deletions lib/omniai/anthropic/chat/media_serializer.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# frozen_string_literal: true

module OmniAI
module Anthropic
class Chat
# Overrides media serialize / deserialize.
module MediaSerializer
# @param payload [OmniAI::Chat::Media]
# @return [Hash]
def self.serialize(media, *)
{
type: media.kind, # i.e. 'image' / 'video' / 'audio' / ...
source: {
type: 'base64',
media_type: media.type, # i.e. 'image/jpeg' / 'video/ogg' / 'audio/mpeg' / ...
data: media.data,
},
}
end
end
end
end
end
30 changes: 30 additions & 0 deletions lib/omniai/anthropic/chat/payload_serializer.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# frozen_string_literal: true

module OmniAI
module Anthropic
class Chat
# Overrides payload serialize / deserialize.
module PayloadSerializer
# @param payload [OmniAI::Chat::Payload]
# @param context [OmniAI::Context]
# @return [Hash]
def self.serialize(payload, context:)
usage = payload.usage.serialize(context:)
choice = payload.choice.serialize(context:)

choice.merge({ usage: })
end

# @param data [Hash]
# @param context [OmniAI::Context]
# @return [OmniAI::Chat::Payload]
def self.deserialize(data, context:)
usage = OmniAI::Chat::Usage.deserialize(data['usage'], context:) if data['usage']
choice = OmniAI::Chat::Choice.deserialize(data, context:)

OmniAI::Chat::Payload.new(choices: [choice], usage:)
end
end
end
end
end
29 changes: 0 additions & 29 deletions lib/omniai/anthropic/chat/response/completion.rb

This file was deleted.

Loading

0 comments on commit 3cf8c02

Please sign in to comment.