Skip to content

Commit

Permalink
Support for client.transcribe
Browse files Browse the repository at this point in the history
  • Loading branch information
ksylvest committed Jun 17, 2024
1 parent 12b09ea commit 7916a40
Show file tree
Hide file tree
Showing 10 changed files with 280 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
omniai (1.0.9)
omniai (1.1.0)
event_stream_parser
http
zeitwerk
Expand Down
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,11 @@ stream = proc do |chunk|
end
client.chat('Tell me a joke.', stream:)
```

### Transcribe

Clients that support chat (e.g. OpenAI w/ "Whisper", etc) convert recordings to text via the following calls:

```ruby
client.transcribe(file.path)
```
4 changes: 4 additions & 0 deletions lib/omniai/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ module Role
SYSTEM = 'system'
end

module Format
JSON = :json
end

def self.process!(...)
new(...).process!
end
Expand Down
14 changes: 14 additions & 0 deletions lib/omniai/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,19 @@ def connection
def chat(messages, model:, temperature: nil, format: nil, stream: nil)
raise NotImplementedError, "#{self.class.name}#chat undefined"
end

# @raise [OmniAI::Error]
#
# @param file [IO]
# @param model [String]
# @param language [String, nil] optional
# @param prompt [String, nil] optional
# @param temperature [Float, nil] optional
# @param format [Symbol] :text, :srt, :vtt, or :json (default)
#
# @return text [OmniAI::Transcribe::Transcription]
def transcribe(file, model:, language: nil, prompt: nil, temperature: nil, format: Transcription::Format::JSON)
raise NotImplementedError, "#{self.class.name}#speak undefined"
end
end
end
152 changes: 152 additions & 0 deletions lib/omniai/transcribe.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# frozen_string_literal: true

module OmniAI
# An abstract class that provides a consistent interface for processing transcribe requests.
#
# Usage:
#
# class OmniAI::OpenAI::Transcribe < OmniAI::Transcribe
# module Model
# WHISPER_1 = "whisper-1"
# end
#
# protected
#
# # @return [Hash]
# def payload
# raise NotImplementedError, "#{self.class.name}#payload undefined"
# end
#
# # @return [String]
# def path
# raise NotImplementedError, "#{self.class.name}#path undefined"
# end
# end
#
# client.transcribe(File.open("..."), model: "...", format: :json)
class Transcribe
module Language
AFRIKAANS = 'af'
ARABIC = 'ar'
ARMENIAN = 'hy'
AZERBAIJANI = 'az'
BELARUSIAN = 'be'
BOSNIAN = 'bs'
BULGARIAN = 'bg'
CATALAN = 'ca'
CHINESE = 'zh'
CROATIAN = 'hr'
CZECH = 'cs'
DANISH = 'da'
DUTCH = 'nl'
ENGLISH = 'en'
ESTONIAN = 'et'
FINNISH = 'fi'
FRENCH = 'fr'
GALICIAN = 'gl'
GERMAN = 'de'
GREEK = 'el'
HEBREW = 'he'
HINDI = 'hi'
HUNGARIAN = 'hu'
ICELANDIC = 'is'
INDONESIAN = 'id'
ITALIAN = 'it'
JAPANESE = 'ja'
KANNADA = 'kn'
KAZAKH = 'kk'
KOREAN = 'ko'
LATVIAN = 'lv'
LITHUANIAN = 'lt'
MACEDONIAN = 'mk'
MALAY = 'ms'
MARATHI = 'mr'
MAORI = 'mi'
NEPALI = 'ne'
NORWEGIAN = 'no'
PERSIAN = 'fa'
POLISH = 'pl'
PORTUGUESE = 'pt'
ROMANIAN = 'ro'
RUSSIAN = 'ru'
SERBIAN = 'sr'
SLOVAK = 'sk'
SLOVENIAN = 'sl'
SPANISH = 'es'
SWAHILI = 'sw'
SWEDISH = 'sv'
TAGALOG = 'tl'
TAMIL = 'ta'
THAI = 'th'
TURKISH = 'tr'
UKRAINIAN = 'uk'
URDU = 'ur'
VIETNAMESE = 'vi'
WELSH = 'cy'
end

module Format
JSON = 'json'
TEXT = 'text'
VTT = 'vtt'
SRT = 'srt'
end

def self.process!(...)
new(...).process!
end

# @param path [String] required
# @param client [OmniAI::Client] the client
# @param model [String] required
# @param language [String, nil] optional
# @param prompt [String, nil] optional
# @param temperature [Float, nil] optional
# @param format [String, nil] optional
def initialize(path, client:, model:, language: nil, prompt: nil, temperature: nil, format: Format::JSON)
@path = path
@model = model
@language = language
@prompt = prompt
@temperature = temperature
@format = format
@client = client
end

# @return [OmniAI::Transcribe::Transcription]
# @raise [ExecutionError]
def process!
response = request!
raise HTTPError, response.flush unless response.status.ok?

data = @format.eql?(Format::JSON) ? response.parse : { text: String(response.body) }
Transcription.new(format: @format, data:)
end

protected

# @return [Hash]
def payload
{
file: HTTP::FormData::File.new(@path),
model: @model,
language: @language,
prompt: @prompt,
temperature: @temperature,
}.compact
end

# @return [String]
def path
raise NotImplementedError, "#{self.class.name}#path undefined"
end

# @return [HTTP::Response]
def request!
@client
.connection
.accept(@format.eql?(Format::JSON) ? :json : :text)
.post(path, form: payload)
end
end
end
26 changes: 26 additions & 0 deletions lib/omniai/transcribe/transcription.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# frozen_string_literal: true

module OmniAI
class Transcribe
# A transcription returned by the API.
class Transcription
attr_accessor :data, :format

# @param data [Hash]
def initialize(data:, format:)
@data = data
@format = format
end

# @return [String]
def text
@data['text']
end

# @return [String]
def inspect
"#<#{self.class} text=#{text.inspect} format=#{format.inspect}>"
end
end
end
end
2 changes: 1 addition & 1 deletion lib/omniai/version.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# frozen_string_literal: true

module OmniAI
VERSION = '1.0.9'
VERSION = '1.1.0'
end
Binary file added spec/fixtures/file.ogg
Binary file not shown.
17 changes: 17 additions & 0 deletions spec/omniai/transcribe/transcription_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# frozen_string_literal: true

RSpec.describe OmniAI::Transcribe::Transcription do
subject(:transcription) { described_class.new(format: OmniAI::Transcribe::Format::JSON, data: { 'text' => 'Hi!' }) }

describe '#format' do
it { expect(transcription.format).to eq('json') }
end

describe '#text' do
it { expect(transcription.text).to eq('Hi!') }
end

describe '#inspect' do
it { expect(transcription.inspect).to eq('#<OmniAI::Transcribe::Transcription text="Hi!" format="json">') }
end
end
57 changes: 57 additions & 0 deletions spec/omniai/transcribe_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# frozen_string_literal: true

class FakeClient < OmniAI::Client
def connection
HTTP.persistent('http://localhost:8080')
end
end

class FakeTranscribe < OmniAI::Transcribe
module Model
FAKE = 'fake'
end

def path
'/transcribe'
end
end

RSpec.describe OmniAI::Transcribe do
subject(:transcribe) { described_class.new(path, model:, client:) }

let(:model) { '...' }
let(:client) { OmniAI::Client.new(api_key: '...') }
let(:path) { Pathname.new(File.dirname(__FILE__)).join('..', 'fixtures', 'file.ogg') }

describe '#path' do
it { expect { transcribe.send(:path) }.to raise_error(NotImplementedError) }
end

describe '.process!' do
subject(:process!) { FakeTranscribe.process!(path, client:, model:) }

let(:client) { FakeClient.new(api_key: '...') }
let(:model) { FakeTranscribe::Model::FAKE }

context 'when OK' do
before do
stub_request(:post, 'http://localhost:8080/transcribe')
.to_return_json(status: 200, body: {
text: 'Hi!',
})
end

it { expect(process!).to be_a(OmniAI::Transcribe::Transcription) }
it { expect(process!.text).to eq('Hi!') }
end

context 'when UNPROCESSABLE' do
before do
stub_request(:post, 'http://localhost:8080/transcribe')
.to_return(status: 422, body: 'An unknown error occurred.')
end

it { expect { process! }.to raise_error(OmniAI::HTTPError) }
end
end
end

0 comments on commit 7916a40

Please sign in to comment.