Skip to content

Commit

Permalink
Add config for sense voice models
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Jul 18, 2024
1 parent 3bae5c3 commit fe31189
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 4 deletions.
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ set(sources
offline-recognizer-impl.cc
offline-recognizer.cc
offline-rnn-lm.cc
offline-sense-voice-model-config.cc
offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// sherpa-onnx/csrc/offline-ct-transformer-model-meta_data.h
// sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
Expand Down
7 changes: 5 additions & 2 deletions sherpa-onnx/csrc/offline-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
tdnn.Register(po);
zipformer_ctc.Register(po);
wenet_ctc.Register(po);
sense_voice.Register(po);

po->Register("telespeech-ctc", &telespeech_ctc,
"Path to model.onnx for telespeech ctc");
Expand Down Expand Up @@ -94,12 +95,14 @@ bool OfflineModelConfig::Validate() const {
return wenet_ctc.Validate();
}

if (!sense_voice.model.empty()) {
return sense_voice.Validate();
}

if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) {
SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist",
telespeech_ctc.c_str());
return false;
} else {
return true;
}

return transducer.Validate();
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
Expand All @@ -24,6 +25,7 @@ struct OfflineModelConfig {
OfflineTdnnModelConfig tdnn;
OfflineZipformerCtcModelConfig zipformer_ctc;
OfflineWenetCtcModelConfig wenet_ctc;
OfflineSenseVoiceModelConfig sense_voice;
std::string telespeech_ctc;

std::string tokens;
Expand Down Expand Up @@ -53,6 +55,7 @@ struct OfflineModelConfig {
const OfflineTdnnModelConfig &tdnn,
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
const OfflineSenseVoiceModelConfig &sense_voice,
const std::string &telespeech_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type,
Expand All @@ -65,6 +68,7 @@ struct OfflineModelConfig {
tdnn(tdnn),
zipformer_ctc(zipformer_ctc),
wenet_ctc(wenet_ctc),
sense_voice(sense_voice),
telespeech_ctc(telespeech_ctc),
tokens(tokens),
num_threads(num_threads),
Expand Down
54 changes: 54 additions & 0 deletions sherpa-onnx/csrc/offline-sense-voice-model-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// sherpa-onnx/csrc/offline-sense-voice-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

void OfflineSenseVoiceModelConfig::Register(ParseOptions *po) {
po->Register("sense-voice", &model, "Path to model.onnx of SenseVoice.");
po->Register(
"sense-voice-language", &language,
"Valid values: auto, zh, en, ja, ko, yue. If left empty, auto is used");
po->Register(
"sense-voice-use-itn", &use_itn,
"True to enable inverse text normalization. False to disable it.");
}

bool OfflineSenseVoiceModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("SenseVoice model '%s' does not exist", model.c_str());
return false;
}

if (!language.empty()) {
if (language != "auto" && language != "zh" && language != "en" &&
language != "ja" && language != "ko" && language != "yue") {
SHERPA_ONNX_LOGE(
"Invalid sense-voice-language: '%s'. Valid values are: auto, zh, en, "
"ja, ko, yue. Or you can leave it empty to use 'auto'",
language.c_str());

return false;
}
}

return true;
}

std::string OfflineSenseVoiceModelConfig::ToString() const {
std::ostringstream os;

os << "OfflineSenseVoiceModelConfig(";
os << "model=\"" << model << "\", ";
os << "language=\"" << language << "\", ";
os << "use_itn=" << (use_itn ? "True" : "False") << ")";

return os.str();
}

} // namespace sherpa_onnx
39 changes: 39 additions & 0 deletions sherpa-onnx/csrc/offline-sense-voice-model-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// sherpa-onnx/csrc/offline-sense-voice-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

struct OfflineSenseVoiceModelConfig {
std::string model;

// "" or "auto" to let the model recognize the language
// valid values:
// zh, en, ja, ko, yue, auto
std::string language = "auto";

// true to use inverse text normalization
// false to not use inverse text normalization
bool use_itn = false;

OfflineSenseVoiceModelConfig() = default;
explicit OfflineSenseVoiceModelConfig(const std::string &model,
const std::string &language,
bool use_itn)
: model(model), language(language), use_itn(use_itn) {}

void Register(ParseOptions *po);
bool Validate() const;

std::string ToString() const;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
1 change: 1 addition & 0 deletions sherpa-onnx/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ set(srcs
offline-paraformer-model-config.cc
offline-punctuation.cc
offline-recognizer.cc
offline-sense-voice-model-config.cc
offline-stream.cc
offline-tdnn-model-config.cc
offline-transducer-model-config.cc
Expand Down
7 changes: 6 additions & 1 deletion sherpa-onnx/python/csrc/offline-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
Expand All @@ -26,6 +27,7 @@ void PybindOfflineModelConfig(py::module *m) {
PybindOfflineTdnnModelConfig(m);
PybindOfflineZipformerCtcModelConfig(m);
PybindOfflineWenetCtcModelConfig(m);
PybindOfflineSenseVoiceModelConfig(m);

using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
Expand All @@ -36,7 +38,8 @@ void PybindOfflineModelConfig(py::module *m) {
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &, const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &, const std::string &,
const OfflineWenetCtcModelConfig &,
const OfflineSenseVoiceModelConfig &, const std::string &,
const std::string &, int32_t, bool, const std::string &,
const std::string &, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
Expand All @@ -46,6 +49,7 @@ void PybindOfflineModelConfig(py::module *m) {
py::arg("tdnn") = OfflineTdnnModelConfig(),
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("sense_voice") = OfflineSenseVoiceModelConfig(),
py::arg("telespeech_ctc") = "", py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "",
Expand All @@ -57,6 +61,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("tdnn", &PyClass::tdnn)
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("sense_voice", &PyClass::sense_voice)
.def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/python/csrc/offline-paraformer-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace sherpa_onnx {
void PybindOfflineParaformerModelConfig(py::module *m) {
using PyClass = OfflineParaformerModelConfig;
py::class_<PyClass>(*m, "OfflineParaformerModelConfig")
.def(py::init<>())
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString);
Expand Down
26 changes: 26 additions & 0 deletions sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// sherpa-onnx/python/csrc/offline-sense-voice-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-sense-voice-model-config.h"

#include <string>
#include <vector>

#include "sherpa-onnx/python/csrc/offline-sense-voice-model-config.h"

namespace sherpa_onnx {

void PybindOfflineSenseVoiceModelConfig(py::module *m) {
using PyClass = OfflineSenseVoiceModelConfig;
py::class_<PyClass>(*m, "OfflineSenseVoiceModelConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &, bool>(),
py::arg("model"), py::arg("language"), py::arg("use_itn"))
.def_readwrite("model", &PyClass::model)
.def_readwrite("language", &PyClass::language)
.def_readwrite("use_itn", &PyClass::use_itn)
.def("__str__", &PyClass::ToString);
}

} // namespace sherpa_onnx
16 changes: 16 additions & 0 deletions sherpa-onnx/python/csrc/offline-sense-voice-model-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/offline-sense-voice-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindOfflineSenseVoiceModelConfig(py::module *m);

}

#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_
1 change: 1 addition & 0 deletions sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
OfflineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
OfflineSenseVoiceModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
Expand Down

0 comments on commit fe31189

Please sign in to comment.