Fix timestamp rules filter
- Also adds back missing language property from merge
ZachNagengast committed Mar 28, 2024
1 parent: d01bca4 · commit: 9e215f0
Showing 3 changed files with 34 additions and 28 deletions.
26 changes: 15 additions & 11 deletions Sources/WhisperKit/Core/LogitsFilter.swift
@@ -74,9 +74,11 @@ open class TimestampRulesFilter: LogitsFiltering {
}

public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
- guard let sampleBegin = sampleBegin(for: tokens) else {
+ guard let sampleBegin = sampleBegin(for: tokens),
+       sampleBegin > tokens.count else {
return logits
}

// suppress <|notimestamps|> which is handled by `withoutTimestamps`
logits.fill(indexes: [[0, 0, specialTokens.noTimestampsToken as NSNumber]], with: -FloatType.infinity)

@@ -109,15 +111,17 @@ open class TimestampRulesFilter: LogitsFiltering {
}
}

- if tokens.count == sampleBegin {
- // suppress generating non-timestamp tokens at the beginning
- logits.fillLastDimension(indexes: 0..<specialTokens.timeTokenBegin, with: -FloatType.infinity)
- if let maxInitialTimestampIndex {
- // apply the `maxInitialTimestamp` option
- let lastAllowed = specialTokens.timeTokenBegin + maxInitialTimestampIndex + 1
- logits.fillLastDimension(indexes: lastAllowed..<logits.count, with: -FloatType.infinity)
- }
- }
+ // TODO: Allow model to predict initial timestamp
+ // Currently initial timestamp is forced to <|0.00|> every time
+ // if tokens.count == sampleBegin {
+ // // suppress generating non-timestamp tokens at the beginning
+ // logits.fillLastDimension(indexes: 0..<specialTokens.timeTokenBegin, with: -FloatType.infinity)
+ // if let maxInitialTimestampIndex {
+ // // apply the `maxInitialTimestamp` option
+ // let lastAllowed = specialTokens.timeTokenBegin + maxInitialTimestampIndex + 1
+ // logits.fillLastDimension(indexes: lastAllowed..<logits.count, with: -FloatType.infinity)
+ // }
+ // }

// if sum of probability over timestamps is above any other token, sample timestamp
if sumOfProbabilityOverTimestampsIsAboveAnyOtherToken(logits: logits, timeTokenBegin: specialTokens.timeTokenBegin) {
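The check on the last line above applies Whisper's timestamp rule: once the total probability mass on timestamp tokens exceeds the probability of every individual text token, a timestamp must be sampled. A rough sketch of that comparison over a plain logits array (the helper name and the log-sum-exp arithmetic are illustrative, not the WhisperKit implementation; comparing raw logits is enough because the softmax normalizer cancels on both sides):

import Foundation

// Illustrative helper: true when log(sum of timestamp probabilities)
// exceeds the log-probability of the best non-timestamp token.
func timestampsOutweighText(logits: [Float], timeTokenBegin: Int) -> Bool {
    guard timeTokenBegin > 0, timeTokenBegin < logits.count else { return false }

    // log-sum-exp over the timestamp region, stabilized by its max logit
    let timestampLogits = logits[timeTokenBegin...]
    let maxTimestamp = timestampLogits.max() ?? -Float.infinity
    let sumExp = timestampLogits.reduce(0.0) { $0 + exp(Double($1 - maxTimestamp)) }
    let timestampLogSum = Double(maxTimestamp) + log(sumExp)

    // best single text-token logit
    let maxTextLogit = logits[..<timeTokenBegin].max() ?? -Float.infinity
    return timestampLogSum > Double(maxTextLogit)
}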
@@ -242,7 +246,7 @@ open class TimestampRulesFilter: LogitsFiltering {


@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
- public class LanguageLogitsFilter: LogitsFiltering {
+ open class LanguageLogitsFilter: LogitsFiltering {
let allLanguageTokens: Set<Int>
let logitsDim: Int
let sampleBegin: Int
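For reference, the block this commit comments out earlier in this file (the `// TODO: Allow model to predict initial timestamp` section) enforced two rules at the first sampled position: only timestamp tokens were allowed, and the first timestamp could not land later than `maxInitialTimestamp`. A minimal sketch of that constraint over a plain array instead of an MLMultiArray (the function name and signature are illustrative, not WhisperKit API):

// Illustrative sketch of the initial-timestamp constraint.
func maskInitialTimestamps(
    logits: inout [Float],
    timeTokenBegin: Int,
    maxInitialTimestampIndex: Int?
) {
    // At the first sampled position, suppress every non-timestamp token.
    for index in 0..<timeTokenBegin {
        logits[index] = -Float.infinity
    }
    // Optionally cap how late the first timestamp may be: anything beyond
    // timeTokenBegin + maxInitialTimestampIndex is also suppressed.
    if let maxInitialTimestampIndex {
        let lastAllowed = timeTokenBegin + maxInitialTimestampIndex + 1
        if lastAllowed < logits.count {
            for index in lastAllowed..<logits.count {
                logits[index] = -Float.infinity
            }
        }
    }
}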
30 changes: 16 additions & 14 deletions Sources/WhisperKit/Core/Models.swift
@@ -884,7 +884,7 @@ public protocol WhisperTokenizer: Tokenizer {
struct WhisperTokenizerWrapper: WhisperTokenizer {
let tokenizer: any Tokenizer
let specialTokens: SpecialTokens

init(tokenizer: any Tokenizer) {
self.tokenizer = tokenizer
self.specialTokens = SpecialTokens(
@@ -904,37 +904,37 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
private func splitTokensOnUnicode(tokens: [Int]) -> (words: [String], wordTokens: [[Int]]) {
let decodedFull = tokenizer.decode(tokens: tokens)
let replacementString = "\u{fffd}"

var words: [String] = []
var wordTokens: [[Int]] = []
var currentTokens: [Int] = []
var unicodeOffset = 0

for token in tokens {
currentTokens.append(token)
let decoded = tokenizer.decode(tokens: currentTokens)

var hasUnicodeInFullString = false
if let range = decoded.range(of: replacementString) {
hasUnicodeInFullString = decodedFull[range] == replacementString
}

if !decoded.contains(replacementString) || hasUnicodeInFullString {
words.append(decoded)
wordTokens.append(currentTokens)
currentTokens = []
unicodeOffset += decoded.count
}
}

return (words, wordTokens)
}

private func splitTokensOnSpaces(tokens: [Int]) -> (words: [String], wordTokens: [[Int]]) {
let (subwords, subwordTokensList) = splitTokensOnUnicode(tokens: tokens)
var words: [String] = []
var wordTokens: [[Int]] = []

for (subword, subwordTokens) in zip(subwords, subwordTokensList) {
let special = subwordTokens.first! >= specialTokens.specialTokenBegin
let withSpace = subword.hasPrefix(" ")
@@ -950,10 +950,10 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
wordTokens[words.count - 1].append(contentsOf: subwordTokens)
}
}

return (words, wordTokens)
}

private func isPunctuation(_ text: String, tokenRange: Range<String.Index>, tag: NLTag?) -> Bool {
let punctuationCharacters = CharacterSet.punctuationCharacters
let token = String(text[tokenRange])
@@ -965,25 +965,27 @@ struct WhisperTokenizerWrapper: WhisperTokenizer {
}
return false
}

/// Decodes token ids into individual words and per-word subtokens
/// - Parameter tokenIds: Array of tokens to decode and then split
/// - Parameter tokenIds: Array of tokens to decode and then split
/// - Returns: Tuple containing and array of the split words and all tokens for each word
func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]]) {
let decodedWords = tokenizer.decode(tokens: tokenIds.filter { $0 < specialTokens.specialTokenBegin })

// Detect language of input text
let recognizer = NLLanguageRecognizer()
recognizer.processString(decodedWords)
let languageCode = recognizer.dominantLanguage?.rawValue

if ["zh", "ja", "th", "lo", "my", "yue"].contains(languageCode) {
return splitTokensOnUnicode(tokens: tokenIds)
} else {
return splitTokensOnSpaces(tokens: tokenIds)
}
}
}

public extension WhisperTokenizer {
var languages: [String: String] {
[
"english": "en",
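Elsewhere in this file, `splitToWordTokens` picks its splitting strategy from the detected language of the decoded text. A minimal sketch of just that dispatch, working on a plain string rather than token ids (the helper name is made up for illustration):

import NaturalLanguage

// Illustrative helper: true when word boundaries must come from unicode
// segmentation rather than spaces (languages written without inter-word spaces).
func usesUnicodeWordSplitting(for text: String) -> Bool {
    let recognizer = NLLanguageRecognizer()
    recognizer.processString(text)
    let languageCode = recognizer.dominantLanguage?.rawValue
    return ["zh", "ja", "th", "lo", "my", "yue"].contains(languageCode)
}

// Example: "こんにちは世界" is typically detected as "ja" and takes the
// unicode-based split path; "hello world" takes the space-based path.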
6 changes: 3 additions & 3 deletions Sources/WhisperKit/Core/TextDecoder.swift
@@ -424,7 +424,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
// Single loop variables
var timings = TranscriptionTimings()
let prefilledIndex = decoderInputs.cacheLength[0].intValue
- let intialPromptIndex = decoderInputs.initialPrompt.count - 1
+ let intialPromptIndex = decoderInputs.initialPrompt.count
var currentTokens: [Int] = decoderInputs.initialPrompt
var nextToken: Int = decoderInputs.initialPrompt.last!
var logProbs: [Float] = Array(repeating: 0, count: prefilledIndex + 1)
@@ -477,8 +477,8 @@ open class TextDecoder: TextDecoding, WhisperMLModel {

// Check if current index is part of the initial prompt
isPrefill = false
- if tokenIndex <= intialPromptIndex {
- isPrefill = tokenIndex < intialPromptIndex // Prefill stops at the last token of the initial prompt
+ if tokenIndex < intialPromptIndex {
+ isPrefill = tokenIndex < intialPromptIndex - 1 // Prefill stops at the last token of the initial prompt
let prefillToken = currentTokens[tokenIndex]
nextToken = prefillToken
Logging.debug("Forcing token \(nextToken) at index \(tokenIndex) from initial prompt")
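To see what the index change above does, a small self-contained sketch of the prefill bookkeeping after this commit, assuming a three-token initial prompt (the token ids are placeholders):

let initialPrompt = [50258, 50259, 50359] // placeholder special-token ids
let intialPromptIndex = initialPrompt.count // after this commit: count, not count - 1

for tokenIndex in 0..<5 {
    var isPrefill = false
    if tokenIndex < intialPromptIndex {
        // Prefill stops at the last token of the initial prompt
        isPrefill = tokenIndex < intialPromptIndex - 1
        let forcedToken = initialPrompt[tokenIndex]
        print("index \(tokenIndex): force token \(forcedToken), isPrefill = \(isPrefill)")
    } else {
        print("index \(tokenIndex): sample from the model, isPrefill = \(isPrefill)")
    }
}
// Indices 0 and 1 force prompt tokens with isPrefill == true,
// index 2 forces the final prompt token with isPrefill == false,
// and indices 3 and up are sampled normally.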
