-
Notifications
You must be signed in to change notification settings - Fork 241
Open
Description
I tried applied the tract model with the help of some ai, and this is what I got (at least it builds partly - see errors below).
use hound;
use ndarray::Array;
use tract_onnx::prelude::*;
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, Read};
use std::path::Path;
use std::process::Command;
const SUPPORTED_LANGUAGES: [&str; 7] = [
"en-us", // English
"en-gb", // English (British)
"es", // Spanish
"fr-fr", // French
"ja", // Japanese
"ko", // Korean
"cmn", // Mandarin Chinese
];
const MAX_PHONEME_LENGTH: usize = 510;
const SAMPLE_RATE: usize = 24000;
#[derive(Debug)]
pub struct EspeakConfig {
pub lib_path: Option<String>,
pub data_path: Option<String>,
}
#[derive(Debug)]
pub struct KoKoroConfig {
pub model_path: String,
pub voices_path: String,
pub espeak_config: Option<EspeakConfig>,
}
impl KoKoroConfig {
pub fn new(model_path: &str, voices_path: &str, espeak_config: Option<EspeakConfig>) -> Self {
KoKoroConfig {
model_path: model_path.to_string(),
voices_path: voices_path.to_string(),
espeak_config,
}
}
pub fn validate(&self) -> Result<(), String> {
if !Path::new(&self.voices_path).exists() {
let error_msg = format!(
"Voices file not found at {}. You can download the voices file using the following command:\nwget https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/voices.bin",
self.voices_path
);
return Err(error_msg);
}
if !Path::new(&self.model_path).exists() {
let error_msg = format!(
"Model file not found at {}. You can download the model file using the following command:\nwget https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/kokoro-v0_19.onnx",
self.model_path
);
return Err(error_msg);
}
Ok(())
}
}
pub fn get_vocab() -> HashMap<char, usize> {
let pad = '$';
let punctuation = ";:,.!?¡¿—…\"«»“ ” ";
let letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
let letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ";
let mut symbols: Vec<char> = vec![pad];
symbols.extend(punctuation.chars());
symbols.extend(letters.chars());
symbols.extend(letters_ipa.chars());
let mut vocab_map = HashMap::new();
for (index, symbol) in symbols.iter().enumerate() {
vocab_map.insert(*symbol, index);
}
vocab_map
}
pub struct Tokenizer {
espeak_config: EspeakConfig,
}
impl Tokenizer {
pub fn new(espeak_config: EspeakConfig) -> Self {
Tokenizer { espeak_config }
}
fn phonemize(&self, text: &str, lang: &str) -> Result<String, Box<dyn std::error::Error>> {
// Run eSpeak and capture the output
let output = Command::new("espeak-ng")
.args(&["-q", "-x", text, "--lang", lang]) // Arguments to eSpeak
.output()?;
// Check for errors
if !output.status.success() {
return Err(Box::new(io::Error::new(io::ErrorKind::Other, "eSpeak failed to execute")));
}
// Convert the output to a String
let phonemes = String::from_utf8_lossy(&output.stdout).to_string();
Ok(phonemes.trim().to_string())
}
pub fn tokenize(&self, text: &str, lang: &str) -> Result<Vec<String>, Box<dyn std::error::Error>> {
let phonemes = self.phonemize(text, lang)?;
if phonemes.len() > MAX_PHONEME_LENGTH {
return Err(Box::new(io::Error::new(io::ErrorKind::InvalidInput, format!(
"Text is too long, must be less than {} phonemes",
MAX_PHONEME_LENGTH
))));
}
// Return a list of phonemes
Ok(phonemes.split_whitespace().map(|s| s.to_string()).collect())
}
}
pub struct Kokoro {
model: SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
voices: HashMap<String, Vec<f32>>, // Store voices as vectors (or more complex types as needed)
vocab: HashMap<char, usize>, // Runtime vocab
tokenizer: Tokenizer,
}
impl Kokoro {
pub fn new(model_path: &str, voices_path: &str, espeak_config: EspeakConfig) -> Result<Self, Box<dyn std::error::Error>> {
let model = tract_onnx::onnx()
.model_for_path(model_path)?
.into_optimized()?
.into_runnable()?;
let voices = read_voices(voices_path)?;
let vocab = get_vocab(); // Get the vocabulary at runtime
let tokenizer = Tokenizer::new(espeak_config);
Ok(Kokoro { model, voices, vocab, tokenizer })
}
fn get_voice_style(&self, name: &str) -> &[f32] {
self.voices.get(name).expect("Voice not found")
}
fn create_audio(&self, tokens: &[String], voice: &[f32], speed: f32) -> Result<Tensor, Box<dyn std::error::Error>> {
assert!(tokens.len() <= MAX_PHONEME_LENGTH, "Too many phonemes!");
// Prepare token input for ONNX model
let mut token_input = Array::<u8, _>::zeros((1, tokens.len() + 2)); // Use u8 for input
for (i, token) in tokens.iter().enumerate() {
token_input[[0, i + 1]] = *self.vocab.get(&token.chars().next().unwrap()).unwrap_or(&0) as u8; // Ensure input is u8
}
// Convert the input array to a Tensor
let input_tensor: Tensor = token_input.into_tensor();
// Create the input tensor vector needed for ONNX model
let input_tensors: TVec<TValue> = tvec!(input_tensor.into());
let result = self.model.run(input_tensors)?;
// Convert the result back to a tensor
let tensor: Tensor = result[0].to_owned().into_tensor();
Ok(tensor)
}
pub fn create(&self, text: &str, voice: &str, speed: f32, lang: &str) -> Result<(Vec<f32>, usize), Box<dyn std::error::Error>> {
// Language validation
if !SUPPORTED_LANGUAGES.contains(&lang) {
return Err(Box::new(io::Error::new(io::ErrorKind::InvalidInput, format!(
"Language must be one of: {:?}. Got: {}", SUPPORTED_LANGUAGES, lang
))));
}
let tokens = self.tokenizer.tokenize(text, lang)?; // Use the Tokenizer for phonemization
let voice_style = self.get_voice_style(voice);
let audio_tensor = self.create_audio(&tokens, voice_style, speed)?;
let audio_samples: Vec<f32> = audio_tensor.to_array_view::<f32>()?.iter().copied().collect();
Ok((audio_samples, SAMPLE_RATE))
}
}
// Function to read voices from a binary file
fn read_voices(file_path: &str) -> io::Result<HashMap<String, Vec<f32>>> {
let mut file = File::open(file_path)?;
let mut buffer = Vec::new();
file.read_to_end(&mut buffer)?;
// Example parsing logic (you will need a proper way to interpret your binary file)
let mut voices = HashMap::new();
// Implement voice parsing logic to fill the voices HashMap here (example placeholder)
voices.insert("af_heart".to_string(), vec![0.0; 256]); // Placeholder data
Ok(voices)
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let model_path = "kokoro-v1.0.fp16.onnx"; // Update path as necessary
let voices_path = "voices.bin"; // Update path as necessary
let espeak_config = EspeakConfig {
lib_path: None, // Optionally set paths if needed
data_path: None,
};
// Create the KoKoro synthesizer
let kokoro = Kokoro::new(model_path, voices_path, espeak_config)?;
// Input Text Configuration
let text = "This is an English phrase for synthesis."; // Example English input
let lang = "en-us";
// Generate audio sample
let (audio_samples, sample_rate) = kokoro.create(text, "af_heart", 1.0, lang)?;
// Write the generated audio to a WAV file
let mut writer = hound::WavWriter::create("audio.wav", hound::WavSpec {
channels: 1, // Mono
sample_rate: sample_rate as u32,
bits_per_sample: 16, // Change as necessary
sample_format: hound::SampleFormat::Int,
})?;
for sample in audio_samples {
let sample_i16 = (sample * i16::MAX as f32).round() as i16; // Convert f32 to i16
writer.write_sample(sample_i16)?;
}
writer.finalize()?;
Ok(())
}
I get the following error:
Error: Failed analyse for node #1034 "/encoder/text_encoder/cnn.0/cnn.0.0/Conv_quant" ConvHir
Caused by:
0: Infering facts
1: Applying rule inputs[0].datum_type == inputs[3].datum_type
2: Impossible to unify U8 with I8.
Any help in the right direction would be appreciated. There is a also a regular fp32 version of the onnx model.
spazziale and andreytkachenko
Metadata
Metadata
Assignees
Labels
No labels