Skip to content

Commit

Permalink
[WenetDemo] improve safety and maintainability (#2621)
Browse files Browse the repository at this point in the history
- Added logging to enhance debugging and traceability.
- Replaced force unwrapping with safer optional binding techniques to prevent potential runtime crashes.
- Ensured proper memory management by using [weak self] in closures to avoid retain cycles and potential memory leaks.
  • Loading branch information
dangbo authored Aug 30, 2024
1 parent 2adf651 commit f446b8c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 41 deletions.
94 changes: 58 additions & 36 deletions runtime/ios/WenetDemo/WenetDemo/ViewController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class ViewController: UIViewController {

var wenetModel: Wenet?
var audioEngine: AVAudioEngine?
var startRecord: Bool?
var startRecord = false
private var workItem: DispatchWorkItem?

override func viewDidLoad() {
Expand All @@ -38,32 +38,43 @@ class ViewController: UIViewController {
}

func initModel() {
let modelPath = Bundle.main.path(forResource: "final", ofType: "zip")
let dictPath = Bundle.main.path(forResource: "units", ofType: "txt")
wenetModel = Wenet(modelPath:modelPath, dictPath:dictPath)!
guard let modelPath = Bundle.main.path(forResource: "final", ofType: "zip"),
let dictPath = Bundle.main.path(forResource: "units", ofType: "txt") else {
print("Error: Model or dictionary file not found.")
return
}

wenetModel = Wenet(modelPath: modelPath, dictPath: dictPath)
wenetModel?.reset()
print("Model initialized successfully.")
}

func initRecorder() {
startRecord = false

audioEngine = AVAudioEngine()
let inputNode = self.audioEngine?.inputNode
guard let inputNode = audioEngine?.inputNode else {
print("Error: Unable to access audio input node.")
return
}

let bus = 0
let inputFormat = inputNode?.outputFormat(forBus: bus)
let outputFormat = AVAudioFormat(commonFormat: .pcmFormatFloat32,
sampleRate: 16000, channels: 1,
interleaved: false)!
let converter = AVAudioConverter(from: inputFormat!, to: outputFormat)!
inputNode!.installTap(onBus: bus,
bufferSize: 1024,
format: inputFormat) {
(buffer: AVAudioPCMBuffer, when: AVAudioTime) in
var newBufferAvailable = true
let inputFormat = inputNode.outputFormat(forBus: bus)
guard let outputFormat = AVAudioFormat(commonFormat: .pcmFormatFloat32,
sampleRate: 16000,
channels: 1,
interleaved: false) else {
print("Error: Unable to create output audio format.")
return
}

let inputCallback: AVAudioConverterInputBlock = {
inNumPackets, outStatus in
guard let converter = AVAudioConverter(from: inputFormat, to: outputFormat) else {
print("Error: Unable to create audio converter.")
return
}

inputNode.installTap(onBus: bus, bufferSize: 1024, format: inputFormat) { [weak self] buffer, _ in
guard let self = self else { return }
var newBufferAvailable = true
let inputCallback: AVAudioConverterInputBlock = { _, outStatus in
if newBufferAvailable {
outStatus.pointee = .haveData
newBufferAvailable = false
Expand All @@ -75,12 +86,11 @@ class ViewController: UIViewController {
}
}

let convertedBuffer = AVAudioPCMBuffer(
pcmFormat: outputFormat,
frameCapacity:
AVAudioFrameCount(outputFormat.sampleRate)
* buffer.frameLength
/ AVAudioFrameCount(buffer.format.sampleRate))!
guard let convertedBuffer = AVAudioPCMBuffer(pcmFormat: outputFormat,
frameCapacity: AVAudioFrameCount(outputFormat.sampleRate) * buffer.frameLength / AVAudioFrameCount(buffer.format.sampleRate)) else {
print("Error: Unable to create converted buffer.")
return
}

var error: NSError?
let status = converter.convert(
Expand All @@ -89,55 +99,67 @@ class ViewController: UIViewController {

// 16000 Hz buffer
let actualSampleCount = Int(convertedBuffer.frameLength)
guard let floatChannelData = convertedBuffer.floatChannelData
else { return }

guard let floatChannelData = convertedBuffer.floatChannelData else {
print("Error: No float channel data available.")
return
}

self.wenetModel?.acceptWaveForm(floatChannelData[0],
Int32(actualSampleCount))
print("Audio data accepted by the model.")
}
print("Audio recorder initialized successfully.")
}

@IBAction func btnClicked(_ sender: Any) {
if(!startRecord!) {
if(!startRecord) {
//Clear result
self.setResult(text: "")

//Reset model
self.wenetModel?.reset()
print("Model reset and result cleared.")

//Start record
do {
try self.audioEngine?.start()
try audioEngine?.start()
print("Audio engine started.")
} catch let error as NSError {
print("Got an error starting audioEngine: \(error.domain), \(error)")
return
}

//Start decode thread
workItem = DispatchWorkItem {
while(!self.workItem!.isCancelled) {
workItem = DispatchWorkItem { [weak self] in
guard let self = self else { return }
while !(self.workItem?.isCancelled ?? true) {
self.wenetModel?.decode()
DispatchQueue.main.sync {
self.setResult(text: (self.wenetModel?.get_result())!)
self.setResult(text: self.wenetModel?.get_result() ?? "")
print("Decoding in progress.")
}
}
}
DispatchQueue.global().async(execute: workItem!)

startRecord = true
button.setTitle("Stop Record", for: UIControl.State.normal)
button.setTitle("Stop Record", for: .normal)
print("Recording started.")
} else {
//Stop record
self.audioEngine?.stop()

//Stop decode thread
workItem!.cancel()

workItem?.cancel()
startRecord = false
button.setTitle("Start Record", for: UIControl.State.normal)
button.setTitle("Start Record", for: .normal)
print("Recording stopped.")
}
}

@objc func setResult(text: String) {
label.text = text
print("Result updated: \(text)")
}
}
11 changes: 6 additions & 5 deletions runtime/ios/WenetDemo/WenetDemo/wenet/wenet.mm
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ - (nullable instancetype)initWithModelPath:
- (void)reset {
decoder->Reset();
state = kEndBatch;
total_result = "";
total_result.clear();
}

- (void)acceptWaveForm: (float*)pcm: (int)size {
Expand Down Expand Up @@ -122,14 +122,15 @@ - (void)decode {
}
}

- (NSString*)get_result {
- (NSString *)get_result {
std::string result;
if (decoder->DecodedSomething()) {
result = decoder->result()[0].sentence;
}
LOG(INFO) << "wenet ui result: " << total_result + result;
NSLog(@"wenet ui result: %s", (total_result + result).c_str());
return [NSString stringWithUTF8String:(total_result + result).c_str()];
std::string final_result = total_result + result;
LOG(INFO) << "wenet ui result: " << final_result;
NSLog(@"wenet ui result: %s", final_result.c_str());
return [NSString stringWithUTF8String:final_result.c_str()];
}

@end

0 comments on commit f446b8c

Please sign in to comment.