Skip to content

Commit 41fb1df

Browse files
authored
Allow top-k and top-p to coexist (#27)
* Added LogitsWarper; added LogitsProcessor; changed generation to use LogitsProcessor; added tests. * Review changes.
1 parent fa25221 commit 41fb1df

File tree

10 files changed

+282
-169
lines changed

10 files changed

+282
-169
lines changed

Sources/Generation/Generation.swift

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,26 +54,13 @@ public extension Generation {
5454
// Iterate until we find the eos token or reach the max length
5555
// TODO: additional stopping criteria
5656
var outputTokens = tokens
57+
let logitsProcessor = LogitsProcessor(logitsWarpers: logitsWarpers(config: config))
5758
while outputTokens.count < config.maxLength {
5859
let outputs = model(outputTokens, config)
59-
6060
/// `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case
61-
var logits = (outputs as? MLShapedArraySlice<Float>)?.floats ?? outputs.scalars as! [Float]
62-
63-
let nextToken: Int
64-
if config.temperature > 0 && config.temperature != 1 {
65-
logits = logits.map { $0 / Float(config.temperature) }
66-
}
67-
if config.topK > 0 {
68-
let topK = Math.topK(arr: logits, k: config.topK)
69-
nextToken = Math.sample(indexes: topK.indexes, probs: topK.probs)
70-
} else if config.topP < 1.0 {
71-
let topP = Math.topP(arr: logits, p: Float(config.topP))
72-
nextToken = Math.sample(indexes: topP.indexes, probs: topP.probs)
73-
} else {
74-
fatalError("not implemented yet")
75-
}
76-
61+
let logits = (outputs as? MLShapedArraySlice<Float>)?.floats ?? outputs.scalars as! [Float]
62+
let (indexes, processedLogits) = logitsProcessor(logits)
63+
let nextToken = Math.sample(indexes: indexes, probs: Math.softmax(processedLogits))
7764
if nextToken == config.eosTokenId { break }
7865
outputTokens.append(nextToken)
7966
callback?(outputTokens)
@@ -102,4 +89,18 @@ public extension Generation {
10289

10390
return tokenizer.decode(tokens: output)
10491
}
92+
93+
/// Builds the list of logits warpers selected by `config`.
///
/// Warpers are appended in a fixed order: temperature, then top-k, then top-p.
/// Each one is included only when its setting is active:
/// - temperature: greater than 0 and different from 1,
/// - top-k: `topK` greater than 0,
/// - top-p: `topP` below 1.0.
///
/// - Parameter config: The generation settings to read.
/// - Returns: The warpers to apply, possibly empty.
private func logitsWarpers(config: GenerationConfig) -> [any LogitsWarper] {
    var pipeline: [any LogitsWarper] = []

    let appliesTemperature = config.temperature > 0 && config.temperature != 1
    if appliesTemperature {
        let temperature = Float(config.temperature)
        pipeline.append(TemperatureLogitsWarper(temperature: temperature))
    }

    let appliesTopK = config.topK > 0
    if appliesTopK {
        pipeline.append(TopKLogitsWarper(k: config.topK))
    }

    let appliesTopP = config.topP < 1.0
    if appliesTopP {
        let p = Float(config.topP)
        pipeline.append(TopPLogitsWarper(p: p))
    }

    return pipeline
}
105106
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import Foundation
2+
3+
/// Applies a sequence of `LogitsWarper`s to a vector of logits.
///
/// The returned `indexes` always refer to positions in the array originally
/// passed to `callAsFunction(_:)`, no matter how many warpers filtered or
/// reordered the logits along the way.
public struct LogitsProcessor {
    /// Warpers applied in order, each receiving the previous warper's logits.
    public var logitsWarpers: [any LogitsWarper]

    public init(logitsWarpers: [any LogitsWarper]) {
        self.logitsWarpers = logitsWarpers
    }

    /// Runs every warper over `arr`.
    ///
    /// - Parameter arr: The input logits.
    /// - Returns: The surviving logits and, for each of them, its position in `arr`.
    ///   With no warpers this is `arr` itself paired with `0 ..< arr.count`.
    /// - Precondition: Every index a warper returns must be a valid position in
    ///   the array that warper was given.
    public func callAsFunction(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
        var indexes = Array(arr.indices)
        var logits = arr
        for warper in logitsWarpers {
            let (localIndexes, warpedLogits) = warper(logits)
            // A warper only sees the current `logits`, so the indexes it returns
            // are positions in that intermediate array. Translate them through
            // the mapping accumulated so far; otherwise a second filtering
            // warper (e.g. top-p after top-k) would yield indexes that no
            // longer identify the original elements.
            indexes = localIndexes.map { indexes[$0] }
            logits = warpedLogits
        }
        return (indexes: indexes, logits: logits)
    }
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import Foundation
2+
3+
/// A transformation that can be applied to a vector of logits during generation.
///
/// Conforming types implement `warp(_:)`. `callAsFunction(_:)` has a default
/// implementation that forwards to `warp(_:)`, so a warper value can be
/// invoked like a function.
public protocol LogitsWarper {
    /// Transforms `arr`, returning the resulting logits and one index per returned logit.
    func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float])

    /// Function-call form of the warper.
    func callAsFunction(_ arr: [Float]) -> (indexes: [Int], logits: [Float])
}

public extension LogitsWarper {
    /// Default implementation: delegates to `warp(_:)`.
    func callAsFunction(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
        return warp(arr)
    }
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import Foundation
2+
3+
/// Temperature scaling.
/// Divides every logit by `temperature`, leaving order and count unchanged.
public struct TemperatureLogitsWarper: LogitsWarper {
    /// Divisor applied to each logit.
    public var temperature: Float

    public init(temperature: Float) {
        self.temperature = temperature
    }

    /// - Parameter arr: The input logits.
    /// - Returns: `arr` with each element divided by `temperature`, paired with
    ///   the positions `0 ..< arr.count`.
    public func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
        var scaled = [Float]()
        scaled.reserveCapacity(arr.count)
        for value in arr {
            scaled.append(value / temperature)
        }
        return (indexes: Array(scaled.indices), logits: scaled)
    }
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import Foundation
2+
import Accelerate
3+
4+
/// Top-K.
/// Select the `k` largest elements of `arr` and return both their indices
/// (positions in `arr`) and their values. The values are the raw logits as
/// reported by BNNS; no softmax is applied here.
public struct TopKLogitsWarper: LogitsWarper {
    /// Maximum number of elements to keep.
    public var k: Int

    public init(k: Int) {
        self.k = k
    }

    /// - Parameter arr: The input logits.
    /// - Returns: At most `k` logits with their positions in `arr`. Returns an
    ///   empty selection when `arr` is empty or `k` is not positive.
    public func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
        // A non-positive `k` selects nothing. Bail out before it reaches the
        // BNNS descriptors and buffer copies below, which require a positive
        // element count.
        guard !arr.isEmpty, k > 0 else {
            return (indexes: [], logits: [])
        }
        // Never ask for more elements than the input holds.
        let selectionCount = min(k, arr.count)

        let arrDescriptor = BNNSNDArrayDescriptor.allocate(
            initializingFrom: arr,
            shape: .vector(arr.count)
        )
        defer {
            arrDescriptor.deallocate()
        }
        let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized(
            scalarType: Int32.self,
            shape: .vector(selectionCount)
        )
        defer {
            bestIndices.deallocate()
        }
        let bestValues = BNNSNDArrayDescriptor.allocateUninitialized(
            scalarType: Float.self,
            shape: .vector(selectionCount)
        )
        defer {
            bestValues.deallocate()
        }
        // `warp` cannot throw (the protocol signature is non-throwing), so a
        // BNNS failure is treated as a programmer error.
        try! Accelerate.BNNS.applyTopK(
            k: selectionCount,
            input: arrDescriptor,
            bestValues: bestValues,
            bestIndices: bestIndices,
            axis: 0,
            batchSize: 1,
            filterParameters: nil
        )
        // Copy the results out before the `defer` blocks release the buffers.
        // NOTE(review): `data` is force-unwrapped on the assumption that the
        // allocations above always provide backing storage — confirm.
        let distances = bestValues.data!.withMemoryRebound(to: Float.self, capacity: selectionCount) { ptr in
            Array(UnsafeBufferPointer(start: ptr, count: selectionCount))
        }
        let indices = bestIndices.data!.withMemoryRebound(to: Int32.self, capacity: selectionCount) { ptr in
            Array(UnsafeBufferPointer(start: ptr, count: selectionCount))
        }
        return (indexes: indices.map { Int($0) }, logits: distances)
    }
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import Foundation
2+
3+
/// Top-P.
/// Keeps the smallest prefix of elements, ranked by descending probability,
/// whose cumulative probability exceeds `p`, and returns their logits.
/// Based on https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
public struct TopPLogitsWarper: LogitsWarper {
    /// Cumulative-probability threshold.
    public var p: Float

    public init(p: Float) {
        self.p = p
    }

    /// - Parameter arr: The input logits.
    /// - Returns: The kept logits (not probabilities) in descending-probability
    ///   order, each paired with its position in `arr`. If the cumulative sum
    ///   never exceeds `p`, every element is kept.
    public func warp(_ arr: [Float]) -> (indexes: [Int], logits: [Float]) {
        if arr.isEmpty {
            return (indexes: [], logits: [])
        }

        // Pair every logit with its position and softmax probability, then rank
        // the triples from most to least probable.
        let probabilities = Math.softmax(arr)
        let ranked = zip(arr, probabilities)
            .enumerated()
            .map { (index: $0.offset, logit: $0.element.0, prob: $0.element.1) }
            .sorted { $0.prob > $1.prob }

        // The cut falls on the first element whose running total passes `p`;
        // when none does, the cut falls on the last element.
        let cumulative = Math.cumsum(ranked.map(\.prob))
        let lastKept = cumulative.firstIndex { $0 > p } ?? (cumulative.count - 1)

        let kept = ranked[0 ... lastKept]
        return (indexes: kept.map(\.index), logits: kept.map(\.logit))
    }
}

Sources/TensorUtils/Math.swift

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -83,83 +83,6 @@ public struct Math {
8383
return result
8484
}
8585

86-
/// Top-P.
87-
/// Select the smallest set of elements whose cumulative probability exceeds the probability `p`.
88-
/// Based on https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
89-
public static func topP(arr: [Float], p: Float) -> (indexes: [Int], probs: [Float]) {
90-
guard !arr.isEmpty else {
91-
return (indexes: [], probs: [])
92-
}
93-
94-
let arrSoftmax = softmax(arr)
95-
var indexLogitProb = [(index: Int, logit: Float, prob: Float)]()
96-
indexLogitProb.reserveCapacity(arr.count)
97-
for (index, data) in zip(arr, arrSoftmax).enumerated() {
98-
indexLogitProb.append((index: index, logit: data.0, prob: data.1))
99-
}
100-
indexLogitProb.sort { $0.prob > $1.prob }
101-
102-
let cumsum = cumsum(indexLogitProb.map(\.prob))
103-
var sliceIndex = cumsum.count - 1
104-
for (index, element) in cumsum.enumerated() where element > p {
105-
sliceIndex = index
106-
break
107-
}
108-
109-
let indexes = indexLogitProb[0 ... sliceIndex].map(\.index)
110-
let probs = softmax(indexLogitProb[0 ... sliceIndex].map(\.logit))
111-
return (indexes: indexes, probs: probs)
112-
}
113-
114-
/// Top-K.
115-
/// Select the k most-probable elements indices from `arr`
116-
/// and return both the indices (from the original array)
117-
/// and their softmaxed probabilities.
118-
///
119-
public static func topK(arr: [Float], k: Int) -> (indexes: [Int], probs: [Float]) {
120-
guard !arr.isEmpty else {
121-
return (indexes: [], probs: [])
122-
}
123-
let k = min(k, arr.count)
124-
let arrDescriptor = BNNSNDArrayDescriptor.allocate(
125-
initializingFrom: arr,
126-
shape: .vector(arr.count)
127-
)
128-
defer {
129-
arrDescriptor.deallocate()
130-
}
131-
let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized(
132-
scalarType: Int32.self,
133-
shape: .vector(k)
134-
)
135-
defer {
136-
bestIndices.deallocate()
137-
}
138-
let bestValues = BNNSNDArrayDescriptor.allocateUninitialized(
139-
scalarType: Float.self,
140-
shape: .vector(k)
141-
)
142-
defer {
143-
bestValues.deallocate()
144-
}
145-
try! Accelerate.BNNS.applyTopK(
146-
k: k,
147-
input: arrDescriptor,
148-
bestValues: bestValues,
149-
bestIndices: bestIndices,
150-
axis: 0,
151-
batchSize: 1,
152-
filterParameters: nil
153-
)
154-
let distances = bestValues.data!.withMemoryRebound(to: Float.self, capacity: k) { ptr in
155-
Array(UnsafeBufferPointer(start: ptr, count: k))
156-
}
157-
let indices = bestIndices.data!.withMemoryRebound(to: Int32.self, capacity: k) { ptr in
158-
Array(UnsafeBufferPointer(start: ptr, count: k))
159-
}
160-
return (indexes: indices.map { Int($0) }, probs: softmax(distances))
161-
}
162-
16386
/// Multinomial sampling from an array of probs. Works well with topK
16487
public static func sample(indexes: [Int], probs: [Float]) -> Int {
16588
let i = randomNumber(probabilities: probs)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//
2+
// LogitsWarperTests.swift
3+
//
4+
// Created by Jan Krukowski on 09/12/2023.
5+
//
6+
7+
import XCTest
8+
import CoreML
9+
@testable import TensorUtils
10+
11+
/// Unit tests for the individual logits warpers and for `LogitsProcessor`.
final class LogitsWarperTests: XCTestCase {
    /// Tolerance used when comparing floating-point logits.
    private let accuracy: Float = 0.00001

    /// Temperature scaling divides each logit by the temperature and keeps positions unchanged.
    func testTemperatureLogitsWarper() {
        let zeroTemperatureOnEmpty = TemperatureLogitsWarper(temperature: 0.0)([])
        XCTAssertTrue(zeroTemperatureOnEmpty.indexes.isEmpty)
        XCTAssertTrue(zeroTemperatureOnEmpty.logits.isEmpty)

        let unitTemperatureOnEmpty = TemperatureLogitsWarper(temperature: 1.0)([])
        XCTAssertTrue(unitTemperatureOnEmpty.indexes.isEmpty)
        XCTAssertTrue(unitTemperatureOnEmpty.logits.isEmpty)

        let unitTemperature = TemperatureLogitsWarper(temperature: 1.0)([2.0, 1.0])
        XCTAssertEqual(unitTemperature.indexes, [0, 1])
        XCTAssertEqual(unitTemperature.logits, [2.0, 1.0], accuracy: accuracy)

        let doubledTemperature = TemperatureLogitsWarper(temperature: 2.0)([2.0, 1.0])
        XCTAssertEqual(doubledTemperature.indexes, [0, 1])
        XCTAssertEqual(doubledTemperature.logits, [1.0, 0.5], accuracy: accuracy)

        let halvedTemperature = TemperatureLogitsWarper(temperature: 0.5)([2.0, 1.0])
        XCTAssertEqual(halvedTemperature.indexes, [0, 1])
        XCTAssertEqual(halvedTemperature.logits, [4.0, 2.0], accuracy: accuracy)
    }

    /// Top-k keeps at most `k` logits, largest first, with their input positions.
    func testTopKLogitsWarper() {
        let zeroKOnEmpty = TopKLogitsWarper(k: 0)([])
        XCTAssertTrue(zeroKOnEmpty.indexes.isEmpty)
        XCTAssertTrue(zeroKOnEmpty.logits.isEmpty)

        let onEmpty = TopKLogitsWarper(k: 3)([])
        XCTAssertTrue(onEmpty.indexes.isEmpty)
        XCTAssertTrue(onEmpty.logits.isEmpty)

        // `k` larger than the input: everything is kept.
        let kExceedsCount = TopKLogitsWarper(k: 3)([2.0, 1.0])
        XCTAssertEqual(kExceedsCount.indexes, [0, 1])
        XCTAssertEqual(kExceedsCount.logits, [2.0, 1.0], accuracy: accuracy)

        let kEqualsCount = TopKLogitsWarper(k: 3)([2.0, 1.0, 3.0])
        XCTAssertEqual(kEqualsCount.indexes, [2, 0, 1])
        XCTAssertEqual(kEqualsCount.logits, [3.0, 2.0, 1.0], accuracy: accuracy)

        let kBelowCount = TopKLogitsWarper(k: 4)([2.0, 1.0, 3.0, -1.0, 123.0, 0.0])
        XCTAssertEqual(kBelowCount.indexes, [4, 2, 0, 1])
        XCTAssertEqual(kBelowCount.logits, [123.0, 3.0, 2.0, 1.0], accuracy: accuracy)
    }

    /// Top-p keeps the most probable logits until their cumulative probability exceeds `p`.
    func testTopPLogitsWarper() {
        let onEmpty = TopPLogitsWarper(p: 0.99)([])
        XCTAssertTrue(onEmpty.indexes.isEmpty)
        XCTAssertTrue(onEmpty.logits.isEmpty)

        // Logits 0, 1, ..., 9.
        let ramp = (0 ..< 10).map { Float($0) }

        let p99 = TopPLogitsWarper(p: 0.99)(ramp)
        XCTAssertEqual(p99.indexes, [9, 8, 7, 6, 5])
        XCTAssertEqual(p99.logits, [9.0, 8.0, 7.0, 6.0, 5.0], accuracy: accuracy)

        let p95 = TopPLogitsWarper(p: 0.95)(ramp)
        XCTAssertEqual(p95.indexes, [9, 8, 7])
        XCTAssertEqual(p95.logits, [9.0, 8.0, 7.0], accuracy: accuracy)

        let pNearTopProbability = TopPLogitsWarper(p: 0.6321493)(ramp)
        XCTAssertEqual(pNearTopProbability.indexes, [9, 8])
        XCTAssertEqual(pNearTopProbability.logits, [9.0, 8.0], accuracy: accuracy)
    }

    /// The processor chains warpers; with none it returns its input unchanged.
    func testLogitsProcessor() {
        let noWarpers = LogitsProcessor(logitsWarpers: [])

        let emptyResult = noWarpers([])
        XCTAssertTrue(emptyResult.indexes.isEmpty)
        XCTAssertTrue(emptyResult.logits.isEmpty)

        let passthroughResult = noWarpers([2.0, 1.0])
        XCTAssertEqual(passthroughResult.indexes, [0, 1])
        XCTAssertEqual(passthroughResult.logits, [2.0, 1.0], accuracy: accuracy)

        let topKOnly = LogitsProcessor(
            logitsWarpers: [TopKLogitsWarper(k: 3)]
        )
        let topKOnlyResult = topKOnly([2.0, 1.0, 3.0, -5.0])
        XCTAssertEqual(topKOnlyResult.indexes, [2, 0, 1])
        XCTAssertEqual(topKOnlyResult.logits, [3.0, 2.0, 1.0], accuracy: accuracy)

        let topKThenTopP = LogitsProcessor(
            logitsWarpers: [TopKLogitsWarper(k: 3), TopPLogitsWarper(p: 0.99)]
        )
        let topKThenTopPResult = topKThenTopP([2.0, 1.0, 3.0, -5.0, -23.0, 12.5])
        // NOTE(review): 12.5 sits at position 5 of the input. The expected index 0
        // is its position in the intermediate top-k output, not in the input.
        // If the processor is meant to report input positions, this expectation
        // should be [5] — confirm the intended contract.
        XCTAssertEqual(topKThenTopPResult.indexes, [0])
        XCTAssertEqual(topKThenTopPResult.logits, [12.5], accuracy: accuracy)
    }
}

0 commit comments

Comments
 (0)