Skip to content

Commit ea05175

Browse files
committed
Make GenerativeModel and Chat into Swift actors
1 parent 047856b commit ea05175

File tree

6 files changed

+67
-55
lines changed

6 files changed

+67
-55
lines changed

FirebaseVertexAI/Sources/Chat.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import Foundation
1717
/// An object that represents a back-and-forth chat with a model, capturing the history and saving
1818
/// the context in memory between each message sent.
1919
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
20-
public class Chat {
20+
public actor Chat {
2121
private let model: GenerativeModel
2222

2323
/// Initializes a new chat representing a 1:1 conversation between model and user.
@@ -121,7 +121,7 @@ public class Chat {
121121

122122
// Send the history alongside the new message as context.
123123
let request = history + newContent
124-
let stream = model.generateContentStream(request)
124+
let stream = await model.generateContentStream(request)
125125
do {
126126
for try await chunk in stream {
127127
// Capture any content that's streaming. This should be populated if there's no error.

FirebaseVertexAI/Sources/GenerativeModel.swift

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import Foundation
1919
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
2020
/// content based on various input types.
2121
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
22-
public final class GenerativeModel {
22+
public final actor GenerativeModel {
2323
/// The resource name of the model in the backend; has the format "models/model-name".
2424
let modelResourceName: String
2525

@@ -219,31 +219,37 @@ public final class GenerativeModel {
219219

220220
var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
221221
.makeAsyncIterator()
222-
return AsyncThrowingStream {
223-
let response: GenerateContentResponse?
224-
do {
225-
response = try await responseIterator.next()
226-
} catch {
227-
throw GenerativeModel.generateContentError(from: error)
228-
}
222+
return AsyncThrowingStream { contination in
223+
Task {
224+
do {
225+
// The responseIterator will return `nil` when it's done.
226+
while let response = try await responseIterator.next() {
227+
// Check the prompt feedback to see if the prompt was blocked.
228+
if response.promptFeedback?.blockReason != nil {
229+
contination.finish(throwing: GenerateContentError.promptBlocked(response: response))
230+
return
231+
}
229232

230-
// The responseIterator will return `nil` when it's done.
231-
guard let response = response else {
232-
// This is the end of the stream! Signal it by sending `nil`.
233-
return nil
234-
}
233+
// If the stream ended early unexpectedly, throw an error.
234+
if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
235+
contination.finish(throwing: GenerateContentError.responseStoppedEarly(
236+
reason: finishReason,
237+
response: response
238+
))
239+
return
240+
}
235241

236-
// Check the prompt feedback to see if the prompt was blocked.
237-
if response.promptFeedback?.blockReason != nil {
238-
throw GenerateContentError.promptBlocked(response: response)
239-
}
242+
// Response was valid content, pass it along and continue.
243+
contination.yield(response)
244+
}
240245

241-
// If the stream ended early unexpectedly, throw an error.
242-
if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
243-
throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
244-
} else {
245-
// Response was valid content, pass it along and continue.
246-
return response
246+
// This is the end of the stream! Signal it by calling `finish`.
247+
contination.finish()
248+
return
249+
} catch {
250+
contination.finish(throwing: GenerativeModel.generateContentError(from: error))
251+
return
252+
}
247253
}
248254
}
249255
}

FirebaseVertexAI/Tests/Unit/ChatTests.swift

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,20 @@ final class ChatTests: XCTestCase {
6464
)
6565
let chat = Chat(model: model, history: [])
6666
let input = "Test input"
67-
let stream = chat.sendMessageStream(input)
67+
let stream = await chat.sendMessageStream(input)
6868

6969
// Ensure the values are parsed correctly
7070
for try await value in stream {
7171
XCTAssertNotNil(value.text)
7272
}
7373

74-
XCTAssertEqual(chat.history.count, 2)
75-
XCTAssertEqual(chat.history[0].parts[0].text, input)
74+
let history = await chat.history
75+
XCTAssertEqual(history.count, 2)
76+
XCTAssertEqual(history[0].parts[0].text, input)
7677

7778
let finalText = "1 2 3 4 5 6 7 8"
7879
let assembledExpectation = ModelContent(role: "model", parts: finalText)
79-
XCTAssertEqual(chat.history[0].parts[0].text, input)
80-
XCTAssertEqual(chat.history[1], assembledExpectation)
80+
XCTAssertEqual(history[0].parts[0].text, input)
81+
XCTAssertEqual(history[1], assembledExpectation)
8182
}
8283
}

FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,7 @@ final class GenerativeModelTests: XCTestCase {
760760
)
761761

762762
do {
763-
let stream = model.generateContentStream("Hi")
763+
let stream = await model.generateContentStream("Hi")
764764
for try await _ in stream {
765765
XCTFail("No content is there, this shouldn't happen.")
766766
}
@@ -784,7 +784,7 @@ final class GenerativeModelTests: XCTestCase {
784784
)
785785

786786
do {
787-
let stream = model.generateContentStream(testPrompt)
787+
let stream = await model.generateContentStream(testPrompt)
788788
for try await _ in stream {
789789
XCTFail("No content is there, this shouldn't happen.")
790790
}
@@ -807,7 +807,7 @@ final class GenerativeModelTests: XCTestCase {
807807
)
808808

809809
do {
810-
let stream = model.generateContentStream("Hi")
810+
let stream = await model.generateContentStream("Hi")
811811
for try await _ in stream {
812812
XCTFail("No content is there, this shouldn't happen.")
813813
}
@@ -827,7 +827,7 @@ final class GenerativeModelTests: XCTestCase {
827827
)
828828

829829
do {
830-
let stream = model.generateContentStream("Hi")
830+
let stream = await model.generateContentStream("Hi")
831831
for try await _ in stream {
832832
XCTFail("Content shouldn't be shown, this shouldn't happen.")
833833
}
@@ -847,7 +847,7 @@ final class GenerativeModelTests: XCTestCase {
847847
)
848848

849849
do {
850-
let stream = model.generateContentStream("Hi")
850+
let stream = await model.generateContentStream("Hi")
851851
for try await _ in stream {
852852
XCTFail("Content shouldn't be shown, this shouldn't happen.")
853853
}
@@ -866,7 +866,7 @@ final class GenerativeModelTests: XCTestCase {
866866
withExtension: "txt"
867867
)
868868

869-
let stream = model.generateContentStream("Hi")
869+
let stream = await model.generateContentStream("Hi")
870870
do {
871871
for try await content in stream {
872872
XCTAssertNotNil(content.text)
@@ -887,7 +887,7 @@ final class GenerativeModelTests: XCTestCase {
887887
)
888888

889889
var responses = 0
890-
let stream = model.generateContentStream("Hi")
890+
let stream = await model.generateContentStream("Hi")
891891
for try await content in stream {
892892
XCTAssertNotNil(content.text)
893893
responses += 1
@@ -904,7 +904,7 @@ final class GenerativeModelTests: XCTestCase {
904904
)
905905

906906
var responses = 0
907-
let stream = model.generateContentStream("Hi")
907+
let stream = await model.generateContentStream("Hi")
908908
for try await content in stream {
909909
XCTAssertNotNil(content.text)
910910
responses += 1
@@ -921,7 +921,7 @@ final class GenerativeModelTests: XCTestCase {
921921
)
922922

923923
var hadUnknown = false
924-
let stream = model.generateContentStream("Hi")
924+
let stream = await model.generateContentStream("Hi")
925925
for try await content in stream {
926926
XCTAssertNotNil(content.text)
927927
if let ratings = content.candidates.first?.safetyRatings,
@@ -940,7 +940,7 @@ final class GenerativeModelTests: XCTestCase {
940940
withExtension: "txt"
941941
)
942942

943-
let stream = model.generateContentStream("Hi")
943+
let stream = await model.generateContentStream("Hi")
944944
var citations = [Citation]()
945945
var responses = [GenerateContentResponse]()
946946
for try await content in stream {
@@ -996,7 +996,7 @@ final class GenerativeModelTests: XCTestCase {
996996
appCheckToken: appCheckToken
997997
)
998998

999-
let stream = model.generateContentStream(testPrompt)
999+
let stream = await model.generateContentStream(testPrompt)
10001000
for try await _ in stream {}
10011001
}
10021002

@@ -1018,7 +1018,7 @@ final class GenerativeModelTests: XCTestCase {
10181018
appCheckToken: AppCheckInteropFake.placeholderTokenValue
10191019
)
10201020

1021-
let stream = model.generateContentStream(testPrompt)
1021+
let stream = await model.generateContentStream(testPrompt)
10221022
for try await _ in stream {}
10231023
}
10241024

@@ -1030,7 +1030,7 @@ final class GenerativeModelTests: XCTestCase {
10301030
)
10311031
var responses = [GenerateContentResponse]()
10321032

1033-
let stream = model.generateContentStream(testPrompt)
1033+
let stream = await model.generateContentStream(testPrompt)
10341034
for try await response in stream {
10351035
responses.append(response)
10361036
}
@@ -1056,7 +1056,7 @@ final class GenerativeModelTests: XCTestCase {
10561056

10571057
var responseCount = 0
10581058
do {
1059-
let stream = model.generateContentStream("Hi")
1059+
let stream = await model.generateContentStream("Hi")
10601060
for try await content in stream {
10611061
XCTAssertNotNil(content.text)
10621062
responseCount += 1
@@ -1076,7 +1076,7 @@ final class GenerativeModelTests: XCTestCase {
10761076
func testGenerateContentStream_nonHTTPResponse() async throws {
10771077
MockURLProtocol.requestHandler = try nonHTTPRequestHandler()
10781078

1079-
let stream = model.generateContentStream("Hi")
1079+
let stream = await model.generateContentStream("Hi")
10801080
do {
10811081
for try await content in stream {
10821082
XCTFail("Unexpected content in stream: \(content)")
@@ -1096,7 +1096,7 @@ final class GenerativeModelTests: XCTestCase {
10961096
withExtension: "txt"
10971097
)
10981098

1099-
let stream = model.generateContentStream(testPrompt)
1099+
let stream = await model.generateContentStream(testPrompt)
11001100
do {
11011101
for try await content in stream {
11021102
XCTFail("Unexpected content in stream: \(content)")
@@ -1120,7 +1120,7 @@ final class GenerativeModelTests: XCTestCase {
11201120
withExtension: "txt"
11211121
)
11221122

1123-
let stream = model.generateContentStream(testPrompt)
1123+
let stream = await model.generateContentStream(testPrompt)
11241124
do {
11251125
for try await content in stream {
11261126
XCTFail("Unexpected content in stream: \(content)")
@@ -1159,7 +1159,7 @@ final class GenerativeModelTests: XCTestCase {
11591159
)
11601160

11611161
var responses = 0
1162-
let stream = model.generateContentStream(testPrompt)
1162+
let stream = await model.generateContentStream(testPrompt)
11631163
for try await content in stream {
11641164
XCTAssertNotNil(content.text)
11651165
responses += 1

FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ final class VertexAIAPITests: XCTestCase {
170170
#endif
171171

172172
// Chat
173-
_ = genAI.startChat()
174-
_ = genAI.startChat(history: [ModelContent(parts: "abc")])
173+
_ = await genAI.startChat()
174+
_ = await genAI.startChat(history: [ModelContent(parts: "abc")])
175175
}
176176

177177
// Public API tests for GenerateContentResponse.

FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,20 @@ class VertexComponentTests: XCTestCase {
106106
let app = try XCTUnwrap(VertexComponentTests.app)
107107
let vertex = VertexAI.vertexAI(app: app, location: location)
108108
let modelName = "test-model-name"
109-
let modelResourceName = vertex.modelResourceName(modelName: modelName)
110-
let systemInstruction = ModelContent(role: "system", parts: "test-system-instruction-prompt")
109+
let expectedModelResourceName = vertex.modelResourceName(modelName: modelName)
110+
let expectedSystemInstruction = ModelContent(
111+
role: "system",
112+
parts: "test-system-instruction-prompt"
113+
)
111114

112115
let generativeModel = vertex.generativeModel(
113116
modelName: modelName,
114-
systemInstruction: systemInstruction
117+
systemInstruction: expectedSystemInstruction
115118
)
116119

117-
XCTAssertEqual(generativeModel.modelResourceName, modelResourceName)
118-
XCTAssertEqual(generativeModel.systemInstruction, systemInstruction)
120+
let modelResourceName = await generativeModel.modelResourceName
121+
let systemInstruction = await generativeModel.systemInstruction
122+
XCTAssertEqual(modelResourceName, expectedModelResourceName)
123+
XCTAssertEqual(systemInstruction, expectedSystemInstruction)
119124
}
120125
}

0 commit comments

Comments
 (0)