Skip to content

Commit c13d297

Browse files
committed
Remove uses of unstructured concurrency
1 parent 6adb06f commit c13d297

File tree

5 files changed

+114
-146
lines changed

5 files changed

+114
-146
lines changed

FirebaseVertexAI/Sources/Chat.swift

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -85,68 +85,62 @@ public actor Chat {
8585
/// - Parameter parts: The new content to send as a single chat message.
8686
/// - Returns: A stream containing the model's response or an error if an error occurred.
8787
@available(macOS 12.0, *)
88-
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...)
88+
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...) async throws
8989
-> AsyncThrowingStream<GenerateContentResponse, Error> {
90-
return try sendMessageStream([ModelContent(parts: parts)])
90+
return try await sendMessageStream([ModelContent(parts: parts)])
9191
}
9292

9393
/// Sends a message using the existing history of this chat as context. If successful, the message
9494
/// and response will be added to the history. If unsuccessful, history will remain unchanged.
9595
/// - Parameter content: The new content to send as a single chat message.
9696
/// - Returns: A stream containing the model's response or an error if an error occurred.
9797
@available(macOS 12.0, *)
98-
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent])
98+
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent]) async throws
9999
-> AsyncThrowingStream<GenerateContentResponse, Error> {
100100
let resolvedContent: [ModelContent]
101101
do {
102102
resolvedContent = try content()
103103
} catch let underlying {
104-
return AsyncThrowingStream { continuation in
105-
let error: Error
104+
// TODO: Consider throwing this before the stream.
105+
return AsyncThrowingStream {
106106
if let contentError = underlying as? ImageConversionError {
107-
error = GenerateContentError.promptImageContentError(underlying: contentError)
107+
throw GenerateContentError.promptImageContentError(underlying: contentError)
108108
} else {
109-
error = GenerateContentError.internalError(underlying: underlying)
109+
throw GenerateContentError.internalError(underlying: underlying)
110110
}
111-
continuation.finish(throwing: error)
112111
}
113112
}
114113

115-
return AsyncThrowingStream { continuation in
116-
Task {
117-
var aggregatedContent: [ModelContent] = []
118-
119-
// Ensure that the new content has the role set.
120-
let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:))
121-
122-
// Send the history alongside the new message as context.
123-
let request = history + newContent
124-
let stream = await model.generateContentStream(request)
125-
do {
126-
for try await chunk in stream {
127-
// Capture any content that's streaming. This should be populated if there's no error.
128-
if let chunkContent = chunk.candidates.first?.content {
129-
aggregatedContent.append(chunkContent)
130-
}
131-
132-
// Pass along the chunk.
133-
continuation.yield(chunk)
134-
}
135-
} catch {
136-
// Rethrow the error that the underlying stream threw. Don't add anything to history.
137-
continuation.finish(throwing: error)
138-
return
139-
}
114+
var aggregatedContent: [ModelContent] = []
115+
116+
// Ensure that the new content has the role set.
117+
let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:))
140118

141-
// Save the request.
142-
history.append(contentsOf: newContent)
119+
// Send the history alongside the new message as context.
120+
let request = history + newContent
121+
let stream = try await model.generateContentStream(request)
143122

144-
// Aggregate the content to add it to the history before we finish.
145-
let aggregated = aggregatedChunks(aggregatedContent)
146-
history.append(aggregated)
123+
var streamIterator = stream.makeAsyncIterator()
147124

148-
continuation.finish()
125+
return AsyncThrowingStream {
126+
while let chunk = try await streamIterator.next() {
127+
// Capture any content that's streaming. This should be populated if there's no error.
128+
if let chunkContent = chunk.candidates.first?.content {
129+
aggregatedContent.append(chunkContent)
130+
}
131+
132+
// Pass along the chunk.
133+
return chunk
149134
}
135+
136+
// Save the request.
137+
self.history.append(contentsOf: newContent)
138+
139+
// Aggregate the content to add it to the history before we finish.
140+
let aggregated = self.aggregatedChunks(aggregatedContent)
141+
self.history.append(aggregated)
142+
143+
return nil
150144
}
151145
}
152146

FirebaseVertexAI/Sources/GenerativeAIService.swift

Lines changed: 45 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -72,94 +72,62 @@ struct GenerativeAIService {
7272
}
7373

7474
@available(macOS 12.0, *)
75-
func loadRequestStream<T: GenerativeAIRequest>(request: T)
75+
func loadRequestStream<T: GenerativeAIRequest>(request: T) async throws
7676
-> AsyncThrowingStream<T.Response, Error> {
77-
return AsyncThrowingStream { continuation in
78-
Task {
79-
let urlRequest: URLRequest
80-
do {
81-
urlRequest = try await self.urlRequest(request: request)
82-
} catch {
83-
continuation.finish(throwing: error)
84-
return
85-
}
77+
let urlRequest = try await self.urlRequest(request: request)
8678

87-
#if DEBUG
88-
printCURLCommand(from: urlRequest)
89-
#endif
90-
91-
let stream: URLSession.AsyncBytes
92-
let rawResponse: URLResponse
93-
do {
94-
(stream, rawResponse) = try await urlSession.bytes(for: urlRequest)
95-
} catch {
96-
continuation.finish(throwing: error)
97-
return
98-
}
79+
#if DEBUG
80+
printCURLCommand(from: urlRequest)
81+
#endif
9982

100-
// Verify the status code is 200
101-
let response: HTTPURLResponse
102-
do {
103-
response = try httpResponse(urlResponse: rawResponse)
104-
} catch {
105-
continuation.finish(throwing: error)
106-
return
107-
}
83+
let stream: URLSession.AsyncBytes
84+
let rawResponse: URLResponse
85+
(stream, rawResponse) = try await urlSession.bytes(for: urlRequest)
86+
87+
let response = try httpResponse(urlResponse: rawResponse)
88+
89+
// Verify the status code is 200
90+
guard response.statusCode == 200 else {
91+
Logging.network
92+
.error("[FirebaseVertexAI] The server responded with an error: \(response)")
93+
var responseBody = ""
94+
for try await line in stream.lines {
95+
responseBody += line + "\n"
96+
}
10897

109-
// Verify the status code is 200
110-
guard response.statusCode == 200 else {
111-
Logging.network
112-
.error("[FirebaseVertexAI] The server responded with an error: \(response)")
113-
var responseBody = ""
114-
for try await line in stream.lines {
115-
responseBody += line + "\n"
116-
}
98+
Logging.default.error("[FirebaseVertexAI] Response payload: \(responseBody)")
99+
throw parseError(responseBody: responseBody)
100+
}
117101

118-
Logging.default.error("[FirebaseVertexAI] Response payload: \(responseBody)")
119-
continuation.finish(throwing: parseError(responseBody: responseBody))
102+
// Received lines that are not server-sent events (SSE); these are not prefixed with "data:"
103+
var extraLines = ""
120104

121-
return
122-
}
105+
let decoder = JSONDecoder()
106+
decoder.keyDecodingStrategy = .convertFromSnakeCase
123107

124-
// Received lines that are not server-sent events (SSE); these are not prefixed with "data:"
125-
var extraLines = ""
126-
127-
let decoder = JSONDecoder()
128-
decoder.keyDecodingStrategy = .convertFromSnakeCase
129-
for try await line in stream.lines {
130-
Logging.network.debug("[FirebaseVertexAI] Stream response: \(line)")
131-
132-
if line.hasPrefix("data:") {
133-
// We can assume 5 characters since it's utf-8 encoded, removing `data:`.
134-
let jsonText = String(line.dropFirst(5))
135-
let data: Data
136-
do {
137-
data = try jsonData(jsonText: jsonText)
138-
} catch {
139-
continuation.finish(throwing: error)
140-
return
141-
}
142-
143-
// Handle the content.
144-
do {
145-
let content = try parseResponse(T.Response.self, from: data)
146-
continuation.yield(content)
147-
} catch {
148-
continuation.finish(throwing: error)
149-
return
150-
}
151-
} else {
152-
extraLines += line
153-
}
154-
}
108+
var linesStream = stream.lines.makeAsyncIterator()
109+
110+
return AsyncThrowingStream<T.Response, Error> {
111+
while let line = try await linesStream.next() {
112+
Logging.network.debug("[FirebaseVertexAI] Stream response: \(line)")
113+
114+
if line.hasPrefix("data:") {
115+
// We can assume 5 characters since it's utf-8 encoded, removing `data:`.
116+
let jsonText = String(line.dropFirst(5))
117+
let data = try jsonData(jsonText: jsonText)
155118

156-
if extraLines.count > 0 {
157-
continuation.finish(throwing: parseError(responseBody: extraLines))
158-
return
119+
// Handle the content.
120+
return try parseResponse(T.Response.self, from: data)
121+
} else {
122+
extraLines += line
159123
}
124+
}
160125

161-
continuation.finish(throwing: nil)
126+
if extraLines.count > 0 {
127+
throw parseError(responseBody: extraLines)
162128
}
129+
130+
return nil
163131
}
164132
}
165133

FirebaseVertexAI/Sources/GenerativeModel.swift

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,9 @@ public final actor GenerativeModel {
179179
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
180180
/// error if an error occurred.
181181
@available(macOS 12.0, *)
182-
public func generateContentStream(_ parts: any ThrowingPartsRepresentable...)
182+
public func generateContentStream(_ parts: any ThrowingPartsRepresentable...) async throws
183183
-> AsyncThrowingStream<GenerateContentResponse, Error> {
184-
return try generateContentStream([ModelContent(parts: parts)])
184+
return try await generateContentStream([ModelContent(parts: parts)])
185185
}
186186

187187
/// Generates new content from input content given to the model as a prompt.
@@ -190,20 +190,19 @@ public final actor GenerativeModel {
190190
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
191191
/// error if an error occurred.
192192
@available(macOS 12.0, *)
193-
public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent])
193+
public func generateContentStream(_ content: @autoclosure () throws
194+
-> [ModelContent]) async throws
194195
-> AsyncThrowingStream<GenerateContentResponse, Error> {
195196
let evaluatedContent: [ModelContent]
196197
do {
197198
evaluatedContent = try content()
198199
} catch let underlying {
199-
return AsyncThrowingStream { continuation in
200-
let error: Error
200+
// TODO: Consider throwing this before the stream.
201+
return AsyncThrowingStream {
201202
if let contentError = underlying as? ImageConversionError {
202-
error = GenerateContentError.promptImageContentError(underlying: contentError)
203-
} else {
204-
error = GenerateContentError.internalError(underlying: underlying)
203+
throw GenerateContentError.promptImageContentError(underlying: contentError)
205204
}
206-
continuation.finish(throwing: error)
205+
throw GenerateContentError.internalError(underlying: underlying)
207206
}
208207
}
209208

@@ -217,8 +216,15 @@ public final actor GenerativeModel {
217216
isStreaming: true,
218217
options: requestOptions)
219218

220-
var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
221-
.makeAsyncIterator()
219+
let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
220+
do {
221+
responseStream = try await generativeAIService
222+
.loadRequestStream(request: generateContentRequest)
223+
} catch {
224+
throw GenerativeModel.generateContentError(from: error)
225+
}
226+
227+
var responseIterator = responseStream.makeAsyncIterator()
222228
return AsyncThrowingStream {
223229
let response: GenerateContentResponse?
224230
do {

FirebaseVertexAI/Tests/Unit/ChatTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ final class ChatTests: XCTestCase {
6464
)
6565
let chat = Chat(model: model, history: [])
6666
let input = "Test input"
67-
let stream = await chat.sendMessageStream(input)
67+
let stream = try await chat.sendMessageStream(input)
6868

6969
// Ensure the values are parsed correctly
7070
for try await value in stream {

0 commit comments

Comments
 (0)