@@ -5,25 +5,35 @@ import Foundation
55
/// Failures surfaced by `Conversation`.
public enum ConversationError: Error {
    /// NOTE(review): the throw site is outside this excerpt; presumably raised
    /// when an operation needs a session that is not available — confirm.
    case sessionNotFound

    /// Thrown by `Conversation.connect(ephemeralKey:model:)` when the WebRTC
    /// connector reports its own `invalidEphemeralKey` error.
    case invalidEphemeralKey

    /// NOTE(review): the throw site is outside this excerpt; presumably raised
    /// when a converter cannot be created — confirm.
    case converterInitializationFailed
}
1011
1112@MainActor @Observable
1213public final class Conversation : @unchecked Sendable {
1314 public typealias SessionUpdateCallback = ( inout Session ) -> Void
1415
15- public var debug : Bool
1616 private let client : WebRTCConnector
1717 private var task : Task < Void , Error > !
1818 private let sessionUpdateCallback : SessionUpdateCallback ?
1919 private let errorStream : AsyncStream < ServerError > . Continuation
2020
21- /// A stream of errors that occur during the conversation.
22- public let errors : AsyncStream < ServerError >
21+ /// Whether to print debug information to the console.
22+ public var debug : Bool
23+
/// Whether the user's microphone is muted. Defaults to `false`.
///
/// After each assignment, the inverse of the new value is written to
/// `client.audioTrack.isEnabled`, so `muted == true` disables the track.
public var muted: Bool = false {
    didSet {
        let microphoneEnabled = !muted
        client.audioTrack.isEnabled = microphoneEnabled
    }
}
2330
2431 /// The unique ID of the conversation.
2532 public private( set) var id : String ?
2633
34+ /// A stream of errors that occur during the conversation.
35+ public let errors : AsyncStream < ServerError >
36+
2737 /// The current session for this conversation.
2838 public private( set) var session : Session ?
2939
@@ -38,6 +48,9 @@ public final class Conversation: @unchecked Sendable {
3848 /// This only works when using the server's voice detection.
3949 public private( set) var isUserSpeaking : Bool = false
4050
51+ /// Whether the model is currently speaking.
52+ public private( set) var isModelSpeaking : Bool = false
53+
4154 /// A list of messages in the conversation.
4255 /// Note that this doesn't include function call events. To get a complete list, use `entries`.
4356 public var messages : [ Item . Message ] {
@@ -47,9 +60,9 @@ public final class Conversation: @unchecked Sendable {
4760 } }
4861 }
4962
50- public required init ( debug: Bool = false , configuring sessionUpdateCallback: SessionUpdateCallback ? = nil ) throws {
63+ public required init ( debug: Bool = false , configuring sessionUpdateCallback: SessionUpdateCallback ? = nil ) {
5164 self . debug = debug
52- client = try WebRTCConnector . create ( )
65+ client = try ! WebRTCConnector . create ( )
5366 self . sessionUpdateCallback = sessionUpdateCallback
5467 ( errors, errorStream) = AsyncStream . makeStream ( of: ServerError . self)
5568
@@ -69,21 +82,23 @@ public final class Conversation: @unchecked Sendable {
6982 }
7083 }
7184
/// Disconnects the WebRTC client, then finishes the error-stream continuation.
deinit {
    // Deferred so that finishing the continuation remains the last step,
    // after the disconnect call has returned.
    defer { errorStream.finish() }
    client.disconnect()
}
89+
/// Requests record permission, then connects the WebRTC client using `request`.
///
/// The outcome of the permission request is not inspected; the connection
/// attempt is made either way.
///
/// - Parameter request: The request handed to the client's `connect(using:)`.
/// - Throws: Whatever the client's `connect(using:)` throws.
public func connect(using request: URLRequest) async throws {
    // Result intentionally discarded, matching the existing behaviour.
    _ = await AVAudioApplication.requestRecordPermission()
    try await client.connect(using: request)
}
7795
/// Connects using a request built from an ephemeral key and a model.
///
/// - Parameters:
///   - ephemeralKey: Key passed to `webRTCConnectionRequest(ephemeralKey:model:)`.
///   - model: Model passed to the same request builder. Defaults to `.gptRealtime`.
/// - Throws: `ConversationError.invalidEphemeralKey` when the connector fails with
///   `WebRTCConnector.WebRTCError.invalidEphemeralKey`. Every other error is
///   rethrown unchanged.
public func connect(ephemeralKey: String, model: Model = .gptRealtime) async throws {
    do {
        try await connect(using: .webRTCConnectionRequest(ephemeralKey: ephemeralKey, model: model))
    } catch let failure as WebRTCConnector.WebRTCError {
        // Only the invalid-key case is translated into the public error type.
        if case .invalidEphemeralKey = failure {
            throw ConversationError.invalidEphemeralKey
        }
        throw failure
    }
}
89104
@@ -137,7 +152,7 @@ public final class Conversation: @unchecked Sendable {
/// Sends a text message, then asks the model to create a response.
///
/// Two events are sent in order: one that creates a conversation item holding
/// the message, and one that creates a response. This method returns once both
/// events have been sent; it does not wait for the model's reply.
///
/// - Parameters:
///   - role: The role the message is sent from.
///   - text: The text content of the message.
///   - response: Optional response configuration forwarded with the
///     create-response event.
/// - Throws: Any error thrown while sending either event. If the first send
///   fails, the create-response event is not sent.
public func send(from role: Item.Message.Role, text: String, response: Response.Config? = nil) throws {
    // The message gets a freshly generated 32-character identifier.
    let messageID = String(randomLength: 32)
    let message = Item.Message(id: messageID, role: role, content: [.inputText(text)])

    try send(event: .createConversationItem(.message(message)))
    try send(event: .createResponse(using: response))
}
143158
@@ -167,9 +182,9 @@ private extension Conversation {
167182 entries. removeAll { $0. id == itemId }
168183 case let . conversationItemInputAudioTranscriptionCompleted( _, itemId, contentIndex, transcript, _, _) :
169184 updateEvent ( id: itemId) { message in
170- guard case let . input_audio ( audio) = message. content [ contentIndex] else { return }
185+ guard case let . inputAudio ( audio) = message. content [ contentIndex] else { return }
171186
172- message. content [ contentIndex] = . input_audio ( . init( audio: audio. audio, transcript: transcript) )
187+ message. content [ contentIndex] = . inputAudio ( . init( audio: audio. audio, transcript: transcript) )
173188 }
174189 case let . conversationItemInputAudioTranscriptionFailed( _, _, _, error) :
175190 errorStream. yield ( error)
@@ -211,7 +226,7 @@ private extension Conversation {
211226 case let . responseOutputAudioDelta( _, _, itemId, _, contentIndex, delta) :
212227 updateEvent ( id: itemId) { message in
213228 guard case let . audio( audio) = message. content [ contentIndex] else { return }
214- message. content [ contentIndex] = . audio( . init( audio: audio. audio. data + delta. data, transcript: audio. transcript) )
229+ message. content [ contentIndex] = . audio( . init( audio: ( audio. audio? . data ?? Data ( ) ) + delta. data, transcript: audio. transcript) )
215230 }
216231 case let . responseFunctionCallArgumentsDelta( _, _, itemId, _, _, delta) :
217232 updateEvent ( id: itemId) { functionCall in
@@ -225,6 +240,10 @@ private extension Conversation {
225240 isUserSpeaking = true
226241 case . inputAudioBufferSpeechStopped:
227242 isUserSpeaking = false
243+ case . outputAudioBufferStarted:
244+ isModelSpeaking = true
245+ case . outputAudioBufferStopped:
246+ isModelSpeaking = false
228247 case let . responseOutputItemDone( _, _, _, item) :
229248 updateEvent ( id: item. id) { message in
230249 guard case let . message( newMessage) = item else { return }
0 commit comments