Skip to content

Commit

Permalink
Properly implement serialize / deserialize
Browse files Browse the repository at this point in the history
  • Loading branch information
ksylvest committed Aug 16, 2024
1 parent f848580 commit 5e385ba
Show file tree
Hide file tree
Showing 30 changed files with 679 additions and 182 deletions.
9 changes: 9 additions & 0 deletions .rubocop.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ RSpec/SpecFilePathFormat:
CustomTransform:
OmniAI: omniai

RSpec/MultipleMemoizedHelpers:
Enabled: false

Style/EmptyCaseCondition:
Enabled: false

Style/TrailingCommaInArrayLiteral:
EnforcedStyleForMultiline: consistent_comma

Expand All @@ -27,3 +33,6 @@ Metrics/ParameterLists:

Metrics/MethodLength:
Enabled: false

Metrics/AbcSize:
Enabled: false
26 changes: 13 additions & 13 deletions Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
omniai-google (1.6.3)
omniai-google (1.8.0)
event_stream_parser
omniai
zeitwerk
Expand All @@ -18,7 +18,7 @@ GEM
bigdecimal
rexml
diff-lcs (1.5.1)
docile (1.4.0)
docile (1.4.1)
domain_name (0.6.20240107)
event_stream_parser (1.0.0)
ffi (1.17.0)
Expand All @@ -35,7 +35,7 @@ GEM
ffi-compiler (1.3.2)
ffi (>= 1.15.5)
rake
hashdiff (1.1.0)
hashdiff (1.1.1)
http (5.2.0)
addressable (~> 2.8)
base64 (~> 0.1)
Expand All @@ -50,20 +50,20 @@ GEM
llhttp-ffi (0.5.0)
ffi-compiler (~> 1.0)
rake (~> 13.0)
omniai (1.6.2)
omniai (1.8.0)
event_stream_parser
http
zeitwerk
parallel (1.25.1)
parser (3.3.4.0)
parallel (1.26.2)
parser (3.3.4.2)
ast (~> 2.4.1)
racc
public_suffix (6.0.0)
racc (1.8.0)
public_suffix (6.0.1)
racc (1.8.1)
rainbow (3.1.1)
rake (13.2.1)
regexp_parser (2.9.2)
rexml (3.3.3)
rexml (3.3.5)
strscan
rspec (3.13.0)
rspec-core (~> 3.13.0)
Expand All @@ -80,7 +80,7 @@ GEM
rspec-support (3.13.1)
rspec_junit_formatter (0.6.0)
rspec-core (>= 2, < 4, != 2.12.0)
rubocop (1.65.0)
rubocop (1.65.1)
json (~> 2.3)
language_server-protocol (>= 3.17.0)
parallel (~> 1.10)
Expand All @@ -91,11 +91,11 @@ GEM
rubocop-ast (>= 1.31.1, < 2.0)
ruby-progressbar (~> 1.7)
unicode-display_width (>= 2.4.0, < 3.0)
rubocop-ast (1.31.3)
rubocop-ast (1.32.0)
parser (>= 3.3.1.0)
rubocop-rake (0.6.0)
rubocop (~> 1.0)
rubocop-rspec (3.0.3)
rubocop-rspec (3.0.4)
rubocop (~> 1.61)
ruby-progressbar (1.13.0)
simplecov (0.22.0)
Expand All @@ -111,7 +111,7 @@ GEM
crack (>= 0.3.2)
hashdiff (>= 0.4.0, < 2.0.0)
yard (0.9.36)
zeitwerk (2.6.16)
zeitwerk (2.6.17)

PLATFORMS
aarch64-linux-gnu
Expand Down
33 changes: 12 additions & 21 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,27 @@ Global configuration is supported for the following options:
OmniAI::Google.configure do |config|
config.api_key = 'sk-...' # default: ENV['GOOGLE_API_KEY']
config.host = '...' # default: 'https://generativelanguage.googleapis.com'
config.version = 'v1beta' # default: 'v1'
config.version = OmniAI::Google::Config::Version::BETA # either 'v1' or 'v1beta'
end
```

### Chat

A chat completion is generated by passing in prompts using any a variety of formats:
A chat completion is generated by passing in a simple text prompt:

```ruby
completion = client.chat('Tell me a joke!')
completion.choice.message.content # 'Why did the chicken cross the road? To get to the other side.'
completion.text # 'Why did the chicken cross the road? To get to the other side.'
```

```ruby
completion = client.chat({
role: OmniAI::Chat::Role::USER,
content: 'Is it wise to jump off a bridge?'
})
completion.choice.message.content # 'No.'
```
A chat completion may also be generated by using the prompt builder:

```ruby
completion = client.chat([
{
role: OmniAI::Chat::Role::USER,
content: 'You are a helpful assistant.'
},
'What is the capital of Canada?',
])
completion.choice.message.content # 'The capital of Canada is Ottawa.'
completion = client.chat do |prompt|
prompt.system('Your are an expert in geography.')
prompt.user('What is the capital of Canada?')
end
completion.text # 'The capital of Canada is Ottawa.'
```

#### Model
Expand All @@ -72,7 +63,7 @@ completion.choice.message.content # 'The capital of Canada is Ottawa.'

```ruby
completion = client.chat('How fast is a cheetah?', model: OmniAI::Google::Chat::Model::GEMINI_FLASH)
completion.choice.message.content # 'A cheetah can reach speeds over 100 km/h.'
completion.text # 'A cheetah can reach speeds over 100 km/h.'
```

[Google API Reference `model`](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions)
Expand All @@ -83,7 +74,7 @@ completion.choice.message.content # 'A cheetah can reach speeds over 100 km/h.'

```ruby
completion = client.chat('Pick a number between 1 and 5', temperature: 2.0)
completion.choice.message.content # '3'
completion.text # '3'
```

[Google API Reference `temperature`](https://ai.google.dev/api/rest/v1/GenerationConfig)
Expand All @@ -94,7 +85,7 @@ completion.choice.message.content # '3'

```ruby
stream = proc do |chunk|
print(chunk.choice.delta.content) # 'Better', 'three', 'hours', ...
print(chunk.text) # 'Better', 'three', 'hours', ...
end
client.chat('Be poetic.', stream:)
```
Expand Down
84 changes: 37 additions & 47 deletions lib/omniai/google/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -25,49 +25,47 @@ module Model

DEFAULT_MODEL = Model::GEMINI_PRO

TEXT_SERIALIZER = lambda do |content, *|
{ text: content.text }
end
# @return [Context]
CONTEXT = Context.build do |context|
context.serializers[:text] = TextSerializer.method(:serialize)
context.deserializers[:text] = TextSerializer.method(:deserialize)

# @param [Message]
# @return [Hash]
# @example
# message = Message.new(...)
# MESSAGE_SERIALIZER.call(message)
MESSAGE_SERIALIZER = lambda do |message, context:|
parts = message.content.is_a?(String) ? [Text.new(message.content)] : message.content
role = message.system? ? Role::USER : message.role

{
role:,
parts: parts.map { |part| part.serialize(context:) },
}
end
context.serializers[:file] = MediaSerializer.method(:serialize)
context.serializers[:url] = MediaSerializer.method(:serialize)

# @param [Media]
# @return [Hash]
# @example
# media = Media.new(...)
# MEDIA_SERIALIZER.call(media)
MEDIA_SERIALIZER = lambda do |media, *|
{
inlineData: {
mimeType: media.type,
data: media.data,
},
}
end
context.serializers[:tool_call] = ToolCallSerializer.method(:serialize)
context.deserializers[:tool_call] = ToolCallSerializer.method(:deserialize)

# @return [Context]
CONTEXT = Context.build do |context|
context.serializers[:message] = MESSAGE_SERIALIZER
context.serializers[:text] = TEXT_SERIALIZER
context.serializers[:file] = MEDIA_SERIALIZER
context.serializers[:url] = MEDIA_SERIALIZER
context.serializers[:tool_call_result] = ToolCallResultSerializer.method(:serialize)
context.deserializers[:tool_call_result] = ToolCallResultSerializer.method(:deserialize)

context.serializers[:function] = FunctionSerializer.method(:serialize)
context.deserializers[:function] = FunctionSerializer.method(:deserialize)

context.serializers[:usage] = UsageSerializer.method(:serialize)
context.deserializers[:usage] = UsageSerializer.method(:deserialize)

context.serializers[:payload] = PayloadSerializer.method(:serialize)
context.deserializers[:payload] = PayloadSerializer.method(:deserialize)

context.serializers[:choice] = ChoiceSerializer.method(:serialize)
context.deserializers[:choice] = ChoiceSerializer.method(:deserialize)

context.serializers[:message] = MessageSerializer.method(:serialize)
context.deserializers[:message] = MessageSerializer.method(:deserialize)

context.deserializers[:content] = ContentSerializer.method(:deserialize)

context.serializers[:tool] = ToolSerializer.method(:serialize)
end

protected

# @return [Context]
def context
CONTEXT
end

# @return [HTTP::Response]
def request!
@client
Expand All @@ -82,7 +80,8 @@ def request!
# @return [Hash]
def payload
OmniAI::Google.config.chat_options.merge({
contents:,
system_instruction: @prompt.messages.find(&:system?)&.serialize(context:),
contents: @prompt.messages.reject(&:system?).map { |message| message.serialize(context:) },
tools:,
generationConfig: generation_config,
}).compact
Expand All @@ -93,7 +92,7 @@ def tools
return unless @tools

[
function_declarations: @tools&.map(&:prepare),
function_declarations: @tools.map { |tool| tool.serialize(context:) },
]
end

Expand All @@ -104,15 +103,6 @@ def generation_config
{ temperature: @temperature }.compact
end

# Example:
#
# [{ role: 'user', parts: [{ text: '...' }] }]
#
# @return [Array<Hash>]
def contents
@prompt.serialize(context: CONTEXT)
end

# @return [String]
def path
"/#{@client.version}/models/#{@model}:#{operation}"
Expand Down
26 changes: 26 additions & 0 deletions lib/omniai/google/chat/choice_serializer.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# frozen_string_literal: true

module OmniAI
module Google
class Chat
# Overrides choice serialize / deserialize.
module ChoiceSerializer
# @param choice [OmniAI::Chat::Choice]
# @param context [Context]
# @return [Hash]
def self.serialize(choice, context:)
content = choice.message.serialize(context:)
{ content: }
end

# @param data [Hash]
# @param context [Context]
# @return [OmniAI::Chat::Choice]
def self.deserialize(data, context:)
message = OmniAI::Chat::Message.deserialize(data['content'], context:)
OmniAI::Chat::Choice.new(message:)
end
end
end
end
end
20 changes: 20 additions & 0 deletions lib/omniai/google/chat/content_serializer.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# frozen_string_literal: true

module OmniAI
module Google
class Chat
# Overrides content serialize / deserialize.
module ContentSerializer
# @param data [Hash]
# @param context [Context]
# @return [OmniAI::Chat::Text, OmniAI::Chat::ToolCall]
def self.deserialize(data, context:)
case
when data['text'] then data['text']
when data['functionCall'] then OmniAI::Chat::ToolCall.deserialize(data, context:)
end
end
end
end
end
end
27 changes: 27 additions & 0 deletions lib/omniai/google/chat/function_serializer.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# frozen_string_literal: true

module OmniAI
module Google
class Chat
# Overrides function serialize / deserialize.
module FunctionSerializer
# @param function [OmniAI::Chat::Function]
# @return [Hash]
def self.serialize(function, *)
{
name: function.name,
args: function.arguments,
}
end

# @param data [Hash]
# @return [OmniAI::Chat::Function]
def self.deserialize(data, *)
name = data['name']
arguments = data['arguments']
OmniAI::Chat::Function.new(name:, arguments:)
end
end
end
end
end
21 changes: 21 additions & 0 deletions lib/omniai/google/chat/media_serializer.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# frozen_string_literal: true

module OmniAI
module Google
class Chat
# Overrides media serialize / deserialize.
module MediaSerializer
# @param media [OmniAI::Chat::Media]
# @return [Hash]
def self.serialize(media, *)
{
inlineData: {
mimeType: media.type,
data: media.data,
},
}
end
end
end
end
end
Loading

0 comments on commit 5e385ba

Please sign in to comment.