Skip to content

[Vertex AI] Parameterize integration tests for Vertex and Dev API #14540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 7 additions & 37 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,18 @@ import Foundation
public class VertexAI {
// MARK: - Public APIs

/// The default `VertexAI` instance.
///
/// - Parameter location: The region identifier, defaulting to `us-central1`; see [Vertex AI
/// regions
/// ](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions)
/// for a list of supported regions.
/// - Returns: An instance of `VertexAI`, configured with the default `FirebaseApp`.
public static func vertexAI(location: String = "us-central1") -> VertexAI {
guard let app = FirebaseApp.app() else {
fatalError("No instance of the default Firebase app was found.")
}
let vertexInstance = vertexAI(app: app, location: location)
assert(vertexInstance.apiConfig.service == .vertexAI)
assert(vertexInstance.apiConfig.service.endpoint == .firebaseVertexAIProd)
assert(vertexInstance.apiConfig.version == .v1beta)

return vertexInstance
}

/// Creates an instance of `VertexAI` configured with a custom `FirebaseApp`.
/// Creates an instance of `VertexAI`.
///
/// - Parameters:
/// - app: The custom `FirebaseApp` used for initialization.
/// - app: A custom `FirebaseApp` used for initialization; if not specified, uses the default
/// ``FirebaseApp``.
/// - location: The region identifier, defaulting to `us-central1`; see
/// [Vertex AI locations]
/// (https://firebase.google.com/docs/vertex-ai/locations?platform=ios#available-locations)
/// for a list of supported locations.
/// - Returns: A `VertexAI` instance, configured with the custom `FirebaseApp`.
public static func vertexAI(app: FirebaseApp, location: String = "us-central1") -> VertexAI {
public static func vertexAI(app: FirebaseApp? = nil,
location: String = "us-central1") -> VertexAI {
let vertexInstance = vertexAI(app: app, location: location, apiConfig: defaultVertexAIAPIConfig)
assert(vertexInstance.apiConfig.service == .vertexAI)
assert(vertexInstance.apiConfig.service.endpoint == .firebaseVertexAIProd)
Expand Down Expand Up @@ -160,25 +143,12 @@ public class VertexAI {
let location: String?

static let defaultVertexAIAPIConfig = APIConfig(service: .vertexAI, version: .v1beta)
static let defaultDeveloperAPIConfig = APIConfig(
service: .developer(endpoint: .generativeLanguage),
version: .v1beta
)

static func developerAPI(apiConfig: APIConfig = defaultDeveloperAPIConfig) -> VertexAI {
guard let app = FirebaseApp.app() else {
static func vertexAI(app: FirebaseApp?, location: String?, apiConfig: APIConfig) -> VertexAI {
guard let app = app ?? FirebaseApp.app() else {
fatalError("No instance of the default Firebase app was found.")
}

return developerAPI(app: app, apiConfig: apiConfig)
}

static func developerAPI(app: FirebaseApp,
apiConfig: APIConfig = defaultDeveloperAPIConfig) -> VertexAI {
return vertexAI(app: app, location: nil, apiConfig: apiConfig)
}

static func vertexAI(app: FirebaseApp, location: String?, apiConfig: APIConfig) -> VertexAI {
os_unfair_lock_lock(&instancesLock)

// Unlock before the function returns.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAuth
import FirebaseCore
import FirebaseStorage
import FirebaseVertexAI
import Testing
import VertexAITestApp

@testable import struct FirebaseVertexAI.APIConfig

@Suite(.serialized)
struct GenerateContentIntegrationTests {
static let vertexV1Config = APIConfig(service: .vertexAI, version: .v1)
static let vertexV1BetaConfig = APIConfig(service: .vertexAI, version: .v1beta)
static let developerV1BetaConfig = APIConfig(
service: .developer(endpoint: .generativeLanguage),
version: .v1beta
)

// Set temperature, topP and topK to lowest allowed values to make responses more deterministic.
static let generationConfig = GenerationConfig(
temperature: 0.0,
topP: 0.0,
topK: 1,
responseMIMEType: "text/plain"
)
static let systemInstruction = ModelContent(
role: "system",
parts: "You are a friendly and helpful assistant."
)
static let safetySettings = [
SafetySetting(harmCategory: .harassment, threshold: .blockLowAndAbove),
SafetySetting(harmCategory: .hateSpeech, threshold: .blockLowAndAbove),
SafetySetting(harmCategory: .sexuallyExplicit, threshold: .blockLowAndAbove),
SafetySetting(harmCategory: .dangerousContent, threshold: .blockLowAndAbove),
SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
]
// Candidates and total token counts may differ slightly between runs due to whitespace tokens.
let tokenCountAccuracy = 1

let storage: Storage
let userID1: String

init() async throws {
let authResult = try await Auth.auth().signIn(
withEmail: Credentials.emailAddress1,
password: Credentials.emailPassword1
)
userID1 = authResult.user.uid

storage = Storage.storage()
}

@Test(arguments: [vertexV1Config, vertexV1BetaConfig, developerV1BetaConfig])
func generateContent(_ apiConfig: APIConfig) async throws {
let model = GenerateContentIntegrationTests.model(apiConfig: apiConfig)
let prompt = "Where is Google headquarters located? Answer with the city name only."

let response = try await model.generateContent(prompt)

let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
#expect(text == "Mountain View")

let usageMetadata = try #require(response.usageMetadata)
#expect(usageMetadata.promptTokenCount == 21)
#expect(usageMetadata.candidatesTokenCount.isEqual(to: 3, accuracy: tokenCountAccuracy))
#expect(usageMetadata.totalTokenCount.isEqual(to: 24, accuracy: tokenCountAccuracy))
#expect(usageMetadata.promptTokensDetails.count == 1)
let promptTokensDetails = try #require(usageMetadata.promptTokensDetails.first)
#expect(promptTokensDetails.modality == .text)
#expect(promptTokensDetails.tokenCount == usageMetadata.promptTokenCount)
#expect(usageMetadata.candidatesTokensDetails.count == 1)
let candidatesTokensDetails = try #require(usageMetadata.candidatesTokensDetails.first)
#expect(candidatesTokensDetails.modality == .text)
#expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount)
}

static func model(apiConfig: APIConfig) -> GenerativeModel {
return instance(apiConfig: apiConfig).generativeModel(
modelName: "gemini-2.0-flash",
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: [],
toolConfig: .init(functionCallingConfig: .none()),
systemInstruction: systemInstruction
)
}

// TODO(andrewheard): Move this helper to a file in the Utilities folder.
static func instance(apiConfig: APIConfig) -> VertexAI {
switch apiConfig.service {
case .vertexAI:
return VertexAI.vertexAI(app: nil, location: "us-central1", apiConfig: apiConfig)
case .developer:
return VertexAI.vertexAI(app: nil, location: nil, apiConfig: apiConfig)
}
}
}

// TODO(andrewheard): Move this extension to a file in the Utilities folder.
extension Numeric where Self: Strideable, Self.Stride.Magnitude: Comparable {
func isEqual(to other: Self, accuracy: Self.Stride) -> Bool {
return distance(to: other).magnitude < accuracy.magnitude
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,6 @@ final class IntegrationTests: XCTestCase {

// MARK: - Generate Content

func testGenerateContent() async throws {
let prompt = "Where is Google headquarters located? Answer with the city name only."

let response = try await model.generateContent(prompt)

let text = try XCTUnwrap(response.text).trimmingCharacters(in: .whitespacesAndNewlines)
XCTAssertEqual(text, "Mountain View")
let usageMetadata = try XCTUnwrap(response.usageMetadata)
XCTAssertEqual(usageMetadata.promptTokenCount, 21)
XCTAssertEqual(usageMetadata.candidatesTokenCount, 3, accuracy: tokenCountAccuracy)
XCTAssertEqual(usageMetadata.totalTokenCount, 24, accuracy: tokenCountAccuracy)
XCTAssertEqual(usageMetadata.promptTokensDetails.count, 1)
let promptTokensDetails = try XCTUnwrap(usageMetadata.promptTokensDetails.first)
XCTAssertEqual(promptTokensDetails.modality, .text)
XCTAssertEqual(promptTokensDetails.tokenCount, usageMetadata.promptTokenCount)
XCTAssertEqual(usageMetadata.candidatesTokensDetails.count, 1)
let candidatesTokensDetails = try XCTUnwrap(usageMetadata.candidatesTokensDetails.first)
XCTAssertEqual(candidatesTokensDetails.modality, .text)
XCTAssertEqual(candidatesTokensDetails.tokenCount, usageMetadata.candidatesTokenCount)
}

func testGenerateContentStream() async throws {
let expectedText = """
1. Mercury
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
8692F29E2CC9477800539E8F /* FirebaseVertexAI in Frameworks */ = {isa = PBXBuildFile; productRef = 8692F29D2CC9477800539E8F /* FirebaseVertexAI */; };
8698D7462CD3CF3600ABA833 /* FirebaseAppTestUtils.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8698D7452CD3CF2F00ABA833 /* FirebaseAppTestUtils.swift */; };
8698D7482CD4332B00ABA833 /* TestAppCheckProviderFactory.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8698D7472CD4332B00ABA833 /* TestAppCheckProviderFactory.swift */; };
86D77DFC2D7A5340003D155D /* GenerateContentIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */; };
/* End PBXBuildFile section */

/* Begin PBXContainerItemProxy section */
Expand Down Expand Up @@ -49,6 +50,7 @@
868A7C552CCC271300E449DD /* TestApp.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = TestApp.entitlements; sourceTree = "<group>"; };
8698D7452CD3CF2F00ABA833 /* FirebaseAppTestUtils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FirebaseAppTestUtils.swift; sourceTree = "<group>"; };
8698D7472CD4332B00ABA833 /* TestAppCheckProviderFactory.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TestAppCheckProviderFactory.swift; sourceTree = "<group>"; };
86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GenerateContentIntegrationTests.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
Expand Down Expand Up @@ -126,6 +128,7 @@
children = (
868A7C4D2CCC1F4700E449DD /* Credentials.swift */,
8661386D2CC943DE00F4B78E /* IntegrationTests.swift */,
86D77DFB2D7A5340003D155D /* GenerateContentIntegrationTests.swift */,
864F8F702D4980D60002EA7E /* ImagenIntegrationTests.swift */,
862218802D04E08D007ED2D4 /* IntegrationTestUtils.swift */,
);
Expand Down Expand Up @@ -273,6 +276,7 @@
868A7C4F2CCC229F00E449DD /* Credentials.swift in Sources */,
864F8F712D4980DD0002EA7E /* ImagenIntegrationTests.swift in Sources */,
862218812D04E098007ED2D4 /* IntegrationTestUtils.swift in Sources */,
86D77DFC2D7A5340003D155D /* GenerateContentIntegrationTests.swift in Sources */,
8661386E2CC943DE00F4B78E /* IntegrationTests.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
Expand Down
41 changes: 34 additions & 7 deletions FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,22 @@ class VertexComponentTests: XCTestCase {
XCTAssertNotNil(NSClassFromString("FIRVertexAIComponent"))
}

/// Tests that a vertex instance can be created properly using the default Firebase pp.
/// Tests that a vertex instance can be created properly using the default Firebase app.
func testVertexInstanceCreation_defaultApp() throws {
let vertex = VertexAI.vertexAI()

XCTAssertNotNil(vertex)
XCTAssertEqual(vertex.firebaseInfo.projectID, VertexComponentTests.projectID)
XCTAssertEqual(vertex.firebaseInfo.apiKey, VertexComponentTests.apiKey)
XCTAssertEqual(vertex.location, "us-central1")
XCTAssertEqual(vertex.apiConfig.service, .vertexAI)
XCTAssertEqual(vertex.apiConfig.service.endpoint, .firebaseVertexAIProd)
XCTAssertEqual(vertex.apiConfig.version, .v1beta)
}

/// Tests that a vertex instance can be created properly using the default Firebase app and custom
/// location.
func testVertexInstanceCreation_defaultApp_customLocation() throws {
let vertex = VertexAI.vertexAI(location: location)

XCTAssertNotNil(vertex)
Expand Down Expand Up @@ -121,8 +135,16 @@ class VertexComponentTests: XCTestCase {
}

func testSameAppAndDifferentAPI_newInstanceCreated() throws {
let vertex1 = VertexAI.vertexAI(app: VertexComponentTests.app)
let vertex2 = VertexAI.developerAPI(app: VertexComponentTests.app)
let vertex1 = VertexAI.vertexAI(
app: VertexComponentTests.app,
location: location,
apiConfig: APIConfig(service: .vertexAI, version: .v1beta)
)
let vertex2 = VertexAI.vertexAI(
app: VertexComponentTests.app,
location: location,
apiConfig: APIConfig(service: .vertexAI, version: .v1)
)

// Ensure they are different instances.
XCTAssert(vertex1 !== vertex2)
Expand Down Expand Up @@ -168,7 +190,8 @@ class VertexComponentTests: XCTestCase {

func testModelResourceName_developerAPI_generativeLanguage() throws {
let app = try XCTUnwrap(VertexComponentTests.app)
let vertex = VertexAI.developerAPI(app: app)
let apiConfig = APIConfig(service: .developer(endpoint: .generativeLanguage), version: .v1beta)
let vertex = VertexAI.vertexAI(app: app, location: nil, apiConfig: apiConfig)
let model = "test-model-name"

let modelResourceName = vertex.modelResourceName(modelName: model)
Expand All @@ -182,7 +205,7 @@ class VertexComponentTests: XCTestCase {
service: .developer(endpoint: .firebaseVertexAIStaging),
version: .v1beta
)
let vertex = VertexAI.developerAPI(app: app, apiConfig: apiConfig)
let vertex = VertexAI.vertexAI(app: app, location: nil, apiConfig: apiConfig)
let model = "test-model-name"
let projectID = vertex.firebaseInfo.projectID

Expand All @@ -208,7 +231,11 @@ class VertexComponentTests: XCTestCase {

func testGenerativeModel_developerAPI() async throws {
let app = try XCTUnwrap(VertexComponentTests.app)
let vertex = VertexAI.developerAPI(app: app)
let apiConfig = APIConfig(
service: .developer(endpoint: .firebaseVertexAIStaging),
version: .v1beta
)
let vertex = VertexAI.vertexAI(app: app, location: nil, apiConfig: apiConfig)
let modelResourceName = vertex.modelResourceName(modelName: modelName)

let generativeModel = vertex.generativeModel(
Expand All @@ -218,6 +245,6 @@ class VertexComponentTests: XCTestCase {

XCTAssertEqual(generativeModel.modelResourceName, modelResourceName)
XCTAssertEqual(generativeModel.systemInstruction, systemInstruction)
XCTAssertEqual(generativeModel.apiConfig, VertexAI.defaultDeveloperAPIConfig)
XCTAssertEqual(generativeModel.apiConfig, apiConfig)
}
}
Loading