Skip to content

Commit 460358e

Browse files
authored
[LogitWrappers] update min-p (#289)
* update min-p
* Apply suggestion from @kashif
1 parent 6fdfa9e commit 460358e

File tree

2 files changed

+41
-55
lines changed

2 files changed

+41
-55
lines changed

Examples/transformers-cli/Sources/transformers-cli/Transformers.swift

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import CoreML
33
import Foundation
44
import Generation
55
import Models
6+
import Tokenizers
67

78
@available(macOS 15.0, iOS 18.0, *)
89
@main
@@ -108,12 +109,13 @@ struct TransformersCLI: AsyncParsableCommand {
108109
let compiledURL = try compile(at: url)
109110
print("Loading model \(compiledURL)")
110111
let model: LanguageModel
111-
if let tokenizerFolder {
112-
let tokenizerURL = URL(filePath: tokenizerFolder, directoryHint: .isDirectory)
112+
if let tokenizerPath {
113+
let tokenizerURL = URL(filePath: tokenizerPath, directoryHint: .isDirectory)
114+
let tokenizer = try await AutoTokenizer.from(modelFolder: tokenizerURL)
113115
model = try LanguageModel.loadCompiled(
114116
url: compiledURL,
115-
tokenizerFolder: tokenizerURL,
116-
computeUnits: computeUnits.asMLComputeUnits
117+
computeUnits: computeUnits.asMLComputeUnits,
118+
tokenizer: tokenizer
117119
)
118120
} else {
119121
model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits)

Sources/Generation/LogitsWarper/MinPLogitsWarper.swift

Lines changed: 35 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -44,71 +44,55 @@ public struct MinPLogitsWarper: LogitsProcessor {
4444
public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor {
4545
// Algorithm (following transformers implementation):
4646
// 1. Compute probabilities from logits
47-
// 2. Find max probability per batch
48-
// 3. Create threshold = minP * maxProb
49-
// 4. Sort logits and mask tokens where prob < threshold
50-
// 5. Keep at least minTokensToKeep
51-
// 6. Scatter back to original order
47+
// 2. Find max probability per batch (with keepdim)
48+
// 3. Calculate threshold = minP * maxProb
49+
// 4. Create mask for tokens where prob < threshold
50+
// 5. Use topK to get min_tokens_to_keep and unmask them
51+
// 6. Apply mask to scores
5252

5353
let vocabSize = scores.shape[scores.rank - 1]
5454

55-
// Compute probabilities
55+
// Convert logits to probabilities
5656
let probs = scores.softmax(alongAxis: -1)
5757

58-
// Sort probabilities descending to get max (first element)
59-
let sortedProbIndices = probs.argsort(alongAxis: -1, descendingOrder: true)
60-
let sortedProbs = probs.gathering(atIndices: sortedProbIndices, alongAxis: -1)
58+
// Get the probability of the top token for each sequence in the batch
59+
// Using max with keepRank=true to maintain dimensions for broadcasting
60+
let topProbs = probs.max(alongAxes: [-1], keepRank: true)
6161

62-
// Extract max prob per batch: first element of each sorted sequence
63-
// Do this on CPU to avoid complex broadcasting issues
64-
let sortedProbsArray = await sortedProbs.shapedArray(of: Float.self)
65-
let batchSize = scores.shape[0]
66-
var thresholdScalars = [Float]()
67-
thresholdScalars.reserveCapacity(batchSize * vocabSize)
68-
for batchIdx in 0..<batchSize {
69-
let maxProb = sortedProbsArray.scalars[batchIdx * vocabSize] // First element
70-
let thresholdVal = minP * maxProb
71-
for _ in 0..<vocabSize {
72-
thresholdScalars.append(thresholdVal)
73-
}
74-
}
75-
let threshold = MLTensor(shape: probs.shape, scalars: thresholdScalars, scalarType: Float.self)
62+
// Calculate the actual min_p threshold by scaling min_p with the top token's probability
63+
let scaledMinP = topProbs * minP
7664

77-
// Create mask: tokensToRemove where prob < threshold
78-
let tokensToRemove = probs .< threshold
65+
// Create a mask for tokens that have a probability less than the scaled min_p
66+
let tokensToRemove = probs .< scaledMinP
7967

80-
// Sort scores descending
81-
let sortedScoreIndices = scores.argsort(alongAxis: -1, descendingOrder: true)
82-
let inversePermutation = sortedScoreIndices.argsort(alongAxis: -1)
68+
// Keep at least min_tokens_to_keep tokens (clip k to vocab size if needed)
69+
let k = min(minTokensToKeep, vocabSize)
8370

84-
// Gather mask in sorted order
85-
let sortedTokensToRemove = tokensToRemove.gathering(atIndices: sortedScoreIndices, alongAxis: -1)
71+
// Get indices of top-k probabilities
72+
let topKResult = probs.topK(k)
73+
let topKIndices = topKResult.indices
8674

87-
// Create position tensor for minTokensToKeep check
88-
let posBaseShape = Array(repeating: 1, count: scores.rank - 1) + [vocabSize]
89-
var posMultiples = scores.shape
90-
posMultiples[posMultiples.count - 1] = 1
75+
// Create a mask to keep the top-k tokens
76+
// Since MLTensor doesn't have a scatter operation that works like PyTorch's scatter_,
77+
// we use replacing(atIndices:with:alongAxis:) which replaces values at specified indices.
78+
// For our case, we want to unmask (set to False/0) the top-k token positions.
9179

92-
let positions = MLTensor(
93-
rangeFrom: Int32(0),
94-
to: Int32(vocabSize),
95-
by: 1,
96-
scalarType: Int32.self
97-
)
98-
.reshaped(to: posBaseShape)
99-
.tiled(multiples: posMultiples)
80+
// Convert boolean mask to Int32 (1 = remove, 0 = keep)
81+
let zerosInt = MLTensor(repeating: Int32(0), shape: tokensToRemove.shape, scalarType: Int32.self)
82+
let onesInt = MLTensor(repeating: Int32(1), shape: tokensToRemove.shape, scalarType: Int32.self)
83+
let tokensToRemoveAsInt = zerosInt.replacing(with: onesInt, where: tokensToRemove)
10084

101-
// Mask: remove if (position >= minTokensToKeep AND shouldRemove)
102-
let beyondMinimum = positions .>= Int32(minTokensToKeep)
103-
let finalRemoveMask = sortedTokensToRemove .& beyondMinimum
85+
// Try using replacing(atIndices:with:alongAxis:) which takes a scalar value
86+
// This replaces slices at the specified indices with the scalar value
87+
let finalTokensToRemoveInt = tokensToRemoveAsInt.replacing(atIndices: topKIndices, with: Int32(0), alongAxis: -1)
10488

105-
// Apply filter in sorted space
106-
let sortedScores = scores.gathering(atIndices: sortedScoreIndices, alongAxis: -1)
107-
let filterTensor = MLTensor(repeating: filterValue, shape: sortedScores.shape, scalarType: Float.self)
108-
let filteredSorted = sortedScores.replacing(with: filterTensor, where: finalRemoveMask)
89+
// Convert back to boolean mask
90+
let zerosComparison = MLTensor(repeating: Int32(0), shape: tokensToRemove.shape, scalarType: Int32.self)
91+
let finalTokensToRemove = finalTokensToRemoveInt .!= zerosComparison
10992

110-
// Scatter back to original order
111-
return filteredSorted.gathering(atIndices: inversePermutation, alongAxis: -1)
93+
// Apply mask to scores
94+
let filterTensor = MLTensor(repeating: filterValue, shape: scores.shape, scalarType: Float.self)
95+
return scores.replacing(with: filterTensor, where: finalTokensToRemove)
11296
}
11397
}
11498
#endif

0 commit comments

Comments (0)