Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Sources/Generation/Generation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ public extension Generation {
if config.topK > 0 {
let topK = Math.topK(arr: logits, k: config.topK)
nextToken = Math.sample(indexes: topK.indexes, probs: topK.probs)
} else if config.topP < 1.0 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If my understanding of this is correct, top-k can coexist with top-p: https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L805-L808

However, it could make sense to merge this PR now and making them coexist in a future one. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say let's merge it now, seems logical to create a separate PR with a common interface to these two

let topP = Math.topP(arr: logits, p: Float(config.topP))
nextToken = Math.sample(indexes: topP.indexes, probs: topP.probs)
} else {
fatalError("topP not implemented yet")
fatalError("not implemented yet")
Comment on lines -71 to +74
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we make top-k compatible with top-p, we'd do a single sample call on the selected tokens and remove this fatalError.

}

if nextToken == config.eosTokenId { break }
Expand Down
42 changes: 42 additions & 0 deletions Sources/TensorUtils/Math.swift
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,48 @@ public struct Math {
let ptr = UnsafeMutablePointer<Float32>(OpaquePointer(multiArray.dataPointer))
return Math.argmax32(ptr, count: multiArray.count)
}

/// Returns the cumulative sum of the array.
public static func cumsum(_ arr: [Float]) -> [Float] {
guard !arr.isEmpty else {
return []
}
let arrCount = vDSP_Length(arr.count)
var weight: Float = 1.0
var result: [Float] = Array(repeating: 0.0, count: arr.count)
var firstItem = arr[0]
vDSP_vrsum(arr, 1, &weight, &result, 1, arrCount)
vDSP_vsadd(result, 1, &firstItem, &result, 1, arrCount)
return result
}

/// Top-P.
/// Select the smallest set of elements whose cumulative probability exceeds the probability `p`.
/// Based on https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
public static func topP(arr: [Float], p: Float) -> (indexes: [Int], probs: [Float]) {
guard !arr.isEmpty else {
return (indexes: [], probs: [])
}

let arrSoftmax = softmax(arr)
var indexLogitProb = [(index: Int, logit: Float, prob: Float)]()
indexLogitProb.reserveCapacity(arr.count)
for (index, data) in zip(arr, arrSoftmax).enumerated() {
indexLogitProb.append((index: index, logit: data.0, prob: data.1))
}
indexLogitProb.sort { $0.prob > $1.prob }

let cumsum = cumsum(indexLogitProb.map(\.prob))
var sliceIndex = cumsum.count - 1
for (index, element) in cumsum.enumerated() where element > p {
sliceIndex = index
break
}

let indexes = indexLogitProb[0 ... sliceIndex].map(\.index)
let probs = softmax(indexLogitProb[0 ... sliceIndex].map(\.logit))
return (indexes: indexes, probs: probs)
}

/// Top-K.
/// Select the k most-probable elements indices from `arr`
Expand Down
27 changes: 27 additions & 0 deletions Tests/TensorUtilsTests/TensorUtilsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,33 @@ class TensorUtilsTests: XCTestCase {
XCTAssertEqual(result5.probs.reduce(0, +), 1.0, accuracy: accuracy)
}

func testTopP() {
let result1 = Math.topP(arr: [], p: 0.99)
XCTAssertTrue(result1.indexes.isEmpty)
XCTAssertTrue(result1.probs.isEmpty)

let result2 = Math.topP(arr: (0 ..< 10).map { Float($0) }, p: 0.99)
XCTAssertEqual(result2.indexes, [9, 8, 7, 6, 5])
XCTAssertEqual(result2.probs, [0.63640857, 0.23412164, 0.08612853, 0.031684916, 0.011656229], accuracy: accuracy)
XCTAssertEqual(result2.probs.reduce(0, +), 1.0, accuracy: accuracy)

let result3 = Math.topP(arr: (0 ..< 10).map { Float($0) }, p: 0.95)
XCTAssertEqual(result3.indexes, [9, 8, 7])
XCTAssertEqual(result3.probs, [0.6652409, 0.24472845, 0.090030566], accuracy: accuracy)
XCTAssertEqual(result3.probs.reduce(0, +), 1.0, accuracy: accuracy)

let result4 = Math.topP(arr: (0 ..< 10).map { Float($0) }, p: 0.6321493)
XCTAssertEqual(result4.indexes, [9, 8])
XCTAssertEqual(result4.probs, [0.7310586, 0.26894143], accuracy: accuracy)
XCTAssertEqual(result4.probs.reduce(0, +), 1.0, accuracy: accuracy)
}

func testCumsum() {
XCTAssertTrue(Math.cumsum([]).isEmpty)
XCTAssertEqual(Math.cumsum([1]), [1])
XCTAssertEqual(Math.cumsum([1, 2, 3, 4]), [1, 3, 6, 10])
}

func testArgMax() throws {
let result1 = Math.argmax([3.0, 4.0, 1.0, 2.0] as [Float], count: 4)

Expand Down