Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/audio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ mod device_monitor;
mod errors;
mod mic;
mod norm;
mod resampler;
mod speaker;

pub use device_monitor::*;
pub use errors::*;
pub use mic::*;
pub use norm::*;
pub use resampler::*;
pub use speaker::*;

pub use cpal;
Expand Down
206 changes: 206 additions & 0 deletions crates/audio/src/resampler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
use dasp::interpolate::Interpolator;
use futures_util::Stream;
use kalosm_sound::AsyncSource;

/// Adapter around an [`AsyncSource`] that yields samples at a fixed target sample rate.
///
/// The `Stream` implementation in this file re-reads the wrapped source's sample rate on
/// every poll, so the step between output samples follows the source if its rate changes.
pub struct ResampledAsyncSource<S: AsyncSource> {
/// The wrapped input source.
source: S,
/// Rate this adapter reports from `AsyncSource::sample_rate`.
target_sample_rate: u32,
/// Position passed to the interpolator; one source frame is pulled for every whole unit
/// accumulated here, and it advances by `source_rate / target_sample_rate` per output sample.
sample_position: f64,
/// Linear interpolator fed with source frames via `next_source_frame`.
resampler: dasp::interpolate::linear::Linear<f32>,
/// Source sample rate seen on the most recent poll (initially the rate at construction).
last_source_rate: u32,
}

impl<S: AsyncSource> ResampledAsyncSource<S> {
    /// Wraps `source` so that it yields samples at `target_sample_rate`.
    ///
    /// The interpolation position starts at the source/target rate ratio and the
    /// interpolator starts with both of its frames set to silence (`0.0`).
    pub fn new(source: S, target_sample_rate: u32) -> Self {
        let last_source_rate = source.sample_rate();
        let sample_position = f64::from(last_source_rate) / f64::from(target_sample_rate);
        let resampler = dasp::interpolate::linear::Linear::new(0.0, 0.0);

        Self {
            source,
            target_sample_rate,
            sample_position,
            resampler,
            last_source_rate,
        }
    }
}

impl<S: AsyncSource + Unpin> Stream for ResampledAsyncSource<S> {
    type Item = f32;

    /// Produces the next output sample at the target rate.
    ///
    /// The source's sample rate is queried on every call, so the step between output
    /// samples tracks the source's current rate. Source frames are pulled until the
    /// interpolation position drops below `1.0`; the stream ends as soon as the source
    /// ends, and `Pending` from the source is propagated unchanged (frames consumed so
    /// far stay accounted for in `sample_position`).
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::{ready, Poll};

        let this = self.get_mut();

        // Remember the rate observed on this poll and derive the per-output step from it.
        let source_rate = this.source.sample_rate();
        this.last_source_rate = source_rate;
        let step = f64::from(source_rate) / f64::from(this.target_sample_rate);

        let mut input = std::pin::pin!(this.source.as_stream());

        // Feed one source frame into the interpolator per whole unit of position.
        while this.sample_position >= 1.0 {
            let Some(frame) = ready!(input.as_mut().poll_next(cx)) else {
                return Poll::Ready(None);
            };
            this.sample_position -= 1.0;
            this.resampler.next_source_frame(frame);
        }

        let sample = this.resampler.interpolate(this.sample_position);
        this.sample_position += step;

        Poll::Ready(Some(sample))
    }
}

impl<S: AsyncSource + Unpin> AsyncSource for ResampledAsyncSource<S> {
    /// The adapter is itself the stream of resampled samples.
    fn as_stream(&mut self) -> impl Stream<Item = f32> + '_ {
        self
    }

    /// Reports the target rate, independent of the wrapped source's current rate.
    fn sample_rate(&self) -> u32 {
        self.target_sample_rate
    }
}

#[cfg(test)]
mod tests {
    use futures_util::{Stream, StreamExt};
    use kalosm_sound::AsyncSource;
    use rodio::Source;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use crate::ResampledAsyncSource;

    /// Output sample rate, in Hz, that both tests resample to.
    const TARGET_SAMPLE_RATE: u32 = 16000;

    /// Decodes the audio file at `path` into `f32` samples plus the decoder's sample rate.
    fn get_samples_with_rate(path: impl AsRef<std::path::Path>) -> (Vec<f32>, u32) {
        let source =
            rodio::Decoder::new(std::io::BufReader::new(std::fs::File::open(path).unwrap()))
                .unwrap();

        let sample_rate = AsyncSource::sample_rate(&source);
        let samples = source.convert_samples::<f32>().collect();
        (samples, sample_rate)
    }

    /// Test source that plays a list of `(samples, sample_rate)` segments back to back,
    /// reporting the sample rate of whichever segment is currently being read.
    struct DynamicRateSource {
        segments: Vec<(Vec<f32>, u32)>,
        current_segment: usize,
        current_position: usize,
    }

    impl DynamicRateSource {
        fn new(segments: Vec<(Vec<f32>, u32)>) -> Self {
            Self {
                segments,
                current_segment: 0,
                current_position: 0,
            }
        }
    }

    impl AsyncSource for DynamicRateSource {
        fn as_stream(&mut self) -> impl Stream<Item = f32> + '_ {
            DynamicRateStream { source: self }
        }

        /// Rate of the current segment.
        ///
        /// Once every segment has been drained `current_segment` equals
        /// `segments.len()`, so fall back to the last segment's rate instead of
        /// panicking when a consumer asks for the rate after the end of the stream.
        fn sample_rate(&self) -> u32 {
            self.segments
                .get(self.current_segment)
                .or_else(|| self.segments.last())
                .map(|(_, rate)| *rate)
                .expect("DynamicRateSource requires at least one segment")
        }
    }

    /// Borrowing stream view over a [`DynamicRateSource`]; it is always ready.
    struct DynamicRateStream<'a> {
        source: &'a mut DynamicRateSource,
    }

    impl<'a> Stream for DynamicRateStream<'a> {
        type Item = f32;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let source = &mut self.source;

            while source.current_segment < source.segments.len() {
                let (samples, _rate) = &source.segments[source.current_segment];

                if source.current_position < samples.len() {
                    let sample = samples[source.current_position];
                    source.current_position += 1;
                    return Poll::Ready(Some(sample));
                }

                // Current segment exhausted: move on to the start of the next one.
                source.current_segment += 1;
                source.current_position = 0;
            }

            Poll::Ready(None)
        }
    }

    /// Builds the six-segment fixture whose sample rate changes between segments.
    fn multi_rate_source() -> DynamicRateSource {
        DynamicRateSource::new(vec![
            get_samples_with_rate(hypr_data::english_1::AUDIO_PART1_8000HZ_PATH),
            get_samples_with_rate(hypr_data::english_1::AUDIO_PART2_16000HZ_PATH),
            get_samples_with_rate(hypr_data::english_1::AUDIO_PART3_22050HZ_PATH),
            get_samples_with_rate(hypr_data::english_1::AUDIO_PART4_32000HZ_PATH),
            get_samples_with_rate(hypr_data::english_1::AUDIO_PART5_44100HZ_PATH),
            get_samples_with_rate(hypr_data::english_1::AUDIO_PART6_48000HZ_PATH),
        ])
    }

    /// Creates a mono, 32-bit float WAV writer at `path` using [`TARGET_SAMPLE_RATE`].
    fn create_wav_writer(path: &str) -> hound::WavWriter<std::io::BufWriter<std::fs::File>> {
        hound::WavWriter::create(
            path,
            hound::WavSpec {
                channels: 1,
                sample_rate: TARGET_SAMPLE_RATE,
                bits_per_sample: 32,
                sample_format: hound::SampleFormat::Float,
            },
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_existing_resampler() {
        let source = multi_rate_source();
        let mut out_wav = create_wav_writer("./out_1.wav");

        let mut resampled = source.resample(TARGET_SAMPLE_RATE);
        while let Some(sample) = resampled.next().await {
            out_wav.write_sample(sample).unwrap();
        }

        // Finalize explicitly: dropping the writer would swallow header/flush errors.
        out_wav.finalize().unwrap();
    }

    #[tokio::test]
    async fn test_new_resampler() {
        let source = multi_rate_source();
        let mut out_wav = create_wav_writer("./out_2.wav");

        let mut resampled = ResampledAsyncSource::new(source, TARGET_SAMPLE_RATE);
        while let Some(sample) = resampled.next().await {
            out_wav.write_sample(sample).unwrap();
        }

        // Finalize explicitly: dropping the writer would swallow header/flush errors.
        out_wav.finalize().unwrap();
    }
}
20 changes: 16 additions & 4 deletions crates/audio/src/speaker/macos.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Poll, Waker};

Expand Down Expand Up @@ -25,23 +26,24 @@ struct WakerState {

/// Handle returned when speaker capture is started; it owns the consumer side of the
/// sample buffer plus the device, callback context and tap created alongside it.
///
/// NOTE(review): this listing is a rendered diff without +/- markers. `stream_desc`
/// looks like the field being removed and `current_sample_rate` the field being added —
/// confirm against the final file.
pub struct SpeakerStream {
consumer: HeapCons<f32>,
stream_desc: cat::AudioStreamBasicDesc,
_device: ca::hardware::StartedDevice<ca::AggregateDevice>,
_ctx: Box<Ctx>,
_tap: ca::TapGuard,
waker_state: Arc<Mutex<WakerState>>,
// Shared with `Ctx`: the device callback stores into it and `sample_rate()` loads from it.
current_sample_rate: Arc<AtomicU32>,
}

impl SpeakerStream {
/// Returns the stream's sample rate.
///
/// NOTE(review): both the old body (`stream_desc.sample_rate`) and the new body (atomic
/// load of `current_sample_rate`) appear below because the diff markers were lost; only
/// one of the two expressions can be present in the real file — presumably the atomic load.
pub fn sample_rate(&self) -> u32 {
self.stream_desc.sample_rate as u32
self.current_sample_rate.load(Ordering::Relaxed)
}
}

/// State passed to the device callback (`proc`) registered in `start_device`.
struct Ctx {
format: arc::R<av::AudioFormat>,
producer: HeapProd<f32>,
waker_state: Arc<Mutex<WakerState>>,
// The callback stores the device's actual sample rate here (falling back to the
// format's rate), using relaxed ordering; `SpeakerStream` holds a clone of the same Arc.
current_sample_rate: Arc<AtomicU32>,
}

impl SpeakerInput {
Expand Down Expand Up @@ -100,7 +102,7 @@ impl SpeakerInput {
ctx: &mut Box<Ctx>,
) -> Result<ca::hardware::StartedDevice<ca::AggregateDevice>> {
extern "C" fn proc(
_device: ca::Device,
device: ca::Device,
_now: &cat::AudioTimeStamp,
input_data: &cat::AudioBufList<1>,
_input_time: &cat::AudioTimeStamp,
Expand All @@ -110,6 +112,13 @@ impl SpeakerInput {
) -> os::Status {
let ctx = ctx.unwrap();

ctx.current_sample_rate.store(
device
.actual_sample_rate()
.unwrap_or(ctx.format.absd().sample_rate) as u32,
Ordering::Relaxed,
);

assert_eq!(ctx.format.common_format(), av::audio::CommonFormat::PcmF32);

if let Some(view) =
Expand Down Expand Up @@ -157,21 +166,24 @@ impl SpeakerInput {
has_data: false,
}));

let current_sample_rate = Arc::new(AtomicU32::new(asbd.sample_rate as u32));

let mut ctx = Box::new(Ctx {
format,
producer,
waker_state: waker_state.clone(),
current_sample_rate: current_sample_rate.clone(),
});

let device = self.start_device(&mut ctx).unwrap();

SpeakerStream {
consumer,
stream_desc: asbd,
_device: device,
_ctx: ctx,
_tap: self.tap,
waker_state,
current_sample_rate,
}
}
}
Expand Down
34 changes: 34 additions & 0 deletions crates/data/scripts/resamples.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/bin/bash
#
# Split an audio file into equal-length consecutive parts and re-encode each part
# at a different sample rate. Outputs are written next to the input file as
# <basename>_part<N>_<rate>hz.wav.
#
# Usage: resamples.sh <input_audio_file>
# Requires: ffprobe, ffmpeg, bc

# Abort on the first failing command, unset variable, or failed pipeline stage so a
# broken ffprobe/ffmpeg run cannot silently produce a partial fixture set.
set -euo pipefail

# Check if input file is provided
if [ $# -eq 0 ]; then
    echo "Usage: $0 <input_audio_file>" >&2
    exit 1
fi

INPUT_FILE="$1"

if [ ! -f "$INPUT_FILE" ]; then
    echo "Error: input file not found: $INPUT_FILE" >&2
    exit 1
fi

for TOOL in ffprobe ffmpeg bc; do
    if ! command -v "$TOOL" > /dev/null 2>&1; then
        echo "Error: required tool not found: $TOOL" >&2
        exit 1
    fi
done

DIR=$(dirname -- "$INPUT_FILE")
BASENAME=$(basename -- "$INPUT_FILE" .wav)

# Array of common sample rates for testing
SAMPLE_RATES=(8000 16000 22050 32000 44100 48000)

# Get duration of input file in seconds
DURATION=$(ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 "$INPUT_FILE")

# ffprobe can print an empty string or "N/A"; refuse to feed that to bc.
if ! [[ "$DURATION" =~ ^[0-9]+([.][0-9]+)?$ ]]; then
    echo "Error: could not determine duration of $INPUT_FILE (got '${DURATION}')" >&2
    exit 1
fi

# Calculate part duration
NUM_PARTS=${#SAMPLE_RATES[@]}
PART_DURATION=$(echo "$DURATION / $NUM_PARTS" | bc -l)

# Generate parts with different sample rates
for i in "${!SAMPLE_RATES[@]}"; do
    RATE=${SAMPLE_RATES[$i]}
    PART_NUM=$((i + 1))
    START=$(echo "$i * $PART_DURATION" | bc -l)
    OUTPUT_FILE="${DIR}/${BASENAME}_part${PART_NUM}_${RATE}hz.wav"

    echo "Creating part ${PART_NUM}: ${START}s-$(echo "$START + $PART_DURATION" | bc -l)s at ${RATE}Hz"
    # Global options (-y, -loglevel) go first; per-output options precede the output file.
    ffmpeg -y -loglevel error -i "$INPUT_FILE" -ss "$START" -t "$PART_DURATION" -ar "$RATE" "$OUTPUT_FILE"
done

echo "Done! Created ${NUM_PARTS} parts with different sample rates in ${DIR}/"
Binary file added crates/data/src/english_1/audio_part1_8000hz.wav
Binary file not shown.
Binary file not shown.
Binary file added crates/data/src/english_1/audio_part3_22050hz.wav
Binary file not shown.
Binary file added crates/data/src/english_1/audio_part4_32000hz.wav
Binary file not shown.
Binary file added crates/data/src/english_1/audio_part5_44100hz.wav
Binary file not shown.
Binary file added crates/data/src/english_1/audio_part6_48000hz.wav
Binary file not shown.
25 changes: 25 additions & 0 deletions crates/data/src/english_1/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,31 @@
/// Bytes produced by `include_wav!` for `./audio.wav` (macro defined outside this view).
pub const AUDIO: &[u8] = include_wav!("./audio.wav");
/// Path to `audio.wav`, rooted at this crate's `CARGO_MANIFEST_DIR` at compile time.
pub const AUDIO_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/src/english_1/audio.wav");

// The `AUDIO_PART*` paths below are rooted at `CARGO_MANIFEST_DIR` at compile time.
// File names follow the `<basename>_part<N>_<rate>hz.wav` pattern; presumably each file
// is one consecutive slice of `audio.wav` at the named sample rate — verify against the
// fixture files.

/// Path to part 1 of the split fixture (file name indicates 8000 Hz).
pub const AUDIO_PART1_8000HZ_PATH: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/english_1/audio_part1_8000hz.wav"
);
/// Path to part 2 of the split fixture (file name indicates 16000 Hz).
pub const AUDIO_PART2_16000HZ_PATH: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/english_1/audio_part2_16000hz.wav"
);
/// Path to part 3 of the split fixture (file name indicates 22050 Hz).
pub const AUDIO_PART3_22050HZ_PATH: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/english_1/audio_part3_22050hz.wav"
);
/// Path to part 4 of the split fixture (file name indicates 32000 Hz).
pub const AUDIO_PART4_32000HZ_PATH: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/english_1/audio_part4_32000hz.wav"
);
/// Path to part 5 of the split fixture (file name indicates 44100 Hz).
pub const AUDIO_PART5_44100HZ_PATH: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/english_1/audio_part5_44100hz.wav"
);
/// Path to part 6 of the split fixture (file name indicates 48000 Hz).
pub const AUDIO_PART6_48000HZ_PATH: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/english_1/audio_part6_48000hz.wav"
);

pub const TRANSCRIPTION_JSON: &str = include_str!("./transcription.json");

pub const TRANSCRIPTION_PATH: &str = concat!(
Expand Down
Loading
Loading