import Foundation
import llama

let arguments = CommandLine.arguments

// Check that we have at least one argument (the model path)
guard arguments.count > 1 else {
    print("Usage: swift MODEL_PATH [PROMPT] [PARALLEL]")
    exit(1)
}
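// Example invocation (the binary name and model path below are placeholders):
//   ./batched_swift ./models/model.gguf "Hello my name is" 4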
let modelPath: String = arguments[1]
let prompt: String = arguments.count > 2 ? arguments[2] : "Hello my name is"
let n_parallel: Int = arguments.count > 3 && Int(arguments[3]) != nil ? Int(arguments[3])! : 1

// total length of the sequences including the prompt
let n_len: Int = 32

// init LLM
llama_backend_init(false)
defer {
    llama_backend_free()
}

let model_params = llama_model_default_params()
guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), model_params) else {
    print("Failed to load model")
    exit(1)
}

defer {
    llama_free_model(model)
}

var tokens = tokenize(text: prompt, add_bos: true)
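// Each parallel sequence shares the prompt and then generates up to (n_len - prompt length)
// tokens of its own, so this is the total number of KV cache cells the run needs.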
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)

var context_params = llama_context_default_params()
context_params.seed = 1234
context_params.n_ctx = n_kv_req
context_params.n_batch = UInt32(max(n_len, n_parallel))
context_params.n_threads = 8
context_params.n_threads_batch = 8

let context = llama_new_context_with_model(model, context_params)
guard context != nil else {
    print("Failed to initialize context")
    exit(1)
}

defer {
    llama_free(context)
}

let n_ctx = llama_n_ctx(context)

print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")

if n_kv_req > n_ctx {
    print("error: n_kv_req (\(n_kv_req)) > n_ctx, the required KV cache size is not big enough\n")
    exit(1)
}

var buffer: [CChar] = []
for id: llama_token in tokens {
    print(token_to_piece(token: id, buffer: &buffer) ?? "", terminator: "")
}

print("\n")

var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0)
defer {
    llama_batch_free(batch)
}

// evaluate the initial prompt
batch.n_tokens = Int32(tokens.count)

for (i, token) in tokens.enumerated() {
    batch.token[i] = token
    batch.pos[i] = Int32(i)
    batch.seq_id[i] = 0
    batch.logits[i] = 0
}

// llama_decode will output logits only for the last token of the prompt
batch.logits[Int(batch.n_tokens) - 1] = 1

if llama_decode(context, batch) != 0 {
    print("llama_decode() failed")
    exit(1)
}
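// copy the prompt's KV cache entries from sequence 0 to every other sequence,
// so the prompt only has to be decoded once and is shared by all parallel streams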
for i in 1 ..< n_parallel {
    llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
}

if n_parallel > 1 {
    print("generating \(n_parallel) sequences ...\n")
}

var streams: [String] = .init(repeating: "", count: n_parallel)
var streamBuffers: [[CChar]] = .init(repeating: [], count: n_parallel)
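// i_batch[i] is the index into the current batch whose logits belong to stream i (-1 once the stream has finished)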
var i_batch = [Int32](repeating: batch.n_tokens - 1, count: n_parallel)

var n_cur = batch.n_tokens
var n_decode = 0

let t_main_start = ggml_time_us()
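// main generation loop: each iteration samples one new token for every unfinished
// stream and submits all of them to llama_decode as a single batch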
while n_cur <= n_len {
    // prepare the next batch
    batch.n_tokens = 0

    // sample the next token for each parallel sequence / stream
    for i in 0 ..< n_parallel {
        if i_batch[i] < 0 {
            // the stream has already finished
            continue
        }

        let n_vocab = llama_n_vocab(model)
        let logits = llama_get_logits_ith(context, i_batch[i])

        // build the candidate list from the logits of this stream's last token
        var candidates: [llama_token_data] = []
        candidates.reserveCapacity(Int(n_vocab))

        for token_id in 0 ..< n_vocab {
            candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
        }

        var candidates_p: llama_token_data_array = .init(
            data: &candidates,
            size: candidates.count,
            sorted: false
        )
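        // sample with top-k, top-p and temperature; the commented-out line below shows the greedy alternative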
        let top_k: Int32 = 40
        let top_p: Float = 0.9
        let temp: Float = 0.4

        llama_sample_top_k(context, &candidates_p, top_k, 1)
        llama_sample_top_p(context, &candidates_p, top_p, 1)
        llama_sample_temp(context, &candidates_p, temp)

        let new_token_id = llama_sample_token(context, &candidates_p)

        // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);

        // is it an end of stream? -> mark the stream as finished
        if new_token_id == llama_token_eos(context) || n_cur == n_len {
            i_batch[i] = -1
            // print("")
            if n_parallel > 1 {
                print("stream \(i) finished at n_cur = \(n_cur)")
            }

            continue
        }

        let nextStringPiece = token_to_piece(token: new_token_id, buffer: &streamBuffers[i]) ?? ""

        // if there is only one stream, we print immediately to stdout
        if n_parallel == 1 {
            print(nextStringPiece, terminator: "")
        }
        streams[i] += nextStringPiece

        // push this new token for next evaluation
        batch.token[Int(batch.n_tokens)] = new_token_id
        batch.pos[Int(batch.n_tokens)] = n_cur
        batch.seq_id[Int(batch.n_tokens)] = Int32(i)
        batch.logits[Int(batch.n_tokens)] = 1

        i_batch[i] = batch.n_tokens

        batch.n_tokens += 1

        n_decode += 1
    }

    // all streams are finished
    if batch.n_tokens == 0 {
        break
    }

    n_cur += 1

    // evaluate the current batch with the transformer model
    if llama_decode(context, batch) != 0 {
        print("llama_decode() failed")
        exit(1)
    }
}

if n_parallel > 1 {
    print("\n")
    for (i, stream) in streams.enumerated() {
        print("sequence \(i):\n\n\(prompt)\(stream)\n")
    }
}

let t_main_end = ggml_time_us()

print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")

llama_print_timings(context)
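// convert a Swift string into llama tokens using the loaded model's vocabulary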
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
    let utf8Count = text.utf8.count
    let n_tokens = utf8Count + (add_bos ? 1 : 0)
    let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
    let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos)
    var swiftTokens: [llama_token] = []
    for i in 0 ..< tokenCount {
        swiftTokens.append(tokens[Int(i)])
    }
    tokens.deallocate()
    return swiftTokens
}
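// convert a single token into a UTF-8 string piece; bytes that do not yet form a
// complete UTF-8 character are accumulated in `buffer` until they can be decoded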
private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
    var result = [CChar](repeating: 0, count: 8)
    let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count))
    if nTokens < 0 {
        // the initial buffer was too small: grow it to the required size and convert again
        result = [CChar](repeating: 0, count: Int(-nTokens))
        let check = llama_token_to_piece(
            model,
            token,
            &result,
            Int32(result.count)
        )
        assert(check == -nTokens)
    } else {
        result.removeLast(result.count - Int(nTokens))
    }
    if buffer.isEmpty, let utfString = String(cString: result + [0], encoding: .utf8) {
        return utfString
    } else {
        buffer.append(contentsOf: result)
        let data = Data(buffer.map { UInt8(bitPattern: $0) })
        if buffer.count >= 4 { // 4 bytes is the max length of a utf8 character so if we're here we need to reset the buffer
            buffer = []
        }
        guard let bufferString = String(data: data, encoding: .utf8) else {
            return nil
        }
        buffer = []
        return bufferString
    }
}