Fix token concatenation implementation in Swift example project #4325

Merged
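The problem this PR addresses: the Swift example previously converted every token piece to a String on its own, but a single multi-byte UTF-8 character (an emoji, CJK text, etc.) can be split across several tokens, so decoding each piece in isolation produces an invalid string. A minimal illustration of that failure mode, using a hypothetical split of the llama emoji's 4-byte encoding (not code from this PR):

// Hypothetical illustration (not from this PR): the 4-byte UTF-8 encoding of "🦙"
// split across two token pieces cannot be decoded piece by piece.
let llamaBytes: [CChar] = Array("🦙".utf8).map { CChar(bitPattern: $0) }
let firstHalf  = Array(llamaBytes.prefix(2))
let secondHalf = Array(llamaBytes.suffix(2))

print(String(validatingUTF8: firstHalf + [0]) ?? "<invalid>")               // "<invalid>" – incomplete sequence
print(String(validatingUTF8: firstHalf + secondHalf + [0]) ?? "<invalid>")  // "🦙" – valid once buffered together

The diff below buffers undecodable bytes in temporary_invalid_cchars and only emits text once the accumulated bytes form valid UTF-8.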
37 changes: 30 additions & 7 deletions examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -11,6 +11,8 @@ actor LlamaContext
     private var context: OpaquePointer
     private var batch: llama_batch
     private var tokens_list: [llama_token]
+    /// This variable is used to store temporarily invalid cchars
+    private var temporary_invalid_cchars: [CChar]

     var n_len: Int32 = 512
     var n_cur: Int32 = 0
@@ -21,6 +23,7 @@ actor LlamaContext
         self.context = context
         self.tokens_list = []
         self.batch = llama_batch_init(512, 0, 1)
+        self.temporary_invalid_cchars = []
     }

     deinit {
@@ -61,6 +64,7 @@
         print("attempting to complete \"\(text)\"")

         tokens_list = tokenize(text: text, add_bos: true)
+        temporary_invalid_cchars = []

         let n_ctx = llama_n_ctx(context)
         let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)
@@ -72,7 +76,7 @@
         }

         for id in tokens_list {
-            print(token_to_piece(token: id))
+            print(String(cString: token_to_piece(token: id) + [0]))
         }

         // batch = llama_batch_init(512, 0) // done in init()
@@ -115,10 +119,25 @@

         if new_token_id == llama_token_eos(context) || n_cur == n_len {
             print("\n")
-            return ""
+            let new_token_str = String(cString: temporary_invalid_cchars + [0])
+            temporary_invalid_cchars.removeAll()
+            return new_token_str
         }

-        let new_token_str = token_to_piece(token: new_token_id)
+        let new_token_cchars = token_to_piece(token: new_token_id)
+        temporary_invalid_cchars.append(contentsOf: new_token_cchars)
+        let new_token_str: String
+        if let string = String(validatingUTF8: temporary_invalid_cchars + [0]) {
+            temporary_invalid_cchars.removeAll()
+            new_token_str = string
+        } else if (0 ..< temporary_invalid_cchars.count).contains(where: {$0 != 0 && String(validatingUTF8: Array(temporary_invalid_cchars.suffix($0)) + [0]) != nil}) {
+            // in this case, at least the suffix of the temporary_invalid_cchars can be interpreted as UTF8 string
+            let string = String(cString: temporary_invalid_cchars + [0])
+            temporary_invalid_cchars.removeAll()
+            new_token_str = string
+        } else {
+            new_token_str = ""
+        }
         print(new_token_str)
         // tokens_list.append(new_token_id)

@@ -144,6 +163,7 @@

     func clear() {
         tokens_list.removeAll()
+        temporary_invalid_cchars.removeAll()
     }

     private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
@@ -162,7 +182,8 @@
         return swiftTokens
     }

-    private func token_to_piece(token: llama_token) -> String {
+    /// - note: The result does not contain null-terminator
+    private func token_to_piece(token: llama_token) -> [CChar] {
         let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
         result.initialize(repeating: Int8(0), count: 8)
         defer {
@@ -176,10 +197,12 @@
             defer {
                 newResult.deallocate()
             }
-            _ = llama_token_to_piece(model, token, newResult, -nTokens)
-            return String(cString: newResult)
+            let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens)
+            let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
+            return Array(bufferPointer)
         } else {
-            return String(cString: result)
+            let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nTokens))
+            return Array(bufferPointer)
         }
     }
 }
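For illustration only, the buffering strategy above can be read as a small standalone type. PieceDecoder is a hypothetical name, and this sketch omits the PR's fallback branch that salvages a decodable suffix when the buffered prefix remains invalid:

// Hypothetical standalone sketch of the buffering approach (not part of the PR).
struct PieceDecoder {
    private var pending: [CChar] = []

    // Append one token piece; return text only once the buffered bytes form valid UTF-8.
    mutating func feed(_ piece: [CChar]) -> String {
        pending.append(contentsOf: piece)
        if let string = String(validatingUTF8: pending + [0]) {
            pending.removeAll()
            return string
        }
        return ""  // keep buffering; the PR additionally recovers a valid suffix after an invalid prefix
    }
}

var decoder = PieceDecoder()
let bytes: [CChar] = Array("a🦙".utf8).map { CChar(bitPattern: $0) }
print(decoder.feed(Array(bytes.prefix(3))))   // "" – "a" plus half of the emoji stays buffered
print(decoder.feed(Array(bytes.suffix(2))))   // "a🦙" – emitted once the sequence completes

Note that the reworked token_to_piece returns a [CChar] without a null terminator, which is why call sites in the diff append [0] before constructing a String.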