@@ -44,71 +44,55 @@ public struct MinPLogitsWarper: LogitsProcessor {
     public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor {
         // Algorithm (following transformers implementation):
         // 1. Compute probabilities from logits
-        // 2. Find max probability per batch
-        // 3. Create threshold = minP * maxProb
-        // 4. Sort logits and mask tokens where prob < threshold
-        // 5. Keep at least minTokensToKeep
-        // 6. Scatter back to original order
+        // 2. Find max probability per batch (with keepdim)
+        // 3. Calculate threshold = minP * maxProb
+        // 4. Create mask for tokens where prob < threshold
+        // 5. Use topK to get min_tokens_to_keep and unmask them
+        // 6. Apply mask to scores

         let vocabSize = scores.shape[scores.rank - 1]

-        // Compute probabilities
+        // Convert logits to probabilities
         let probs = scores.softmax(alongAxis: -1)

-        // Sort probabilities descending to get max (first element)
-        let sortedProbIndices = probs.argsort(alongAxis: -1, descendingOrder: true)
-        let sortedProbs = probs.gathering(atIndices: sortedProbIndices, alongAxis: -1)
+        // Get the probability of the top token for each sequence in the batch
+        // Using max with keepRank=true to maintain dimensions for broadcasting
+        let topProbs = probs.max(alongAxes: [-1], keepRank: true)

-        // Extract max prob per batch: first element of each sorted sequence
-        // Do this on CPU to avoid complex broadcasting issues
-        let sortedProbsArray = await sortedProbs.shapedArray(of: Float.self)
-        let batchSize = scores.shape[0]
-        var thresholdScalars = [Float]()
-        thresholdScalars.reserveCapacity(batchSize * vocabSize)
-        for batchIdx in 0..<batchSize {
-            let maxProb = sortedProbsArray.scalars[batchIdx * vocabSize] // First element
-            let thresholdVal = minP * maxProb
-            for _ in 0..<vocabSize {
-                thresholdScalars.append(thresholdVal)
-            }
-        }
-        let threshold = MLTensor(shape: probs.shape, scalars: thresholdScalars, scalarType: Float.self)
+        // Calculate the actual min_p threshold by scaling min_p with the top token's probability
+        let scaledMinP = topProbs * minP

-        // Create mask: tokensToRemove where prob < threshold
-        let tokensToRemove = probs .< threshold
+        // Create a mask for tokens that have a probability less than the scaled min_p
+        let tokensToRemove = probs .< scaledMinP

-        // Sort scores descending
-        let sortedScoreIndices = scores.argsort(alongAxis: -1, descendingOrder: true)
-        let inversePermutation = sortedScoreIndices.argsort(alongAxis: -1)
+        // Keep at least min_tokens_to_keep tokens (clip k to vocab size if needed)
+        let k = min(minTokensToKeep, vocabSize)

-        // Gather mask in sorted order
-        let sortedTokensToRemove = tokensToRemove.gathering(atIndices: sortedScoreIndices, alongAxis: -1)
+        // Get indices of top-k probabilities
+        let topKResult = probs.topK(k)
+        let topKIndices = topKResult.indices

-        // Create position tensor for minTokensToKeep check
-        let posBaseShape = Array(repeating: 1, count: scores.rank - 1) + [vocabSize]
-        var posMultiples = scores.shape
-        posMultiples[posMultiples.count - 1] = 1
+        // Create a mask to keep the top-k tokens
+        // Since MLTensor doesn't have a scatter operation that works like PyTorch's scatter_,
+        // we use replacing(atIndices:with:alongAxis:) which replaces values at specified indices.
+        // For our case, we want to unmask (set to False/0) the top-k token positions.

-        let positions = MLTensor(
-            rangeFrom: Int32(0),
-            to: Int32(vocabSize),
-            by: 1,
-            scalarType: Int32.self
-        )
-        .reshaped(to: posBaseShape)
-        .tiled(multiples: posMultiples)
+        // Convert boolean mask to Int32 (1 = remove, 0 = keep)
+        let zerosInt = MLTensor(repeating: Int32(0), shape: tokensToRemove.shape, scalarType: Int32.self)
+        let onesInt = MLTensor(repeating: Int32(1), shape: tokensToRemove.shape, scalarType: Int32.self)
+        let tokensToRemoveAsInt = zerosInt.replacing(with: onesInt, where: tokensToRemove)

-        // Mask: remove if (position >= minTokensToKeep AND shouldRemove)
-        let beyondMinimum = positions .>= Int32(minTokensToKeep)
-        let finalRemoveMask = sortedTokensToRemove .& beyondMinimum
+        // Try using replacing(atIndices:with:alongAxis:) which takes a scalar value
+        // This replaces slices at the specified indices with the scalar value
+        let finalTokensToRemoveInt = tokensToRemoveAsInt.replacing(atIndices: topKIndices, with: Int32(0), alongAxis: -1)

-        // Apply filter in sorted space
-        let sortedScores = scores.gathering(atIndices: sortedScoreIndices, alongAxis: -1)
-        let filterTensor = MLTensor(repeating: filterValue, shape: sortedScores.shape, scalarType: Float.self)
-        let filteredSorted = sortedScores.replacing(with: filterTensor, where: finalRemoveMask)
+        // Convert back to boolean mask
+        let zerosComparison = MLTensor(repeating: Int32(0), shape: tokensToRemove.shape, scalarType: Int32.self)
+        let finalTokensToRemove = finalTokensToRemoveInt .!= zerosComparison

-        // Scatter back to original order
-        return filteredSorted.gathering(atIndices: inversePermutation, alongAxis: -1)
+        // Apply mask to scores
+        let filterTensor = MLTensor(repeating: filterValue, shape: scores.shape, scalarType: Float.self)
+        return scores.replacing(with: filterTensor, where: finalTokensToRemove)
     }
 }
 #endif
0 commit comments