Skip to content

Commit

Permalink
Properly implement serialize / deserialize
Browse files Browse the repository at this point in the history
  • Loading branch information
ksylvest committed Aug 15, 2024
1 parent a9e3e09 commit a6333d9
Show file tree
Hide file tree
Showing 86 changed files with 1,635 additions and 928 deletions.
10 changes: 10 additions & 0 deletions .rubocop.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
require:
- rubocop-factory_bot
- rubocop-rake
- rubocop-rspec

Expand All @@ -15,6 +16,15 @@ Layout/FirstHashElementIndentation:
Layout/FirstArrayElementIndentation:
EnforcedStyle: consistent

Metrics/AbcSize:
Enabled: false

Metrics/CyclomaticComplexity:
Enabled: false

Metrics/PerceivedComplexity:
Enabled: false

Metrics/ClassLength:
Enabled: false

Expand Down
2 changes: 2 additions & 0 deletions Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ source 'https://rubygems.org'

gemspec

gem 'factory_bot'
gem 'logger'
gem 'rake'
gem 'rspec'
gem 'rspec_junit_formatter'
gem 'rubocop'
gem 'rubocop-factory_bot'
gem 'rubocop-rake'
gem 'rubocop-rspec'
gem 'simplecov'
Expand Down
41 changes: 33 additions & 8 deletions Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,26 +1,41 @@
PATH
remote: .
specs:
omniai (1.7.0)
omniai (1.8.0)
event_stream_parser
http
zeitwerk

GEM
remote: https://rubygems.org/
specs:
activesupport (7.1.3.4)
base64
bigdecimal
concurrent-ruby (~> 1.0, >= 1.0.2)
connection_pool (>= 2.2.5)
drb
i18n (>= 1.6, < 2)
minitest (>= 5.1)
mutex_m
tzinfo (~> 2.0)
addressable (2.8.7)
public_suffix (>= 2.0.2, < 7.0)
ast (2.4.2)
base64 (0.2.0)
bigdecimal (3.1.8)
concurrent-ruby (1.3.3)
connection_pool (2.4.1)
crack (1.0.0)
bigdecimal
rexml
diff-lcs (1.5.1)
docile (1.4.0)
docile (1.4.1)
domain_name (0.6.20240107)
drb (2.2.1)
event_stream_parser (1.0.0)
factory_bot (6.4.6)
activesupport (>= 5.0.0)
ffi (1.17.0)
ffi (1.17.0-aarch64-linux-gnu)
ffi (1.17.0-aarch64-linux-musl)
Expand All @@ -35,7 +50,7 @@ GEM
ffi-compiler (1.3.2)
ffi (>= 1.15.5)
rake
hashdiff (1.1.0)
hashdiff (1.1.1)
http (5.2.0)
addressable (~> 2.8)
base64 (~> 0.1)
Expand All @@ -45,17 +60,21 @@ GEM
http-cookie (1.0.6)
domain_name (~> 0.5)
http-form_data (2.3.0)
i18n (1.14.5)
concurrent-ruby (~> 1.0)
json (2.7.2)
language_server-protocol (3.17.0.3)
llhttp-ffi (0.5.0)
ffi-compiler (~> 1.0)
rake (~> 13.0)
logger (1.6.0)
parallel (1.25.1)
parser (3.3.4.0)
minitest (5.24.1)
mutex_m (0.2.0)
parallel (1.26.1)
parser (3.3.4.2)
ast (~> 2.4.1)
racc
public_suffix (6.0.0)
public_suffix (6.0.1)
racc (1.8.1)
rainbow (3.1.1)
rake (13.2.1)
Expand Down Expand Up @@ -88,11 +107,13 @@ GEM
rubocop-ast (>= 1.31.1, < 2.0)
ruby-progressbar (~> 1.7)
unicode-display_width (>= 2.4.0, < 3.0)
rubocop-ast (1.31.3)
rubocop-ast (1.32.0)
parser (>= 3.3.1.0)
rubocop-factory_bot (2.26.1)
rubocop (~> 1.61)
rubocop-rake (0.6.0)
rubocop (~> 1.0)
rubocop-rspec (3.0.3)
rubocop-rspec (3.0.4)
rubocop (~> 1.61)
ruby-progressbar (1.13.0)
simplecov (0.22.0)
Expand All @@ -102,6 +123,8 @@ GEM
simplecov-html (0.12.3)
simplecov_json_formatter (0.1.4)
strscan (3.1.0)
tzinfo (2.0.6)
concurrent-ruby (~> 1.0)
unicode-display_width (2.5.0)
webmock (3.23.1)
addressable (>= 2.8.0)
Expand All @@ -124,12 +147,14 @@ PLATFORMS
x86_64-linux-musl

DEPENDENCIES
factory_bot
logger
omniai!
rake
rspec
rspec_junit_formatter
rubocop
rubocop-factory_bot
rubocop-rake
rubocop-rspec
simplecov
Expand Down
13 changes: 3 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ Generating a completion is as simple as sending in the text:

```ruby
completion = client.chat('Tell me a joke.')
completion.choice.message.content # 'Why don't scientists trust atoms? They make up everything!'
completion.text # 'Why don't scientists trust atoms? They make up everything!'
```

#### Completions using a Complex Prompt
Expand All @@ -145,7 +145,7 @@ completion = client.chat do |prompt|
message.file('./hamster.jpeg', "image/jpeg")
end
end
completion.choice.message.content # 'They are photos of a cat, a cat, and a hamster.'
completion.text # 'They are photos of a cat, a cat, and a hamster.'
```

#### Completions using Streaming via Proc
Expand All @@ -154,7 +154,7 @@ A real-time stream of messages can be generated by passing in a proc:

```ruby
stream = proc do |chunk|
print(chunk.choice.delta.content) # '...'
print(chunk.text) # '...'
end
client.chat('Tell me a joke.', stream:)
```
Expand Down Expand Up @@ -315,10 +315,3 @@ Type 'exit' or 'quit' to abort.
0.0
...
```

0.0
...

```
```
95 changes: 61 additions & 34 deletions lib/omniai/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,23 @@ module OmniAI
class Chat
JSON_PROMPT = 'Respond with valid JSON. Do not include any non-JSON in the response.'

# An error raised when a chat makes a tool-call for a tool that cannot be found.
class ToolCallLookupError < Error
def initialize(tool_call)
super("missing tool for tool_call=#{tool_call.inspect}")
# An error raised for tool-call issues.
class ToolCallError < Error
# @param tool_call [OmniAI::Chat::ToolCall]
# @param message [String]
def initialize(tool_call:, message:)
super(message)
@tool_call = tool_call
end
end

# An error raised when a tool-call is missing.
class ToolCallMissingError < ToolCallError
def initialize(tool_call:)
super(tool_call:, message: "missing tool for tool_call=#{tool_call.inspect}")
end
end

module Role
ASSISTANT = 'assistant'
USER = 'user'
Expand Down Expand Up @@ -87,10 +96,22 @@ def process!

protected

# Override to provide an context for serializers / deserializes for a provider.
#
# @return [Context, nil]
def context
nil
end

# @return [Logger, nil]
def logger
@client.logger
end

# Used to spawn another chat with the same configuration using different messages.
#
# @param prompt [OmniAI::Chat::Prompt]
# @return [OmniAI::Chat::Prompt]
# @return [OmniAI::Chat]
def spawn!(prompt)
self.class.new(
prompt,
Expand All @@ -100,7 +121,7 @@ def spawn!(prompt)
stream: @stream,
tools: @tools,
format: @format
).process!
)
end

# @return [Hash]
Expand All @@ -114,7 +135,7 @@ def path
end

# @param response [HTTP::Response]
# @return [OmniAI::Chat::Response::Completion]
# @return [OmniAI::Chat::Response]
def parse!(response:)
if @stream
stream!(response:)
Expand All @@ -124,31 +145,32 @@ def parse!(response:)
end

# @param response [HTTP::Response]
# @return [OmniAI::Chat::Response::Completion]
# @return [OmniAI::Chat::Response]
def complete!(response:)
completion = self.class::Response::Completion.new(data: response.parse)

if @tools && completion.tool_call_list.any?
spawn!([
*@prompt.serialize,
*completion.choices.map(&:message).map(&:data),
*(completion.tool_call_list.map { |tool_call| execute_tool_call(tool_call) }),
])
completion = self.class::Response.new(data: response.parse, context:)

if @tools && completion.tool_call_list?
spawn!(
@prompt.dup.tap do |prompt|
prompt.messages += completion.messages
prompt.messages += build_tool_call_messages(completion.tool_call_list)
end
).process!
else
completion
end
end

# @param response [HTTP::Response]
# @return [OmniAI::Chat::Response::Stream]
# @return [OmniAI::Chat::Stream]
def stream!(response:)
raise Error, "#{self.class.name}#stream! unstreamable" unless @stream

self.class::Response::Stream.new(response:).stream! do |chunk|
self.class::Stream.new(body: response.body, logger:, context:).stream! do |chunk|
case @stream
when IO, StringIO
if chunk.content?
@stream << chunk.content
if chunk.text
@stream << chunk.text
@stream.flush
end
else @stream.call(chunk)
Expand All @@ -160,31 +182,36 @@ def stream!(response:)

# @return [HTTP::Response]
def request!
logger&.debug("Chat#request! payload=#{payload.inspect}")

@client
.connection
.accept(:json)
.post(path, json: payload)
end

# @param tool_call_list [Array<OmniAI::Chat::ToolCall>]
# @return [Array<Message>]
def build_tool_call_messages(tool_call_list)
tool_call_list.map do |tool_call|
content = execute_tool_call(tool_call)
ToolCallMessage.new(content:, tool_call_id: tool_call.id)
end
end

# @raise [ToolCallError]
# @param tool_call [OmniAI::Chat::ToolCall]
# @return [ToolCallResult]
def execute_tool_call(tool_call)
function = tool_call.function
logger&.debug("Chat#execute_tool_call tool_call=#{tool_call.inspect}")

tool = @tools.find { |entry| function.name == entry.name } || raise(ToolCallLookupError, tool_call)
result = tool.call(function.arguments)
function = tool_call.function
tool = @tools.find { |entry| function.name == entry.name } || raise(ToolCallMissingError, tool_call)
content = tool.call(function.arguments)

prepare_tool_call_message(tool_call:, content: result)
end
logger&.debug("Chat#execute_tool_call content=#{content.inspect}")

# @param tool_call [OmniAI::Chat::ToolCall]
# @param content [String]
def prepare_tool_call_message(tool_call:, content:)
{
role: Role::TOOL,
name: tool_call.function.name,
tool_call_id: tool_call.id,
content:,
}
content
end
end
end
Loading

0 comments on commit a6333d9

Please sign in to comment.