Skip to content

Commit 3b7ddc7

Browse files
committed
wip
1 parent 061fec5 commit 3b7ddc7

File tree

4 files changed

+88
-38
lines changed

4 files changed

+88
-38
lines changed

Sources/Core/Models/Item.swift

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,18 @@ import MetaCodable
88

99
public struct Audio: Equatable, Hashable, Codable, Sendable {
1010
/// Audio bytes
11-
public var audio: AudioData
11+
public var audio: AudioData?
1212

1313
/// The transcript of the audio
1414
public var transcript: String?
1515

16-
public init(audio: AudioData, transcript: String? = nil) {
16+
public init(audio: AudioData? = nil, transcript: String? = nil) {
1717
self.audio = audio
1818
self.transcript = transcript
1919
}
2020

21-
public init(audio: Data = Data(), transcript: String? = nil) {
22-
self.init(audio: AudioData(data: audio), transcript: transcript)
21+
public init(audio: Data? = nil, transcript: String? = nil) {
22+
self.init(audio: audio.map { AudioData(data: $0) }, transcript: transcript)
2323
}
2424
}
2525

@@ -36,15 +36,15 @@ import MetaCodable
3636
public enum Content: Equatable, Hashable, Sendable {
3737
case text(String)
3838
case audio(Audio)
39-
case input_text(String)
40-
case input_audio(Audio)
39+
case inputText(String)
40+
case inputAudio(Audio)
4141

4242
public var text: String? {
4343
switch self {
4444
case let .text(text): text
45-
case let .input_text(text): text
45+
case let .inputText(text): text
4646
case let .audio(audio): audio.transcript
47-
case let .input_audio(audio): audio.transcript
47+
case let .inputAudio(audio): audio.transcript
4848
}
4949
}
5050
}
@@ -439,11 +439,11 @@ extension Item.Message.Content: Codable {
439439
self = try .text(container.decode(String.self, forKey: .text))
440440
case "input_text":
441441
let container = try decoder.container(keyedBy: Text.CodingKeys.self)
442-
self = try .input_text(container.decode(String.self, forKey: .text))
443-
case "audio":
442+
self = try .inputText(container.decode(String.self, forKey: .text))
443+
case "output_audio":
444444
self = try .audio(Item.Audio(from: decoder))
445445
case "input_audio":
446-
self = try .input_audio(Item.Audio(from: decoder))
446+
self = try .inputAudio(Item.Audio(from: decoder))
447447
default:
448448
throw DecodingError.dataCorruptedError(forKey: .type, in: container, debugDescription: "Unknown content type: \(type)")
449449
}
@@ -456,14 +456,14 @@ extension Item.Message.Content: Codable {
456456
case let .text(text):
457457
try container.encode(text, forKey: .text)
458458
try container.encode("text", forKey: .type)
459-
case let .input_text(text):
459+
case let .inputText(text):
460460
try container.encode(text, forKey: .text)
461461
try container.encode("input_text", forKey: .type)
462462
case let .audio(audio):
463-
try container.encode("audio", forKey: .type)
463+
try container.encode("output_audio", forKey: .type)
464464
try container.encode(audio.audio, forKey: .audio)
465465
try container.encode(audio.transcript, forKey: .transcript)
466-
case let .input_audio(audio):
466+
case let .inputAudio(audio):
467467
try container.encode(audio.audio, forKey: .audio)
468468
try container.encode("input_audio", forKey: .type)
469469
try container.encode(audio.transcript, forKey: .transcript)

Sources/Core/Models/ServerEvent.swift

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ import MetaCodable
4646
@CodedAs("conversation.item.created")
4747
case conversationItemCreated(eventId: String, item: Item, previousItemId: String?)
4848

49+
/// Returned when a conversation item is added.
50+
///
51+
/// - Parameter eventId: The unique ID of the server event.
52+
/// - Parameter item: A single item within a Realtime conversation.
53+
/// - Parameter previousItemId: The ID of the item that precedes this one, if any.
54+
@CodedAs("conversation.item.added")
55+
case conversationItemAdded(eventId: String, item: Item, previousItemId: String?)
56+
4957
/// Returned when a conversation item is finalized.
5058
/// - Parameter eventId: The unique ID of the server event.
5159
/// - Parameter item: A single item within a Realtime conversation.
@@ -202,6 +210,20 @@ import MetaCodable
202210
@CodedAs("input_audio_buffer.timeout_triggered")
203211
case inputAudioBufferTimeoutTriggered(eventId: String, itemId: String, audioStartMs: Int, audioEndMs: Int)
204212

213+
/// Returned when the output audio buffer starts playing audio.
214+
///
215+
/// - Parameter eventId: The unique ID of the server event.
216+
/// - Parameter responseId: The ID of the Response to which the output audio belongs.
217+
@CodedAs("output_audio_buffer.started")
218+
case outputAudioBufferStarted(eventId: String, responseId: String)
219+
220+
/// Returned when the output audio buffer stops playing audio.
221+
///
222+
/// - Parameter eventId: The unique ID of the server event.
223+
/// - Parameter responseId: The ID of the Response to which the output audio belongs.
224+
@CodedAs("output_audio_buffer.stopped")
225+
case outputAudioBufferStopped(eventId: String, responseId: String)
226+
205227
/// Returned when a new Response is created.
206228
///
207229
/// The first event of response creation, where the response is in an initial state of `inProgress`.
@@ -318,7 +340,7 @@ import MetaCodable
318340
/// - Parameter outputIndex: The index of the output item in the Response.
319341
/// - Parameter contentIndex: The index of the content part in the item's content array.
320342
/// - Parameter delta: The transcript delta.
321-
@CodedAs("response.audio_transcript.delta")
343+
@CodedAs("response.output_audio_transcript.delta")
322344
case responseAudioTranscriptDelta(
323345
eventId: String,
324346
responseId: String,
@@ -336,7 +358,7 @@ import MetaCodable
336358
/// - Parameter outputIndex: The index of the output item in the Response.
337359
/// - Parameter contentIndex: The index of the content part in the item's content array.
338360
/// - Parameter transcript: The final transcript of the audio.
339-
@CodedAs("response.audio_transcript.done")
361+
@CodedAs("response.output_audio_transcript.done")
340362
case responseAudioTranscriptDone(
341363
eventId: String,
342364
responseId: String,
@@ -509,6 +531,7 @@ extension ServerEvent: Identifiable {
509531
case let .error(id, _): id
510532
case let .sessionCreated(id, _): id
511533
case let .sessionUpdated(id, _): id
534+
case let .conversationItemAdded(id, _, _): id
512535
case let .conversationItemCreated(id, _, _): id
513536
case let .conversationItemDone(id, _, _): id
514537
case let .conversationItemRetrieved(id, _): id
@@ -523,6 +546,8 @@ extension ServerEvent: Identifiable {
523546
case let .inputAudioBufferSpeechStarted(id, _, _): id
524547
case let .inputAudioBufferSpeechStopped(id, _, _): id
525548
case let .inputAudioBufferTimeoutTriggered(id, _, _, _): id
549+
case let .outputAudioBufferStarted(id, _): id
550+
case let .outputAudioBufferStopped(id, _): id
526551
case let .responseCreated(id, _): id
527552
case let .responseDone(id, _): id
528553
case let .responseOutputItemAdded(id, _, _, _): id

Sources/UI/Conversation.swift

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,35 @@ import Foundation
55

66
public enum ConversationError: Error {
77
case sessionNotFound
8+
case invalidEphemeralKey
89
case converterInitializationFailed
910
}
1011

1112
@MainActor @Observable
1213
public final class Conversation: @unchecked Sendable {
1314
public typealias SessionUpdateCallback = (inout Session) -> Void
1415

15-
public var debug: Bool
1616
private let client: WebRTCConnector
1717
private var task: Task<Void, Error>!
1818
private let sessionUpdateCallback: SessionUpdateCallback?
1919
private let errorStream: AsyncStream<ServerError>.Continuation
2020

21-
/// A stream of errors that occur during the conversation.
22-
public let errors: AsyncStream<ServerError>
21+
/// Whether to print debug information to the console.
22+
public var debug: Bool
23+
24+
/// Whether to mute the user's microphone.
25+
public var muted: Bool = false {
26+
didSet {
27+
client.audioTrack.isEnabled = !muted
28+
}
29+
}
2330

2431
/// The unique ID of the conversation.
2532
public private(set) var id: String?
2633

34+
/// A stream of errors that occur during the conversation.
35+
public let errors: AsyncStream<ServerError>
36+
2737
/// The current session for this conversation.
2838
public private(set) var session: Session?
2939

@@ -38,6 +48,9 @@ public final class Conversation: @unchecked Sendable {
3848
/// This only works when using the server's voice detection.
3949
public private(set) var isUserSpeaking: Bool = false
4050

51+
/// Whether the model is currently speaking.
52+
public private(set) var isModelSpeaking: Bool = false
53+
4154
/// A list of messages in the conversation.
4255
/// Note that this doesn't include function call events. To get a complete list, use `entries`.
4356
public var messages: [Item.Message] {
@@ -47,9 +60,9 @@ public final class Conversation: @unchecked Sendable {
4760
} }
4861
}
4962

50-
public required init(debug: Bool = false, configuring sessionUpdateCallback: SessionUpdateCallback? = nil) throws {
63+
public required init(debug: Bool = false, configuring sessionUpdateCallback: SessionUpdateCallback? = nil) {
5164
self.debug = debug
52-
client = try WebRTCConnector.create()
65+
client = try! WebRTCConnector.create()
5366
self.sessionUpdateCallback = sessionUpdateCallback
5467
(errors, errorStream) = AsyncStream.makeStream(of: ServerError.self)
5568

@@ -69,21 +82,23 @@ public final class Conversation: @unchecked Sendable {
6982
}
7083
}
7184

85+
deinit {
86+
client.disconnect()
87+
errorStream.finish()
88+
}
89+
7290
public func connect(using request: URLRequest) async throws {
7391
await AVAudioApplication.requestRecordPermission()
7492

7593
try await client.connect(using: request)
7694
}
7795

7896
public func connect(ephemeralKey: String, model: Model = .gptRealtime) async throws {
79-
try await connect(using: .webRTCConnectionRequest(ephemeralKey: ephemeralKey, model: model))
80-
}
81-
82-
deinit {
83-
errorStream.finish()
84-
85-
Task { @MainActor [weak self] in
86-
self?.task?.cancel()
97+
do {
98+
try await connect(using: .webRTCConnectionRequest(ephemeralKey: ephemeralKey, model: model))
99+
} catch let error as WebRTCConnector.WebRTCError {
100+
guard case .invalidEphemeralKey = error else { throw error }
101+
throw ConversationError.invalidEphemeralKey
87102
}
88103
}
89104

@@ -137,7 +152,7 @@ public final class Conversation: @unchecked Sendable {
137152
/// Send a text message and wait for a response.
138153
/// Optionally, you can provide a response configuration to customize the model's behavior.
139154
public func send(from role: Item.Message.Role, text: String, response: Response.Config? = nil) throws {
140-
try send(event: .createConversationItem(.message(Item.Message(id: String(randomLength: 32), role: role, content: [.input_text(text)]))))
155+
try send(event: .createConversationItem(.message(Item.Message(id: String(randomLength: 32), role: role, content: [.inputText(text)]))))
141156
try send(event: .createResponse(using: response))
142157
}
143158

@@ -167,9 +182,9 @@ private extension Conversation {
167182
entries.removeAll { $0.id == itemId }
168183
case let .conversationItemInputAudioTranscriptionCompleted(_, itemId, contentIndex, transcript, _, _):
169184
updateEvent(id: itemId) { message in
170-
guard case let .input_audio(audio) = message.content[contentIndex] else { return }
185+
guard case let .inputAudio(audio) = message.content[contentIndex] else { return }
171186

172-
message.content[contentIndex] = .input_audio(.init(audio: audio.audio, transcript: transcript))
187+
message.content[contentIndex] = .inputAudio(.init(audio: audio.audio, transcript: transcript))
173188
}
174189
case let .conversationItemInputAudioTranscriptionFailed(_, _, _, error):
175190
errorStream.yield(error)
@@ -211,7 +226,7 @@ private extension Conversation {
211226
case let .responseOutputAudioDelta(_, _, itemId, _, contentIndex, delta):
212227
updateEvent(id: itemId) { message in
213228
guard case let .audio(audio) = message.content[contentIndex] else { return }
214-
message.content[contentIndex] = .audio(.init(audio: audio.audio.data + delta.data, transcript: audio.transcript))
229+
message.content[contentIndex] = .audio(.init(audio: (audio.audio?.data ?? Data()) + delta.data, transcript: audio.transcript))
215230
}
216231
case let .responseFunctionCallArgumentsDelta(_, _, itemId, _, _, delta):
217232
updateEvent(id: itemId) { functionCall in
@@ -225,6 +240,10 @@ private extension Conversation {
225240
isUserSpeaking = true
226241
case .inputAudioBufferSpeechStopped:
227242
isUserSpeaking = false
243+
case .outputAudioBufferStarted:
244+
isModelSpeaking = true
245+
case .outputAudioBufferStopped:
246+
isModelSpeaking = false
228247
case let .responseOutputItemDone(_, _, _, item):
229248
updateEvent(id: item.id) { message in
230249
guard case let .message(newMessage) = item else { return }

Sources/WebRTC/WebRTCConnector.swift

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ import FoundationNetworking
77
#endif
88

99
@Observable public final class WebRTCConnector: NSObject, Connector, Sendable {
10-
enum WebRTCError: Error {
10+
public enum WebRTCError: Error {
11+
case invalidEphemeralKey
1112
case missingAudioPermission
1213
case failedToCreateDataChannel
1314
case failedToCreatePeerConnection
@@ -24,7 +25,7 @@ import FoundationNetworking
2425
!audioTrack.isEnabled
2526
}
2627

27-
private let audioTrack: LKRTCAudioTrack
28+
package let audioTrack: LKRTCAudioTrack
2829
private let dataChannel: LKRTCDataChannel
2930
private let connection: LKRTCPeerConnection
3031

@@ -82,7 +83,6 @@ import FoundationNetworking
8283
public func disconnect() {
8384
connection.close()
8485
stream.finish()
85-
Task { @MainActor in status = .disconnected }
8686
}
8787

8888
public func toggleMute() {
@@ -166,7 +166,9 @@ private extension WebRTCConnector {
166166
request.setValue("application/sdp", forHTTPHeaderField: "Content-Type")
167167

168168
let (data, response) = try await URLSession.shared.data(for: request)
169-
guard let httpResponse = response as? HTTPURLResponse, (200...299).contains(httpResponse.statusCode), let remoteSdp = String(data: data, encoding: .utf8) else {
169+
170+
guard let response = response as? HTTPURLResponse, response.statusCode == 201, let remoteSdp = String(data: data, encoding: .utf8) else {
171+
if (response as? HTTPURLResponse)?.statusCode == 401 { throw WebRTCError.invalidEphemeralKey }
170172
throw WebRTCError.badServerResponse(response)
171173
}
172174

@@ -191,7 +193,11 @@ extension WebRTCConnector: LKRTCPeerConnectionDelegate {
191193

192194
extension WebRTCConnector: LKRTCDataChannelDelegate {
193195
public func dataChannel(_: LKRTCDataChannel, didReceiveMessageWith buffer: LKRTCDataBuffer) {
194-
stream.yield(with: Result { try self.decoder.decode(ServerEvent.self, from: buffer.data) })
196+
do { try stream.yield(decoder.decode(ServerEvent.self, from: buffer.data)) }
197+
catch {
198+
print("Failed to decode server event: \(String(data: buffer.data, encoding: .utf8) ?? "<invalid utf8>")")
199+
stream.finish(throwing: error)
200+
}
195201
}
196202

197203
public func dataChannelDidChangeState(_ dataChannel: LKRTCDataChannel) {

0 commit comments

Comments (0)