Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions Sources/Generation/Generation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,13 @@ public extension Generation {
// Iterate until we find the eos token or reach the max length
// TODO: additional stopping criteria
var outputTokens = tokens
let logitsProcessor = LogitsProcessor(logitsWarpers: logitsWarpers(config: config))
while outputTokens.count < config.maxLength {
let outputs = model(outputTokens, config)

/// `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case
var logits = (outputs as? MLShapedArraySlice<Float>)?.floats ?? outputs.scalars as! [Float]

let nextToken: Int
if config.temperature > 0 && config.temperature != 1 {
logits = logits.map { $0 / Float(config.temperature) }
}
if config.topK > 0 {
let topK = Math.topK(arr: logits, k: config.topK)
nextToken = Math.sample(indexes: topK.indexes, probs: topK.probs)
} else if config.topP < 1.0 {
let topP = Math.topP(arr: logits, p: Float(config.topP))
nextToken = Math.sample(indexes: topP.indexes, probs: topP.probs)
} else {
fatalError("not implemented yet")
}

let logits = (outputs as? MLShapedArraySlice<Float>)?.floats ?? outputs.scalars as! [Float]
let (indexes, processedLogits) = logitsProcessor(logits)
let nextToken = Math.sample(indexes: indexes, probs: Math.softmax(processedLogits))
if nextToken == config.eosTokenId { break }
outputTokens.append(nextToken)
callback?(outputTokens)
Expand Down Expand Up @@ -102,4 +89,18 @@ public extension Generation {

return tokenizer.decode(tokens: output)
}

/// Builds the chain of logits warpers implied by `config`.
/// Order matters: temperature scaling runs before top-k / top-p filtering.
private func logitsWarpers(config: GenerationConfig) -> [any LogitsWarper] {
    // A `nil` entry means the corresponding warper is disabled by the config.
    let temperature: (any LogitsWarper)? = (config.temperature > 0 && config.temperature != 1)
        ? TemperatureLogitsWarper(temperature: Float(config.temperature))
        : nil
    let topK: (any LogitsWarper)? = config.topK > 0
        ? TopKLogitsWarper(k: config.topK)
        : nil
    let topP: (any LogitsWarper)? = config.topP < 1.0
        ? TopPLogitsWarper(p: Float(config.topP))
        : nil
    return [temperature, topK, topP].compactMap { $0 }
}
}
18 changes: 18 additions & 0 deletions Sources/TensorUtils/LogitsWarper/LogitsProcessor.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import Foundation

/// Applies a chain of logits warpers to a logits vector.
///
/// Each warper may filter the vector down to a subset of entries. The
/// `indexes` returned here are always positions in the ORIGINAL input
/// array, so callers can map a sampled position back to a token id even
/// after several warpers have narrowed the candidate set.
public struct LogitsProcessor {
    public var logitsWarpers: [any LogitsWarper]

    public init(logitsWarpers: [any LogitsWarper]) {
        self.logitsWarpers = logitsWarpers
    }

    /// Runs every warper, in order, over `arr`.
    /// - Parameter arr: the raw logits vector.
    /// - Returns: the surviving logits and their indices into `arr`.
    public func callAsFunction(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
        var indexes = Array(arr.indices)
        var logits = arr
        for warper in logitsWarpers {
            let warped = warper(logits)
            // BUG FIX: a warper's indexes refer to the array it was given
            // (the output of the *previous* warper), not to `arr`. Compose
            // them with the indexes accumulated so far, otherwise chained
            // warpers (e.g. top-k followed by top-p) return positions in the
            // filtered array and the caller samples the wrong token id.
            indexes = warped.indexes.map { indexes[$0] }
            logits = warped.logits
        }
        return (indexes: indexes, logits: logits)
    }
}
13 changes: 13 additions & 0 deletions Sources/TensorUtils/LogitsWarper/LogitsWarper.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import Foundation

/// Protocol for all logit warpers that can be applied during generation
public protocol LogitsWarper {
    /// Transforms a logits vector, returning the surviving logits together
    /// with their indices into the input array `arr`.
    func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float])
    /// Function-call sugar; see the default implementation below.
    func callAsFunction(_ arr: [Float]) -> (indexes: [Int], logits: [Float])
}

extension LogitsWarper {
    /// Default implementation: calling a warper as a function forwards to `warp(_:)`.
    public func callAsFunction(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
        warp(arr)
    }
}
14 changes: 14 additions & 0 deletions Sources/TensorUtils/LogitsWarper/TemperatureLogitsWarper.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import Foundation

/// Scales every logit by the inverse temperature.
/// Values of `temperature` below 1 sharpen the distribution, values above 1
/// flatten it. Indices are returned unchanged.
public struct TemperatureLogitsWarper: LogitsWarper {
    public var temperature: Float

    public init(temperature: Float) {
        self.temperature = temperature
    }

    /// Divides each element of `arr` by `temperature`.
    /// - Returns: all original indices plus the rescaled logits.
    public func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
        var scaled: [Float] = []
        scaled.reserveCapacity(arr.count)
        for value in arr {
            scaled.append(value / temperature)
        }
        return (indexes: Array(scaled.indices), logits: scaled)
    }
}
58 changes: 58 additions & 0 deletions Sources/TensorUtils/LogitsWarper/TopKLogitsWarper.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import Foundation
import Accelerate

/// Top-K.
/// Select the k most-probable element indices from `arr`
/// and return both the indices (from the original array)
/// and their probabilities.
/// Top-K.
/// Select the k most-probable element indices from `arr`
/// and return both the indices (from the original array)
/// and their logits.
public struct TopKLogitsWarper: LogitsWarper {
    public var k: Int

    public init(k: Int) {
        self.k = k
    }

    /// Keeps the `k` largest logits.
    /// - Parameter arr: the logits vector to filter.
    /// - Returns: indices into `arr` of the selected logits (largest first)
    ///   and the corresponding logit values.
    public func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
        // Nothing to select from an empty vector, and BNNS cannot be asked
        // for k <= 0 results — bail out early instead of trapping in `try!`.
        guard !arr.isEmpty, k > 0 else {
            return (indexes: [], logits: [])
        }
        let k = min(k, arr.count)
        let arrDescriptor = BNNSNDArrayDescriptor.allocate(
            initializingFrom: arr,
            shape: .vector(arr.count)
        )
        defer {
            arrDescriptor.deallocate()
        }
        let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized(
            scalarType: Int32.self,
            shape: .vector(k)
        )
        defer {
            bestIndices.deallocate()
        }
        let bestValues = BNNSNDArrayDescriptor.allocateUninitialized(
            scalarType: Float.self,
            shape: .vector(k)
        )
        defer {
            bestValues.deallocate()
        }
        // `try!` is acceptable here: inputs are well-formed by construction
        // (non-empty vector, 0 < k <= count), so a throw would be a logic bug.
        try! Accelerate.BNNS.applyTopK(
            k: k,
            input: arrDescriptor,
            bestValues: bestValues,
            bestIndices: bestIndices,
            axis: 0,
            batchSize: 1,
            filterParameters: nil
        )
        // Copy the BNNS output buffers into plain Swift arrays.
        let distances = bestValues.data!.withMemoryRebound(to: Float.self, capacity: k) { ptr in
            Array(UnsafeBufferPointer(start: ptr, count: k))
        }
        let indices = bestIndices.data!.withMemoryRebound(to: Int32.self, capacity: k) { ptr in
            Array(UnsafeBufferPointer(start: ptr, count: k))
        }
        return (indexes: indices.map { Int($0) }, logits: distances)
    }
}
37 changes: 37 additions & 0 deletions Sources/TensorUtils/LogitsWarper/TopPLogitsWarper.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import Foundation

/// Top-P.
/// Select the smallest set of elements whose cumulative probability exceeds the probability `p`.
/// Based on https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
/// Top-P (nucleus) filtering.
/// Select the smallest set of elements whose cumulative probability exceeds the probability `p`.
/// Based on https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
public struct TopPLogitsWarper: LogitsWarper {
    public var p: Float

    public init(p: Float) {
        self.p = p
    }

    /// Filters `arr` down to its probability nucleus.
    /// - Returns: indices into `arr` of the kept logits (most probable
    ///   first) and the corresponding logit values.
    public func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
        guard !arr.isEmpty else {
            return (indexes: [], logits: [])
        }

        // Pair every logit with its original position and softmax
        // probability, then rank by descending probability.
        let probs = Math.softmax(arr)
        let ranked = arr.indices
            .map { (index: $0, logit: arr[$0], prob: probs[$0]) }
            .sorted { $0.prob > $1.prob }

        // The nucleus ends at the first element whose running probability
        // mass exceeds `p`; that element itself is still included. If the
        // mass never exceeds `p` (e.g. p >= 1), keep everything.
        let cumulative = Math.cumsum(ranked.map(\.prob))
        let cutoff = cumulative.firstIndex { $0 > p } ?? (cumulative.count - 1)

        let kept = ranked[...cutoff]
        return (indexes: kept.map(\.index), logits: kept.map(\.logit))
    }
}
77 changes: 0 additions & 77 deletions Sources/TensorUtils/Math.swift
Original file line number Diff line number Diff line change
Expand Up @@ -83,83 +83,6 @@ public struct Math {
return result
}

/// Top-P.
/// Select the smallest set of elements whose cumulative probability exceeds the probability `p`.
/// Based on https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
public static func topP(arr: [Float], p: Float) -> (indexes: [Int], probs: [Float]) {
guard !arr.isEmpty else {
return (indexes: [], probs: [])
}

let arrSoftmax = softmax(arr)
var indexLogitProb = [(index: Int, logit: Float, prob: Float)]()
indexLogitProb.reserveCapacity(arr.count)
for (index, data) in zip(arr, arrSoftmax).enumerated() {
indexLogitProb.append((index: index, logit: data.0, prob: data.1))
}
indexLogitProb.sort { $0.prob > $1.prob }

let cumsum = cumsum(indexLogitProb.map(\.prob))
var sliceIndex = cumsum.count - 1
for (index, element) in cumsum.enumerated() where element > p {
sliceIndex = index
break
}

let indexes = indexLogitProb[0 ... sliceIndex].map(\.index)
let probs = softmax(indexLogitProb[0 ... sliceIndex].map(\.logit))
return (indexes: indexes, probs: probs)
}

/// Top-K.
/// Select the k most-probable elements indices from `arr`
/// and return both the indices (from the original array)
/// and their softmaxed probabilities.
///
public static func topK(arr: [Float], k: Int) -> (indexes: [Int], probs: [Float]) {
guard !arr.isEmpty else {
return (indexes: [], probs: [])
}
let k = min(k, arr.count)
let arrDescriptor = BNNSNDArrayDescriptor.allocate(
initializingFrom: arr,
shape: .vector(arr.count)
)
defer {
arrDescriptor.deallocate()
}
let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized(
scalarType: Int32.self,
shape: .vector(k)
)
defer {
bestIndices.deallocate()
}
let bestValues = BNNSNDArrayDescriptor.allocateUninitialized(
scalarType: Float.self,
shape: .vector(k)
)
defer {
bestValues.deallocate()
}
try! Accelerate.BNNS.applyTopK(
k: k,
input: arrDescriptor,
bestValues: bestValues,
bestIndices: bestIndices,
axis: 0,
batchSize: 1,
filterParameters: nil
)
let distances = bestValues.data!.withMemoryRebound(to: Float.self, capacity: k) { ptr in
Array(UnsafeBufferPointer(start: ptr, count: k))
}
let indices = bestIndices.data!.withMemoryRebound(to: Int32.self, capacity: k) { ptr in
Array(UnsafeBufferPointer(start: ptr, count: k))
}
return (indexes: indices.map { Int($0) }, probs: softmax(distances))
}

/// Multinomial sampling from an array of probs. Works well with topK
public static func sample(indexes: [Int], probs: [Float]) -> Int {
let i = randomNumber(probabilities: probs)
Expand Down
101 changes: 101 additions & 0 deletions Tests/TensorUtilsTests/LogitsWarperTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//
// LogitsWarperTests.swift
//
// Created by Jan Krukowski on 09/12/2023.
//

import XCTest
import CoreML
@testable import TensorUtils

/// Unit tests for the individual logits warpers and the `LogitsProcessor`
/// that chains them. Each test drives the warper with small hand-computed
/// vectors and checks both the returned indices and logit values.
final class LogitsWarperTests: XCTestCase {
    // Tolerance for floating-point comparisons of warped logits.
    private let accuracy: Float = 0.00001

    func testTemperatureLogitsWarper() {
        // Empty input stays empty regardless of temperature.
        let result1 = TemperatureLogitsWarper(temperature: 0.0)([])
        XCTAssertTrue(result1.indexes.isEmpty)
        XCTAssertTrue(result1.logits.isEmpty)

        let result2 = TemperatureLogitsWarper(temperature: 1.0)([])
        XCTAssertTrue(result2.indexes.isEmpty)
        XCTAssertTrue(result2.logits.isEmpty)

        // temperature == 1 leaves the logits unchanged.
        let result3 = TemperatureLogitsWarper(temperature: 1.0)([2.0, 1.0])
        XCTAssertEqual(result3.indexes, [0, 1])
        XCTAssertEqual(result3.logits, [2.0, 1.0], accuracy: accuracy)

        // temperature > 1 flattens (divides) the logits.
        let result4 = TemperatureLogitsWarper(temperature: 2.0)([2.0, 1.0])
        XCTAssertEqual(result4.indexes, [0, 1])
        XCTAssertEqual(result4.logits, [1.0, 0.5], accuracy: accuracy)

        // temperature < 1 sharpens (scales up) the logits.
        let result5 = TemperatureLogitsWarper(temperature: 0.5)([2.0, 1.0])
        XCTAssertEqual(result5.indexes, [0, 1])
        XCTAssertEqual(result5.logits, [4.0, 2.0], accuracy: accuracy)
    }

    func testTopKLogitsWarper() {
        let result1 = TopKLogitsWarper(k: 0)([])
        XCTAssertTrue(result1.indexes.isEmpty)
        XCTAssertTrue(result1.logits.isEmpty)

        let result2 = TopKLogitsWarper(k: 3)([])
        XCTAssertTrue(result2.indexes.isEmpty)
        XCTAssertTrue(result2.logits.isEmpty)

        // k larger than the input keeps everything, ordered by value.
        let result3 = TopKLogitsWarper(k: 3)([2.0, 1.0])
        XCTAssertEqual(result3.indexes, [0, 1])
        XCTAssertEqual(result3.logits, [2.0, 1.0], accuracy: accuracy)

        // Indices refer to positions in the original (unsorted) array.
        let result4 = TopKLogitsWarper(k: 3)([2.0, 1.0, 3.0])
        XCTAssertEqual(result4.indexes, [2, 0, 1])
        XCTAssertEqual(result4.logits, [3.0, 2.0, 1.0], accuracy: accuracy)

        let result5 = TopKLogitsWarper(k: 4)([2.0, 1.0, 3.0, -1.0, 123.0, 0.0])
        XCTAssertEqual(result5.indexes, [4, 2, 0, 1])
        XCTAssertEqual(result5.logits, [123.0, 3.0, 2.0, 1.0], accuracy: accuracy)
    }

    func testTopPLogitsWarper() {
        let result1 = TopPLogitsWarper(p: 0.99)([])
        XCTAssertTrue(result1.indexes.isEmpty)
        XCTAssertTrue(result1.logits.isEmpty)

        // For logits 0..9 the softmax mass is dominated by the largest
        // values, so the nucleus shrinks as p gets smaller.
        let result2 = TopPLogitsWarper(p: 0.99)((0 ..< 10).map { Float($0) })
        XCTAssertEqual(result2.indexes, [9, 8, 7, 6, 5])
        XCTAssertEqual(result2.logits, [9.0, 8.0, 7.0, 6.0, 5.0], accuracy: accuracy)

        let result3 = TopPLogitsWarper(p: 0.95)((0 ..< 10).map { Float($0) })
        XCTAssertEqual(result3.indexes, [9, 8, 7])
        XCTAssertEqual(result3.logits, [9.0, 8.0, 7.0], accuracy: accuracy)

        let result4 = TopPLogitsWarper(p: 0.6321493)((0 ..< 10).map { Float($0) })
        XCTAssertEqual(result4.indexes, [9, 8])
        XCTAssertEqual(result4.logits, [9.0, 8.0], accuracy: accuracy)
    }

    func testLogitsProcessor() {
        // An empty warper chain is the identity transform.
        let processor1 = LogitsProcessor(logitsWarpers: [])
        let result1 = processor1([])
        XCTAssertTrue(result1.indexes.isEmpty)
        XCTAssertTrue(result1.logits.isEmpty)

        let processor2 = LogitsProcessor(logitsWarpers: [])
        let result2 = processor2([2.0, 1.0])
        XCTAssertEqual(result2.indexes, [0, 1])
        XCTAssertEqual(result2.logits, [2.0, 1.0], accuracy: accuracy)

        let processor3 = LogitsProcessor(
            logitsWarpers: [TopKLogitsWarper(k: 3)]
        )
        let result3 = processor3([2.0, 1.0, 3.0, -5.0])
        XCTAssertEqual(result3.indexes, [2, 0, 1])
        XCTAssertEqual(result3.logits, [3.0, 2.0, 1.0], accuracy: accuracy)

        let processor4 = LogitsProcessor(
            logitsWarpers: [TopKLogitsWarper(k: 3), TopPLogitsWarper(p: 0.99)]
        )
        let result4 = processor4([2.0, 1.0, 3.0, -5.0, -23.0, 12.5])
        // NOTE(review): `[0]` is the index of 12.5 within TopK's *filtered*
        // output, not within the original array (where 12.5 sits at index 5).
        // If chained warpers are meant to report original-array indices — as
        // Generation.swift seems to assume when it samples a token id from
        // them — the expected value here should be `[5]`. Confirm intended
        // semantics.
        XCTAssertEqual(result4.indexes, [0])
        XCTAssertEqual(result4.logits, [12.5], accuracy: accuracy)
    }
}
Loading