Skip to content

Commit

Permalink
Improvements to Swift LLM app
Browse files Browse the repository at this point in the history
  • Loading branch information
schmidt-sebastian committed May 11, 2024
1 parent 08f7d95 commit 58d9638
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
objects = {

/* Begin PBXBuildFile section */
8D60EEED2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin in Resources */ = {isa = PBXBuildFile; fileRef = 8D60EEEC2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin */; };
0687D6A32BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin in Resources */ = {isa = PBXBuildFile; fileRef = 0687D6A22BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin */; };
0B6641881CDD8AB5E5F6C3FE /* Pods_InferenceExample.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = F02369998C59A0B7AAC3A1D5 /* Pods_InferenceExample.framework */; };
8DCF4C452B99289E00427D77 /* InferenceExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8DCF4C442B99289E00427D77 /* InferenceExampleApp.swift */; };
8DCF4C472B99289E00427D77 /* ConversationScreen.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8DCF4C462B99289E00427D77 /* ConversationScreen.swift */; };
8DCF4C492B99289E00427D77 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 8DCF4C482B99289E00427D77 /* Assets.xcassets */; };
Expand All @@ -17,7 +18,8 @@
/* End PBXBuildFile section */

/* Begin PBXFileReference section */
8D60EEEC2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; path = "gemma-2b-it-cpu-int4.bin"; sourceTree = "<group>"; };
0687D6A22BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; name = "gemma-1.1-2b-it-gpu-int4.bin"; path = "../../../../../../Downloads/gemma-1.1-2b-it-gpu-int4.bin"; sourceTree = "<group>"; };
84EC0509CA2D0A791FE6A50D /* Pods-InferenceExample.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-InferenceExample.release.xcconfig"; path = "Target Support Files/Pods-InferenceExample/Pods-InferenceExample.release.xcconfig"; sourceTree = "<group>"; };
8DCF4C412B99289E00427D77 /* InferenceExample.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = InferenceExample.app; sourceTree = BUILT_PRODUCTS_DIR; };
8DCF4C442B99289E00427D77 /* InferenceExampleApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InferenceExampleApp.swift; sourceTree = "<group>"; };
8DCF4C462B99289E00427D77 /* ConversationScreen.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationScreen.swift; sourceTree = "<group>"; };
Expand All @@ -26,33 +28,47 @@
8DCF4C4C2B99289E00427D77 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
8DCF4C572B992B9C00427D77 /* ConversationViewModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ConversationViewModel.swift; sourceTree = "<group>"; };
8DCF4C5B2B9939D700427D77 /* OnDeviceModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OnDeviceModel.swift; sourceTree = "<group>"; };
E0D7FB7ACC6233403D36FDF8 /* Pods-InferenceExample.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-InferenceExample.debug.xcconfig"; path = "Target Support Files/Pods-InferenceExample/Pods-InferenceExample.debug.xcconfig"; sourceTree = "<group>"; };
F02369998C59A0B7AAC3A1D5 /* Pods_InferenceExample.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_InferenceExample.framework; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
8DCF4C3E2B99289E00427D77 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
0B6641881CDD8AB5E5F6C3FE /* Pods_InferenceExample.framework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */

/* Begin PBXGroup section */
32B29128DD33338DDF219030 /* Frameworks */ = {
isa = PBXGroup;
children = (
F02369998C59A0B7AAC3A1D5 /* Pods_InferenceExample.framework */,
);
name = Frameworks;
sourceTree = "<group>";
};
540B7F154909C4C9EF376B57 /* Pods */ = {
isa = PBXGroup;
children = (
E0D7FB7ACC6233403D36FDF8 /* Pods-InferenceExample.debug.xcconfig */,
84EC0509CA2D0A791FE6A50D /* Pods-InferenceExample.release.xcconfig */,
);
path = Pods;
sourceTree = "<group>";
};
8DCF4C382B99289D00427D77 = {
isa = PBXGroup;
children = (
0687D6A22BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin */,
8DCF4C432B99289E00427D77 /* InferenceExample */,
8DCF4C422B99289E00427D77 /* Products */,
540B7F154909C4C9EF376B57 /* Pods */,
8D60EEEC2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin */,
32B29128DD33338DDF219030 /* Frameworks */,
);
sourceTree = "<group>";
};
Expand Down Expand Up @@ -93,6 +109,7 @@
isa = PBXNativeTarget;
buildConfigurationList = 8DCF4C502B99289E00427D77 /* Build configuration list for PBXNativeTarget "InferenceExample" */;
buildPhases = (
3C9FFF9E17DFA2F74C306852 /* [CP] Check Pods Manifest.lock */,
8DCF4C3D2B99289E00427D77 /* Sources */,
8DCF4C3E2B99289E00427D77 /* Frameworks */,
8DCF4C3F2B99289E00427D77 /* Resources */,
Expand Down Expand Up @@ -146,12 +163,37 @@
files = (
8DCF4C4D2B99289E00427D77 /* Preview Assets.xcassets in Resources */,
8DCF4C492B99289E00427D77 /* Assets.xcassets in Resources */,
8D60EEED2B9A8A180019075E /* gemma-2b-it-cpu-int4.bin in Resources */,
0687D6A32BEFC65D00167F37 /* gemma-1.1-2b-it-gpu-int4.bin in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */

/* Begin PBXShellScriptBuildPhase section */
3C9FFF9E17DFA2F74C306852 /* [CP] Check Pods Manifest.lock */ = {
isa = PBXShellScriptBuildPhase;
buildActionMask = 2147483647;
files = (
);
inputFileListPaths = (
);
inputPaths = (
"${PODS_PODFILE_DIR_PATH}/Podfile.lock",
"${PODS_ROOT}/Manifest.lock",
);
name = "[CP] Check Pods Manifest.lock";
outputFileListPaths = (
);
outputPaths = (
"$(DERIVED_FILE_DIR)/Pods-InferenceExample-checkManifestLockResult.txt",
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n";
showEnvVarsInLog = 0;
};
/* End PBXShellScriptBuildPhase section */

/* Begin PBXSourcesBuildPhase section */
8DCF4C3D2B99289E00427D77 /* Sources */ = {
isa = PBXSourcesBuildPhase;
Expand Down Expand Up @@ -283,13 +325,15 @@
};
8DCF4C512B99289E00427D77 /* Debug */ = {
isa = XCBuildConfiguration;
baseConfigurationReference = E0D7FB7ACC6233403D36FDF8 /* Pods-InferenceExample.debug.xcconfig */;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_ENTITLEMENTS = InferenceExample/InferenceExample.entitlements;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_ASSET_PATHS = "\"InferenceExample/Preview Content\"";
DEVELOPMENT_TEAM = M3D535GFVK;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
Expand All @@ -307,7 +351,7 @@
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.2;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample;
PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample.foo;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
Expand All @@ -319,13 +363,15 @@
};
8DCF4C522B99289E00427D77 /* Release */ = {
isa = XCBuildConfiguration;
baseConfigurationReference = 84EC0509CA2D0A791FE6A50D /* Pods-InferenceExample.release.xcconfig */;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_ENTITLEMENTS = InferenceExample/InferenceExample.entitlements;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_ASSET_PATHS = "\"InferenceExample/Preview Content\"";
DEVELOPMENT_TEAM = M3D535GFVK;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
Expand All @@ -343,7 +389,7 @@
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.2;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample;
PRODUCT_BUNDLE_IDENTIFIER = com.mediapipe.InferenceExample.foo;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@ class ConversationViewModel: ObservableObject {
messages.append(systemMessage)

do {
let response = try await chat.sendMessage(text)
let response = try await chat.sendMessage(text, progress : { [weak self] partialResult in
guard let self = self else { return }
DispatchQueue.main.async {
self.messages[self.messages.count - 1].message = partialResult
}
})

// replace pending message with model response
messages[messages.count - 1].message = response
Expand Down
44 changes: 33 additions & 11 deletions examples/llm_inference/ios/InferenceExample/OnDeviceModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,27 @@ import MediaPipeTasksGenAI

final class OnDeviceModel {

private var inference: LlmInference! = {
let path = Bundle.main.path(forResource: "gemma-2b-it-cpu-int4", ofType: "bin")!
let llmOptions = LlmInference.Options(modelPath: path)
return LlmInference(options: llmOptions)
}()

func generateResponse(prompt: String) async throws -> String {
private var cachedInference: LlmInference?

/// Lazily-constructed LLM inference engine, cached after first successful creation.
///
/// Declared `get throws` because `LlmInference` initialization can fail; the
/// engine is built at most once and reused for subsequent calls.
/// - Throws: An error describing a missing model file, or any error thrown by
///   `LlmInference(options:)`.
private var inference: LlmInference {
  get throws {
    // Reuse the previously built engine — model load is expensive.
    if let cached = cachedInference {
      return cached
    }
    // Resolve the bundled model file. Throw a descriptive error instead of
    // force-unwrap crashing when the .bin was never added to the app bundle
    // (easy to get wrong: the project references it from outside the repo).
    guard let path = Bundle.main.path(forResource: "gemma-1.1-2b-it-gpu-int4", ofType: "bin")
    else {
      throw NSError(
        domain: "OnDeviceModel",
        code: 1,
        userInfo: [
          NSLocalizedDescriptionKey:
            "Model file 'gemma-1.1-2b-it-gpu-int4.bin' is missing from the app bundle."
        ])
    }
    let llmOptions = LlmInference.Options(modelPath: path)
    let engine = try LlmInference(options: llmOptions)
    cachedInference = engine
    return engine
  }
}

func generateResponse(prompt: String, progress: @escaping (String) -> Void) async throws -> String {
var partialResult = ""

let inference = try inference
return try await withCheckedThrowingContinuation { continuation in
do {
try inference.generateResponseAsync(inputText: prompt) { partialResponse, error in
Expand All @@ -34,6 +47,7 @@ final class OnDeviceModel {
}
if let partial = partialResponse {
partialResult += partial
progress(partialResult.trimmingCharacters(in: .whitespacesAndNewlines))
}
} completion: {
let aggregate = partialResult.trimmingCharacters(in: .whitespacesAndNewlines)
Expand Down Expand Up @@ -62,16 +76,24 @@ final class Chat {
self.model = model
}

/// Wraps `newMessage` in the Gemma prompt-format delimiters for a user turn.
private func composeUserTurn(_ newMessage: String) -> String {
  let openTag = "<start_of_turn>user\n"
  let closeTag = "<end_of_turn>\n"
  return openTag + newMessage + closeTag
}

/// Wraps `newMessage` in the Gemma prompt-format delimiters for a model turn.
private func composeModelTurn(_ newMessage: String) -> String {
  let openTag = "<start_of_turn>model\n"
  let closeTag = "<end_of_turn>\n"
  return openTag + newMessage + closeTag
}

/// Builds the full prompt: the stored conversation history, newline-joined,
/// followed by a newline and the new message.
/// NOTE(review): when `history` is empty the result begins with a bare "\n" —
/// preserved here; confirm that is intended for the first turn.
private func compositePrompt(newMessage: String) -> String {
  let joinedHistory = history.joined(separator: "\n")
  return "\(joinedHistory)\n\(newMessage)"
}

func sendMessage(_ text: String) async throws -> String {
let prompt = compositePrompt(newMessage: text)
let reply = try await model.generateResponse(prompt: prompt)

history.append(text)
history.append(reply)
func sendMessage(_ text: String, progress: @escaping (String) -> Void) async throws -> String {
let prompt = compositePrompt(newMessage: composeUserTurn(text))
let reply = try await model.generateResponse(prompt: prompt, progress: progress)

history = [prompt, composeModelTurn(reply)]

print("Prompt: \(prompt)")
print("Reply: \(reply)")
Expand Down
10 changes: 5 additions & 5 deletions examples/llm_inference/ios/Podfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PODS:
- MediaPipeTasksGenAI (0.10.11):
- MediaPipeTasksGenAIC (= 0.10.11)
- MediaPipeTasksGenAIC (0.10.11)
- MediaPipeTasksGenAI (0.10.14):
- MediaPipeTasksGenAIC (= 0.10.14)
- MediaPipeTasksGenAIC (0.10.14)

DEPENDENCIES:
- MediaPipeTasksGenAI
Expand All @@ -12,8 +12,8 @@ SPEC REPOS:
- MediaPipeTasksGenAIC

SPEC CHECKSUMS:
MediaPipeTasksGenAI: 9fb3fc0e9f9329d0b3f89d741dbdb4cb4429b87a
MediaPipeTasksGenAIC: 9bb1f037b742d7d642c8b8fbec8b4626f73c18c5
MediaPipeTasksGenAI: 8cd77fa32ea21f7a6319b025aa28cfc3e20ab73b
MediaPipeTasksGenAIC: 270ec81f85e96fac283945702e34112ebbfd5e77

PODFILE CHECKSUM: b561fe84c5e19b81e1111ba0f8f21564f7006b85

Expand Down

0 comments on commit 58d9638

Please sign in to comment.